mirror of
https://github.com/princeton-nlp/tree-of-thought-llm
synced 2025-05-22 03:46:02 +00:00
adjust values by distance (game24 only)
This commit is contained in:
parent 6d9dda04d8
commit 0857d2840e
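For context, the adjustment this commit introduces multiplies each candidate's value by exp(-k * dist), where dist is the candidate's distance from the start node. A minimal sketch of the decay factor, assuming the default k = 0.1 used by adjust_value_with_dist in the file below:

import numpy as np

# Decay factor exp(-k * dist) for the default k = 0.1:
# a candidate one step deeper keeps ~90.5% of its value, three steps deeper ~74.1%.
for dist in range(4):
    print(dist, round(np.exp(-0.1 * dist), 4))
# 0 1.0
# 1 0.9048
# 2 0.8187
# 3 0.7408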
142 src/tot/methods/a_star.py Normal file
@@ -0,0 +1,142 @@
import itertools
import numpy as np
from functools import partial
from ..models import inference_model


def get_value(task, x, y, n_evaluate_sample, cache_value=True):
    value_prompt = task.value_prompt_wrap(x, y)
    if cache_value and value_prompt in task.value_cache:
        return task.value_cache[value_prompt], 0.0  # 0.0 inference latency because the cache was used
    value_outputs, eval_time = inference_model(value_prompt, n=n_evaluate_sample, stop=None)
    value = task.value_outputs_unwrap(x, y, value_outputs)
    if cache_value:
        task.value_cache[value_prompt] = value
    return value, eval_time[0]  # assumes one value and one time element


def get_values(task, x, ys, n_evaluate_sample, cache_value=True):
    values = []
    times = []
    local_value_cache = {}
    for y in ys:  # each partial output
        if y in local_value_cache:  # avoid duplicate candidates
            value = 0
            time = 0
        else:
            value, time = get_value(task, x, y, n_evaluate_sample, cache_value=cache_value)
            local_value_cache[y] = value
        values.append(value)
        times.append(time)
    # print(values)
    return values, times


def adjust_value_with_dist(value, dist, k=0.1):
    # apply a multiplicative penalty to each value based on its distance from the start
    return value * np.exp(-k * dist)


def get_votes(task, x, ys, n_evaluate_sample):
    vote_prompt = task.vote_prompt_wrap(x, ys)
    vote_outputs, eval_times = inference_model(vote_prompt, n=n_evaluate_sample, stop=None)
    values = task.vote_outputs_unwrap(vote_outputs, len(ys))
    return values, eval_times


def get_proposals(task, x, y):
    final_proposals = []
    propose_prompt = task.propose_prompt_wrap(x, y)
    proposals, generate_times = inference_model(propose_prompt, n=5, stop=None)
    for prop in proposals:
        final_proposals.extend(prop.split('\n'))
    return [y + _ + '\n' for _ in final_proposals], generate_times


def get_samples(task, x, y, n_generate_sample, prompt_sample, stop):
    if prompt_sample == 'standard':
        prompt = task.standard_prompt_wrap(x, y)
    elif prompt_sample == 'cot':
        prompt = task.cot_prompt_wrap(x, y)
    else:
        raise ValueError(f'prompt_sample {prompt_sample} not recognized')
    samples, generate_times = inference_model(prompt, n=n_generate_sample, stop=stop)
    return [y + _ for _ in samples], generate_times


def solve(args, task, idx, model, tokenizer, to_print=True):
    global inference_model
    # bind the model, tokenizer, and temperature once so the helpers above can call inference_model directly
    inference_model = partial(inference_model, model=model, tokenizer=tokenizer, temperature=args.temperature)
    # print(inference_model)
    x = task.get_input(idx)  # input
    ys = ['']  # current output candidates
    infos = []
    dist_from_start = {'': 0}  # number of expansion steps from the root for each candidate

    for step in range(task.steps):
        print("Started Steps!!")

        # generation
        print("Started Generation...")
        new_ys_with_times = []
        # TODO: be mindful with n_generate_sample and n_select_sample to avoid node explosion:
        # n_select_sample * n_generate_sample is the highest possible number of additional evaluations each step,
        # for a total of task.steps * n_select_sample * n_generate_sample
        if args.method_generate == 'sample':
            for y in ys:
                generated_ys, generate_times = get_samples(task, x, y, args.n_generate_sample, prompt_sample=args.prompt_sample, stop=task.stops[step])
                new_ys_with_times.extend(zip(generated_ys, generate_times))
                dist_from_start.update({new_y: dist_from_start[y] + 1 for new_y in generated_ys})
        else:
            for y in ys:
                generated_ys, generate_times = get_proposals(task, x, y)
                new_ys_with_times.extend(zip(generated_ys, generate_times))
                dist_from_start.update({new_y: dist_from_start[y] + 1 for new_y in generated_ys})

        new_ys, generate_times = zip(*new_ys_with_times)
        new_ys = list(new_ys)
        generate_times = list(generate_times)

        # new_ys = list(itertools.chain(new_ys))
        new_ys = ys + new_ys  # extend the current queue with the frontier
        ids = list(range(len(new_ys)))

        # evaluation
        # re-evaluating the same ys is cheap because their values are saved in the task cache,
        # but logic could be added to remove already-expanded candidates from the queue
        print("Finished Generation...Started Eval!")
        if args.method_evaluate == 'vote':
            values, eval_times = get_votes(task, x, new_ys, args.n_evaluate_sample)
        elif args.method_evaluate == 'value':
            values, eval_times = get_values(task, x, new_ys, args.n_evaluate_sample)

        # penalize deeper candidates so the search does not drift too far from the start
        values = [adjust_value_with_dist(value, dist_from_start[new_y]) for value, new_y in zip(values, new_ys)]

        # selection
        print("Finished Eval...Started Selection...")
        if args.method_select == 'sample':
            ps = np.array(values) / sum(values)
            select_ids = np.random.choice(ids, size=args.n_select_sample, p=ps).tolist()
        elif args.method_select == 'greedy':
            select_ids = sorted(ids, key=lambda x: values[x], reverse=True)[:args.n_select_sample]
        select_new_ys = [new_ys[select_id] for select_id in select_ids]

        # log
        print("Finished Selection...Logging...")
        if to_print:
            sorted_new_ys, sorted_values = zip(*sorted(zip(new_ys, values), key=lambda x: x[1], reverse=True))
            print(f'-- new_ys --: {sorted_new_ys}\n-- sol values --: {sorted_values}\n-- choices --: {select_new_ys}\n-- generate times --: {generate_times}\n-- eval times --: {eval_times}\n')

        infos.append({'step': step, 'x': x, 'ys': ys, 'new_ys': new_ys, 'values': values, 'select_new_ys': select_new_ys, 'generate_times': generate_times, 'eval_times': eval_times})
        ys = select_new_ys

    if to_print:
        print(ys)
    return ys, {'steps': infos}


def naive_solve(args, task, idx, to_print=True):
    global inference_model
    inference_model = partial(inference_model, model=args.backend, temperature=args.temperature)
    print(inference_model)
    x = task.get_input(idx)  # input
    ys, _ = get_samples(task, x, '', args.n_generate_sample, args.prompt_sample, stop=None)  # generation times discarded
    return ys, {}
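To see how the penalty interacts with greedy selection, here is a standalone sketch (it re-declares adjust_value_with_dist so it runs outside the package, and the two candidates are hypothetical): with equal raw values, the shallower candidate is ranked first.

import numpy as np

def adjust_value_with_dist(value, dist, k=0.1):
    return value * np.exp(-k * dist)

# Two hypothetical candidates with equal raw value 20, at depths 1 and 3:
candidates = {'shallow': (20, 1), 'deep': (20, 3)}
adjusted = {name: adjust_value_with_dist(v, d) for name, (v, d) in candidates.items()}
print(adjusted)  # {'shallow': 18.0967..., 'deep': 14.8163...}
# greedy selection (method_select == 'greedy') would now prefer 'shallow'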
@@ -66,6 +66,7 @@ def solve(args, task, idx, model, tokenizer, to_print=True):
     ys = ['']  # current output candidates
     infos = []
+
 
     for step in range(task.steps):
         print("Started Steps!!")
 
@@ -89,7 +90,7 @@ def solve(args, task, idx, model, tokenizer, to_print=True):
         generate_times = list(generate_times)
 
         # new_ys = list(itertools.chain(new_ys))
-        new_yes = ys + new_ys  # extend the current queue with the frontier
+        new_ys = ys + new_ys  # extend the current queue with the frontier
 
         ids = list(range(len(new_ys)))
 
@@ -91,4 +91,5 @@ class Game24Task(Task):
         value_names = [_.split('\n')[-1] for _ in value_outputs]
         value_map = {'impossible': 0.001, 'likely': 1, 'sure': 20}  # TODO: ad hoc
         value = sum(value * value_names.count(name) for name, value in value_map.items())
-        return value
+        normalized_value = value / (len(value_names) * value_map['sure'])  # scale to [0, 1]
+        return normalized_value
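To make the new normalization concrete, a worked example with a hypothetical set of ratings (the three ratings below are illustrative, not from the source): with n_evaluate_sample = 3 and one rating each of 'sure', 'likely', and 'impossible', the raw score is 20 + 1 + 0.001 = 21.001, and the normalized score is 21.001 / (3 * 20) ≈ 0.35.

value_names = ['sure', 'likely', 'impossible']  # hypothetical ratings from 3 value calls
value_map = {'impossible': 0.001, 'likely': 1, 'sure': 20}
value = sum(v * value_names.count(name) for name, v in value_map.items())  # 21.001
normalized_value = value / (len(value_names) * value_map['sure'])          # ~0.35
print(value, normalized_value)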