diff --git a/run_bench.py b/run_bench.py
index df6c2be..0b18f5f 100644
--- a/run_bench.py
+++ b/run_bench.py
@@ -74,14 +74,21 @@ def load_llama(quant=None):
 def a_star_penalty(num, depth, k=0.1):
     return num * np.exp(-k*depth)
 
-def value_proposals(problem, current_state, proposals, tokenizer, model, device, a_star=False, depth=None):
+def value_proposals(problem, current_state, proposals, tokenizer, model, device, cache=None, a_star=False, depth=None):
     '''
     Takes in string values of problem, current state, and proposals.
    Returns a numerical valuation of each combination of the three factors above.
     '''
     valuations = []
     prompts = []
-    for p in proposals:
+
+    if cache is None: # guard the default so the function also works without a shared cache
+        cache = {}
+    # only evaluate proposals that have not been scored before
+    noncached_proposals = [p for p in proposals if p not in cache]
+    cache_hits = len(proposals) - len(noncached_proposals)
+
+    for p in noncached_proposals:
         prompts.append(value_prompt.format(problem=problem, current_state=current_state, proposal=p))
 
     values = tokenizer(prompts, return_tensors='pt', padding=True, truncation=True)
@@ -113,8 +120,13 @@ def value_proposals(problem, current_state, proposals, tokenizer, model, device,
         else:
             valuations.append(0.0)
 
+    for p, v in zip(noncached_proposals, valuations):
+        cache[p] = v
 
-    return valuations
+    # rebuild the full valuation list in input order from the cache
+    valuations = [cache[p] for p in proposals]
+
+    return valuations, cache_hits
 
 def parse_problem(problem, math=False):
     '''
@@ -242,6 +254,9 @@ def solve(input_ids, label, mask, model, tokenizer, device, args):
     problem = tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0]
     # print(problem)
     selected = ""
+    valuation_cache = {} # cache for repeated valuations
+    proposals = [] # persist the proposal queue across iterations
+
     for i in range(args.depth): #args.depth number of attempts to reach the solution
 
         #propose next step/solutions per node/prompt
@@ -259,7 +274,7 @@ def solve(input_ids, label, mask, model, tokenizer, device, args):
 
         #evaluate/rate the proposals
         current_state = tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0]
-        proposals = []
+        # proposals = [] # no longer reset here; the queue persists across iterations
         for o in out:
             string_answer = tokenizer.decode(o, skip_special_tokens=True)
             string_answer = string_answer.split("Possible next step:")[-1]
@@ -268,8 +283,9 @@ def solve(input_ids, label, mask, model, tokenizer, device, args):
             # assert isinstance(string_answer, str)
             proposals.extend([string_answer])
         # exit()
+        # cache_hits could be collected here if cache statistics are needed
         reval = time.perf_counter()
-        valuations = value_proposals(problem=problem, current_state=current_state, proposals=proposals, tokenizer=tokenizer, model=model, device=device)
+        valuations, cache_hits = value_proposals(problem=problem, current_state=current_state, proposals=proposals, tokenizer=tokenizer, model=model, device=device, cache=valuation_cache)
         reval = time.perf_counter() - reval
         average_eval_time_per_sample.append(reval)
 
@@ -279,8 +295,12 @@ def solve(input_ids, label, mask, model, tokenizer, device, args):
 
         #select the best proposal
         val_props = list(zip(proposals, valuations))
+
         val_props.sort(key = lambda ele: ele[1], reverse=True)
-        selected = val_props[:args.greedy_n][0][0]
+        val_props = val_props[:args.greedy_n]
+        selected = val_props[0][0]
+        val_props = val_props[1:] # drop the selected node from the queue to avoid re-evaluation
+        proposals = [vp[0] for vp in val_props] # carry the remaining top-ranked nodes over to the next iteration
 
         #format the chosen proposal for the next iteration
         next_prompt = propose_prompt.format(problem=problem, current_state=selected)
@@ -451,4 +471,4 @@ if __name__ == '__main__':
     # print("THIS IS TEMP TUNING")
     # print(temp_tuning.items())
     # temp = pd.DataFrame(temp_tuning)
-    # temp.to_csv('./temp_tuning.csv')
\ No newline at end of file
+    # temp.to_csv('./temp_tuning.csv')
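
Note for reviewers: below is a minimal, self-contained sketch of the memoization pattern these hunks introduce, with a hypothetical `score` callable standing in for the model-backed valuation; it is illustration only, not part of the patch.

```python
# Sketch of the caching pattern in value_proposals(): score only unseen
# proposals, record their values, then rebuild results in input order.
# `score` is a hypothetical stand-in for the model-backed valuation call.
def value_proposals_cached(proposals, score, cache=None):
    if cache is None:
        cache = {}
    noncached = [p for p in proposals if p not in cache]  # skip cached proposals
    cache_hits = len(proposals) - len(noncached)
    for p in noncached:
        cache[p] = score(p)
    # return valuations in the same order as the input proposals
    return [cache[p] for p in proposals], cache_hits

cache = {}  # persists across calls, like valuation_cache in solve()
vals1, hits1 = value_proposals_cached(["a", "bb"], score=len, cache=cache)
vals2, hits2 = value_proposals_cached(["a", "ccc"], score=len, cache=cache)
print(vals1, hits1)  # [1, 2] 0 -- nothing cached yet
print(vals2, hits2)  # [1, 3] 1 -- "a" is served from the cache
```

One caveat worth noting: duplicates within a single batch are all treated as uncached (the membership test runs before any of them is scored), so `cache_hits` only counts repeats across iterations.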