From 6d9dda04d8c5d9adee6f3a4fb3927402d57a537c Mon Sep 17 00:00:00 2001 From: Dmitrii Zakharov <91263222+r1p71d3@users.noreply.github.com> Date: Mon, 25 Nov 2024 20:23:13 -0500 Subject: [PATCH] Update bfs.py --- src/tot/methods/bfs.py | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/src/tot/methods/bfs.py b/src/tot/methods/bfs.py index 25355c2..b4a73fc 100644 --- a/src/tot/methods/bfs.py +++ b/src/tot/methods/bfs.py @@ -71,17 +71,31 @@ def solve(args, task, idx, model, tokenizer, to_print=True): # generation print("Started Generation...") + new_ys_with_times = [] + # TODO: be mindful with n_generate_sample and n_select_sample to avoid node explosion + # n_select_sample * n_generate_sample is the highest possible number of additional evaluations each step + # total task.steps * n_select_sample * n_generate_sample if args.method_generate == 'sample': - new_ys = [get_samples(task, x, y, args.n_generate_sample, prompt_sample=args.prompt_sample, stop=task.stops[step]) for y in ys] - # elif args.method_generate == 'propose': + for y in ys: + generated_ys, generate_times = get_samples(task, x, y, args.n_generate_sample, prompt_sample=args.prompt_sample, stop=task.stops[step]) + new_ys_with_times.extend(zip(generated_ys, generate_times)) else: - new_ys = [get_proposals(task, x, y) for y in ys] - new_ys, generate_times = new_ys[0] + for y in ys: + generated_ys, generate_times = get_proposals(task, x, y) + new_ys_with_times.extend(zip(generated_ys, generate_times)) - new_ys = list(itertools.chain(new_ys)) + new_ys, generate_times = zip(*new_ys_with_times) + new_ys = list(new_ys) + generate_times = list(generate_times) + + # new_ys = list(itertools.chain(new_ys)) + new_yes = ys + new_ys # extend the current queue with the frontier + ids = list(range(len(new_ys))) # evaluation + # shouldn't worry about reevaluating the same ys as the values should be saved in the task cache + # but could potentially add logic to remove expanded from queue print("Finished Generation...Started Eval!") if args.method_evaluate == 'vote': values = get_votes(task, x, new_ys, args.n_evaluate_sample) @@ -116,4 +130,4 @@ def naive_solve(args, task, idx, to_print=True): print(inference_model) x = task.get_input(idx) # input ys = get_samples(task, x, '', args.n_generate_sample, args.prompt_sample, stop=None) - return ys, {} \ No newline at end of file + return ys, {}