adding bench run scripts pt1

This commit is contained in:
emilyworks 2024-12-07 01:36:15 +00:00
parent 2df2e0238b
commit ff04bbcd41
2 changed files with 246 additions and 0 deletions

164
run_bench.py Normal file
View File

@ -0,0 +1,164 @@
import os
import json
import argparse
import time
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import torch.quantization
from src.tot.prompts.bench import value_prompt, propose_prompt
def load_llama(quant=None):
    '''
    Load the Llama-3.2-3B-Instruct tokenizer together with the requested model variant.

    Args:
        quant: None (or any falsy value) loads the base "meta-llama/Llama-3.2-3B-Instruct"
            checkpoint. 'ptq_int4', 'ptq_int8' or 'qat' loads the matching local
            quantized checkpoint onto CUDA and wraps it with torch.compile.

    Returns:
        A (model, tokenizer) tuple.

    Raises:
        ValueError: if quant is a non-empty value that is not one of the known variants.
    '''
    # checkpoint directory and extra from_pretrained kwargs for every quantized variant
    quantized_variants = {
        'ptq_int4': ("src/tot/quant/hf_quant_int4", {"weights_only": False}),
        'ptq_int8': ("src/tot/quant/ptq_int8", {}),
        'qat': ("src/tot/quant/qat_int8", {"weights_only": False}),
    }

    # validate before any (slow) loading so that a typo fails fast instead of
    # silently falling back to the unquantized model
    if quant and quant not in quantized_variants:
        raise ValueError(
            f"unknown quantization variant {quant!r}; expected one of {sorted(quantized_variants)}"
        )

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")

    if not quant:
        model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
        return model, tokenizer

    # use the function argument, not a module-level `args`, so the loader also works
    # when this module is imported rather than executed as a script
    checkpoint_dir, extra_kwargs = quantized_variants[quant]
    model = AutoModelForCausalLM.from_pretrained(checkpoint_dir, device_map="cuda", **extra_kwargs)
    model = torch.compile(model, mode="max-autotune")
    return model, tokenizer
def propose(problem, current_state, tokenizer, model, device):
    '''
    Placeholder: not implemented yet.

    Accepts the problem text, the current state, a tokenizer, a model and a device,
    does nothing with them and returns None.
    '''
    return None
def value_proposals(problem, current_state, proposals, tokenizer, model, device):
    '''
    Score each proposal for the given problem and current state with the value prompt.

    Args:
        problem: problem statement as a string.
        current_state: current partial solution as a string.
        proposals: list of proposed next steps (strings).
        tokenizer: tokenizer used to encode the prompts and decode the verdicts.
        model: causal LM exposing generate().
        device: device the encoded prompts are moved to.

    Returns:
        A list with one float per proposal, in the same order:
        100.0 - affirmative verdict and the proposal states "current state is the solution"
        1.0   - affirmative verdict otherwise
        0.5   - verdict "likely"
        0.0   - anything else
        An empty proposal list yields an empty list without calling the model.
    '''
    import re

    if not proposals:
        return []

    prompts = [
        value_prompt.format(problem=problem, current_state=current_state, proposal=p)
        for p in proposals
    ]

    # Prompts differ in length, so the batch has to be padded. Llama tokenizers may
    # ship without a pad token; reuse EOS in that case.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Decoder-only generation needs left padding so new tokens directly follow the
    # prompt; restore the caller's setting afterwards.
    previous_side = tokenizer.padding_side
    tokenizer.padding_side = 'left'
    try:
        values = tokenizer(prompts, return_tensors='pt', padding=True)
    finally:
        tokenizer.padding_side = previous_side

    value_inputs = values['input_ids'].to(device)
    value_masks = values['attention_mask'].to(device)
    outputs = model.generate(
        value_inputs,
        attention_mask=value_masks,
        max_new_tokens=5,
        pad_token_id=tokenizer.pad_token_id,
    )

    # Decode only the newly generated tokens. The few-shot prompt itself contains
    # "likely" and "current state is the solution", so matching against the full
    # text would score every proposal from the prompt rather than from the verdict.
    prompt_len = value_inputs.shape[1]
    verdicts = tokenizer.batch_decode(outputs[:, prompt_len:], skip_special_tokens=True)

    valuations = []
    for proposal, verdict in zip(proposals, verdicts):
        words = set(re.findall(r'[a-z]+', verdict.lower()))
        # the prompt asks for yes/likely/no; 'sure' is still accepted for compatibility
        affirmative = 'yes' in words or 'sure' in words
        claims_solution = 'current state is the solution' in proposal.lower()
        if affirmative and claims_solution:
            valuations.append(100.0)
        elif affirmative:
            valuations.append(1.0)
        elif 'likely' in words:
            valuations.append(0.5)
        else:
            valuations.append(0.0)
    return valuations
def final_eval(gt=None, final_proposal=None):
    '''
    Placeholder for judging a final proposal against the ground truth.

    Args:
        gt: decoded ground-truth text (optional).
        final_proposal: text of the final proposal (optional).

    Returns:
        None. The comparison is not implemented yet. Both parameters are optional so
        that the zero-argument call form keeps working.
    '''
    return None


