mirror of
https://github.com/princeton-nlp/tree-of-thought-llm
synced 2025-04-26 08:58:58 +00:00
103 lines
4.1 KiB
Python
import itertools
|
|
import numpy as np
|
|
from functools import partial
|
|
from ..models import inference_model
|
|
|
|
def get_value(task, x, y, n_evaluate_sample, cache_value=True):
    """Score one partial output ``y`` for input ``x`` via the task's value prompt.

    Args:
        task: task object providing ``value_prompt_wrap``,
            ``value_outputs_unwrap`` and (when caching) ``value_cache``.
        x: the task input.
        y: the partial output to score.
        n_evaluate_sample: number of model completions requested for scoring.
        cache_value: when true, look up / store the score in
            ``task.value_cache`` keyed by the wrapped prompt.

    Returns:
        Whatever ``task.value_outputs_unwrap`` produces for the completions
        (or the cached value for an identical prompt).
    """
    prompt = task.value_prompt_wrap(x, y)

    # Guard clause: reuse a previously computed score for the same prompt.
    # ``task.value_cache`` is only touched when caching is enabled.
    if cache_value:
        if prompt in task.value_cache:
            return task.value_cache[prompt]

    completions = inference_model(prompt, n=n_evaluate_sample, stop=None)
    score = task.value_outputs_unwrap(x, y, completions)

    if cache_value:
        task.value_cache[prompt] = score
    return score
|
|
|
|
def get_values(task, x, ys, n_evaluate_sample, cache_value=True):
    """Score a list of partial outputs, one score per entry of ``ys``.

    Each distinct candidate is scored once with ``get_value``. Any later
    repeat of a candidate already seen in this call gets a score of 0, so
    duplicate candidates are discouraged during selection.

    Args:
        task: task object forwarded to ``get_value``.
        x: the task input.
        ys: partial outputs to score, in order.
        n_evaluate_sample: completions per ``get_value`` call.
        cache_value: forwarded to ``get_value``.

    Returns:
        A list with the same length and order as ``ys``.
    """
    already_scored = set()
    scores = []
    for candidate in ys:
        if candidate in already_scored:
            scores.append(0)
            continue
        already_scored.add(candidate)
        scores.append(get_value(task, x, candidate, n_evaluate_sample, cache_value=cache_value))
    return scores
|
|
|
|
def get_votes(task, x, ys, n_evaluate_sample):
    """Score candidates ``ys`` jointly with the task's vote prompt.

    All candidates go into one vote prompt. ``n_evaluate_sample``
    completions are requested, and ``task.vote_outputs_unwrap`` turns them
    into scores, given the number of candidates.

    Returns:
        The result of ``task.vote_outputs_unwrap``. ``solve`` indexes it by
        candidate position, so it is presumably one score per candidate.
    """
    ballot_prompt = task.vote_prompt_wrap(x, ys)
    ballots = inference_model(ballot_prompt, n=n_evaluate_sample, stop=None)
    candidate_count = len(ys)
    return task.vote_outputs_unwrap(ballots, candidate_count)
|
|
|
|
def get_proposals(task, x, y):
    """Extend partial output ``y`` with one candidate per proposed line.

    The propose prompt is sent once. The first completion is split on
    newlines, and each line (plus a trailing newline) is appended to ``y``
    to form a new candidate.

    Args:
        task: task object providing ``propose_prompt_wrap``.
        x: the task input.
        y: the current partial output being extended.

    Returns:
        A list of extended candidates, one per line of the completion.
    """
    propose_prompt = task.propose_prompt_wrap(x, y)
    # Only element [0] is ever used, so request a single completion;
    # generating more would be discarded work.
    completion = inference_model(propose_prompt, n=1, stop=None)[0]
    # NOTE(review): split('\n') yields '' for blank or trailing lines, which
    # become candidates equal to ``y + '\n'``; filter them if unwanted.
    return [y + line + '\n' for line in completion.split('\n')]
|
|
|
|
def get_samples(task, x, y, n_generate_sample, prompt_sample, stop):
    """Sample continuations of partial output ``y`` with a standard or CoT prompt.

    Args:
        task: task object providing ``standard_prompt_wrap`` / ``cot_prompt_wrap``.
        x: the task input.
        y: the current partial output; each sample is appended to it.
        n_generate_sample: number of completions to request.
        prompt_sample: ``'standard'`` or ``'cot'``; selects the prompt wrapper.
        stop: stop sequence(s) forwarded to ``inference_model``.

    Returns:
        A list of ``y + completion`` strings, one per completion.

    Raises:
        ValueError: if ``prompt_sample`` is not ``'standard'`` or ``'cot'``.
    """
    if prompt_sample not in ('standard', 'cot'):
        raise ValueError(f'prompt_sample {prompt_sample} not recognized')

    wrap = task.standard_prompt_wrap if prompt_sample == 'standard' else task.cot_prompt_wrap
    prompt = wrap(x, y)
    completions = inference_model(prompt, n=n_generate_sample, stop=stop)
    return [y + completion for completion in completions]
|
|
|
|
def _evaluate_candidates(args, task, x, new_ys):
    """Score every candidate in ``new_ys`` with the method named by ``args.method_evaluate``.

    Raises:
        ValueError: if ``args.method_evaluate`` is neither ``'vote'`` nor ``'value'``.
    """
    if args.method_evaluate == 'vote':
        return get_votes(task, x, new_ys, args.n_evaluate_sample)
    if args.method_evaluate == 'value':
        return get_values(task, x, new_ys, args.n_evaluate_sample)
    # Fail loudly: falling through would leave the scores unbound on the
    # first step, or silently reuse the previous step's scores later on.
    raise ValueError(f'method_evaluate {args.method_evaluate} not recognized')


