persistent queue

This commit is contained in:
Dmitrii Zakharov 2024-12-10 19:21:13 -05:00 committed by GitHub
parent 4fdc3ade55
commit a44ace1219
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -74,14 +74,19 @@ def load_llama(quant=None):
def a_star_penalty(num, depth, k=0.1): def a_star_penalty(num, depth, k=0.1):
return num * np.exp(-k*depth) return num * np.exp(-k*depth)
def value_proposals(problem, current_state, proposals, tokenizer, model, device, a_star=False, depth=None): def value_proposals(problem, current_state, proposals, tokenizer, model, device, cache=None, a_star=False, depth=None):
''' '''
Takes in string values of problem, current state, and proposals. Takes in string values of problem, current state, and proposals.
Returns a numerical valuation of each combination of the three factors above. Returns a numerical valuation of each combination of the three factors above.
''' '''
valuations = [] valuations = []
prompts = [] prompts = []
for p in proposals:
# only eval if not prev evaluated
noncached_proposals = [p for p in proposals if p not in cache]
cache_hits = len(proposals) - len(noncached_proposals)
for p in noncached_proposals:
prompts.append(value_prompt.format(problem=problem, current_state=current_state, proposal=p)) prompts.append(value_prompt.format(problem=problem, current_state=current_state, proposal=p))
values = tokenizer(prompts, return_tensors='pt', padding=True, truncation=True) values = tokenizer(prompts, return_tensors='pt', padding=True, truncation=True)
@ -113,8 +118,13 @@ def value_proposals(problem, current_state, proposals, tokenizer, model, device,
else: else:
valuations.append(0.0) valuations.append(0.0)
for p, v in list(zip(noncached_proposals, valuations):
cache[p] = v
return valuations # could maybe be optimized but should be fine
valuations = [cache[p] for p in proposals]
return valuations, cache_hits
def parse_problem(problem, math=False): def parse_problem(problem, math=False):
''' '''
@ -242,6 +252,9 @@ def solve(input_ids, label, mask, model, tokenizer, device, args):
problem = tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0] problem = tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0]
# print(problem) # print(problem)
selected = "" selected = ""
valuation_cache = {} # cache for repeated valuations
proposals = [] # persist the queue across iterations
for i in range(args.depth): #args.depth number of attempts to reach the solution for i in range(args.depth): #args.depth number of attempts to reach the solution
#propose next step/solutions per node/prompt #propose next step/solutions per node/prompt
@ -259,7 +272,7 @@ def solve(input_ids, label, mask, model, tokenizer, device, args):
#evaluate/rate the proposals #evaluate/rate the proposals
current_state = tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0] current_state = tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0]
proposals = [] # proposals = []
for o in out: for o in out:
string_answer = tokenizer.decode(o, skip_special_tokens=True) string_answer = tokenizer.decode(o, skip_special_tokens=True)
string_answer = string_answer.split("Possible next step:")[-1] string_answer = string_answer.split("Possible next step:")[-1]
@ -268,8 +281,9 @@ def solve(input_ids, label, mask, model, tokenizer, device, args):
# assert isinstance(string_answer, str) # assert isinstance(string_answer, str)
proposals.extend([string_answer]) proposals.extend([string_answer])
# exit() # exit()
# could collect cache hit statistics if necessary
reval = time.perf_counter() reval = time.perf_counter()
valuations = value_proposals(problem=problem, current_state=current_state, proposals=proposals, tokenizer=tokenizer, model=model, device=device) valuations, cache_hits = value_proposals(problem=problem, current_state=current_state, proposals=proposals, tokenizer=tokenizer, model=model, device=device, cache=valuation_cache)
reval = time.perf_counter() - reval reval = time.perf_counter() - reval
average_eval_time_per_sample.append(reval) average_eval_time_per_sample.append(reval)
@ -279,8 +293,12 @@ def solve(input_ids, label, mask, model, tokenizer, device, args):
#select the best proposal #select the best proposal
val_props = list(zip(proposals, valuations)) val_props = list(zip(proposals, valuations))
val_props.sort(key = lambda ele: ele[1], reverse=True) val_props.sort(key = lambda ele: ele[1], reverse=True)
selected = val_props[:args.greedy_n][0][0] val_props = val_props[:args.greedy_n]
selected = val_props[0][0]
val_props = val_props[1:] # remove the selected node from the queue to avoid reeval
proposals = [p for vp[0] in val_props] # update the queue to include the greedy_n highest ranking nodes
#format the chosen proposal for the next iteration #format the chosen proposal for the next iteration
next_prompt = propose_prompt.format(problem=problem, current_state=selected) next_prompt = propose_prompt.format(problem=problem, current_state=selected)
@ -451,4 +469,4 @@ if __name__ == '__main__':
# print("THIS IS TEMP TUNING") # print("THIS IS TEMP TUNING")
# print(temp_tuning.items()) # print(temp_tuning.items())
# temp = pd.DataFrame(temp_tuning) # temp = pd.DataFrame(temp_tuning)
# temp.to_csv('./temp_tuning.csv') # temp.to_csv('./temp_tuning.csv')