From ff04bbcd41e250a116062eb3e9380ec42cde4456 Mon Sep 17 00:00:00 2001 From: emilyworks Date: Sat, 7 Dec 2024 01:36:15 +0000 Subject: [PATCH] adding bench run scripts pt1 --- run_bench.py | 164 +++++++++++++++++++++++++++++++++++++++ src/tot/prompts/bench.py | 82 ++++++++++++++++++++ 2 files changed, 246 insertions(+) create mode 100644 run_bench.py create mode 100644 src/tot/prompts/bench.py diff --git a/run_bench.py b/run_bench.py new file mode 100644 index 0000000..c92a06b --- /dev/null +++ b/run_bench.py @@ -0,0 +1,164 @@ +import os +import json +import argparse +import time + +from transformers import AutoTokenizer, AutoModelForCausalLM +import torch +import torch.quantization + +from src.tot.prompts.bench import value_prompt, propose_prompt + +def load_llama(quant=None): + '''Load in one of the llama models''' + tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct") + + if args.quantize and args.quantize=='ptq_int4': + model = AutoModelForCausalLM.from_pretrained("src/tot/quant/hf_quant_int4", device_map="cuda", weights_only=False) + model = torch.compile(model, mode="max-autotune") + elif args.quantize and args.quantize=='ptq_int8': + model = AutoModelForCausalLM.from_pretrained("src/tot/quant/ptq_int8", device_map="cuda") + model = torch.compile(model, mode="max-autotune") + elif args.quantize and args.quantize == 'qat': + model = AutoModelForCausalLM.from_pretrained("src/tot/quant/qat_int8", device_map="cuda", weights_only=False) + model = torch.compile(model, mode="max-autotune") + else: + model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B-Instruct") + + return model, tokenizer + +def propose(problem, current_state, tokenizer, model, device): + pass + +def value_proposals(problem, current_state, proposals, tokenizer, model, device): + ''' + Takes in string values of problem, current state, and proposals. + Returns a numerical valuation of each combination of the three factors above. + ''' + valuations = [] + prompts = [] + for p in proposals: + prompts.append(value_prompt.format(problem=problem, current_state=current_state, proposal=proposal)) + + values = tokenizer(prompts, return_tensors='pt') + value_inputs = values['input_ids'].to(device) + value_masks = values['attention_mask'].to(device) + + outputs = model.generate(value_inputs, attention_mask=value_masks, max_new_tokens=5) + readable_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True) + for o in readable_outputs: + if 'sure' in o and 'current state is the solution' in o: + valuations.append(100.0) + elif 'sure' in o and 'current state is the solution' not in o: + valuations.append(1.0) + elif 'likely' in o: + valuations.append(0.5) + else: + valuations.append(0.0) + + return valuations + +def final_eval(): + pass + + +def run(args): + ''' + main run function + ''' + #load in specific llama model, if applicable + #bc of the way the original repo is structured, will need to load in llama models in run.py to avoid repeated loading in models.py + if args.backend == 'llama': + if args.quantize: + model, tokenizer = load_llama(args.quantize) + else: + model, tokenizer = load_llama() + else: #gpt4 will be used later in this case + model = None + tokenizer = None + + device = 'cuda' if torch.cuda.is_available() else 'cpu' + model.to(device) + + #set up + test_data = torch.load('src/tot/data/agg_dl_test.pt') + + for samples in test_data: + + for sample in samples: #going to do this one problem at a time for now. + + #extract out the sample parts for the initial input + input_ids = sample['input_ids'].to(device) + label = sample['label'].to(device) + mask = sample['attention_mask'].to(device) + + problem = tokenizer.decode(sample, skip_special_tokens=True) + + #start solving via tot + start_timer = perf.counter() + + for i in range(args.depth): #args.depth number of attempts to reach the solution + + #propose next step/solutions per node/prompt + + out = model.generate( + input_ids, + attention_mask=mask, + temperature=args.temperature, + max_new_tokens=max_tokens, + num_return_sequences=args.breadth) + + + #evaluate/rate the proposals + current_state = tokenizer.decode(input_ids, skip_special_tokens=True) + + proposals = [] + for o in out: + string_answer = tokenizer.decode(o) + proposals.extend([string_answer]) + + valuations = value_proposals(problem=problem, current_state=current_state, proposals=proposals, tokenizer=tokenizer, model=model, device=device) + + #if the model believes it has reached the final solution before args.depth is up, break + if 100.0 in valuations: + break + + #select the proposals that act as input to the next iteration + val_props = list(zip(proposals, valuations)) + valuations.sort(key = lambda ele: ele[1], descending=True) + selected = valuations[:args.greedy_n] + + inputs = tokenizer(valuations, return_tensors='pt') + input_ids = inputs['input_ids'].to(device) + mask = inputs['attention_mask'].to(device) + + + #compare the proposed final answer vs the ground truth + gt = tokenizer.decode(label, skip_special_token=True) + + judgement = final_eval(gt, final_proposal) + + +def parse_args(): + ''' + Determines the conditions for the run. + ''' + args = argparse.ArgumentParser() + + #the arguments to use for our purposes + args.add_argument('--backend', type=str, choices=['gpt-4o', 'llama'], default='gpt-4o') + args.add_argument('--quantize', type=str, choices=['qat', 'ptq_int4', 'ptq_int8']) + args.add_argument('--temperature', type=float, default=0.0) + args.add_argument('--max_new_tokens', type=int, default=100) + args.add_argument('--depth', type=int, default=3) + args.add_argument('--breadth', type=int, default=3) + args.add_argument('--greedy_n', type=int, default=1) + + args = args.parse_args() + return args + + +if __name__ == '__main__': + args = parse_args() + print(args) + run(args) \ No newline at end of file diff --git a/src/tot/prompts/bench.py b/src/tot/prompts/bench.py new file mode 100644 index 0000000..8327882 --- /dev/null +++ b/src/tot/prompts/bench.py @@ -0,0 +1,82 @@ +propose_prompt = '''Given a problem goal and a current state, propose one possible next step to solve the problem from the current state. If the current state solves the problem, say "current state is the solution". +### +Here are two examples to help: + +Example #1: +Problem Goal: +I have four numbers: 4, 4, 6, and 8. How can I use those numbers with basic arithmetic operations (+ - * /) to obtain 24? +Current State: +4+8 = 12 +I have 4, 6, and 12 left. +Possible next step: +6-4 = 2 + +Example #2: +Problem Goal: +Choose the choice that best answer the following question: + Question: + Davis decided to kill Adams. He set out for Adams's house. Before he got there he saw Brooks, who resembled Adams. Thinking that Brooks was Adams, Davis shot at Brooks. The shot missed Brooks but wounded Case, who was some distance away. Davis had not seen Case. In a prosecution under a statute that proscribes any attempt to commit murder, the district attorney should indicate that the intended victim(s) was/were + Choices: + ['Adams only.', 'Brooks only.', 'Case only.', 'Adams and Brooks'] +Current State: +Brooks Only. +Possible next step: +current state is the solution. + +### +Problem Goal: +{problem} +Current State: +{current_state} +Possible next step: +''' + +value_prompt = '''Given a problem goal, a current state, and a proposed next step, evaluate if the next step from the current state can help solve or answer the problem (yes/likely/no) +### +Here are three examples to help: +Example #1: +Problem Goal: +I have four numbers: 4, 4, 6, and 8. How can I use those numbers with basic arithmetic operations (+ - * /) to obtain 24? +Current State: +4+8 = 12 (remaining numbers: 4, 6, 12) +Proposal: +6-4 = 2 (remaining numbers: 2, 12) +Evaluation: +likely + +Example #2: +Problem Goal: +Choose the choice that best answer the following question: + Question: + Davis decided to kill Adams. He set out for Adams's house. Before he got there he saw Brooks, who resembled Adams. Thinking that Brooks was Adams, Davis shot at Brooks. The shot missed Brooks but wounded Case, who was some distance away. Davis had not seen Case. In a prosecution under a statute that proscribes any attempt to commit murder, the district attorney should indicate that the intended victim(s) was/were + Choices: + ['Adams only.', 'Brooks only.', 'Case only.', 'Adams and Brooks'] +Current State: +Adams and Brooks. +Proposal: +current state is the solution. +Evaluation: +no + +Example #3: +Choose the choice that best answer the following question: + Question: + Davis decided to kill Adams. He set out for Adams's house. Before he got there he saw Brooks, who resembled Adams. Thinking that Brooks was Adams, Davis shot at Brooks. The shot missed Brooks but wounded Case, who was some distance away. Davis had not seen Case. In a prosecution under a statute that proscribes any attempt to commit murder, the district attorney should indicate that the intended victim(s) was/were + Choices: + ['Adams only.', 'Brooks only.', 'Case only.', 'Adams and Brooks'] +Current State: +Brooks Only. +Proposal: +current state is the solution. +Evaluation: +yes + +### +Problem Goal: +{problem} +Current State: +{current_state} +Proposal: +{proposal} +Evaluation: +''' \ No newline at end of file