def _encode_left_padded(tokenizer, texts, device):
    '''Tokenize texts as one left-padded batch; return (input_ids, attention_mask) on device.'''
    previous_side = tokenizer.padding_side
    # decoder-only generation needs the prompt to end right where generation starts
    tokenizer.padding_side = 'left'
    try:
        encoded = tokenizer(texts, return_tensors='pt', padding=True)
    finally:
        tokenizer.padding_side = previous_side
    return encoded['input_ids'].to(device), encoded['attention_mask'].to(device)


def run(args):
    '''
    main run function

    For every test sample: generate args.breadth continuations per state, score them
    with value_proposals, keep the args.greedy_n best as the next states, and repeat
    for at most args.depth rounds or until a proposal is scored 100.0.

    Args:
        args: namespace with backend, quantize, temperature, max_new_tokens, depth,
            breadth and greedy_n.

    Returns:
        A list with one dict per sample holding 'ground_truth', 'final_proposal',
        'judgement' (result of final_eval) and 'seconds' (wall time of the search).

    Raises:
        NotImplementedError: if args.backend is not 'llama'.
    '''
    #load in specific llama model, if applicable
    #bc of the way the original repo is structured, will need to load in llama models in run.py to avoid repeated loading in models.py
    if args.backend != 'llama':
        #gpt4 will be used later; fail clearly instead of calling .to() on None
        raise NotImplementedError(
            f"backend {args.backend!r} is not implemented yet; use --backend llama"
        )
    model, tokenizer = load_llama(args.quantize)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device)

    # batched generation and re-tokenization below need a pad token
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Several return sequences require either sampling or beam search; plain greedy
    # decoding (temperature 0) cannot return more than one sequence.
    if args.temperature > 0:
        decode_kwargs = {'do_sample': True, 'temperature': args.temperature}
    else:
        decode_kwargs = {'do_sample': False, 'num_beams': args.breadth}

    #set up
    # NOTE(review): assumes iterating test_data yields iterables of dicts with
    # 'input_ids', 'attention_mask' and 'label' tensors - confirm against the data file
    test_data = torch.load('src/tot/data/agg_dl_test.pt')
    results = []
    for samples in test_data:
        for sample in samples: #going to do this one problem at a time for now.
            #extract out the sample parts for the initial input
            input_ids = sample['input_ids'].to(device)
            label = sample['label'].to(device)
            mask = sample['attention_mask'].to(device)
            # generate() expects (batch, seq); promote a single 1-D sequence
            if input_ids.dim() == 1:
                input_ids = input_ids.unsqueeze(0)
                mask = mask.unsqueeze(0)
            problem = tokenizer.decode(input_ids[0], skip_special_tokens=True)
            # fallback if no round is executed (args.depth == 0)
            final_proposal = problem

            #start solving via tot
            start_timer = time.perf_counter()
            for _ in range(args.depth): #args.depth number of attempts to reach the solution
                #propose next step/solutions per node/prompt
                out = model.generate(
                    input_ids,
                    attention_mask=mask,
                    max_new_tokens=args.max_new_tokens,
                    num_return_sequences=args.breadth,
                    pad_token_id=tokenizer.pad_token_id,
                    **decode_kwargs)

                #evaluate/rate the proposals
                prompt_len = input_ids.shape[1]
                states = tokenizer.batch_decode(input_ids, skip_special_tokens=True)
                # full text (state + new step) becomes a candidate next state;
                # only the new step is shown to the value prompt as the proposal
                candidates = tokenizer.batch_decode(out, skip_special_tokens=True)
                steps = tokenizer.batch_decode(out[:, prompt_len:], skip_special_tokens=True)

                valuations = []
                for row, current_state in enumerate(states):
                    # rows [row*breadth, (row+1)*breadth) of `out` belong to input row `row`
                    group = slice(row * args.breadth, (row + 1) * args.breadth)
                    valuations.extend(value_proposals(
                        problem=problem,
                        current_state=current_state,
                        proposals=steps[group],
                        tokenizer=tokenizer,
                        model=model,
                        device=device))

                # best first; sorted() is stable so ties keep generation order
                val_props = sorted(zip(candidates, valuations), key=lambda ele: ele[1], reverse=True)
                final_proposal = val_props[0][0]

                #if the model believes it has reached the final solution before args.depth is up, break
                if val_props[0][1] == 100.0:
                    break

                #select the proposals that act as input to the next iteration
                selected = [text for text, _ in val_props[:args.greedy_n]]
                input_ids, mask = _encode_left_padded(tokenizer, selected, device)
            elapsed = time.perf_counter() - start_timer

            #compare the proposed final answer vs the ground truth
            # NOTE(review): assumes 'label' holds token ids of the answer text - confirm
            gt = tokenizer.decode(label, skip_special_tokens=True)
            judgement = final_eval(gt, final_proposal)
            results.append({
                'ground_truth': gt,
                'final_proposal': final_proposal,
                'judgement': judgement,
                'seconds': elapsed,
            })
    return results
