From bfa8280a48f1b290639a27c4cd1b3923916e174e Mon Sep 17 00:00:00 2001
From: emilyworks
Date: Mon, 18 Nov 2024 05:00:11 +0000
Subject: [PATCH] temp prints to test runs

---
 src/tot/methods/bfs.py  | 10 ++++------
 src/tot/models.py       |  6 +++++-
 src/tot/tasks/game24.py |  2 ++
 3 files changed, 11 insertions(+), 7 deletions(-)

diff --git a/src/tot/methods/bfs.py b/src/tot/methods/bfs.py
index 59e071d..25355c2 100644
--- a/src/tot/methods/bfs.py
+++ b/src/tot/methods/bfs.py
@@ -70,6 +70,7 @@ def solve(args, task, idx, model, tokenizer, to_print=True):
         print("Started Steps!!")
 
         # generation
+        print("Started Generation...")
         if args.method_generate == 'sample':
             new_ys = [get_samples(task, x, y, args.n_generate_sample, prompt_sample=args.prompt_sample, stop=task.stops[step]) for y in ys]
         # elif args.method_generate == 'propose':
@@ -77,22 +78,18 @@ def solve(args, task, idx, model, tokenizer, to_print=True):
             new_ys = [get_proposals(task, x, y) for y in ys]
         new_ys, generate_times = new_ys[0]
-
-        new_ys = list(itertools.chain(new_ys))
         ids = list(range(len(new_ys)))
-        print("these are the new ys")
-        print(new_ys)
-        print("****")
-
         # evaluation
+        print("Finished Generation...Started Eval!")
         if args.method_evaluate == 'vote':
             values = get_votes(task, x, new_ys, args.n_evaluate_sample)
         elif args.method_evaluate == 'value':
             values, eval_times = get_values(task, x, new_ys, args.n_evaluate_sample)
 
         # selection
+        print("Finished Eval...Started Selection...")
         if args.method_select == 'sample':
             ps = np.array(values) / sum(values)
             select_ids = np.random.choice(ids, size=args.n_select_sample, p=ps).tolist()
@@ -101,6 +98,7 @@ def solve(args, task, idx, model, tokenizer, to_print=True):
         select_new_ys = [new_ys[select_id] for select_id in select_ids]
 
         # log
+        print("Finished Selection...Logging...")
         if to_print:
             sorted_new_ys, sorted_values = zip(*sorted(zip(new_ys, values), key=lambda x: x[1], reverse=True))
             print(f'-- new_ys --: {sorted_new_ys}\n-- sol values --: {sorted_values}\n-- choices --: {select_new_ys}\n-- generate times --: {generate_times}\n-- eval times --: {eval_times}\n')
diff --git a/src/tot/models.py b/src/tot/models.py
index 62bb082..e2797f8 100644
--- a/src/tot/models.py
+++ b/src/tot/models.py
@@ -54,16 +54,20 @@ def hf_model(model, tokenizer, prompt, temperature=0.7, max_tokens=1000, n=5, st
         n -= cnt
 
         #actual generation
+        print(f"cnt is {cnt}. Sending inputs to model...")
+
         start_time = time.perf_counter()
         out = model.generate(**inputs, temperature=temperature, max_new_tokens=max_tokens, num_return_sequences=cnt) #might add stopping criteria depending on heuristics experimentation
         generate_time = time.perf_counter()-start_time
+        print("output received")
+
         for o in out:
             string_answer = tokenizer.decode(o)
             outputs.extend([string_answer])
         all_times.extend([generate_time/(float(cnt))]*cnt) #may modify this later to be more precise given hf is open source
-
+    print("Returning model inference outputs")
     return outputs, all_times
 
 # @backoff.on_exception(backoff.expo, openai.error.OpenAIError)
diff --git a/src/tot/tasks/game24.py b/src/tot/tasks/game24.py
index c1cefaf..da264ff 100644
--- a/src/tot/tasks/game24.py
+++ b/src/tot/tasks/game24.py
@@ -46,12 +46,14 @@ class Game24Task(Task):
         numbers = re.findall(r'\d+', expression)
         problem_numbers = re.findall(r'\d+', self.data[idx])
         if sorted(numbers) != sorted(problem_numbers):
+            print("sorted numbers do not match sorted problem numbers")
             return {'r': 0}
         try:
             # print(sympy.simplify(expression))
             return {'r': int(sympy.simplify(expression) == 24)}
         except Exception as e:
             # print(e)
+            print(f"entered exception: {e}")
             return {'r': 0}
 
     @staticmethod