Mirror of https://github.com/princeton-nlp/tree-of-thought-llm, synced 2025-04-23 07:34:13 +00:00
Added other model support, part 1; part 2 pending Hugging Face gated repo access.
This commit is contained in: parent ab400345c5, commit b7e9c7242c
@@ -1,13 +1,13 @@
 import itertools
 import numpy as np
 from functools import partial
-from tot.models import gpt
+from tot.models import inference_model
 
 def get_value(task, x, y, n_evaluate_sample, cache_value=True):
     value_prompt = task.value_prompt_wrap(x, y)
     if cache_value and value_prompt in task.value_cache:
         return task.value_cache[value_prompt]
-    value_outputs = gpt(value_prompt, n=n_evaluate_sample, stop=None)
+    value_outputs = inference_model(value_prompt, n=n_evaluate_sample, stop=None)
     value = task.value_outputs_unwrap(x, y, value_outputs)
     if cache_value:
         task.value_cache[value_prompt] = value
@@ -27,13 +27,13 @@ def get_values(task, x, ys, n_evaluate_sample, cache_value=True):
 
 def get_votes(task, x, ys, n_evaluate_sample):
     vote_prompt = task.vote_prompt_wrap(x, ys)
-    vote_outputs = gpt(vote_prompt, n=n_evaluate_sample, stop=None)
+    vote_outputs = inference_model(vote_prompt, n=n_evaluate_sample, stop=None)
     values = task.vote_outputs_unwrap(vote_outputs, len(ys))
     return values
 
 def get_proposals(task, x, y):
     propose_prompt = task.propose_prompt_wrap(x, y)
-    proposals = gpt(propose_prompt, n=1, stop=None)[0].split('\n')
+    proposals = inference_model(propose_prompt, n=1, stop=None)[0].split('\n')
     return [y + _ + '\n' for _ in proposals]
 
 def get_samples(task, x, y, n_generate_sample, prompt_sample, stop):
@@ -43,13 +43,13 @@ def get_samples(task, x, y, n_generate_sample, prompt_sample, stop):
         prompt = task.cot_prompt_wrap(x, y)
     else:
         raise ValueError(f'prompt_sample {prompt_sample} not recognized')
-    samples = gpt(prompt, n=n_generate_sample, stop=stop)
+    samples = inference_model(prompt, n=n_generate_sample, stop=stop)
     return [y + _ for _ in samples]
 
 def solve(args, task, idx, to_print=True):
-    global gpt
-    gpt = partial(gpt, model=args.backend, temperature=args.temperature)
-    print(gpt)
+    global inference_model
+    inference_model = partial(inference_model, model=args.backend, temperature=args.temperature)
+    print(inference_model)
     x = task.get_input(idx) # input
     ys = [''] # current output candidates
     infos = []
@@ -88,9 +88,9 @@ def solve(args, task, idx, to_print=True):
     return ys, {'steps': infos}
 
 def naive_solve(args, task, idx, to_print=True):
-    global gpt
-    gpt = partial(gpt, model=args.backend, temperature=args.temperature)
-    print(gpt)
+    global inference_model
+    inference_model = partial(inference_model, model=args.backend, temperature=args.temperature)
+    print(inference_model)
     x = task.get_input(idx) # input
     ys = get_samples(task, x, '', args.n_generate_sample, args.prompt_sample, stop=None)
     return ys, {}
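The changes above are a mechanical rename in the search code: every call site that used the module-level gpt() helper now goes through inference_model(), and solve()/naive_solve() bind the backend and temperature onto it with functools.partial exactly as before. Purely as an illustration of the new entry point added in tot.models below, a direct call that exercises the Llama route could look like this; the prompt text and argument values are invented for the example, and, as the routing is written, the llama_32 branch is only taken when both model="llama_3.2" and vllm=True are passed.

# hypothetical direct call; only the parameter names come from this commit
outputs = inference_model(
    "example task prompt",  # made-up prompt text
    model="llama_3.2",
    vllm=True,      # required for the Llama branch as the routing is written
    quant=False,    # True would select the QLORA_INT4_EO8 checkpoint instead
    n=3,
    stop=None,
)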
@@ -2,6 +2,8 @@ import os
 import openai
 import backoff
 
+from transformers import AutoTokenizer, AutoModelForCausalLM
+
 completion_tokens = prompt_tokens = 0
 
 api_key = os.getenv("OPENAI_API_KEY", "")
@@ -15,25 +17,69 @@ if api_base != "":
     print("Warning: OPENAI_API_BASE is set to {}".format(api_base))
     openai.api_base = api_base
 
+#######################
+### Model Inference ###
+#######################
+
+def inference_model(prompt, model="gpt-4", temperature=0.7, max_tokens=1000, n=1, stop=None, vllm=False, quant=False) -> list:
+    '''
+    Driver function for model inference.
+    '''
+    if model == "llama_3.2" and vllm:
+        return llama_32(prompt, quant, vllm, temperature, max_tokens, n, stop)
+    else:
+        messages = [{"role": "user", "content": prompt}]
+        return chatgpt(messages, model=model, temperature=temperature, max_tokens=max_tokens, n=n, stop=stop)
+
+def llama_32(prompt, quant, vllm, temperature, max_tokens, n, stop):  # will add vllm support later
+    '''
+    Use llama3.2 for inference
+    '''
+    if quant:
+        tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8")
+        model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8")
+    else:
+        tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B")
+        model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B")
+
+    inputs = tokenizer(prompt, return_tensors="pt")
+    outputs = hf_model(model, inputs, temperature, max_tokens, n, stop)
+
+    return outputs
+
+def hf_model(model, input_tokens, temperature=0.7, max_tokens=1000, n=1, stop=None):
+    """
+    Given a model (Huggingface) and input tokens, generate an output
+    """
+    outputs = []
+
+    while n > 0:
+        cnt = min(n, 20)  # never generate more than 20 outputs per same input
+        n -= cnt
+        outputs = model.generate(**input_tokens, temperature=temperature, max_new_tokens=max_tokens, num_return_sequences=cnt)  # might add stopping criteria depending on heuristics experimentation
+        # need to take a look at the specific output format once i get access to the gated repo
+        # need to outputs.extend()
+
+    return outputs
+
 @backoff.on_exception(backoff.expo, openai.error.OpenAIError)
 def completions_with_backoff(**kwargs):
     return openai.ChatCompletion.create(**kwargs)
 
-def gpt(prompt, model="gpt-4", temperature=0.7, max_tokens=1000, n=1, stop=None) -> list:
-    messages = [{"role": "user", "content": prompt}]
-    return chatgpt(messages, model=model, temperature=temperature, max_tokens=max_tokens, n=n, stop=stop)
 
 def chatgpt(messages, model="gpt-4", temperature=0.7, max_tokens=1000, n=1, stop=None) -> list:
     global completion_tokens, prompt_tokens
     outputs = []
     while n > 0:
-        cnt = min(n, 20)
+        cnt = min(n, 20)  # never generate more than 20 outputs per same input
         n -= cnt
 
         res = completions_with_backoff(model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, n=cnt, stop=stop)
         outputs.extend([choice["message"]["content"] for choice in res["choices"]])
 
         # log completion tokens
         completion_tokens += res["usage"]["completion_tokens"]
         prompt_tokens += res["usage"]["prompt_tokens"]
 
     return outputs
 
 def gpt_usage(backend="gpt-4"):
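The two trailing comments in hf_model leave the decode-and-accumulate step open until the gated checkpoints are accessible. As a minimal sketch of what that step might look like under standard transformers semantics (the helper name is hypothetical and not part of this commit, and the tokenizer would have to be threaded into hf_model, which currently receives only the model and the tokenized input):

# hypothetical helper, not part of this commit
def decode_generations(tokenizer, input_tokens, generated_ids):
    # for a decoder-only model, generate() returns token ids that still include
    # the prompt, so drop the prompt length before decoding the newly generated text
    prompt_len = input_tokens["input_ids"].shape[1]
    return tokenizer.batch_decode(generated_ids[:, prompt_len:], skip_special_tokens=True)

Inside the while loop this would feed outputs.extend(decode_generations(tokenizer, input_tokens, generated_ids)), mirroring the outputs.extend() pattern chatgpt() already uses, rather than overwriting outputs on each iteration. Note also that temperature only affects generate() when do_sample=True is passed, so that flag would need to be set for the temperature argument to matter.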