def parse_args(argv=None):
    '''
    Determines the conditions for the run.

    Args:
        argv: optional list of argument strings to parse. None (the default) makes
            argparse read the process command line, as before.

    Returns:
        argparse.Namespace with backend, quantize, temperature, max_new_tokens,
        depth, breadth and greedy_n.
    '''
    parser = argparse.ArgumentParser()
    #the arguments to use for our purposes
    parser.add_argument('--backend', type=str, choices=['gpt-4o', 'llama'], default='gpt-4o',
                        help='model backend to run')
    parser.add_argument('--quantize', type=str, choices=['qat', 'ptq_int4', 'ptq_int8'],
                        help='quantized llama variant to load (omit for the base model)')
    parser.add_argument('--temperature', type=float, default=0.0,
                        help='generation temperature')
    parser.add_argument('--max_new_tokens', type=int, default=100,
                        help='maximum number of new tokens per generated proposal')
    parser.add_argument('--depth', type=int, default=3,
                        help='number of attempts to reach the solution')
    parser.add_argument('--breadth', type=int, default=3,
                        help='number of sequences generated per step')
    parser.add_argument('--greedy_n', type=int, default=1,
                        help='number of best-rated proposals kept for the next step')
    return parser.parse_args(argv)
# Script entry point: parse the command-line options, echo them, then start the run.
if __name__ == '__main__':
    # `args` is assigned at module scope here, so it is also visible as a global
    # name to the functions defined in this file.
    args = parse_args()
    print(args)
    run(args)

82
src/tot/prompts/bench.py Normal file
View File

@ -0,0 +1,82 @@
# Few-shot prompt asking for one possible next step from a current state.
# Contains two worked examples and ends with the replacement fields
# {problem} and {current_state}, followed by an open "Possible next step:" line.
propose_prompt = '''Given a problem goal and a current state, propose one possible next step to solve the problem from the current state. If the current state solves the problem, say "current state is the solution".
###
Here are two examples to help:
Example #1:
Problem Goal:
I have four numbers: 4, 4, 6, and 8. How can I use those numbers with basic arithmetic operations (+ - * /) to obtain 24?
Current State:
4+8 = 12
I have 4, 6, and 12 left.
Possible next step:
6-4 = 2
Example #2:
Problem Goal:
Choose the choice that best answer the following question:
Question:
Davis decided to kill Adams. He set out for Adams's house. Before he got there he saw Brooks, who resembled Adams. Thinking that Brooks was Adams, Davis shot at Brooks. The shot missed Brooks but wounded Case, who was some distance away. Davis had not seen Case. In a prosecution under a statute that proscribes any attempt to commit murder, the district attorney should indicate that the intended victim(s) was/were
Choices:
['Adams only.', 'Brooks only.', 'Case only.', 'Adams and Brooks']
Current State:
Brooks Only.
Possible next step:
current state is the solution.
###
Problem Goal:
{problem}
Current State:
{current_state}
Possible next step:
'''
# Few-shot prompt asking whether a proposed next step helps solve the problem
# (expected answer: yes / likely / no). Contains three worked examples and ends with
# the replacement fields {problem}, {current_state} and {proposal}, followed by an
# open "Evaluation:" line. Every example uses the same section headers as the final
# query (Problem Goal / Current State / Proposal / Evaluation).
value_prompt = '''Given a problem goal, a current state, and a proposed next step, evaluate if the next step from the current state can help solve or answer the problem (yes/likely/no)
###
Here are three examples to help:
Example #1:
Problem Goal:
I have four numbers: 4, 4, 6, and 8. How can I use those numbers with basic arithmetic operations (+ - * /) to obtain 24?
Current State:
4+8 = 12 (remaining numbers: 4, 6, 12)
Proposal:
6-4 = 2 (remaining numbers: 2, 12)
Evaluation:
likely
Example #2:
Problem Goal:
Choose the choice that best answer the following question:
Question:
Davis decided to kill Adams. He set out for Adams's house. Before he got there he saw Brooks, who resembled Adams. Thinking that Brooks was Adams, Davis shot at Brooks. The shot missed Brooks but wounded Case, who was some distance away. Davis had not seen Case. In a prosecution under a statute that proscribes any attempt to commit murder, the district attorney should indicate that the intended victim(s) was/were
Choices:
['Adams only.', 'Brooks only.', 'Case only.', 'Adams and Brooks']
Current State:
Adams and Brooks.
Proposal:
current state is the solution.
Evaluation:
no
Example #3:
Problem Goal:
Choose the choice that best answer the following question:
Question:
Davis decided to kill Adams. He set out for Adams's house. Before he got there he saw Brooks, who resembled Adams. Thinking that Brooks was Adams, Davis shot at Brooks. The shot missed Brooks but wounded Case, who was some distance away. Davis had not seen Case. In a prosecution under a statute that proscribes any attempt to commit murder, the district attorney should indicate that the intended victim(s) was/were
Choices:
['Adams only.', 'Brooks only.', 'Case only.', 'Adams and Brooks']
Current State:
Brooks Only.
Proposal:
current state is the solution.
Evaluation:
yes
###
Problem Goal:
{problem}
Current State:
{current_state}
Proposal:
{proposal}
Evaluation:
'''