mirror of
https://github.com/princeton-nlp/tree-of-thought-llm
synced 2025-04-23 23:54:04 +00:00
persistent queue
This commit is contained in:
parent
4fdc3ade55
commit
a44ace1219
32
run_bench.py
32
run_bench.py
@ -74,14 +74,19 @@ def load_llama(quant=None):
|
||||
def a_star_penalty(num, depth, k=0.1):
|
||||
return num * np.exp(-k*depth)
|
||||
|
||||
def value_proposals(problem, current_state, proposals, tokenizer, model, device, a_star=False, depth=None):
|
||||
def value_proposals(problem, current_state, proposals, tokenizer, model, device, cache=None, a_star=False, depth=None):
|
||||
'''
|
||||
Takes in string values of problem, current state, and proposals.
|
||||
Returns a numerical valuation of each combination of the three factors above.
|
||||
'''
|
||||
valuations = []
|
||||
prompts = []
|
||||
for p in proposals:
|
||||
|
||||
# only eval if not prev evaluated
|
||||
noncached_proposals = [p for p in proposals if p not in cache]
|
||||
cache_hits = len(proposals) - len(noncached_proposals)
|
||||
|
||||
for p in noncached_proposals:
|
||||
prompts.append(value_prompt.format(problem=problem, current_state=current_state, proposal=p))
|
||||
|
||||
values = tokenizer(prompts, return_tensors='pt', padding=True, truncation=True)
|
||||
@ -113,8 +118,13 @@ def value_proposals(problem, current_state, proposals, tokenizer, model, device,
|
||||
else:
|
||||
valuations.append(0.0)
|
||||
|
||||
for p, v in list(zip(noncached_proposals, valuations):
|
||||
cache[p] = v
|
||||
|
||||
return valuations
|
||||
# could maybe be optimized but should be fine
|
||||
valuations = [cache[p] for p in proposals]
|
||||
|
||||
return valuations, cache_hits
|
||||
|
||||
def parse_problem(problem, math=False):
|
||||
'''
|
||||
@ -242,6 +252,9 @@ def solve(input_ids, label, mask, model, tokenizer, device, args):
|
||||
problem = tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0]
|
||||
# print(problem)
|
||||
selected = ""
|
||||
valuation_cache = {} # cache for repeated valuations
|
||||
proposals = [] # persist the queue across iterations
|
||||
|
||||
for i in range(args.depth): #args.depth number of attempts to reach the solution
|
||||
|
||||
#propose next step/solutions per node/prompt
|
||||
@ -259,7 +272,7 @@ def solve(input_ids, label, mask, model, tokenizer, device, args):
|
||||
#evaluate/rate the proposals
|
||||
current_state = tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0]
|
||||
|
||||
proposals = []
|
||||
# proposals = []
|
||||
for o in out:
|
||||
string_answer = tokenizer.decode(o, skip_special_tokens=True)
|
||||
string_answer = string_answer.split("Possible next step:")[-1]
|
||||
@ -268,8 +281,9 @@ def solve(input_ids, label, mask, model, tokenizer, device, args):
|
||||
# assert isinstance(string_answer, str)
|
||||
proposals.extend([string_answer])
|
||||
# exit()
|
||||
# could collect cache hit statistics if necessary
|
||||
reval = time.perf_counter()
|
||||
valuations = value_proposals(problem=problem, current_state=current_state, proposals=proposals, tokenizer=tokenizer, model=model, device=device)
|
||||
valuations, cache_hits = value_proposals(problem=problem, current_state=current_state, proposals=proposals, tokenizer=tokenizer, model=model, device=device, cache=valuation_cache)
|
||||
reval = time.perf_counter() - reval
|
||||
average_eval_time_per_sample.append(reval)
|
||||
|
||||
@ -279,8 +293,12 @@ def solve(input_ids, label, mask, model, tokenizer, device, args):
|
||||
|
||||
#select the best proposal
|
||||
val_props = list(zip(proposals, valuations))
|
||||
|
||||
val_props.sort(key = lambda ele: ele[1], reverse=True)
|
||||
selected = val_props[:args.greedy_n][0][0]
|
||||
val_props = val_props[:args.greedy_n]
|
||||
selected = val_props[0][0]
|
||||
val_props = val_props[1:] # remove the selected node from the queue to avoid reeval
|
||||
proposals = [p for vp[0] in val_props] # update the queue to include the greedy_n highest ranking nodes
|
||||
|
||||
#format the chosen proposal for the next iteration
|
||||
next_prompt = propose_prompt.format(problem=problem, current_state=selected)
|
||||
@ -451,4 +469,4 @@ if __name__ == '__main__':
|
||||
# print("THIS IS TEMP TUNING")
|
||||
# print(temp_tuning.items())
|
||||
# temp = pd.DataFrame(temp_tuning)
|
||||
# temp.to_csv('./temp_tuning.csv')
|
||||
# temp.to_csv('./temp_tuning.csv')
|
||||
|
Loading…
Reference in New Issue
Block a user