mirror of https://github.com/princeton-nlp/tree-of-thought-llm
synced 2025-04-23 15:44:04 +00:00
fix for get_value function modification

This commit is contained in:
parent 5796bd562c
commit 033978c4aa

run.py: 18 lines changed
@@ -17,8 +17,19 @@ def run(args):
     '''
     #load in non-gpt model in this driver function for now to avoid repeated loading later on
     if args.backend == 'llama':
-        tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
-        model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
+        if not args.quantize:
+            tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
+            model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
+        elif args.quantize == 'qat':
+            pass
+            # tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
+            # model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B-Instruct-QLORA_INT4_EO8")
+        elif args.backend == 'spinquant':
+            pass
+            # tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
+            # model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8")
+        else:
+            pass
     else:
         model = None
         tokenizer = None
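A note on the hunk above: the `elif args.backend == 'spinquant'` branch sits inside `if args.backend == 'llama':`, so it can never be taken; the surrounding branches suggest `args.quantize == 'spinquant'` was intended. Below is a minimal, hypothetical sketch (not part of this commit) of keying every branch off args.quantize; whether the commented-out INT4 checkpoints actually load through AutoModelForCausalLM is not verified here.

from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical sketch: pick a checkpoint per quantize mode instead of branching on args.backend.
LLAMA_CHECKPOINTS = {
    None: "meta-llama/Llama-3.2-3B-Instruct",
    "qat": "meta-llama/Llama-3.2-3B-Instruct-QLORA_INT4_EO8",            # referenced in the comments above
    "spinquant": "meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8",  # referenced in the comments above
}

def load_llama(quantize=None):
    name = LLAMA_CHECKPOINTS.get(quantize, LLAMA_CHECKPOINTS[None])
    tokenizer = AutoTokenizer.from_pretrained(name)
    model = AutoModelForCausalLM.from_pretrained(name)
    return model, tokenizer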
@@ -43,7 +54,7 @@ def run(args):
         ys, info = solve(args, task, i, model, tokenizer)

         runtime = time.perf_counter()-start_timer
-        print(runtime)
+        # print(runtime)

         # log
         infos = [task.test_output(i, y) for y in ys]
@@ -71,6 +82,7 @@ def parse_args():

     #what model to use
     args.add_argument('--backend', type=str, choices=['gpt-4o', 'llama'], default='gpt-4o')
+    args.add_argument('--quantize', type=str, choices=['qat', 'ptq', 'spinquant'])

     #what temperature to use
     args.add_argument('--temperature', type=float, default=0.0)
@@ -6,26 +6,33 @@ from ..models import inference_model
 def get_value(task, x, y, n_evaluate_sample, cache_value=True):
     value_prompt = task.value_prompt_wrap(x, y)
     if cache_value and value_prompt in task.value_cache:
-        return task.value_cache[value_prompt]
-    value_outputs = inference_model(value_prompt, n=n_evaluate_sample, stop=None)
+        return task.value_cache[value_prompt], 0.0 #0 for inference latency bc cache was used
+    value_outputs, eval_time = inference_model(value_prompt, n=n_evaluate_sample, stop=None)
+
     value = task.value_outputs_unwrap(x, y, value_outputs)
+
     if cache_value:
         task.value_cache[value_prompt] = value
-    return value
+
+    return [value, eval_time[0]] #assumes one value, one time element

 def get_values(task, x, ys, n_evaluate_sample, cache_value=True):
+
     values = []
+    times = []
     local_value_cache = {}
     for y in ys:  # each partial output
         if y in local_value_cache:  # avoid duplicate candidates
             value = 0
+            time = 0
         else:
-            value = get_value(task, x, y, n_evaluate_sample, cache_value=cache_value)
+            value, time = get_value(task, x, y, n_evaluate_sample, cache_value=cache_value)
             local_value_cache[y] = value
         values.append(value)
+        times.append(time)
         # print(values)
-    return values
+    return values, times

 def get_votes(task, x, ys, n_evaluate_sample):
     vote_prompt = task.vote_prompt_wrap(x, ys)
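One detail worth flagging in the hunk above: the cache-hit path returns a tuple `(cached_value, 0.0)` while the miss path returns a list `[value, eval_time[0]]`. Both unpack the same way in get_values, but a uniform tuple keeps the contract obvious. A small standalone sketch of the same pattern follows; `evaluator` and `score_outputs` are stand-ins for inference_model and task.value_outputs_unwrap, and are assumptions rather than repo code.

def value_with_latency(prompt, cache, evaluator, score_outputs, n=3):
    """Return (value, seconds); 0.0 seconds when the value comes from the cache."""
    if prompt in cache:
        return cache[prompt], 0.0        # cache hit: no inference latency
    outputs, times = evaluator(prompt, n=n)   # evaluator mirrors inference_model's (outputs, times) contract
    value = score_outputs(outputs)            # stand-in for task.value_outputs_unwrap
    cache[prompt] = value
    return value, times[0]                    # one latency entry per call, as in the diff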
@@ -34,10 +41,12 @@ def get_votes(task, x, ys, n_evaluate_sample):
     return values

 def get_proposals(task, x, y):
+    final_proposals = []
     propose_prompt = task.propose_prompt_wrap(x, y)
-    proposals = inference_model(propose_prompt, n=5, stop=None)[0].split('\n')
-    return [y + _ + '\n' for _ in proposals]
+    proposals, generate_times = inference_model(propose_prompt, n=5, stop=None)
+    for prop in proposals:
+        final_proposals.extend(prop.split('\n'))
+    return ([y + _ + '\n' for _ in final_proposals], generate_times)

 def get_samples(task, x, y, n_generate_sample, prompt_sample, stop):
     if prompt_sample == 'standard':
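Behavior change to note: the old code split only the first of the n=5 sampled completions into candidate lines, while the new code flattens all five, so the proposal pool grows roughly five-fold and can contain duplicate lines. If that matters, a hypothetical order-preserving de-duplication helper (not part of this commit) could sit between the split and the return:

def flatten_unique(proposals):
    """Flatten several sampled proposal strings into unique, order-preserved lines."""
    seen, flat = set(), []
    for prop in proposals:
        for line in prop.split('\n'):
            if line and line not in seen:
                seen.add(line)
                flat.append(line)
    return flat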
@@ -59,20 +68,29 @@ def solve(args, task, idx, model, tokenizer, to_print=True):
     for step in range(task.steps):
         print("Started Steps!!")

         # generation
         if args.method_generate == 'sample':
             new_ys = [get_samples(task, x, y, args.n_generate_sample, prompt_sample=args.prompt_sample, stop=task.stops[step]) for y in ys]
         # elif args.method_generate == 'propose':
         else:
             new_ys = [get_proposals(task, x, y) for y in ys]
-        new_ys = list(itertools.chain(*new_ys))
+        new_ys, generate_times = new_ys[0]
+
+        new_ys = list(itertools.chain(new_ys))
         ids = list(range(len(new_ys)))

+        print("these are the new ys")
+        print(new_ys)
+        print("****")

         # evaluation
         if args.method_evaluate == 'vote':
             values = get_votes(task, x, new_ys, args.n_evaluate_sample)
         elif args.method_evaluate == 'value':
-            values = get_values(task, x, new_ys, args.n_evaluate_sample)
+            values, eval_times = get_values(task, x, new_ys, args.n_evaluate_sample)

         # selection
         if args.method_select == 'sample':
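Caveat on `new_ys, generate_times = new_ys[0]`: it unpacks only the first (proposals, times) pair, which is fine while ys holds a single candidate but drops proposals once more than one candidate survives selection; the 'sample' path, which this commit does not update, also does not produce (proposals, times) pairs here. A hypothetical helper that keeps every pair is sketched below as an alternative; it is not what the commit does.

import itertools

def flatten_with_times(pairs):
    """pairs: one (proposals, generate_times) tuple per current candidate y."""
    all_ys = list(itertools.chain.from_iterable(p for p, _ in pairs))
    all_times = list(itertools.chain.from_iterable(t for _, t in pairs))
    return all_ys, all_times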
@@ -85,9 +103,9 @@ def solve(args, task, idx, model, tokenizer, to_print=True):
         # log
         if to_print:
             sorted_new_ys, sorted_values = zip(*sorted(zip(new_ys, values), key=lambda x: x[1], reverse=True))
-            print(f'-- new_ys --: {sorted_new_ys}\n-- sol values --: {sorted_values}\n-- choices --: {select_new_ys}\n')
+            print(f'-- new_ys --: {sorted_new_ys}\n-- sol values --: {sorted_values}\n-- choices --: {select_new_ys}\n-- generate times --: {generate_times}\n-- eval times --: {eval_times}\n')

-        infos.append({'step': step, 'x': x, 'ys': ys, 'new_ys': new_ys, 'values': values, 'select_new_ys': select_new_ys})
+        infos.append({'step': step, 'x': x, 'ys': ys, 'new_ys': new_ys, 'values': values, 'select_new_ys': select_new_ys, 'generate_times': generate_times, 'eval_times': eval_times})
         ys = select_new_ys

     if to_print:
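With 'generate_times' and 'eval_times' now stored per step in infos, total latency can be recovered after a run. A minimal post-processing sketch over the list of step dicts follows; the field names match the hunk above, while the summary format is an assumption. Note that eval_times is only assigned on the 'value' evaluation path, so the 'vote' path would need a default before the logging line.

def summarize_latency(step_infos):
    """step_infos: the per-step dicts appended to `infos` in solve()."""
    gen = [t for info in step_infos for t in info.get('generate_times', [])]
    ev = [t for info in step_infos for t in info.get('eval_times', [])]
    return {
        'generate_s': round(sum(gen), 3),       # seconds spent generating proposals
        'eval_s': round(sum(ev), 3),            # seconds spent scoring candidates
        'total_s': round(sum(gen) + sum(ev), 3),
        'timed_entries': len(gen) + len(ev),
    }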
@@ -4,6 +4,7 @@ from openai import OpenAI
 import backoff
 import torch
 import transformers
+import time

 completion_tokens = prompt_tokens = 0
@@ -39,7 +40,7 @@ def hf_model(model, tokenizer, prompt, temperature=0.7, max_tokens=1000, n=5, stop=None):
     Given a model (Huggingface) and input tokens, generate an output
     """
     outputs = []
-
+    all_times = []
     device = 'cuda' if torch.cuda.is_available() else 'cpu'

     #tokenize inputs
@@ -53,13 +54,17 @@ def hf_model(model, tokenizer, prompt, temperature=0.7, max_tokens=1000, n=5, stop=None):
         n -= cnt

         #actual generation
+        start_time = time.perf_counter()
         out = model.generate(**inputs, temperature=temperature, max_new_tokens=max_tokens, num_return_sequences=cnt) #might add stopping criteria depending on heuristics experimentation
+        generate_time = time.perf_counter()-start_time

         for o in out:
             string_answer = tokenizer.decode(o)
             outputs.extend([string_answer])

-    return outputs
+        all_times.extend([generate_time/(float(cnt))]*cnt) #may modify this later to be more precise given hf is open source
+
+    return outputs, all_times

 # @backoff.on_exception(backoff.expo, openai.error.OpenAIError)
 def completions_with_backoff(**kwargs):
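The timing pattern above, and the analogous one in chatgpt below, measures one batched generate call and attributes the wall-clock cost evenly across the cnt returned sequences, as the inline comments note. A self-contained sketch of that convention follows; `generate_batch` is a stand-in for model.generate or a chat API call and is an assumption, not repo code.

import time

def timed_batches(generate_batch, n, batch_size=20):
    """Return (outputs, all_times) with one average-latency entry per generated sequence."""
    outputs, all_times = [], []
    while n > 0:
        cnt = min(n, batch_size)
        n -= cnt
        start = time.perf_counter()
        outputs.extend(generate_batch(cnt))
        elapsed = time.perf_counter() - start
        all_times.extend([elapsed / cnt] * cnt)   # average per-sequence cost, as in hf_model/chatgpt
    return outputs, all_times

# e.g. outputs, times = timed_batches(lambda k: [f"seq {i}" for i in range(k)], n=5)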
@@ -68,20 +73,26 @@ def completions_with_backoff(**kwargs):
 def chatgpt(model, messages, temperature=0.7, max_tokens=1000, n=5, stop=None) -> list:
     global completion_tokens, prompt_tokens
     outputs = []
+    all_times = []
     client = OpenAI()

     while n > 0:
         cnt = min(n, 20)
         n -= cnt

+        start_time = time.perf_counter()
         res = client.chat.completions.create(model=model, messages=messages, temperature=temperature, n=cnt, stop=stop)
+        generate_time = time.perf_counter()-start_time

         outputs.extend([choice.message.content for choice in res.choices])

         # log completion tokens
         completion_tokens += res.usage.completion_tokens
         prompt_tokens += res.usage.prompt_tokens

-    return outputs
+        all_times.extend([generate_time/(float(cnt))]*cnt) #since generation happens behind black box api, going to just take the average time per seq
+
+    return outputs, all_times

 def gpt_usage(backend="gpt-4o"):
     global completion_tokens, prompt_tokens
src/tot/utils.py: 0 lines changed (new file)