mirror of
https://github.com/princeton-nlp/tree-of-thought-llm
synced 2025-05-22 03:46:02 +00:00
code pt2 for run_bench
This commit is contained in:
parent
c69681e3a0
commit
e968d35da4
run_bench.py | 179
@@ -7,8 +7,14 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import torch.quantization

from src.tot.data.benchmark.bench import *
from src.tot.prompts.bench import value_prompt, propose_prompt

from datasets import load_dataset
from torch.utils.data import DataLoader
import random
import multiprocessing

def load_llama(quant=None):
    '''Load in one of the llama models'''
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
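The hunk is truncated here. For context, a hypothetical sketch of how the new torch.quantization import could be used to finish loading a quantized model (the function body and the meaning of quant are assumptions, not the commit's code):

def load_llama_sketch(quant=None):
    # hypothetical: load model/tokenizer, then optionally apply dynamic int8 quantization
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
    model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
    if quant == "dynamic":
        # quantize_dynamic swaps nn.Linear modules for int8 dynamically-quantized versions
        model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
    return model, tokenizer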
@@ -62,19 +68,103 @@ def final_eval(gt, final_prop):
    print("THIS IS THE FINAL PROP")
    print(final_prop)
    print("THIS IS THE GT")
    print(gt)


    if gt in final_prop:
        return 1.0
    else:
        return 0.0
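final_eval grades by substring containment, so the decoded ground truth only needs to appear somewhere in the selected proposal. An illustrative check (values made up):

final_eval("Brooks only.", "The answer is Brooks only.")  # -> 1.0
final_eval("Brooks only.", "The answer is Adams only.")   # -> 0.0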

def get_test_data(tokenizer, batch_size):
    '''
    Process and return the composite benchmark test data in a dataloader
    '''
    # print(tokenizer)

    gpqa_raw = load_dataset("Idavidrein/gpqa", "gpqa_diamond")
    gpqa_choices = [[a, b, c, d] for a, b, c, d in
                    zip(gpqa_raw['train']['Correct Answer'], gpqa_raw['train']['Incorrect Answer 1'],
                        gpqa_raw['train']['Incorrect Answer 2'], gpqa_raw['train']['Incorrect Answer 3'])]
    for choices in gpqa_choices:
        random.shuffle(choices)

    gpqa_questions_proc = format_for_mm(gpqa_raw['train']['Question'], gpqa_choices)

    #math (for math)
    math_raw = load_dataset("lighteval/MATH", "all")

    # #mmlu (for gen knowledge + reasoning)
    mmlu_raw = load_dataset("cais/mmlu", "all")
    mmlu_questions_proc_test = format_for_mm(mmlu_raw['test']['question'], mmlu_raw['test']['choices'])

    #master list - test
    sublist_input_test = gpqa_questions_proc[158:] + math_raw['test']['problem'] + mmlu_questions_proc_test
    sublist_answer_test = gpqa_raw['train']['Correct Answer'][158:] + math_raw['test']['solution'] + mmlu_raw['test']['answer']
    agg_test_set = benchmark_dataset(sublist_input_test, sublist_answer_test, tokenizer)

    return DataLoader(agg_test_set, batch_size=batch_size, shuffle=True, collate_fn=collate_fn_qat)
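A minimal usage sketch (not part of the commit) of how the returned loader would be iterated, assuming the batch keys produced by collate_fn_qat later in this commit:

test_loader = get_test_data(tokenizer, batch_size=4)  # batch size is illustrative
for batch in test_loader:
    input_ids = batch['input_ids']        # (batch_size, 150) padded question tokens
    mask = batch['attention_mask']
    labels = batch['label']               # (batch_size, 150) padded answer tokens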

def solve(input_ids, label, mask, model, tokenizer, device, args):
    '''
    the main tot run
    '''

    problem = tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0]
    # print(problem)
    selected = ""
    for i in range(args.depth): #args.depth number of attempts to reach the solution

        #propose next step/solutions per node/prompt

        out = model.generate(
            input_ids,
            attention_mask=mask,
            temperature=args.temperature,
            max_new_tokens=args.max_new_tokens,
            num_return_sequences=args.breadth,
        )


        #evaluate/rate the proposals
        current_state = tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0]

        proposals = []
        for o in out:
            string_answer = tokenizer.decode(o[-args.max_new_tokens:], skip_special_tokens=True)
            assert isinstance(string_answer, str)
            proposals.extend([string_answer])

        valuations = value_proposals(problem=problem, current_state=current_state, proposals=proposals, tokenizer=tokenizer, model=model, device=device)

        #if the model believes it has reached the final solution before args.depth is up, break
        if 100.0 in valuations:
            break

        #select the best proposal
        val_props = list(zip(proposals, valuations))
        val_props.sort(key = lambda ele: ele[1], reverse=True)
        selected = val_props[:args.greedy_n][0][0]

        # print("THIS IS SELECTED")
        # print(selected)

        #format the chosen proposal for the next iteration
        next_prompt = propose_prompt.format(problem=problem, current_state=selected)
        inputs = tokenizer(next_prompt, return_tensors='pt')
        input_ids = inputs['input_ids'].to(device)
        mask = inputs['attention_mask'].to(device)


    #compare the proposed final answer vs the ground truth
    gt = tokenizer.batch_decode(label, skip_special_tokens=True)
    judgement = final_eval(gt[0], selected)

    return judgement
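solve() calls a value_proposals helper that is not included in this diff. A rough sketch of the assumed interface, where the value_prompt's yes/likely/no verdict is mapped to a numeric score and 100.0 is treated upstream as "solved" (placeholder names, the score mapping, and the generation settings are assumptions, not the commit's code):

def value_proposals(problem, current_state, proposals, tokenizer, model, device):
    '''Assumed behaviour: rate each proposal with value_prompt and map the verdict to a score.'''
    valuations = []
    for proposal in proposals:
        # format kwargs assumed from the prompt template placeholders
        prompt = value_prompt.format(problem=problem, current_state=current_state, proposal=proposal)
        inputs = tokenizer(prompt, return_tensors='pt').to(device)
        out = model.generate(**inputs, max_new_tokens=10)
        verdict = tokenizer.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True).lower()
        if 'yes' in verdict:
            valuations.append(100.0)
        elif 'likely' in verdict:
            valuations.append(50.0)
        else:
            valuations.append(0.0)
    return valuations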

def run(args):
    '''
    main run function
    '''
    #load in specific llama model, if applicable
    ### SETUP MODEL ###
    #bc of the way the original repo is structured, will need to load in llama models in run.py to avoid repeated loading in models.py
    if args.backend == 'llama':
        if args.quantize:
@@ -85,78 +175,34 @@ def run(args):
        model = None
        tokenizer = None

    ### SETUP DATA ###
    test_data = get_test_data(tokenizer, args.concurrent)

    ### OTHER SETUP ###
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device)

    #set up
    test_data = torch.load('src/tot/data/agg_dl_test.pt')

    total = 0
    right = 0

    for samples in test_data:
    for sample in test_data:

        #extract out the sample parts for the initial input
        input_ids = sample['input_ids'].to(device)
        label = sample['label'].to(device)
        mask = sample['attention_mask'].to(device)

        for sample in samples: #going to do this one problem at a time for now.

            #extract out the sample parts for the initial input
            input_ids = sample['input_ids'].to(device)
            label = sample['label'].to(device)
            mask = sample['attention_mask'].to(device)

            problem = tokenizer.decode(sample, skip_special_tokens=True)

            #start solving via tot
            start_timer = perf.counter()

            selected = ""
            for i in range(args.depth): #args.depth number of attempts to reach the solution

                #propose next step/solutions per node/prompt

                out = model.generate(
                    input_ids,
                    attention_mask=mask,
                    temperature=args.temperature,
                    max_new_tokens=args.max_tokens,
                    num_return_sequences=args.breadth)


                #evaluate/rate the proposals
                current_state = tokenizer.decode(input_ids, skip_special_tokens=True)

                proposals = []
                for o in out:
                    string_answer = tokenizer.decode(o)
                    proposals.extend([string_answer])

                valuations = value_proposals(problem=problem, current_state=current_state, proposals=proposals, tokenizer=tokenizer, model=model, device=device)

                #if the model believes it has reached the final solution before args.depth is up, break
                if 100.0 in valuations:
                    break

                #select the best proposal
                val_props = list(zip(proposals, valuations))
                val_props.sort(key = lambda ele: ele[1], descending=True)
                selected = val_props[:args.greedy_n]

                #format the chosen proposal for the next iteration
                next_prompt = propose_prompt.format(problem=problem, current_state=selected)
                inputs = tokenizer(next_prompt, return_tensors='pt')
                input_ids = inputs['input_ids'].to(device)
                mask = inputs['attention_mask'].to(device)


            #compare the proposed final answer vs the ground truth
            gt = tokenizer.decode(label, skip_special_token=True)
            judgement = final_eval(gt, selected)

            #keep track of the running totals
        #cannot get multiple gpus. will run this on a single gpu one sample at a time for simplicity
        for i in range(len(input_ids)):
            judgement = solve(input_ids[i].view(1,-1), label[i].view(1,-1), mask[i].view(1,-1), model, tokenizer, device, args)
            total += 1.0
            right += judgement
            print("Accuracy so far: ", total/right)

            total_accuracy = right/total
            #keep track of the running totals
            print("Accuracy so far: ", right/total)

    print("FINAL ACCURACY: ", right/total)


def parse_args():
    '''
@@ -172,6 +218,7 @@ def parse_args():
    args.add_argument('--depth', type=int, default=3)
    args.add_argument('--breadth', type=int, default=3)
    args.add_argument('--greedy_n', type=int, default=1)
    args.add_argument('--concurrent', type=int, default=4)

    args = args.parse_args()
    return args
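An illustrative invocation combining the flags added here with ones referenced elsewhere in the script (the values are assumptions; --temperature and --max_new_tokens are inferred from their use in solve()):

python run_bench.py --backend llama --temperature 0.7 --max_new_tokens 100 --depth 3 --breadth 3 --greedy_n 1 --concurrent 4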

@@ -0,0 +1,79 @@
import torch
import torch.quantization
from datasets import load_dataset
from torch.utils.data import DataLoader
import random
import multiprocessing
import os
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
tokenizer.pad_token = tokenizer.eos_token

class benchmark_dataset(torch.utils.data.Dataset):
    '''formats the data for dataloader'''

    def __init__(self, input, labels, tokenizer, filter_n=150):
        '''constructor. input samples and output labels'''

        self.input = input
        self.labels = labels
        self.tokenizer = tokenizer

        self.filter_len(filter_n)

    def filter_len(self, n):

        new_input = []
        new_label = []

        for q, a in zip(self.input, self.labels):
            tk_len_q = len(tokenizer(str(q), return_tensors='pt')['input_ids'][0])
            tk_len_a = len(tokenizer(str(a), return_tensors='pt')['input_ids'][0])

            if tk_len_q <= n and tk_len_a <= n:
                new_input.append(q)
                new_label.append(a)

        print(f"""
        Len of Original Input: {len(self.input)}
        Len of Original Labels: {len(self.labels)}
        Len of New_Input: {len(new_input)}
        Len of New_Label: {len(new_label)}

        Sample Input, Label: {new_input[0], new_label[0]}

        """)

        self.input = new_input
        self.labels = new_label

    def __len__(self):
        return len(self.input)

    def __getitem__(self, idx):

        return {"question": self.input[idx], "answer": self.labels[idx]}

def format_for_mm(question, choices):
    '''
    Formats questions and choices into one multiple-choice-question string
    '''
    return [f"""Choose the choice that best answer the following question:
Question:
{q.strip()}
Choices:
{c}
"""
    for q, c in zip(question, choices)]
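For example, one question with its shuffled choices (made-up values) becomes a single prompt string:

qs = format_for_mm(["Which planet is largest?"], [["Mars", "Jupiter", "Venus", "Mercury"]])
print(qs[0])
# Choose the choice that best answer the following question:
# Question:
# Which planet is largest?
# Choices:
# ['Mars', 'Jupiter', 'Venus', 'Mercury']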

def collate_fn_qat(batch):

    # Now collate into mini-batches
    inputs = tokenizer([i['question'] for i in batch], return_tensors='pt', padding='max_length', truncation=True, max_length=150)
    # labels = tokenizer([str(i['answer']) for i in batch], return_tensors='pt', padding='max_length', truncation=True, max_length=65)
    labels = tokenizer([str(i['answer']) for i in batch], return_tensors='pt', padding='max_length', truncation=True, max_length=150)

    # labels = [ele[-100:] for ele in labels['input_ids']]

    return {'input_ids': inputs['input_ids'], 'attention_mask': inputs['attention_mask'], 'label': labels['input_ids']}
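A small end-to-end sketch of how this module's pieces fit together (the sample QA pair is illustrative):

ds = benchmark_dataset(["What is 2 + 2?"], ["4"], tokenizer)
dl = DataLoader(ds, batch_size=1, shuffle=False, collate_fn=collate_fn_qat)
batch = next(iter(dl))
# each tensor is padded/truncated to max_length=150
print(batch['input_ids'].shape, batch['attention_mask'].shape, batch['label'].shape)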

@@ -62,7 +62,7 @@ def get_samples(task, x, y, n_generate_sample, prompt_sample, stop):

def solve(args, task, idx, model, tokenizer, to_print=True):
    global inference_model
    inference_model = partial(inference_model, model=model, tokenizer=tokenizer, temperature=args.temperature)
    # print(inference_model)
    print(inference_model)
    x = task.get_input(idx)  # input
    ys = ['']  # current output candidates
    infos = []
@@ -145,3 +145,72 @@ def naive_solve(args, task, idx, to_print=True):
    x = task.get_input(idx)  # input
    ys = get_samples(task, x, '', args.n_generate_sample, args.prompt_sample, stop=None)
    return ys, {}

def solve_bench(args, model, tokenizer, to_print=True, depth=5, breadth=2):
    global inference_model
    inference_model = partial(inference_model, model=model, tokenizer=tokenizer, temperature=args.temperature)

    ys = ['']  # current output candidates
    infos = []

    for step in range(5):
        # print("Started Steps!!")

        # generation
        print("Started Generation...")
        new_ys_with_times = []

        generated_ys, generate_times = get_samples(task, x, y, args.n_generate_sample, prompt_sample=args.prompt_sample, stop=task.stops[step])
        new_ys_with_times.extend(zip(generated_ys, generate_times))

        new_ys, generate_times = zip(*new_ys_with_times)
        new_ys = list(new_ys)
        generate_times = list(generate_times)

        # new_ys = list(itertools.chain(new_ys))
        new_ys = ys + new_ys  # extend the current queue with the frontier

        ids = list(range(len(new_ys)))

        # evaluation
        # shouldn't worry about reevaluating the same ys as the values should be saved in the task cache
        # but could potentially add logic to remove expanded from queue
        print("Finished Generation...Started Eval!")

        start_time = time.perf_counter()

        if args.method_evaluate == 'vote':
            values = get_votes(task, x, new_ys, args.n_evaluate_sample)
        elif args.method_evaluate == 'value':
            values, eval_times = get_values(task, x, new_ys, args.n_evaluate_sample)

        eval_time = time.perf_counter()-start_time
        print(f"Node Eval Time: {eval_time} seconds")

        # selection
        print("Finished Eval...Started Selection...")

        start_time = time.perf_counter()

        if args.method_select == 'sample':
            ps = np.array(values) / sum(values)
            select_ids = np.random.choice(ids, size=args.n_select_sample, p=ps).tolist()
        elif args.method_select == 'greedy':
            select_ids = sorted(ids, key=lambda x: values[x], reverse=True)[:args.n_select_sample]
        select_new_ys = [new_ys[select_id] for select_id in select_ids]

        selection_time = time.perf_counter()-start_time
        print(f"Selection Time: {selection_time} seconds")

        # log
        print("Finished Selection...Logging...")
        if to_print:
            sorted_new_ys, sorted_values = zip(*sorted(zip(new_ys, values), key=lambda x: x[1], reverse=True))
            print(f'-- new_ys --: {sorted_new_ys}\n-- sol values --: {sorted_values}\n-- choices --: {select_new_ys}\n-- generate times --: {generate_times}\n-- eval times --: {eval_times}\n')

        infos.append({'step': step, 'x': x, 'ys': ys, 'new_ys': new_ys, 'values': values, 'select_new_ys': select_new_ys, 'generate_times': generate_times, 'eval_times': eval_times})
        ys = select_new_ys

    if to_print:
        print(ys)
    return ys, {'steps': infos}
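A toy illustration (values are made up) of the two selection branches above: greedy keeps the highest-valued candidates, while sample draws candidates with probability proportional to their value:

import numpy as np

values = [0.5, 2.0, 1.5]
ids = list(range(len(values)))      # [0, 1, 2]
n_select_sample = 2

# greedy: sort ids by value, keep the top two -> [1, 2]
greedy_ids = sorted(ids, key=lambda i: values[i], reverse=True)[:n_select_sample]

# sample: normalized probabilities are [0.125, 0.5, 0.375]
ps = np.array(values) / sum(values)
sample_ids = np.random.choice(ids, size=n_select_sample, p=ps).tolist()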

@@ -1,4 +1,4 @@
propose_prompt = '''Given a problem goal and a current state, propose one possible next step to solve the problem from the current state. If the current state solves the problem, say "current state is the solution".
propose_prompt = '''[INST]Given a problem goal and a current state, propose one possible next step to solve the problem from the current state. If the current state solves the problem, say "current state is the solution".
###
Here are two examples to help:

@@ -19,7 +19,7 @@ Choose the choice that best answer the following question:
Choices:
['Adams only.', 'Brooks only.', 'Case only.', 'Adams and Brooks']
Current State:
Brooks Only.
Choice 2
Possible next step:
current state is the solution.

@@ -28,10 +28,10 @@ Problem Goal:
{problem}
Current State:
{current_state}
Possible next step:
Possible next step:[/INST]
'''

value_prompt = '''Given a problem goal, a current state, and a proposed next step, evaluate if the next step from the current state can help solve or answer the problem (yes/likely/no)
value_prompt = '''[INST]Given a problem goal, a current state, and a proposed next step, evaluate if the next step from the current state can help solve or answer the problem (yes/likely/no)
###
Here are three examples to help:
Example #1:
@@ -52,7 +52,7 @@ Choose the choice that best answer the following question:
Choices:
['Adams only.', 'Brooks only.', 'Case only.', 'Adams and Brooks']
Current State:
Adams and Brooks.
Choice 4
Proposal:
current state is the solution.
Evaluation:
@@ -65,7 +65,7 @@ Choose the choice that best answer the following question:
Choices:
['Adams only.', 'Brooks only.', 'Case only.', 'Adams and Brooks']
Current State:
Brooks Only.
Choice 2
Proposal:
current state is the solution.
Evaluation:
@@ -78,5 +78,5 @@ Current State:
{current_state}
Proposal:
{proposal}
Evaluation:
Evaluation:[/INST]
'''

@@ -11,7 +11,7 @@ def get_current_numbers(y: str) -> str:
    return last_line.split('left: ')[-1].split(')')[0]


class Game24Task(Task):
class Bench(Task):
    """
    Input (x) : a string of
    Output (y) : a trajectory of 3 steps to reach 24
@@ -24,7 +24,7 @@ class Game24Task(Task):
    6 * 4 = 24 (left: 24)
    (1 + 2 + 3) * 4 = 24
    """
    def __init__(self, file='24.csv', depth=5):
    def __init__(self, depth=5):
        """
        file: a csv file (fixed)
        """
@@ -34,13 +34,14 @@ class Game24Task(Task):
        self.steps = depth
        self.stops = None

    def __len__(self) -> int:
        return len(self.data)
    # def __len__(self) -> int:
    # return len(self.data)

    def get_input(self, idx: int) -> str:
        return self.data[idx]
    # def get_input(self, idx: int) -> str:
    # return self.data[idx]

    def test_output(self, idx: int, output: str):
        gt =
        expression = output.strip().split('\n')[-1].lower().replace('answer: ', '').split('=')[0]
        numbers = re.findall(r'\d+', expression)
        problem_numbers = re.findall(r'\d+', self.data[idx])
@@ -55,13 +56,13 @@ class Game24Task(Task):
            print("entered exception!!")
            return {'r': 0}

    @staticmethod
    def standard_prompt_wrap(x: str, y:str='') -> str:
        return standard_prompt.format(input=x) + y
    # @staticmethod
    # def standard_prompt_wrap(x: str, y:str='') -> str:
    # return standard_prompt.format(input=x) + y

    @staticmethod
    def cot_prompt_wrap(x: str, y:str='') -> str:
        return cot_prompt.format(input=x) + y
    # @staticmethod
    # def cot_prompt_wrap(x: str, y:str='') -> str:
    # return cot_prompt.format(input=x) + y

    @staticmethod
    def propose_prompt_wrap(x: str, y: str='') -> str: