Update bfs.py

This commit is contained in:
Dmitrii Zakharov 2024-11-25 20:23:13 -05:00 committed by GitHub
parent bfa8280a48
commit 6d9dda04d8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -71,17 +71,31 @@ def solve(args, task, idx, model, tokenizer, to_print=True):
# generation
print("Started Generation...")
new_ys_with_times = []
# TODO: be mindful with n_generate_sample and n_select_sample to avoid node explosion
# n_select_sample * n_generate_sample is the highest possible number of additional evaluations each step
# total task.steps * n_select_sample * n_generate_sample
if args.method_generate == 'sample':
new_ys = [get_samples(task, x, y, args.n_generate_sample, prompt_sample=args.prompt_sample, stop=task.stops[step]) for y in ys]
# elif args.method_generate == 'propose':
for y in ys:
generated_ys, generate_times = get_samples(task, x, y, args.n_generate_sample, prompt_sample=args.prompt_sample, stop=task.stops[step])
new_ys_with_times.extend(zip(generated_ys, generate_times))
else:
new_ys = [get_proposals(task, x, y) for y in ys]
new_ys, generate_times = new_ys[0]
for y in ys:
generated_ys, generate_times = get_proposals(task, x, y)
new_ys_with_times.extend(zip(generated_ys, generate_times))
new_ys = list(itertools.chain(new_ys))
new_ys, generate_times = zip(*new_ys_with_times)
new_ys = list(new_ys)
generate_times = list(generate_times)
# new_ys = list(itertools.chain(new_ys))
new_yes = ys + new_ys # extend the current queue with the frontier
ids = list(range(len(new_ys)))
# evaluation
# shouldn't worry about reevaluating the same ys as the values should be saved in the task cache
# but could potentially add logic to remove expanded from queue
print("Finished Generation...Started Eval!")
if args.method_evaluate == 'vote':
values = get_votes(task, x, new_ys, args.n_evaluate_sample)
@ -116,4 +130,4 @@ def naive_solve(args, task, idx, to_print=True):
print(inference_model)
x = task.get_input(idx) # input
ys = get_samples(task, x, '', args.n_generate_sample, args.prompt_sample, stop=None)
return ys, {}
return ys, {}