mirror of
https://github.com/princeton-nlp/tree-of-thought-llm
synced 2025-06-26 18:26:00 +00:00
persistent queue
parent 4fdc3ade55
commit a44ace1219
run_bench.py: 32 changed lines
@@ -74,14 +74,19 @@ def load_llama(quant=None):
 def a_star_penalty(num, depth, k=0.1):
     return num * np.exp(-k*depth)
 
-def value_proposals(problem, current_state, proposals, tokenizer, model, device, a_star=False, depth=None):
+def value_proposals(problem, current_state, proposals, tokenizer, model, device, cache=None, a_star=False, depth=None):
     '''
     Takes in string values of problem, current state, and proposals.
     Returns a numerical valuation of each combination of the three factors above.
     '''
     valuations = []
     prompts = []
-    for p in proposals:
+    # only eval if not prev evaluated
+    noncached_proposals = [p for p in proposals if p not in cache]
+    cache_hits = len(proposals) - len(noncached_proposals)
+
+    for p in noncached_proposals:
         prompts.append(value_prompt.format(problem=problem, current_state=current_state, proposal=p))
 
     values = tokenizer(prompts, return_tensors='pt', padding=True, truncation=True)
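Note on the `a_star_penalty` helper shown in context above: it discounts a raw valuation exponentially with search depth, num * exp(-k*depth). A quick self-contained illustration, using the function exactly as defined in the hunk (the example values are arbitrary):

    import numpy as np

    def a_star_penalty(num, depth, k=0.1):
        return num * np.exp(-k*depth)

    print(a_star_penalty(1.0, 0))   # 1.0     -- no penalty at the root
    print(a_star_penalty(1.0, 5))   # ~0.607  -- e^-0.5
    print(a_star_penalty(1.0, 10))  # ~0.368  -- e^-1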
@@ -113,8 +118,13 @@ def value_proposals(problem, current_state, proposals, tokenizer, model, device,
         else:
             valuations.append(0.0)
 
+    for p, v in zip(noncached_proposals, valuations):
+        cache[p] = v
+
-    return valuations
+    # could maybe be optimized but should be fine
+    valuations = [cache[p] for p in proposals]
+
+    return valuations, cache_hits
 
 def parse_problem(problem, math=False):
     '''
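Taken together, the two hunks above turn `value_proposals` into a memoized scorer: only proposals missing from `cache` are batched into prompts, fresh scores are written back, and the final valuations are read out of the cache for every proposal so the returned list stays aligned with the input order. A minimal standalone sketch of that pattern; `value_proposals_sketch` and `score` are hypothetical stand-ins for the model-based valuation, not part of this commit:

    def value_proposals_sketch(proposals, cache, score):
        # score only the proposals we have not seen before
        noncached = [p for p in proposals if p not in cache]
        cache_hits = len(proposals) - len(noncached)
        for p in noncached:
            cache[p] = score(p)
        # read every valuation out of the cache, preserving input order
        valuations = [cache[p] for p in proposals]
        return valuations, cache_hits

    cache = {}
    vals, hits = value_proposals_sketch(["a", "bb"], cache, score=len)
    print(vals, hits)   # [1, 2] 0  -- first round: nothing cached yet
    vals, hits = value_proposals_sketch(["a", "bb", "ccc"], cache, score=len)
    print(vals, hits)   # [1, 2, 3] 2  -- "a" and "bb" were cache hits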
@@ -242,6 +252,9 @@ def solve(input_ids, label, mask, model, tokenizer, device, args):
     problem = tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0]
     # print(problem)
     selected = ""
+    valuation_cache = {} # cache for repeated valuations
+    proposals = [] # persist the queue across iterations
+
     for i in range(args.depth): #args.depth number of attempts to reach the solution
 
         #propose next step/solutions per node/prompt
@@ -259,7 +272,7 @@ def solve(input_ids, label, mask, model, tokenizer, device, args):
         #evaluate/rate the proposals
         current_state = tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0]
 
-        proposals = []
+        # proposals = []
         for o in out:
             string_answer = tokenizer.decode(o, skip_special_tokens=True)
             string_answer = string_answer.split("Possible next step:")[-1]
@@ -268,8 +281,9 @@ def solve(input_ids, label, mask, model, tokenizer, device, args):
             # assert isinstance(string_answer, str)
             proposals.extend([string_answer])
             # exit()
+        # could collect cache hit statistics if necessary
         reval = time.perf_counter()
-        valuations = value_proposals(problem=problem, current_state=current_state, proposals=proposals, tokenizer=tokenizer, model=model, device=device)
+        valuations, cache_hits = value_proposals(problem=problem, current_state=current_state, proposals=proposals, tokenizer=tokenizer, model=model, device=device, cache=valuation_cache)
         reval = time.perf_counter() - reval
         average_eval_time_per_sample.append(reval)
 
@@ -279,8 +293,12 @@ def solve(input_ids, label, mask, model, tokenizer, device, args):
 
         #select the best proposal
         val_props = list(zip(proposals, valuations))
 
         val_props.sort(key = lambda ele: ele[1], reverse=True)
-        selected = val_props[:args.greedy_n][0][0]
+        val_props = val_props[:args.greedy_n]
+        selected = val_props[0][0]
+        val_props = val_props[1:] # remove the selected node from the queue to avoid reeval
+        proposals = [vp[0] for vp in val_props] # update the queue to include the greedy_n highest ranking nodes
+
         #format the chosen proposal for the next iteration
         next_prompt = propose_prompt.format(problem=problem, current_state=selected)
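This hunk is where the commit's "persistent queue" takes shape: rank the frontier by valuation, keep the top `greedy_n` nodes, pop the best one as `selected`, and carry the rest forward into the next depth iteration instead of rebuilding the frontier from scratch. A standalone sketch of that update step, detached from the model code; the helper name `select_and_requeue` and the example values are hypothetical:

    def select_and_requeue(proposals, valuations, greedy_n):
        # rank candidate next steps by valuation, best first
        val_props = sorted(zip(proposals, valuations), key=lambda ele: ele[1], reverse=True)
        val_props = val_props[:greedy_n]             # keep only the top greedy_n nodes
        selected = val_props[0][0]                   # expand the single best node
        remaining = [vp[0] for vp in val_props[1:]]  # drop it from the queue to avoid re-evaluation
        return selected, remaining

    selected, proposals = select_and_requeue(
        ["step A", "step B", "step C", "step D"],
        [0.2, 0.9, 0.5, 0.1],
        greedy_n=3,
    )
    print(selected)   # step B
    print(proposals)  # ['step C', 'step A'] -- carried into the next depth iteration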
@@ -451,4 +469,4 @@ if __name__ == '__main__':
     # print("THIS IS TEMP TUNING")
     # print(temp_tuning.items())
     # temp = pd.DataFrame(temp_tuning)
     # temp.to_csv('./temp_tuning.csv')