def _select_candidate_ids(args, ids, values):
    """Choose which candidate indices survive to the next step.

    ``'sample'`` draws ``args.n_select_sample`` indices via ``np.random.choice``
    with probability proportional to score. ``'greedy'`` keeps the
    ``args.n_select_sample`` highest-scoring indices.

    Raises:
        ValueError: if ``args.method_select`` is neither ``'sample'`` nor ``'greedy'``.
    """
    if args.method_select == 'sample':
        total = sum(values)
        # With a non-positive total (e.g. every score is 0) the proportional
        # weights would be NaN/invalid and np.random.choice would raise;
        # p=None makes it draw uniformly instead.
        ps = np.array(values) / total if total > 0 else None
        return np.random.choice(ids, size=args.n_select_sample, p=ps).tolist()
    if args.method_select == 'greedy':
        return sorted(ids, key=lambda i: values[i], reverse=True)[:args.n_select_sample]
    raise ValueError(f'method_select {args.method_select} not recognized')


def solve(args, task, idx, model, tokenizer, to_print=True):
    """Run a step-wise Tree-of-Thought search on task input ``idx``.

    For each of ``task.steps`` steps, every current candidate is expanded
    (by sampling or proposing), the expanded candidates are scored, and
    ``args.n_select_sample`` of them are kept for the next step.

    Args:
        args: run configuration. Reads ``temperature``, ``method_generate``,
            ``n_generate_sample``, ``prompt_sample``, ``method_evaluate``,
            ``n_evaluate_sample``, ``method_select`` and ``n_select_sample``.
        task: task object supplying ``get_input``, ``steps``, per-step
            ``stops`` and the prompt wrappers used by the helpers.
        idx: index passed to ``task.get_input``.
        model: bound into ``inference_model`` for every model call.
        tokenizer: bound into ``inference_model`` for every model call.
        to_print: when true, print progress and per-step rankings.

    Returns:
        ``(ys, {'steps': infos})``: the final candidates and a per-step log.

    Raises:
        ValueError: if ``args.method_evaluate`` or ``args.method_select`` is
            not recognized.
    """
    # NOTE: rebinds the module-level ``inference_model`` so the module
    # helpers (get_samples, get_proposals, get_value(s), get_votes) use this
    # model/tokenizer/temperature. The binding persists after solve() returns.
    global inference_model
    inference_model = partial(inference_model, model=model, tokenizer=tokenizer, temperature=args.temperature)

    x = task.get_input(idx)  # input
    ys = ['']  # current output candidates, starting from an empty output
    infos = []

    for step in range(task.steps):
        if to_print:
            print("Started Steps!!")

        # generation
        if args.method_generate == 'sample':
            expansions = [
                get_samples(task, x, y, args.n_generate_sample,
                            prompt_sample=args.prompt_sample, stop=task.stops[step])
                for y in ys
            ]
        else:
            # any method other than 'sample' is treated as 'propose'
            expansions = [get_proposals(task, x, y) for y in ys]
        new_ys = list(itertools.chain.from_iterable(expansions))
        ids = list(range(len(new_ys)))

        # evaluation
        values = _evaluate_candidates(args, task, x, new_ys)

        # selection
        select_ids = _select_candidate_ids(args, ids, values)
        select_new_ys = [new_ys[i] for i in select_ids]

        # log; built without zip(*...) unpacking so an empty step cannot crash
        if to_print:
            ranked = sorted(zip(new_ys, values), key=lambda pair: pair[1], reverse=True)
            sorted_new_ys = tuple(y for y, _ in ranked)
            sorted_values = tuple(v for _, v in ranked)
            print(f'-- new_ys --: {sorted_new_ys}\n-- sol values --: {sorted_values}\n-- choices --: {select_new_ys}\n')

        infos.append({'step': step, 'x': x, 'ys': ys, 'new_ys': new_ys, 'values': values, 'select_new_ys': select_new_ys})
        ys = select_new_ys

    if to_print:
        print(ys)
    return ys, {'steps': infos}
|
|
|
|
def naive_solve(args, task, idx, to_print=True):
    """Sample complete outputs for input ``idx`` in one shot, without tree search.

    Args:
        args: run configuration. Reads ``backend``, ``temperature``,
            ``n_generate_sample`` and ``prompt_sample``.
        task: task object supplying ``get_input`` and the prompt wrappers
            used by ``get_samples``.
        idx: index passed to ``task.get_input``.
        to_print: when true, print the bound inference function.

    Returns:
        ``(ys, {})``: the sampled outputs and an empty info dict.
    """
    # NOTE: rebinds the module-level ``inference_model`` (persists after return).
    # NOTE(review): unlike solve(), no tokenizer is bound here and the model
    # is taken from ``args.backend``; confirm inference_model supports this.
    global inference_model
    inference_model = partial(inference_model, model=args.backend, temperature=args.temperature)
    if to_print:
        print(inference_model)

    x = task.get_input(idx)  # input
    ys = get_samples(task, x, '', args.n_generate_sample, args.prompt_sample, stop=None)
    return ys, {}