persistent queue

Dmitrii Zakharov 2024-12-10 19:21:13 -05:00 committed by GitHub
parent 4fdc3ade55
commit a44ace1219

@@ -74,14 +74,19 @@ def load_llama(quant=None):
def a_star_penalty(num, depth, k=0.1):
return num * np.exp(-k*depth)
def value_proposals(problem, current_state, proposals, tokenizer, model, device, a_star=False, depth=None):
def value_proposals(problem, current_state, proposals, tokenizer, model, device, cache=None, a_star=False, depth=None):
'''
Takes in string values of problem, current state, and proposals.
Returns a numerical valuation of each combination of the three factors above.
'''
valuations = []
prompts = []
for p in proposals:
# only eval proposals not already cached (the cache=None default assumes the caller passes a dict)
noncached_proposals = [p for p in proposals if p not in cache]
cache_hits = len(proposals) - len(noncached_proposals)
for p in noncached_proposals:
prompts.append(value_prompt.format(problem=problem, current_state=current_state, proposal=p))
values = tokenizer(prompts, return_tensors='pt', padding=True, truncation=True)
@@ -113,8 +118,13 @@ def value_proposals(problem, current_state, proposals, tokenizer, model, device,
else:
valuations.append(0.0)
for p, v in zip(noncached_proposals, valuations):
cache[p] = v
return valuations
# rebuild valuations in the original proposal order; per-proposal dict lookups are cheap next to model calls
valuations = [cache[p] for p in proposals]
return valuations, cache_hits
def parse_problem(problem, math=False):
'''
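
As a reading aid, here is a minimal sketch of the memoization pattern this hunk introduces, assuming a stand-in `value_fn` for the batched model call (the real code scores all uncached proposals in one padded batch). Unlike the diff, it also guards the `cache=None` default, which would otherwise make `p not in cache` raise a TypeError:

def value_proposals_cached(proposals, value_fn, cache=None):
    # Score each proposal at most once: reuse cached valuations and only
    # call the model for proposals that have not been seen before.
    if cache is None:
        cache = {}
    noncached = [p for p in proposals if p not in cache]
    cache_hits = len(proposals) - len(noncached)
    for p in noncached:
        cache[p] = value_fn(p)  # one model call per unseen proposal
    # Return valuations in the original proposal order, plus the hit count.
    return [cache[p] for p in proposals], cache_hits
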
@@ -242,6 +252,9 @@ def solve(input_ids, label, mask, model, tokenizer, device, args):
problem = tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0]
# print(problem)
selected = ""
valuation_cache = {} # cache for repeated valuations
proposals = [] # persist the queue across iterations
for i in range(args.depth): #args.depth number of attempts to reach the solution
#propose next step/solutions per node/prompt
@@ -259,7 +272,7 @@ def solve(input_ids, label, mask, model, tokenizer, device, args):
#evaluate/rate the proposals
current_state = tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0]
proposals = []
# proposals = []
for o in out:
string_answer = tokenizer.decode(o, skip_special_tokens=True)
string_answer = string_answer.split("Possible next step:")[-1]
@@ -268,8 +281,9 @@ def solve(input_ids, label, mask, model, tokenizer, device, args):
# assert isinstance(string_answer, str)
proposals.extend([string_answer])
# exit()
# could collect cache hit statistics if necessary
reval = time.perf_counter()
valuations = value_proposals(problem=problem, current_state=current_state, proposals=proposals, tokenizer=tokenizer, model=model, device=device)
valuations, cache_hits = value_proposals(problem=problem, current_state=current_state, proposals=proposals, tokenizer=tokenizer, model=model, device=device, cache=valuation_cache)
reval = time.perf_counter() - reval
average_eval_time_per_sample.append(reval)
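
The "# could collect cache hit statistics if necessary" note above is left open in the commit; a small hypothetical helper for that bookkeeping might look like the following (`CacheStats` is illustrative and not part of the commit):

class CacheStats:
    '''Accumulate valuation-cache effectiveness across solve() iterations.'''
    def __init__(self):
        self.hits = 0
        self.lookups = 0
    def update(self, cache_hits, num_proposals):
        # Call once per iteration with the values returned by value_proposals.
        self.hits += cache_hits
        self.lookups += num_proposals
    def hit_rate(self):
        return self.hits / self.lookups if self.lookups else 0.0

# per iteration: stats.update(cache_hits, len(proposals))
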
@@ -279,8 +293,12 @@ def solve(input_ids, label, mask, model, tokenizer, device, args):
#select the best proposal
val_props = list(zip(proposals, valuations))
val_props.sort(key = lambda ele: ele[1], reverse=True)
selected = val_props[:args.greedy_n][0][0]
val_props = val_props[:args.greedy_n]
selected = val_props[0][0]
val_props = val_props[1:] # remove the selected node from the queue to avoid re-evaluation
proposals = [vp[0] for vp in val_props] # update the queue with the greedy_n highest-ranking remaining nodes
#format the chosen proposal for the next iteration
next_prompt = propose_prompt.format(problem=problem, current_state=selected)
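
Taken together with the commented-out reset above, this hunk is what makes the queue persistent: the best proposal becomes the next state while the runners-up stay queued and are re-ranked (for free, via the cache) on the next iteration. A compact sketch of the step under the same `greedy_n` semantics; `select_and_persist` is an illustrative name, not part of the commit:

def select_and_persist(proposals, valuations, greedy_n):
    # Rank proposals by valuation, advance with the best one, and keep the
    # remaining top candidates queued for the next iteration (their cached
    # valuations make re-ranking them cheap). Assumes proposals is non-empty.
    ranked = sorted(zip(proposals, valuations), key=lambda pv: pv[1], reverse=True)
    top = ranked[:greedy_n]
    selected = top[0][0]
    queue = [p for p, _ in top[1:]]  # drop the selected node to avoid re-expansion
    return selected, queue
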
@@ -451,4 +469,4 @@ if __name__ == '__main__':
# print("THIS IS TEMP TUNING")
# print(temp_tuning.items())
# temp = pd.DataFrame(temp_tuning)
# temp.to_csv('./temp_tuning.csv')
# temp.to_csv('./temp_tuning.csv')