From 0857d2840e6c44428c92ad41300c6e7afc7e829c Mon Sep 17 00:00:00 2001
From: r1p71d3
Date: Wed, 27 Nov 2024 02:37:58 -0500
Subject: [PATCH] adjust values by distance (game24 only)

---
 src/tot/methods/a_star.py | 142 ++++++++++++++++++++++++++++++++++++++
 src/tot/methods/bfs.py    |   3 +-
 src/tot/tasks/game24.py   |   3 +-
 3 files changed, 146 insertions(+), 2 deletions(-)
 create mode 100644 src/tot/methods/a_star.py

diff --git a/src/tot/methods/a_star.py b/src/tot/methods/a_star.py
new file mode 100644
index 0000000..56c28ee
--- /dev/null
+++ b/src/tot/methods/a_star.py
@@ -0,0 +1,142 @@
+import itertools
+import numpy as np
+from functools import partial
+from ..models import inference_model
+
+def get_value(task, x, y, n_evaluate_sample, cache_value=True):
+    value_prompt = task.value_prompt_wrap(x, y)
+    if cache_value and value_prompt in task.value_cache:
+        return task.value_cache[value_prompt], 0.0  # zero inference latency, since the cached value is reused
+    value_outputs, eval_time = inference_model(value_prompt, n=n_evaluate_sample, stop=None)
+
+    value = task.value_outputs_unwrap(x, y, value_outputs)
+
+    if cache_value:
+        task.value_cache[value_prompt] = value
+
+    return value, eval_time[0]  # assumes one value and one latency element per call
+
+def get_values(task, x, ys, n_evaluate_sample, cache_value=True):
+    values = []
+    times = []
+    local_value_cache = {}
+    for y in ys:  # each partial output
+        if y in local_value_cache:  # avoid re-scoring duplicate candidates
+            value = 0
+            time = 0.0
+        else:
+            value, time = get_value(task, x, y, n_evaluate_sample, cache_value=cache_value)
+            local_value_cache[y] = value
+
+        values.append(value)
+        times.append(time)
+    return values, times
+
+def adjust_value_with_dist(value, dist, k=0.1):
+    # apply a multiplicative penalty to each value based on its distance from the start node
+    return value * np.exp(-k * dist)
+
+def get_votes(task, x, ys, n_evaluate_sample):
+    vote_prompt = task.vote_prompt_wrap(x, ys)
+    vote_outputs, _ = inference_model(vote_prompt, n=n_evaluate_sample, stop=None)  # vote latency is not tracked
+    values = task.vote_outputs_unwrap(vote_outputs, len(ys))
+    return values
+
+def get_proposals(task, x, y):
+    final_proposals = []
+    propose_prompt = task.propose_prompt_wrap(x, y)
+    proposals, generate_times = inference_model(propose_prompt, n=5, stop=None)
+    for prop in proposals:
+        final_proposals.extend(prop.split('\n'))
+    return [y + _ + '\n' for _ in final_proposals], generate_times
+
+def get_samples(task, x, y, n_generate_sample, prompt_sample, stop):
+    if prompt_sample == 'standard':
+        prompt = task.standard_prompt_wrap(x, y)
+    elif prompt_sample == 'cot':
+        prompt = task.cot_prompt_wrap(x, y)
+    else:
+        raise ValueError(f'prompt_sample {prompt_sample} not recognized')
+    samples, generate_times = inference_model(prompt, n=n_generate_sample, stop=stop)
+    return [y + _ for _ in samples], generate_times
+
+def solve(args, task, idx, model, tokenizer, to_print=True):
+    global inference_model
+    inference_model = partial(inference_model, model=model, tokenizer=tokenizer, temperature=args.temperature)
+    x = task.get_input(idx)  # input
+    ys = ['']  # current output candidates
+    infos = []
+    dist_from_start = {'': 0}  # depth of each candidate in the search tree
+
+    for step in range(task.steps):
+        print("Started Steps!!")
+
+        # generation
+        print("Started Generation...")
+        new_ys_with_times = []
+        # TODO: be mindful with n_generate_sample and n_select_sample to avoid node explosion:
+        # n_select_sample * n_generate_sample is the highest possible number of additional
+        # evaluations each step, i.e. task.steps * n_select_sample * n_generate_sample in total
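+        # e.g. (hypothetical numbers) with n_select_sample=5 and n_generate_sample=8,
+        # sampling can add up to 5 * 8 = 40 candidates per step on top of the kept queue,
+        # so a 4-step game24 run triggers at most 160 extra value evaluations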
+        if args.method_generate == 'sample':
+            for y in ys:
+                generated_ys, generate_times = get_samples(task, x, y, args.n_generate_sample, prompt_sample=args.prompt_sample, stop=task.stops[step])
+                new_ys_with_times.extend(zip(generated_ys, generate_times))
+                dist_from_start.update({new_y: dist_from_start[y] + 1 for new_y in generated_ys})
+        else:
+            for y in ys:
+                generated_ys, generate_times = get_proposals(task, x, y)
+                new_ys_with_times.extend(zip(generated_ys, generate_times))
+                dist_from_start.update({new_y: dist_from_start[y] + 1 for new_y in generated_ys})
+
+        new_ys, generate_times = zip(*new_ys_with_times)
+        new_ys = list(new_ys)
+        generate_times = list(generate_times)
+
+        # new_ys = list(itertools.chain(new_ys))
+        new_ys = ys + new_ys  # extend the current queue with the new frontier
+
+        ids = list(range(len(new_ys)))
+
+        # evaluation
+        # re-evaluating the same ys is cheap because their values are saved in the task cache,
+        # but logic could be added to remove already-expanded nodes from the queue
+        print("Finished Generation...Started Eval!")
+        eval_times = []  # stays empty on the vote path, which does not track latency
+        if args.method_evaluate == 'vote':
+            values = get_votes(task, x, new_ys, args.n_evaluate_sample)
+        elif args.method_evaluate == 'value':
+            values, eval_times = get_values(task, x, new_ys, args.n_evaluate_sample)
+
+        values = [adjust_value_with_dist(value, dist_from_start[new_y]) for value, new_y in zip(values, new_ys)]
+
+        # selection
+        print("Finished Eval...Started Selection...")
+        if args.method_select == 'sample':
+            ps = np.array(values) / sum(values)
+            select_ids = np.random.choice(ids, size=args.n_select_sample, p=ps).tolist()
+        elif args.method_select == 'greedy':
+            select_ids = sorted(ids, key=lambda x: values[x], reverse=True)[:args.n_select_sample]
+        select_new_ys = [new_ys[select_id] for select_id in select_ids]
+
+        # log
+        print("Finished Selection...Logging...")
+        if to_print:
+            sorted_new_ys, sorted_values = zip(*sorted(zip(new_ys, values), key=lambda x: x[1], reverse=True))
+            print(f'-- new_ys --: {sorted_new_ys}\n-- sol values --: {sorted_values}\n-- choices --: {select_new_ys}\n-- generate times --: {generate_times}\n-- eval times --: {eval_times}\n')
+
+        infos.append({'step': step, 'x': x, 'ys': ys, 'new_ys': new_ys, 'values': values, 'select_new_ys': select_new_ys, 'generate_times': generate_times, 'eval_times': eval_times})
+        ys = select_new_ys
+
+    if to_print:
+        print(ys)
+    return ys, {'steps': infos}
+
+def naive_solve(args, task, idx, to_print=True):
+    global inference_model
+    inference_model = partial(inference_model, model=args.backend, temperature=args.temperature)
+    x = task.get_input(idx)  # input
+    ys, _ = get_samples(task, x, '', args.n_generate_sample, args.prompt_sample, stop=None)
+    return ys, {}
diff --git a/src/tot/methods/bfs.py b/src/tot/methods/bfs.py
index b4a73fc..c22dbc3 100644
--- a/src/tot/methods/bfs.py
+++ b/src/tot/methods/bfs.py
@@ -66,6 +66,7 @@ def solve(args, task, idx, model, tokenizer, to_print=True):
     ys = ['']  # current output candidates
     infos = []
 
+
     for step in range(task.steps):
         print("Started Steps!!")
 
@@ -89,7 +90,7 @@ def solve(args, task, idx, model, tokenizer, to_print=True):
         generate_times = list(generate_times)
 
         # new_ys = list(itertools.chain(new_ys))
-        new_yes = ys + new_ys  # extend the current queue with the frontier
+        new_ys = ys + new_ys  # extend the current queue with the frontier
 
         ids = list(range(len(new_ys)))
 
diff --git a/src/tot/tasks/game24.py b/src/tot/tasks/game24.py
index da264ff..4f61e94 100644
--- a/src/tot/tasks/game24.py
+++ b/src/tot/tasks/game24.py
@@ -91,4 +91,5 @@ class Game24Task(Task):
         value_names = [_.split('\n')[-1] for _ in value_outputs]
         value_map = {'impossible': 0.001, 'likely': 1, 'sure': 20}  # TODO: ad hoc
         value = sum(value * value_names.count(name) for name, value in value_map.items())
-        return value
\ No newline at end of file
+        normalized_value = value / (len(value_names) * value_map['sure'])  # scale to [0, 1]
+        return normalized_value
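
Note (outside the diff): a quick, self-contained sketch of how the new distance
penalty interacts with the normalized game24 values. The k=0.1 default and the
'impossible'/'likely'/'sure' map come from the patch above; the depths and the
3-sample evaluator count below are made-up examples.

    import numpy as np

    def adjust_value_with_dist(value, dist, k=0.1):
        # same multiplicative penalty as in a_star.py
        return value * np.exp(-k * dist)

    # with the game24 normalization, a candidate rated 'sure' by all 3
    # evaluator samples scores (3 * 20) / (3 * 20) = 1.0, and one rated
    # 'likely' by all 3 scores 3 / 60 = 0.05
    for dist in range(5):
        print(dist, round(adjust_value_with_dist(1.0, dist), 3))
    # -> 0 1.0, 1 0.905, 2 0.819, 3 0.741, 4 0.67

With k=0.1 the penalty is mild: a depth-4 node keeps roughly two thirds of its
raw value, so distance mostly breaks ties between similarly rated candidates
rather than dominating selection.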