Mirror of https://github.com/princeton-nlp/tree-of-thought-llm (synced 2025-06-26 18:26:00 +00:00)

Update bfs.py

commit 6d9dda04d8 (parent bfa8280a48)
@@ -71,17 +71,31 @@ def solve(args, task, idx, model, tokenizer, to_print=True):
         # generation
         print("Started Generation...")
+        new_ys_with_times = []
+        # TODO: be mindful with n_generate_sample and n_select_sample to avoid node explosion
+        # n_select_sample * n_generate_sample is the highest possible number of additional evaluations each step
+        # total: task.steps * n_select_sample * n_generate_sample
         if args.method_generate == 'sample':
-            new_ys = [get_samples(task, x, y, args.n_generate_sample, prompt_sample=args.prompt_sample, stop=task.stops[step]) for y in ys]
-        # elif args.method_generate == 'propose':
+            for y in ys:
+                generated_ys, generate_times = get_samples(task, x, y, args.n_generate_sample, prompt_sample=args.prompt_sample, stop=task.stops[step])
+                new_ys_with_times.extend(zip(generated_ys, generate_times))
         else:
-            new_ys = [get_proposals(task, x, y) for y in ys]
-            new_ys, generate_times = new_ys[0]
+            for y in ys:
+                generated_ys, generate_times = get_proposals(task, x, y)
+                new_ys_with_times.extend(zip(generated_ys, generate_times))
 
-        new_ys = list(itertools.chain(new_ys))
+        new_ys, generate_times = zip(*new_ys_with_times)
+        new_ys = list(new_ys)
+        generate_times = list(generate_times)
+
+        # new_ys = list(itertools.chain(new_ys))
+        new_ys = ys + new_ys  # extend the current queue with the frontier
 
         ids = list(range(len(new_ys)))
 
         # evaluation
+        # shouldn't worry about re-evaluating the same ys, as the values should be saved in the task cache
+        # but could potentially add logic to remove already-expanded nodes from the queue
         print("Finished Generation...Started Eval!")
         if args.method_evaluate == 'vote':
             values = get_votes(task, x, new_ys, args.n_evaluate_sample)
@@ -116,4 +130,4 @@ def naive_solve(args, task, idx, to_print=True):
     print(inference_model)
     x = task.get_input(idx)  # input
     ys = get_samples(task, x, '', args.n_generate_sample, args.prompt_sample, stop=None)
     return ys, {}
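
The TODO comments in the first hunk bound the extra evaluation work per step. A minimal worked example of that arithmetic, using illustrative parameter values (the numbers below are assumptions, not the repo's defaults):

# Worked example of the node-explosion bound from the TODO comments.
# All three values are illustrative assumptions, not repo defaults.
n_select_sample = 5    # nodes kept in the queue after selection
n_generate_sample = 5  # candidates generated per kept node
steps = 4              # task.steps: number of BFS levels

per_step = n_select_sample * n_generate_sample  # at most 25 new evaluations per step
total = steps * per_step                        # at most 100 across the whole task
print(f"per step: {per_step}, total: {total}")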
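
The refactor also assumes `get_samples` and `get_proposals` each return a `(candidates, times)` pair per node. A self-contained sketch of the collect-then-unzip pattern the diff introduces; `timed_generate` and the toy frontier are hypothetical stand-ins for illustration:

import time

def timed_generate(generate_one, n):
    # Hypothetical helper: call a single-candidate generator n times,
    # recording wall-clock seconds per candidate.
    candidates, times = [], []
    for _ in range(n):
        start = time.perf_counter()
        candidates.append(generate_one())
        times.append(time.perf_counter() - start)
    return candidates, times

# Gather (candidate, time) pairs across the frontier, then unzip them,
# mirroring the new_ys_with_times pattern in the diff.
frontier = ["step1 -> ", "step2 -> "]
new_ys_with_times = []
for y in frontier:
    generated_ys, generate_times = timed_generate(lambda: y + "next", n=3)
    new_ys_with_times.extend(zip(generated_ys, generate_times))

new_ys, generate_times = zip(*new_ys_with_times)  # tuples after unzip
new_ys, generate_times = list(new_ys), list(generate_times)
print(len(new_ys), len(generate_times))  # 6 6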