Mirror of https://github.com/princeton-nlp/tree-of-thought-llm, synced 2025-05-08 13:44:22 +00:00
adding last pieces of time profiling

parent 975a6e8608
commit 17c8f9b8dc

Changed files include run.py (3 lines changed).
run.py
@@ -61,7 +61,8 @@ def run(args):
         ys, info = solve(args, task, i, model, tokenizer)
 
         runtime = time.perf_counter() - start_timer
-        print(runtime)
+        # print(runtime)
+        print(f"""For iteration {i} --
+        Total Time to Solve: {runtime} seconds""")
 
         # log
         infos = [task.test_output(i, y) for y in ys]
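Note: the pattern this hunk adds is a plain wall-clock bracket around one solve call. A minimal, self-contained sketch of the same idea (profile_iteration and solve_fn are hypothetical names for illustration, not the repository's API):

    import time

    def profile_iteration(solve_fn, *args):
        # time.perf_counter() is a monotonic, high-resolution clock,
        # so it is safer for interval timing than time.time().
        start = time.perf_counter()
        result = solve_fn(*args)
        runtime = time.perf_counter() - start
        print(f"Total Time to Solve: {runtime:.3f} seconds")
        return result, runtime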
@@ -2,6 +2,7 @@ import itertools
 import numpy as np
 from functools import partial
 from ..models import inference_model
+import time
 
 def get_value(task, x, y, n_evaluate_sample, cache_value=True):
     value_prompt = task.value_prompt_wrap(x, y)
@@ -104,6 +105,9 @@ def solve(args, task, idx, model, tokenizer, to_print=True):
         # shouldn't worry about reevaluating the same ys as the values should be saved in the task cache
         # but could potentially add logic to remove expanded from queue
         print("Finished Generation...Started Eval!")
+
+        start_time = time.perf_counter()
+
         if args.method_evaluate == 'vote':
             values = get_votes(task, x, new_ys, args.n_evaluate_sample)
         elif args.method_evaluate == 'value':
@@ -111,8 +115,14 @@ def solve(args, task, idx, model, tokenizer, to_print=True):
 
         values = [adjust_value_with_dist(value, dist_from_start[new_y]) for value, new_y in zip(values, new_ys)]
 
+        eval_time = time.perf_counter() - start_time
+        print(f"Node Eval Time: {eval_time} seconds")
+
         # selection
         print("Finished Eval...Started Selection...")
+
+        start_time = time.perf_counter()
+
         if args.method_select == 'sample':
             ps = np.array(values) / sum(values)
             select_ids = np.random.choice(ids, size=args.n_select_sample, p=ps).tolist()
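Note: the 'sample' branch above draws n_select_sample candidates in proportion to their scores. A standalone sketch of that selection step (variable names mirror the diff, but this is an illustration, not the repository's exact code):

    import numpy as np

    values = [0.5, 2.0, 1.5]   # example node scores; must be non-negative,
    ids = list(range(len(values)))  # and sum(values) must be non-zero

    # Normalize scores into a probability distribution; np.random.choice
    # requires p to sum to 1 and samples with replacement by default.
    ps = np.array(values) / sum(values)
    select_ids = np.random.choice(ids, size=2, p=ps).tolist()
    print(select_ids)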
@@ -120,6 +130,9 @@ def solve(args, task, idx, model, tokenizer, to_print=True):
             select_ids = sorted(ids, key=lambda x: values[x], reverse=True)[:args.n_select_sample]
         select_new_ys = [new_ys[select_id] for select_id in select_ids]
 
+        selection_time = time.perf_counter() - start_time
+        print(f"Selection Time: {selection_time} seconds")
+
         # log
         print("Finished Selection...Logging...")
         if to_print:
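Note: these hunks reuse a single start_time variable to bracket successive phases (eval, then selection), which is what made the start_time() call bug easy to introduce. A context-manager sketch that expresses the same pattern less error-prone (phase_timer is a hypothetical helper, not part of this repo):

    import time
    from contextlib import contextmanager

    @contextmanager
    def phase_timer(label):
        # Prints the wall-clock duration of the enclosed phase on exit.
        start = time.perf_counter()
        try:
            yield
        finally:
            print(f"{label}: {time.perf_counter() - start} seconds")

    with phase_timer("Node Eval Time"):
        sum(range(1_000_000))  # stand-in for the evaluation phase
    with phase_timer("Selection Time"):
        sorted(range(1_000))   # stand-in for the selection phase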
@@ -2,6 +2,7 @@ import itertools
 import numpy as np
 from functools import partial
 from ..models import inference_model
+import time
 
 def get_value(task, x, y, n_evaluate_sample, cache_value=True):
     value_prompt = task.value_prompt_wrap(x, y)
@@ -68,7 +69,7 @@ def solve(args, task, idx, model, tokenizer, to_print=True):
 
 
     for step in range(task.steps):
-        print("Started Steps!!")
+        # print("Started Steps!!")
 
         # generation
         print("Started Generation...")
@@ -98,13 +99,22 @@ def solve(args, task, idx, model, tokenizer, to_print=True):
         # shouldn't worry about reevaluating the same ys as the values should be saved in the task cache
         # but could potentially add logic to remove expanded from queue
         print("Finished Generation...Started Eval!")
+
+        start_time = time.perf_counter()
+
         if args.method_evaluate == 'vote':
             values = get_votes(task, x, new_ys, args.n_evaluate_sample)
         elif args.method_evaluate == 'value':
             values, eval_times = get_values(task, x, new_ys, args.n_evaluate_sample)
 
+        eval_time = time.perf_counter() - start_time
+        print(f"Node Eval Time: {eval_time} seconds")
+
         # selection
         print("Finished Eval...Started Selection...")
+
+        start_time = time.perf_counter()
+
         if args.method_select == 'sample':
             ps = np.array(values) / sum(values)
             select_ids = np.random.choice(ids, size=args.n_select_sample, p=ps).tolist()
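Note: in this file the 'value' branch unpacks a second return, eval_times, from get_values. Its body is not shown in the diff; one plausible shape, purely illustrative (the per-candidate timing split and the stand-in scoring are assumptions):

    import time

    def get_values(task, x, new_ys, n_evaluate_sample):
        # Illustrative only: score each candidate and record how long
        # each individual evaluation took.
        values, eval_times = [], []
        for y in new_ys:
            start = time.perf_counter()
            value = len(y)  # stand-in for the real value_prompt_wrap + model call
            eval_times.append(time.perf_counter() - start)
            values.append(value)
        return values, eval_times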
@@ -112,6 +122,9 @@ def solve(args, task, idx, model, tokenizer, to_print=True):
             select_ids = sorted(ids, key=lambda x: values[x], reverse=True)[:args.n_select_sample]
         select_new_ys = [new_ys[select_id] for select_id in select_ids]
 
+        selection_time = time.perf_counter() - start_time
+        print(f"Selection Time: {selection_time} seconds")
+
         # log
         print("Finished Selection...Logging...")
         if to_print:
@@ -54,20 +54,21 @@ def hf_model(model, tokenizer, prompt, temperature=0.7, max_tokens=1000, n=5, stop=None):
         n -= cnt
 
         # actual generation
-        print(f"cnt is {cnt}. Sending inputs to model...")
+        # print(f"cnt is {cnt}. Sending inputs to model...")
 
         start_time = time.perf_counter()
         out = model.generate(**inputs, temperature=temperature, max_new_tokens=max_tokens, num_return_sequences=cnt)  # might add stopping criteria depending on heuristics experimentation
         generate_time = time.perf_counter() - start_time
+        print(f"Averaged Time to Generate Single Node Proposal: {[generate_time/float(cnt)]*cnt} seconds")
 
-        print("output received")
+        # print("output received")
 
         for o in out:
             string_answer = tokenizer.decode(o)
             outputs.extend([string_answer])
 
         all_times.extend([generate_time/float(cnt)]*cnt)  # may modify this later to be more precise given hf is open source
-        print("Returning model inference outputs")
+        # print("Returning model inference outputs")
     return outputs, all_times
 
 # @backoff.on_exception(backoff.expo, openai.error.OpenAIError)
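Note: a batched generate call returns cnt sequences under a single wall-clock measurement, so the code above attributes generate_time/cnt to each sequence. A generic sketch of that attribution (timed_generate is a hypothetical helper; also note that HF transformers only applies temperature when do_sample=True):

    import time

    def timed_generate(generate_fn, cnt):
        # Run one batched generation call and split the elapsed wall-clock
        # time evenly across the cnt returned sequences.
        start = time.perf_counter()
        out = generate_fn()
        elapsed = time.perf_counter() - start
        return out, [elapsed / float(cnt)] * cnt

    # usage with any zero-argument callable standing in for model.generate(...)
    out, per_seq_times = timed_generate(lambda: ["seq"] * 3, cnt=3)
    print(per_seq_times)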
@@ -88,6 +89,8 @@ def chatgpt(model, messages, temperature=0.7, max_tokens=1000, n=5, stop=None) ->
         res = client.chat.completions.create(model=model, messages=messages, temperature=temperature, n=cnt, stop=stop)
         generate_time = time.perf_counter() - start_time
 
+        print(f"Averaged Time to Generate Single Node Proposal: {[generate_time/float(cnt)]*cnt} seconds")
+
         outputs.extend([choice.message.content for choice in res.choices])
 
         # log completion tokens
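Note: the OpenAI path gets the same treatment: one API call returns cnt choices, and the elapsed time is split evenly across them. A minimal sketch against the v1 openai client used above (model and messages are placeholders, and OPENAI_API_KEY must be set in the environment):

    import time
    from openai import OpenAI

    client = OpenAI()  # reads OPENAI_API_KEY from the environment

    def timed_chat(model, messages, cnt):
        start = time.perf_counter()
        res = client.chat.completions.create(model=model, messages=messages, n=cnt)
        elapsed = time.perf_counter() - start
        texts = [choice.message.content for choice in res.choices]
        # Even split across choices, mirroring the local hf_model path.
        return texts, [elapsed / float(cnt)] * cnt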