mirror of
https://github.com/princeton-nlp/tree-of-thought-llm
synced 2025-06-26 18:26:00 +00:00
Update bfs.py
This commit is contained in:
parent
bfa8280a48
commit
6d9dda04d8
@ -71,17 +71,31 @@ def solve(args, task, idx, model, tokenizer, to_print=True):
|
||||
|
||||
# generation
|
||||
print("Started Generation...")
|
||||
new_ys_with_times = []
|
||||
# TODO: be mindful with n_generate_sample and n_select_sample to avoid node explosion
|
||||
# n_select_sample * n_generate_sample is the highest possible number of additional evaluations each step
|
||||
# total task.steps * n_select_sample * n_generate_sample
|
||||
if args.method_generate == 'sample':
|
||||
new_ys = [get_samples(task, x, y, args.n_generate_sample, prompt_sample=args.prompt_sample, stop=task.stops[step]) for y in ys]
|
||||
# elif args.method_generate == 'propose':
|
||||
for y in ys:
|
||||
generated_ys, generate_times = get_samples(task, x, y, args.n_generate_sample, prompt_sample=args.prompt_sample, stop=task.stops[step])
|
||||
new_ys_with_times.extend(zip(generated_ys, generate_times))
|
||||
else:
|
||||
new_ys = [get_proposals(task, x, y) for y in ys]
|
||||
new_ys, generate_times = new_ys[0]
|
||||
for y in ys:
|
||||
generated_ys, generate_times = get_proposals(task, x, y)
|
||||
new_ys_with_times.extend(zip(generated_ys, generate_times))
|
||||
|
||||
new_ys = list(itertools.chain(new_ys))
|
||||
new_ys, generate_times = zip(*new_ys_with_times)
|
||||
new_ys = list(new_ys)
|
||||
generate_times = list(generate_times)
|
||||
|
||||
# new_ys = list(itertools.chain(new_ys))
|
||||
new_yes = ys + new_ys # extend the current queue with the frontier
|
||||
|
||||
ids = list(range(len(new_ys)))
|
||||
|
||||
# evaluation
|
||||
# shouldn't worry about reevaluating the same ys as the values should be saved in the task cache
|
||||
# but could potentially add logic to remove expanded from queue
|
||||
print("Finished Generation...Started Eval!")
|
||||
if args.method_evaluate == 'vote':
|
||||
values = get_votes(task, x, new_ys, args.n_evaluate_sample)
|
||||
@ -116,4 +130,4 @@ def naive_solve(args, task, idx, to_print=True):
|
||||
print(inference_model)
|
||||
x = task.get_input(idx) # input
|
||||
ys = get_samples(task, x, '', args.n_generate_sample, args.prompt_sample, stop=None)
|
||||
return ys, {}
|
||||
return ys, {}
|
||||
|
Loading…
Reference in New Issue
Block a user