mirror of
https://github.com/princeton-nlp/tree-of-thought-llm
synced 2025-04-20 22:24:15 +00:00
adding all files on benchmarks, data, new model tot scripts, new prompts - pt1
This commit is contained in:
parent ff04bbcd41
commit c69681e3a0
39 =0.26.0 Normal file
@@ -0,0 +1,39 @@
+Collecting accelerate
+Downloading accelerate-1.2.0-py3-none-any.whl.metadata (19 kB)
+Requirement already satisfied: numpy<3.0.0,>=1.17 in /opt/conda/lib/python3.10/site-packages (from accelerate) (1.25.2)
+Requirement already satisfied: packaging>=20.0 in /opt/conda/lib/python3.10/site-packages (from accelerate) (24.1)
+Requirement already satisfied: psutil in /opt/conda/lib/python3.10/site-packages (from accelerate) (5.9.3)
+Requirement already satisfied: pyyaml in /opt/conda/lib/python3.10/site-packages (from accelerate) (6.0.2)
+Requirement already satisfied: torch>=1.10.0 in /opt/conda/lib/python3.10/site-packages (from accelerate) (2.5.1)
+Requirement already satisfied: huggingface-hub>=0.21.0 in /opt/conda/lib/python3.10/site-packages (from accelerate) (0.26.3)
+Requirement already satisfied: safetensors>=0.4.3 in /opt/conda/lib/python3.10/site-packages (from accelerate) (0.4.5)
+Requirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from huggingface-hub>=0.21.0->accelerate) (3.16.1)
+Requirement already satisfied: fsspec>=2023.5.0 in /opt/conda/lib/python3.10/site-packages (from huggingface-hub>=0.21.0->accelerate) (2024.10.0)
+Requirement already satisfied: requests in /opt/conda/lib/python3.10/site-packages (from huggingface-hub>=0.21.0->accelerate) (2.32.3)
+Requirement already satisfied: tqdm>=4.42.1 in /opt/conda/lib/python3.10/site-packages (from huggingface-hub>=0.21.0->accelerate) (4.67.0)
+Requirement already satisfied: typing-extensions>=3.7.4.3 in /opt/conda/lib/python3.10/site-packages (from huggingface-hub>=0.21.0->accelerate) (4.12.2)
+Requirement already satisfied: networkx in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (3.4.2)
+Requirement already satisfied: jinja2 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (3.1.4)
+Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (12.4.127)
+Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (12.4.127)
+Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (12.4.127)
+Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (9.1.0.70)
+Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (12.4.5.8)
+Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (11.2.1.3)
+Requirement already satisfied: nvidia-curand-cu12==10.3.5.147 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (10.3.5.147)
+Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (11.6.1.9)
+Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (12.3.1.170)
+Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (2.21.5)
+Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (12.4.127)
+Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.127 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (12.4.127)
+Requirement already satisfied: triton==3.1.0 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (3.1.0)
+Requirement already satisfied: sympy==1.13.1 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (1.13.1)
+Requirement already satisfied: mpmath<1.4,>=1.1.0 in /opt/conda/lib/python3.10/site-packages (from sympy==1.13.1->torch>=1.10.0->accelerate) (1.3.0)
+Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.10/site-packages (from jinja2->torch>=1.10.0->accelerate) (3.0.2)
+Requirement already satisfied: charset-normalizer<4,>=2 in /opt/conda/lib/python3.10/site-packages (from requests->huggingface-hub>=0.21.0->accelerate) (3.4.0)
+Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.10/site-packages (from requests->huggingface-hub>=0.21.0->accelerate) (3.10)
+Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/conda/lib/python3.10/site-packages (from requests->huggingface-hub>=0.21.0->accelerate) (1.26.20)
+Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.10/site-packages (from requests->huggingface-hub>=0.21.0->accelerate) (2024.8.30)
+Downloading accelerate-1.2.0-py3-none-any.whl (336 kB)
+Installing collected packages: accelerate
+Successfully installed accelerate-1.2.0
135 run.py
@@ -4,7 +4,8 @@ import argparse
 import time
 
 from src.tot.tasks import get_task
-from src.tot.methods.bfs import solve, naive_solve
+from src.tot.methods.bfs import solve, naive_solve, solve_bench
+# from src.tot.methods.a_star import solve, naive_solve, solve_bench
 from src.tot.models import gpt_usage
 
 from transformers import AutoTokenizer, AutoModelForCausalLM
@@ -16,19 +17,21 @@ def run(args):
     '''
     main run function
     '''
-    #bc of the way the original repo is structured, will need to load in llama models in run.py to avoid repeated loading
+    #load in specific llama model, if applicable
+    #bc of the way the original repo is structured, will need to load in llama models in run.py to avoid repeated loading in models.py
     if args.backend == 'llama':
+        tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
 
         if args.quantize and args.quantize=='ptq_int4':
-            model = AutoModelForCausalLM.from_pretrained("src/tot/quant/hf_quant_int4", device_map="cuda")
+            model = AutoModelForCausalLM.from_pretrained("src/tot/quant/hf_quant_int4", device_map="cuda", weights_only=False)
             model = torch.compile(model, mode="max-autotune")
-        if args.quantize and args.quantize=='ptq_int8':
-            model = AutoModelForCausalLM.from_pretrained("src/tot/ptq_int8", device_map="cuda")
+        elif args.quantize and args.quantize=='ptq_int8':
+            model = AutoModelForCausalLM.from_pretrained("src/tot/quant/ptq_int8", device_map="cuda")
             model = torch.compile(model, mode="max-autotune")
-        elif args.backend == 'qat':
-            model = AutoModelForCausalLM.from_pretrained("src/tot/qat_int8", device_map="cuda")
+        elif args.quantize and args.quantize == 'qat':
+            model = AutoModelForCausalLM.from_pretrained("src/tot/quant/qat_int8", device_map="cuda", weights_only=False)
             model = torch.compile(model, mode="max-autotune")
         else:
-            tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
             model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
 
     else: #gpt4 will be used later in this case
@@ -36,44 +39,82 @@ def run(args):
         tokenizer = None
 
     #set up
-    task = get_task(args.task)
-    logs, cnt_avg, cnt_any = [], 0, 0
-    if args.naive_run:
-        file = f'./logs/{args.task}/{args.backend}_{args.temperature}_naive_{args.prompt_sample}_sample_{args.n_generate_sample}_start{args.task_start_index}_end{args.task_end_index}.json'
-    else:
-        file = f'./logs/{args.task}/{args.backend}_{args.temperature}_{args.method_generate}{args.n_generate_sample}_{args.method_evaluate}{args.n_evaluate_sample}_{args.method_select}{args.n_select_sample}_start{args.task_start_index}_end{args.task_end_index}.json'
-    os.makedirs(os.path.dirname(file), exist_ok=True)
-
-    #run the specified range of tasks
-    for i in range(args.task_start_index, args.task_end_index):
-
-        # solve
-        start_timer = time.perf_counter()
-        if args.naive_run:
-            ys, info = naive_solve(args, task, i, model, tokenizer)
-        else:
-            ys, info = solve(args, task, i, model, tokenizer)
-
-        runtime = time.perf_counter()-start_timer
-        print(f"""For iteration {i} --
-        Total Time to Solve: {runtime} seconds""")
-
-        # log
-        infos = [task.test_output(i, y) for y in ys]
-        info.update({'idx': i, 'ys': ys, 'infos': infos, 'usage_so_far (gpt only)': gpt_usage(args.backend), 'total_runtime': runtime})
-        logs.append(info)
-        with open(file, 'w') as f:
-            json.dump(logs, f, indent=4)
-
-        # log main metric
-        accs = [info['r'] for info in infos]
-        cnt_avg += sum(accs) / len(accs)
-        cnt_any += any(accs)
-        print(i, 'sum(accs)', sum(accs), 'cnt_avg', cnt_avg, 'cnt_any', cnt_any, '\n')
-
-    n = args.task_end_index - args.task_start_index
-    print(cnt_avg / n, cnt_any / n)
-    print('usage_so_far', gpt_usage(args.backend))
+    if args.task == 'comp_bench':
+        test_data = torch.load('src/tot/data/agg_dl_test.pt')
+
+        #run the specified range of tasks
+        for sample in test_data:
+
+            start_timer = perf.counter()
+
+            WHILE WE GO THROUGH THE TREE
+
+            #solve
+            ys, info = solve_bench(args, model, tokenizer)
+
+            runtime = time.perf_counter()-start_timer
+            print(f"""For iteration {i} --
+            Total Time to Solve: {runtime} seconds""")
+
+            #evaluate/rate the proposals
+
+            #compare the proposed final answer vs the ground truth
+            gt = tokenizer.decode(sample['label'][0], return_tensors='pt').replace("<|eot_id|>", "").replace("<|begin_of_text|>", "").strip()
+            # log
+            # infos = [task.test_output(i, y) for y in ys]
+            # info.update({'idx': i, 'ys': ys, 'infos': infos, 'usage_so_far (gpt only)': gpt_usage(args.backend), 'total_runtime': runtime})
+            # logs.append(info)
+            # with open(file, 'w') as f:
+            #     json.dump(logs, f, indent=4)
+
+            # log main metric
+            # accs = [info['r'] for info in infos]
+            # cnt_avg += sum(accs) / len(accs)
+            # cnt_any += any(accs)
+            # print(i, 'sum(accs)', sum(accs), 'cnt_avg', cnt_avg, 'cnt_any', cnt_any, '\n')
+
+            # n = args.task_end_index - args.task_start_index
+            # print(cnt_avg / n, cnt_any / n)
+            # print('usage_so_far', gpt_usage(args.backend))
+
+    else: #the original tot tasks from the paper
+        task = get_task(args.task)
+        logs, cnt_avg, cnt_any = [], 0, 0
+        if args.naive_run:
+            file = f'./logs/{args.task}/{args.backend}_{args.temperature}_naive_{args.prompt_sample}_sample_{args.n_generate_sample}_start{args.task_start_index}_end{args.task_end_index}.json'
+        else:
+            file = f'./logs/{args.task}/{args.backend}_{args.temperature}_{args.method_generate}{args.n_generate_sample}_{args.method_evaluate}{args.n_evaluate_sample}_{args.method_select}{args.n_select_sample}_start{args.task_start_index}_end{args.task_end_index}.json'
+        os.makedirs(os.path.dirname(file), exist_ok=True)
+
+        #run the specified range of tasks
+        for i in range(args.task_start_index, args.task_end_index):
+
+            # solve
+            start_timer = time.perf_counter()
+            if args.naive_run:
+                ys, info = naive_solve(args, task, i, model, tokenizer)
+            else:
+                ys, info = solve(args, task, i, model, tokenizer)
+
+            runtime = time.perf_counter()-start_timer
+            print(f"""For iteration {i} --
+            Total Time to Solve: {runtime} seconds""")
+
+            # log
+            infos = [task.test_output(i, y) for y in ys]
+            info.update({'idx': i, 'ys': ys, 'infos': infos, 'usage_so_far (gpt only)': gpt_usage(args.backend), 'total_runtime': runtime})
+            logs.append(info)
+            with open(file, 'w') as f:
+                json.dump(logs, f, indent=4)
+
+            # log main metric
+            accs = [info['r'] for info in infos]
+            cnt_avg += sum(accs) / len(accs)
+            cnt_any += any(accs)
+            print(i, 'sum(accs)', sum(accs), 'cnt_avg', cnt_avg, 'cnt_any', cnt_any, '\n')
+
+        n = args.task_end_index - args.task_start_index
+        print(cnt_avg / n, cnt_any / n)
+        print('usage_so_far', gpt_usage(args.backend))
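The comp_bench branch above reconstructs the ground-truth string by decoding the stored label ids and stripping the Llama-3 special-token markers by hand. A hypothetical, self-contained sketch of that step (skip_special_tokens=True is an equivalent way to drop <|begin_of_text|> and <|eot_id|>; the label tensor here is a stand-in, not part of the commit):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
    label_ids = tokenizer("42", return_tensors="pt")["input_ids"][0]  # stand-in for sample['label'][0]
    gt = tokenizer.decode(label_ids, skip_special_tokens=True).strip()
    print(gt)  # -> 42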
@@ -82,25 +123,23 @@ def parse_args():
     '''
     args = argparse.ArgumentParser()
 
-    #what model to use
+    #the arguments to use for our purposes
     args.add_argument('--backend', type=str, choices=['gpt-4o', 'llama'], default='gpt-4o')
-    args.add_argument('--quantize', type=str, choices=['qat', 'ptq', 'spinquant'])
+    args.add_argument('--quantize', type=str, choices=['qat', 'ptq_int4', 'ptq_int8'])
 
-    #what temperature to use
     args.add_argument('--temperature', type=float, default=0.0)
+    args.add_argument('--method_generate', type=str, choices=['sample', 'propose'])
+    args.add_argument('--method_evaluate', type=str, choices=['value', 'vote'])
+    args.add_argument('--method_select', type=str, choices=['sample', 'greedy'], default='greedy')
 
-    #the problem task
-    args.add_argument('--task', type=str, required=True, choices=['game24', 'text', 'crosswords'])
+    #other args from the original repo
+    args.add_argument('--task', type=str, required=True, choices=['game24', 'text', 'crosswords', 'comp_bench'])
 
-    #which tasks from the data file to solve
     args.add_argument('--task_start_index', type=int, default=900)
     args.add_argument('--task_end_index', type=int, default=1000)
 
     args.add_argument('--naive_run', action='store_true')
     args.add_argument('--prompt_sample', type=str, choices=['standard', 'cot']) # only used when method_generate = sample, or naive_run
-    args.add_argument('--method_generate', type=str, choices=['sample', 'propose'])
-    args.add_argument('--method_evaluate', type=str, choices=['value', 'vote'])
-    args.add_argument('--method_select', type=str, choices=['sample', 'greedy'], default='greedy')
     args.add_argument('--n_generate_sample', type=int, default=1) # only thing needed if naive_run
     args.add_argument('--n_evaluate_sample', type=int, default=1)
     args.add_argument('--n_select_sample', type=int, default=1)
43 run_bench.py
@@ -27,9 +27,6 @@ def load_llama(quant=None):
 
     return model, tokenizer
 
-def propose(problem, current_state, tokenizer, model, device):
-    pass
-
 def value_proposals(problem, current_state, proposals, tokenizer, model, device):
     '''
     Takes in string values of problem, current state, and proposals.
@@ -38,7 +35,7 @@ def value_proposals(problem, current_state, proposals, tokenizer, model, device)
     valuations = []
     prompts = []
     for p in proposals:
-        prompts.append(value_prompt.format(problem=problem, current_state=current_state, proposal=proposal))
+        prompts.append(value_prompt.format(problem=problem, current_state=current_state, proposal=p))
 
     values = tokenizer(prompts, return_tensors='pt')
     value_inputs = values['input_ids'].to(device)
@@ -58,8 +55,19 @@ def value_proposals(problem, current_state, proposals, tokenizer, model, device)
 
     return valuations
 
-def final_eval():
-    pass
+def final_eval(gt, final_prop):
+    '''
+    compare the ground truth and final proposed solution by the model
+    '''
+    print("THIS IS THE FINAL PROP")
+    print(final_prop)
+    print("THIS IS THE GT")
+    print(gt)
+
+    if gt in final_prop:
+        return 1.0
+    else:
+        return 0.0
 
 def run(args):
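final_eval above scores a proposal by plain substring containment of the decoded ground truth, returning 1.0 only when the full ground-truth string appears in the final proposal. A quick illustrative check with hypothetical strings:

    # substring containment: the exact ground-truth text must appear in the proposal
    assert final_eval("24", "The final answer is 24.") == 1.0
    assert final_eval("24", "The final answer is 42.") == 0.0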
@@ -83,6 +91,9 @@ def run(args):
     #set up
     test_data = torch.load('src/tot/data/agg_dl_test.pt')
 
+    total = 0
+    right = 0
+
     for samples in test_data:
 
         for sample in samples: #going to do this one problem at a time for now.
@@ -97,6 +108,7 @@ def run(args):
             #start solving via tot
             start_timer = perf.counter()
 
+            selected = ""
             for i in range(args.depth): #args.depth number of attempts to reach the solution
 
                 #propose next step/solutions per node/prompt
@@ -105,7 +117,7 @@ def run(args):
                    input_ids,
                    attention_mask=mask,
                    temperature=args.temperature,
-                   max_new_tokens=max_tokens,
+                   max_new_tokens=args.max_tokens,
                    num_return_sequences=args.breadth)
 
 
@@ -123,21 +135,28 @@ def run(args):
                 if 100.0 in valuations:
                     break
 
-                #select the proposals that act as input to the next iteration
+                #select the best proposal
                 val_props = list(zip(proposals, valuations))
-                valuations.sort(key = lambda ele: ele[1], descending=True)
-                selected = valuations[:args.greedy_n]
+                val_props.sort(key = lambda ele: ele[1], descending=True)
+                selected = val_props[:args.greedy_n]
 
-                inputs = tokenizer(valuations, return_tensors='pt')
+                #format the chosen proposal for the next iteration
+                next_prompt = propose_prompt.format(problem=problem, current_state=selected)
+                inputs = tokenizer(next_prompt, return_tensors='pt')
                 input_ids = inputs['input_ids'].to(device)
                 mask = inputs['attention_mask'].to(device)
 
 
             #compare the proposed final answer vs the ground truth
             gt = tokenizer.decode(label, skip_special_token=True)
+            judgement = final_eval(gt, selected)
 
-            judgement = final_eval(gt, final_proposal)
+            #keep track of the running totals
+            total += 1.0
+            right += judgement
+            print("Accuracy so far: ", total/right)
+
+            total_accuracy = right/total
 
 def parse_args():
     '''
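One detail in the selection step kept above: Python's list.sort() accepts reverse=True rather than a descending keyword, so the committed call would raise a TypeError at runtime. A self-contained toy sketch of the intended greedy pick under that correction (names are hypothetical):

    proposals = ["step A", "step B", "step C"]              # hypothetical proposals
    valuations = [1.0, 100.0, 50.0]                         # hypothetical scores
    val_props = list(zip(proposals, valuations))
    val_props.sort(key=lambda pair: pair[1], reverse=True)  # highest valuation first
    selected = val_props[:1]                                # greedy top-1, mirroring args.greedy_n
    print(selected)                                         # [('step B', 100.0)]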
0 src/tot/data/benchmark/bench.py Normal file
0 src/tot/data/test.ipynb Normal file
@@ -36,7 +36,7 @@ def get_values(task, x, ys, n_evaluate_sample, cache_value=True):
     return values, times
 
 def get_votes(task, x, ys, n_evaluate_sample):
-    vote_prompt = task.vote_prompt_wrap(x, ys)
+    vote_prompt = task.cot_prompt_wrap(x, ys)
     vote_outputs = inference_model(vote_prompt, n=n_evaluate_sample, stop=None)
     values = task.vote_outputs_unwrap(vote_outputs, len(ys))
     return values
@@ -1,7 +1,7 @@
 import os
 import openai
 from openai import OpenAI
-import backoff
+# import backoff
 import torch
 import transformers
 import time
@@ -41,7 +41,8 @@ def hf_model(model, tokenizer, prompt, temperature=0.7, max_tokens=1000, n=5, st
     """
     outputs = []
     all_times = []
-    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    # device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    device = 'cpu'
 
     #tokenize inputs
     inputs = tokenizer(prompt, return_tensors="pt")
48 src/tot/quant/hf_quant_int4/config.json Normal file
@@ -0,0 +1,48 @@
+{
+  "_name_or_path": "meta-llama/Llama-3.2-3B-Instruct",
+  "architectures": [
+    "LlamaForCausalLM"
+  ],
+  "attention_bias": false,
+  "attention_dropout": 0.0,
+  "bos_token_id": 128000,
+  "eos_token_id": [
+    128001,
+    128008,
+    128009
+  ],
+  "head_dim": 128,
+  "hidden_act": "silu",
+  "hidden_size": 3072,
+  "initializer_range": 0.02,
+  "intermediate_size": 8192,
+  "max_position_embeddings": 131072,
+  "mlp_bias": false,
+  "model_type": "llama",
+  "num_attention_heads": 24,
+  "num_hidden_layers": 28,
+  "num_key_value_heads": 8,
+  "pretraining_tp": 1,
+  "quantization_config": {
+    "modules_to_not_convert": null,
+    "quant_method": "torchao",
+    "quant_type": "int4_weight_only",
+    "quant_type_kwargs": {
+      "group_size": 128
+    }
+  },
+  "rms_norm_eps": 1e-05,
+  "rope_scaling": {
+    "factor": 32.0,
+    "high_freq_factor": 4.0,
+    "low_freq_factor": 1.0,
+    "original_max_position_embeddings": 8192,
+    "rope_type": "llama3"
+  },
+  "rope_theta": 500000.0,
+  "tie_word_embeddings": true,
+  "torch_dtype": "bfloat16",
+  "transformers_version": "4.46.2",
+  "use_cache": true,
+  "vocab_size": 128256
+}
12 src/tot/quant/hf_quant_int4/generation_config.json Normal file
@@ -0,0 +1,12 @@
+{
+  "bos_token_id": 128000,
+  "do_sample": true,
+  "eos_token_id": [
+    128001,
+    128008,
+    128009
+  ],
+  "temperature": 0.6,
+  "top_p": 0.9,
+  "transformers_version": "4.46.2"
+}
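The quantization_config block above (quant_method "torchao", int4_weight_only, group_size 128) is what transformers records when a model is quantized through its TorchAO integration. A rough, hypothetical sketch of how a local checkpoint such as src/tot/quant/hf_quant_int4 could be produced (assumes transformers >= 4.46 with the torchao package installed; the production script itself is not part of this commit):

    from transformers import AutoModelForCausalLM, TorchAoConfig

    # int4 weight-only quantization with group size 128, mirroring the committed config.json
    quant_config = TorchAoConfig("int4_weight_only", group_size=128)
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-3.2-3B-Instruct",
        torch_dtype="bfloat16",
        device_map="cuda",
        quantization_config=quant_config,
    )
    # TorchAO tensors are not safetensors-serializable, which is why run.py reloads the
    # checkpoint with weights_only=False
    model.save_pretrained("src/tot/quant/hf_quant_int4", safe_serialization=False)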
40 src/tot/quant/ptq_int8/config.json Normal file
@@ -0,0 +1,40 @@
+{
+  "_name_or_path": "meta-llama/Llama-3.2-3B-Instruct",
+  "architectures": [
+    "LlamaForCausalLM"
+  ],
+  "attention_bias": false,
+  "attention_dropout": 0.0,
+  "bos_token_id": 128000,
+  "eos_token_id": [
+    128001,
+    128008,
+    128009
+  ],
+  "head_dim": 128,
+  "hidden_act": "silu",
+  "hidden_size": 3072,
+  "initializer_range": 0.02,
+  "intermediate_size": 8192,
+  "max_position_embeddings": 131072,
+  "mlp_bias": false,
+  "model_type": "llama",
+  "num_attention_heads": 24,
+  "num_hidden_layers": 28,
+  "num_key_value_heads": 8,
+  "pretraining_tp": 1,
+  "rms_norm_eps": 1e-05,
+  "rope_scaling": {
+    "factor": 32.0,
+    "high_freq_factor": 4.0,
+    "low_freq_factor": 1.0,
+    "original_max_position_embeddings": 8192,
+    "rope_type": "llama3"
+  },
+  "rope_theta": 500000.0,
+  "tie_word_embeddings": true,
+  "torch_dtype": "float32",
+  "transformers_version": "4.46.2",
+  "use_cache": true,
+  "vocab_size": 128256
+}
12 src/tot/quant/ptq_int8/generation_config.json Normal file
@@ -0,0 +1,12 @@
+{
+  "bos_token_id": 128000,
+  "do_sample": true,
+  "eos_token_id": [
+    128001,
+    128008,
+    128009
+  ],
+  "temperature": 0.6,
+  "top_p": 0.9,
+  "transformers_version": "4.46.2"
+}
40 src/tot/quant/qat_int8/config.json Normal file
@@ -0,0 +1,40 @@
+{
+  "_name_or_path": "meta-llama/Llama-3.2-3B-Instruct",
+  "architectures": [
+    "LlamaForCausalLM"
+  ],
+  "attention_bias": false,
+  "attention_dropout": 0.0,
+  "bos_token_id": 128000,
+  "eos_token_id": [
+    128001,
+    128008,
+    128009
+  ],
+  "head_dim": 128,
+  "hidden_act": "silu",
+  "hidden_size": 3072,
+  "initializer_range": 0.02,
+  "intermediate_size": 8192,
+  "max_position_embeddings": 131072,
+  "mlp_bias": false,
+  "model_type": "llama",
+  "num_attention_heads": 24,
+  "num_hidden_layers": 28,
+  "num_key_value_heads": 8,
+  "pretraining_tp": 1,
+  "rms_norm_eps": 1e-05,
+  "rope_scaling": {
+    "factor": 32.0,
+    "high_freq_factor": 4.0,
+    "low_freq_factor": 1.0,
+    "original_max_position_embeddings": 8192,
+    "rope_type": "llama3"
+  },
+  "rope_theta": 500000.0,
+  "tie_word_embeddings": true,
+  "torch_dtype": "float32",
+  "transformers_version": "4.46.2",
+  "use_cache": true,
+  "vocab_size": 128256
+}
12 src/tot/quant/qat_int8/generation_config.json Normal file
@@ -0,0 +1,12 @@
+{
+  "bos_token_id": 128000,
+  "do_sample": true,
+  "eos_token_id": [
+    128001,
+    128008,
+    128009
+  ],
+  "temperature": 0.6,
+  "top_p": 0.9,
+  "transformers_version": "4.46.2"
+}
@@ -8,5 +8,8 @@ def get_task(name):
     elif name == 'crosswords':
         from src.tot.tasks.crosswords import MiniCrosswordsTask
         return MiniCrosswordsTask()
+    elif name == 'comp_bench':
+        from src.tot.tasks.bench import Bench
+        return Bench()
     else:
         raise NotImplementedError
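With this branch registered, the benchmark task is selected through the same factory as the original tasks; a minimal usage sketch (it assumes src/tot/tasks/bench.py exports a Bench class matching the import above, which the bench.py added below does not yet define under that name):

    from src.tot.tasks import get_task

    task = get_task('comp_bench')  # resolves to the benchmark task added in this commit
    print(type(task).__name__)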
@@ -5,11 +5,11 @@ class Task:
     def __init__(self):
         pass
 
-    def __len__(self) -> int:
-        pass
+    # def __len__(self) -> int:
+    #     pass
 
-    def get_input(self, idx: int) -> str:
-        pass
+    # def get_input(self, idx: int) -> str:
+    #     pass
 
     def test_output(self, idx: int, output: str):
         pass
94 src/tot/tasks/bench.py Normal file
@@ -0,0 +1,94 @@
+import re
+import os
+import sympy
+import pandas as pd
+from src.tot.tasks.base import Task, DATA_PATH
+from src.tot.prompts.game24 import *
+
+
+def get_current_numbers(y: str) -> str:
+    last_line = y.strip().split('\n')[-1]
+    return last_line.split('left: ')[-1].split(')')[0]
+
+
+class Game24Task(Task):
+    """
+    Input (x)   : a string of
+    Output (y)  : a trajectory of 3 steps to reach 24
+    Reward (r)  : 0 or 1, depending on whether the trajectory is correct
+    Input Example:
+        1 2 3 4
+    Output Example:
+        1 + 2 = 3 (left: 3 3 4)
+        3 + 3 = 6 (left: 4 6)
+        6 * 4 = 24 (left: 24)
+        (1 + 2 + 3) * 4 = 24
+    """
+    def __init__(self, file='24.csv', depth=5):
+        """
+        file: a csv file (fixed)
+        """
+        super().__init__()
+        self.data = torch.load('../data/agg_dl_test.pt')
+        self.value_cache = {}
+        self.steps = depth
+        self.stops = None
+
+    def __len__(self) -> int:
+        return len(self.data)
+
+    def get_input(self, idx: int) -> str:
+        return self.data[idx]
+
+    def test_output(self, idx: int, output: str):
+        expression = output.strip().split('\n')[-1].lower().replace('answer: ', '').split('=')[0]
+        numbers = re.findall(r'\d+', expression)
+        problem_numbers = re.findall(r'\d+', self.data[idx])
+        if sorted(numbers) != sorted(problem_numbers):
+            print("sorted numbers length not equal to sorted problem numbers length")
+            return {'r': 0}
+        try:
+            # print(sympy.simplify(expression))
+            return {'r': int(sympy.simplify(expression) == 24)}
+        except Exception as e:
+            # print(e)
+            print("entered exception!!")
+            return {'r': 0}
+
+    @staticmethod
+    def standard_prompt_wrap(x: str, y:str='') -> str:
+        return standard_prompt.format(input=x) + y
+
+    @staticmethod
+    def cot_prompt_wrap(x: str, y:str='') -> str:
+        return cot_prompt.format(input=x) + y
+
+    @staticmethod
+    def propose_prompt_wrap(x: str, y: str='') -> str:
+        current_numbers = get_current_numbers(y if y else x)
+        if current_numbers == '24':
+            prompt = cot_prompt.format(input=x) + 'Steps:' + y
+            # print([prompt])
+        else:
+            prompt = propose_prompt.format(input=current_numbers)
+        return prompt
+
+    @staticmethod
+    def value_prompt_wrap(x: str, y: str) -> str:
+        last_line = y.strip().split('\n')[-1]
+        if 'left: ' not in last_line:  # last step
+            ans = last_line.lower().replace('answer: ', '')
+            # print([value_last_step_prompt.format(input=x, answer=ans)])
+            return value_last_step_prompt.format(input=x, answer=ans)
+        current_numbers = get_current_numbers(y)
+        return value_prompt.format(input=current_numbers)
+
+    @staticmethod
+    def value_outputs_unwrap(x: str, y: str, value_outputs: list) -> float:
+        if len(y.strip().split('\n')) == 4 and 'answer' not in y.lower():
+            return 0
+        value_names = [_.split('\n')[-1] for _ in value_outputs]
+        value_map = {'impossible': 0.001, 'likely': 1, 'sure': 20}  # TODO: ad hoc
+        value = sum(value * value_names.count(name) for name, value in value_map.items())
+        normalized_value = value / (len(value_names) * value_map['sure'])  # scale to [0, 1]
+        return normalized_value
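A worked example of the value_outputs_unwrap scoring above, with three hypothetical evaluator outputs whose last lines are "sure", "likely", and "impossible":

    value_outputs = ["...\nsure", "...\nlikely", "...\nimpossible"]             # hypothetical outputs
    value_names = [o.split('\n')[-1] for o in value_outputs]                    # ['sure', 'likely', 'impossible']
    value_map = {'impossible': 0.001, 'likely': 1, 'sure': 20}
    value = sum(v * value_names.count(name) for name, v in value_map.items())   # 20 + 1 + 0.001 = 21.001
    normalized_value = value / (len(value_names) * value_map['sure'])           # 21.001 / 60 ≈ 0.35
    print(normalized_value)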