commit b7e9c7242c
parent ab400345c5
Author: emilyworks
Date:   2024-10-27 00:08:03 -04:00

    added other model support part 1. part 2 pending hf gated repo access

2 changed files with 62 additions and 16 deletions

File 1:

@@ -1,13 +1,13 @@
import itertools
import numpy as np
from functools import partial
-from tot.models import gpt
+from tot.models import inference_model

def get_value(task, x, y, n_evaluate_sample, cache_value=True):
    value_prompt = task.value_prompt_wrap(x, y)
    if cache_value and value_prompt in task.value_cache:
        return task.value_cache[value_prompt]
-    value_outputs = gpt(value_prompt, n=n_evaluate_sample, stop=None)
+    value_outputs = inference_model(value_prompt, n=n_evaluate_sample, stop=None)
    value = task.value_outputs_unwrap(x, y, value_outputs)
    if cache_value:
        task.value_cache[value_prompt] = value

@@ -27,13 +27,13 @@ def get_values(task, x, ys, n_evaluate_sample, cache_value=True):
def get_votes(task, x, ys, n_evaluate_sample):
    vote_prompt = task.vote_prompt_wrap(x, ys)
-    vote_outputs = gpt(vote_prompt, n=n_evaluate_sample, stop=None)
+    vote_outputs = inference_model(vote_prompt, n=n_evaluate_sample, stop=None)
    values = task.vote_outputs_unwrap(vote_outputs, len(ys))
    return values

def get_proposals(task, x, y):
    propose_prompt = task.propose_prompt_wrap(x, y)
-    proposals = gpt(propose_prompt, n=1, stop=None)[0].split('\n')
+    proposals = inference_model(propose_prompt, n=1, stop=None)[0].split('\n')
    return [y + _ + '\n' for _ in proposals]

def get_samples(task, x, y, n_generate_sample, prompt_sample, stop):

@@ -43,13 +43,13 @@ def get_samples(task, x, y, n_generate_sample, prompt_sample, stop):
        prompt = task.cot_prompt_wrap(x, y)
    else:
        raise ValueError(f'prompt_sample {prompt_sample} not recognized')
-    samples = gpt(prompt, n=n_generate_sample, stop=stop)
+    samples = inference_model(prompt, n=n_generate_sample, stop=stop)
    return [y + _ for _ in samples]

def solve(args, task, idx, to_print=True):
-    global gpt
-    gpt = partial(gpt, model=args.backend, temperature=args.temperature)
-    print(gpt)
+    global inference_model
+    inference_model = partial(inference_model, model=args.backend, temperature=args.temperature)
+    print(inference_model)
    x = task.get_input(idx)  # input
    ys = ['']  # current output candidates
    infos = []

@@ -88,9 +88,9 @@ def solve(args, task, idx, to_print=True):
    return ys, {'steps': infos}

def naive_solve(args, task, idx, to_print=True):
-    global gpt
-    gpt = partial(gpt, model=args.backend, temperature=args.temperature)
-    print(gpt)
+    global inference_model
+    inference_model = partial(inference_model, model=args.backend, temperature=args.temperature)
+    print(inference_model)
    x = task.get_input(idx)  # input
    ys = get_samples(task, x, '', args.n_generate_sample, args.prompt_sample, stop=None)
    return ys, {}
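For reference, the rebinding pattern used in solve() and naive_solve() above can be exercised on its own. The snippet below is a minimal, hypothetical sketch: the backend and temperature literals stand in for args.backend and args.temperature, and it assumes the updated tot.models.inference_model from this commit plus an OPENAI_API_KEY in the environment.

from functools import partial
from tot.models import inference_model

# hypothetical stand-ins for the argparse values used in solve()/naive_solve()
backend = "gpt-4"
temperature = 0.7

# bind model and temperature once; the search helpers then pass only prompt-level arguments
bound_model = partial(inference_model, model=backend, temperature=temperature)

samples = bound_model("Propose the next step for: 4 4 6 8", n=1, stop=None)
print(samples[0])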

File 2:

@@ -2,6 +2,8 @@ import os
import openai
import backoff
+from transformers import AutoTokenizer, AutoModelForCausalLM
+

completion_tokens = prompt_tokens = 0

api_key = os.getenv("OPENAI_API_KEY", "")

@@ -15,25 +17,69 @@ if api_base != "":
    print("Warning: OPENAI_API_BASE is set to {}".format(api_base))
    openai.api_base = api_base

+#######################
+### Model Inference ###
+#######################
+
+def inference_model(prompt, model="gpt-4", temperature=0.7, max_tokens=1000, n=1, stop=None, vllm=False, quant=False) -> list:
+    '''
+    Driver function for model inference.
+    '''
+    if model == "llama_3.2" and vllm:
+        return llama_32(prompt, quant, vllm, temperature, max_tokens, n, stop)
+    else:
+        messages = [{"role": "user", "content": prompt}]
+        return chatgpt(messages, model=model, temperature=temperature, max_tokens=max_tokens, n=n, stop=stop)
+
+def llama_32(prompt, quant, vllm, temperature, max_tokens, n, stop):  # will add vllm support later
+    '''
+    Use llama3.2 for inference
+    '''
+    if quant:
+        tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8")
+        model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8")
+    else:
+        tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B")
+        model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B")
+
+    inputs = tokenizer(prompt, return_tensors="pt")
+    outputs = hf_model(model, inputs, temperature, max_tokens, n, stop)
+    return outputs
+
+def hf_model(model, input_tokens, temperature=0.7, max_tokens=1000, n=1, stop=None):
+    """
+    Given a model (Huggingface) and input tokens, generate an output
+    """
+    outputs = []
+    while n > 0:
+        cnt = min(n, 20)  # never generate more than 20 outputs per same input
+        n -= cnt
+        outputs = model.generate(**input_tokens, temperature=temperature, max_new_tokens=max_tokens, num_return_sequences=cnt)  # might add stopping criteria depending on heuristics experimentation
+        # need to take a look at the specific output format once i get access to the gated repo
+        # need to outputs.extend()
+    return outputs
+
@backoff.on_exception(backoff.expo, openai.error.OpenAIError)
def completions_with_backoff(**kwargs):
    return openai.ChatCompletion.create(**kwargs)

-def gpt(prompt, model="gpt-4", temperature=0.7, max_tokens=1000, n=1, stop=None) -> list:
-    messages = [{"role": "user", "content": prompt}]
-    return chatgpt(messages, model=model, temperature=temperature, max_tokens=max_tokens, n=n, stop=stop)
-
def chatgpt(messages, model="gpt-4", temperature=0.7, max_tokens=1000, n=1, stop=None) -> list:
    global completion_tokens, prompt_tokens
    outputs = []
    while n > 0:
-        cnt = min(n, 20)
+        cnt = min(n, 20)  # never generate more than 20 outputs per same input
        n -= cnt
        res = completions_with_backoff(model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, n=cnt, stop=stop)
        outputs.extend([choice["message"]["content"] for choice in res["choices"]])
        # log completion tokens
        completion_tokens += res["usage"]["completion_tokens"]
        prompt_tokens += res["usage"]["prompt_tokens"]
    return outputs

def gpt_usage(backend="gpt-4"):
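As the inline comments in hf_model note, the output handling is left unfinished in this commit: the generate() result is neither decoded to text nor accumulated across batches. Below is a hedged sketch of one way to complete it, assuming the standard transformers generate()/batch_decode() API; the tokenizer parameter and the prompt-stripping step are assumptions for illustration, not part of the committed code. Called from llama_32, it would return a list of strings in the same shape chatgpt produces.

def hf_model_sketch(model, tokenizer, input_tokens, temperature=0.7, max_tokens=1000, n=1):
    """Hypothetical completion of hf_model: decode and accumulate generations."""
    outputs = []
    while n > 0:
        cnt = min(n, 20)  # never generate more than 20 outputs per same input
        n -= cnt
        generated = model.generate(
            **input_tokens,
            do_sample=True,                  # sampling so temperature takes effect
            temperature=temperature,
            max_new_tokens=max_tokens,
            num_return_sequences=cnt,
            pad_token_id=tokenizer.eos_token_id,
        )
        # drop the prompt tokens, then decode each returned sequence to text
        prompt_len = input_tokens["input_ids"].shape[-1]
        outputs.extend(tokenizer.batch_decode(generated[:, prompt_len:], skip_special_tokens=True))
    return outputs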