import os
import json
import argparse
import time

from src.tot.tasks import get_task
from src.tot.methods.bfs import solve, naive_solve, solve_bench
# from src.tot.methods.a_star import solve, naive_solve, solve_bench
from src.tot.models import gpt_usage

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import torch.quantization
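
# NOTE: the --quantize option expects pre-built quantized Llama checkpoints to exist locally
# under src/tot/quant/ (hf_quant_int4, ptq_int8, qat_int8), as referenced in run() below.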


def run(args):
    '''
    Main run function.
    '''
    # Load a specific llama model, if applicable.
    # Because of the way the original repo is structured, llama models need to be loaded here
    # in run.py to avoid repeated loading in models.py.
    if args.backend == 'llama':
        tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")

        if args.quantize and args.quantize == 'ptq_int4':
            model = AutoModelForCausalLM.from_pretrained("src/tot/quant/hf_quant_int4", device_map="cuda", weights_only=False)
            model = torch.compile(model, mode="max-autotune")
        elif args.quantize and args.quantize == 'ptq_int8':
            model = AutoModelForCausalLM.from_pretrained("src/tot/quant/ptq_int8", device_map="cuda")
            model = torch.compile(model, mode="max-autotune")
        elif args.quantize and args.quantize == 'qat':
            model = AutoModelForCausalLM.from_pretrained("src/tot/quant/qat_int8", device_map="cuda", weights_only=False)
            model = torch.compile(model, mode="max-autotune")
        else:
            model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
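        # Note: torch.compile(mode="max-autotune") spends longer on the first forward pass
        # searching for fast kernels, in exchange for quicker steady-state generation.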
    else:  # gpt-4o will be used later in this case
        model = None
        tokenizer = None

    # set up
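    # For comp_bench, agg_dl_test.pt is assumed to contain an iterable of samples whose
    # 'label' field holds token ids (decoded against the llama tokenizer further down).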
    if args.task == 'comp_bench':
        test_data = torch.load('src/tot/data/agg_dl_test.pt')

        for i, sample in enumerate(test_data):

            start_timer = time.perf_counter()

            # solve: walk through the tree for this sample
            ys, info = solve_bench(args, model, tokenizer)

            runtime = time.perf_counter() - start_timer
            print(f"""For iteration {i} --
            Total Time to Solve: {runtime} seconds""")

            # evaluate/rate the proposals

            # compare the proposed final answer vs the ground truth
            gt = tokenizer.decode(sample['label'][0]).replace("<|eot_id|>", "").replace("<|begin_of_text|>", "").strip()
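
            # A minimal sketch of the comparison step (assumptions: ys holds decoded answer
            # strings and exact string match is a reasonable metric for this benchmark):
            # correct = any(y.strip() == gt for y in ys)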

            # log
            # infos = [task.test_output(i, y) for y in ys]
            # info.update({'idx': i, 'ys': ys, 'infos': infos, 'usage_so_far (gpt only)': gpt_usage(args.backend), 'total_runtime': runtime})
            # logs.append(info)
            # with open(file, 'w') as f:
            #     json.dump(logs, f, indent=4)

            # log main metric
            # accs = [info['r'] for info in infos]
            # cnt_avg += sum(accs) / len(accs)
            # cnt_any += any(accs)
            # print(i, 'sum(accs)', sum(accs), 'cnt_avg', cnt_avg, 'cnt_any', cnt_any, '\n')

        # n = args.task_end_index - args.task_start_index
        # print(cnt_avg / n, cnt_any / n)
        # print('usage_so_far', gpt_usage(args.backend))

    else:  # the original tot tasks from the paper
        task = get_task(args.task)
        logs, cnt_avg, cnt_any = [], 0, 0
        if args.naive_run:
            file = f'./logs/{args.task}/{args.backend}_{args.temperature}_naive_{args.prompt_sample}_sample_{args.n_generate_sample}_start{args.task_start_index}_end{args.task_end_index}.json'
        else:
            file = f'./logs/{args.task}/{args.backend}_{args.temperature}_{args.method_generate}{args.n_generate_sample}_{args.method_evaluate}{args.n_evaluate_sample}_{args.method_select}{args.n_select_sample}_start{args.task_start_index}_end{args.task_end_index}.json'
        os.makedirs(os.path.dirname(file), exist_ok=True)

        # run the specified range of tasks
        for i in range(args.task_start_index, args.task_end_index):

            # solve
            start_timer = time.perf_counter()
            if args.naive_run:
                ys, info = naive_solve(args, task, i, model, tokenizer)
            else:
                ys, info = solve(args, task, i, model, tokenizer)

            runtime = time.perf_counter() - start_timer
            print(f"""For iteration {i} --
            Total Time to Solve: {runtime} seconds""")

            # log
            infos = [task.test_output(i, y) for y in ys]
            info.update({'idx': i, 'ys': ys, 'infos': infos, 'usage_so_far (gpt only)': gpt_usage(args.backend), 'total_runtime': runtime})
            logs.append(info)
            with open(file, 'w') as f:
                json.dump(logs, f, indent=4)

            # log main metric
            accs = [info['r'] for info in infos]
            cnt_avg += sum(accs) / len(accs)
            cnt_any += any(accs)
            print(i, 'sum(accs)', sum(accs), 'cnt_avg', cnt_avg, 'cnt_any', cnt_any, '\n')

        n = args.task_end_index - args.task_start_index
        print(cnt_avg / n, cnt_any / n)
        print('usage_so_far', gpt_usage(args.backend))


def parse_args():
    '''
    Determines the conditions for the run.
    '''
    args = argparse.ArgumentParser()

    # the arguments to use for our purposes
    args.add_argument('--backend', type=str, choices=['gpt-4o', 'llama'], default='gpt-4o')
    args.add_argument('--quantize', type=str, choices=['qat', 'ptq_int4', 'ptq_int8'])
    args.add_argument('--temperature', type=float, default=0.0)
    args.add_argument('--method_generate', type=str, choices=['sample', 'propose'])
    args.add_argument('--method_evaluate', type=str, choices=['value', 'vote'])
    args.add_argument('--method_select', type=str, choices=['sample', 'greedy'], default='greedy')

    # other args from the original repo
    args.add_argument('--task', type=str, required=True, choices=['game24', 'text', 'crosswords', 'comp_bench'])

    args.add_argument('--task_start_index', type=int, default=900)
    args.add_argument('--task_end_index', type=int, default=1000)

    args.add_argument('--naive_run', action='store_true')
    args.add_argument('--prompt_sample', type=str, choices=['standard', 'cot'])  # only used when method_generate = sample, or naive_run

    args.add_argument('--n_generate_sample', type=int, default=1)  # only thing needed if naive_run
    args.add_argument('--n_evaluate_sample', type=int, default=1)
    args.add_argument('--n_select_sample', type=int, default=1)

    args = args.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()
    print(args)
    run(args)
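

# Example invocations (illustrative only; the flag values here are assumptions, not required settings):
#   python run.py --task game24 --backend gpt-4o --method_generate propose --method_evaluate value --n_select_sample 5
#   python run.py --task comp_bench --backend llama --quantize ptq_int8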