Work with Groq API

ricardo-larosa 2024-04-05 18:03:12 +02:00
parent a0fe552731
commit c1ada54223
7 changed files with 120 additions and 31 deletions

.gitignore
View File

@@ -4,3 +4,4 @@ src/tree_of_thoughts_llm.egg-info/
 .env
 *.pyc
 *.DS_Store
+/datasets_cache/

View File

@@ -62,4 +62,5 @@ dependencies:
 - urllib3==2.0.2
 - yarl==1.9.2
 - datasets
+- python-dotenv
 prefix: /Users/ricardo/anaconda3/envs/tot

View File

@@ -3,12 +3,12 @@ from tot.methods.bfs import solve
 from tot.tasks.swe import SWETask
 from datasets import load_dataset

-train_dataset = load_dataset("princeton-nlp/SWE-bench_Lite", split = "test")
+print("Downloading dataset...")
+train_dataset = load_dataset("princeton-nlp/SWE-bench_Lite", split = "dev", cache_dir='datasets_cache')

 args = argparse.Namespace(
-    backend='gpt-4',
-    temperature=0.1,
+    backend='mixtral-8x7b-32768',
+    temperature=0.2,
     task='swe',
     naive_run=False,
     prompt_sample='cot',
@@ -19,8 +19,12 @@ args = argparse.Namespace(
     n_evaluate_sample=3,
     n_select_sample=5)
+print("Solving...")
 task = SWETask(train_dataset)
-ys, infos = solve(args, task, 1)
-print(ys[0])
+i = 10
+ys, infos = solve(args, task, i, to_print=False)
+print("Solution:")
+print(SWETask.parse_diff_block(ys[0]))
+print("Expected solution:")
+print(train_dataset[i]["patch"])
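
Note: this commit also adds python-dotenv to the environment, and tot.methods.bfs reads OPENAI_API_BASE at import time (see below), so the script presumably loads a .env file before the tot imports. A minimal sketch of that assumed setup (the load_dotenv call and the .env contents are not shown in this commit):

# assumed .env contents:
#   OPENAI_API_BASE=https://api.groq.com/openai/v1
from dotenv import load_dotenv

load_dotenv()  # must run before `from tot.methods.bfs import solve`,
               # which reads OPENAI_API_BASE at import time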

src/tot/methods/bfs.py
View File

@@ -1,13 +1,21 @@
-import itertools
+import itertools, os
 import numpy as np
 from functools import partial
-from tot.models import gpt
+
+api_base = os.getenv("OPENAI_API_BASE", "")
+if api_base == 'https://api.groq.com/openai/v1':
+    from tot.models import groq
+    platform = groq
+else:
+    from tot.models import gpt
+    platform = gpt

 def get_value(task, x, y, n_evaluate_sample, cache_value=True):
     value_prompt = task.value_prompt_wrap(x, y)
     if cache_value and value_prompt in task.value_cache:
         return task.value_cache[value_prompt]
-    value_outputs = gpt(value_prompt, n=n_evaluate_sample, stop=None)
+    value_outputs = platform(value_prompt, n=n_evaluate_sample, stop=None)
     value = task.value_outputs_unwrap(x, y, value_outputs)
     if cache_value:
         task.value_cache[value_prompt] = value
@@ -27,13 +35,13 @@ def get_values(task, x, ys, n_evaluate_sample, cache_value=True):
 def get_votes(task, x, ys, n_evaluate_sample):
     vote_prompt = task.vote_prompt_wrap(x, ys)
-    vote_outputs = gpt(vote_prompt, n=n_evaluate_sample, stop=None)
+    vote_outputs = platform(vote_prompt, n=n_evaluate_sample, stop=None)
     values = task.vote_outputs_unwrap(vote_outputs, len(ys))
     return values

 def get_proposals(task, x, y):
     propose_prompt = task.propose_prompt_wrap(x, y)
-    proposals = gpt(propose_prompt, n=1, stop=None)[0].split('\n')
+    proposals = platform(propose_prompt, n=1, stop=None)[0].split('\n')
     return [y + _ + '\n' for _ in proposals]

 def get_samples(task, x, y, n_generate_sample, prompt_sample, stop):
@@ -43,13 +51,13 @@ def get_samples(task, x, y, n_generate_sample, prompt_sample, stop):
         prompt = task.cot_prompt_wrap(x, y)
     else:
         raise ValueError(f'prompt_sample {prompt_sample} not recognized')
-    samples = gpt(prompt, n=n_generate_sample, stop=stop)
+    samples = platform(prompt, n=n_generate_sample, stop=stop)
     return [y + _ for _ in samples]

 def solve(args, task, idx, to_print=True):
-    global gpt
-    gpt = partial(gpt, model=args.backend, temperature=args.temperature)
-    print(gpt)
+    global platform
+    platform = partial(platform, model=args.backend, temperature=args.temperature)
+    print(platform)
     x = task.get_input(idx)  # input
     ys = ['']  # current output candidates
     infos = []
@@ -88,9 +96,9 @@ def solve(args, task, idx, to_print=True):
     return ys, {'steps': infos}

 def naive_solve(args, task, idx, to_print=True):
-    global gpt
-    gpt = partial(gpt, model=args.backend, temperature=args.temperature)
-    print(gpt)
+    global platform
+    platform = partial(platform, model=args.backend, temperature=args.temperature)
+    print(platform)
     x = task.get_input(idx)  # input
     ys = get_samples(task, x, '', args.n_generate_sample, args.prompt_sample, stop=None)
     return ys, {}
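
Because the groq/gpt switch above runs at module import time, the backend is fixed the first time tot.methods.bfs is imported and cannot be changed afterwards through the environment. A short sketch of selecting the Groq path (only the endpoint URL and module names come from the diff; the rest is illustrative):

import os

# Set the variable before the import so the module-level check
# binds `platform` to tot.models.groq instead of tot.models.gpt.
os.environ["OPENAI_API_BASE"] = "https://api.groq.com/openai/v1"

from tot.methods.bfs import solve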

src/tot/models.py
View File

@@ -19,6 +19,9 @@ if api_base != "":
 def completions_with_backoff(**kwargs):
     return openai.ChatCompletion.create(**kwargs)

+def completions(**kwargs):
+    return openai.ChatCompletion.create(**kwargs)
+
 def gpt(prompt, model="gpt-4", temperature=0.7, max_tokens=1000, n=1, stop=None) -> list:
     messages = [{"role": "user", "content": prompt}]
     return chatgpt(messages, model=model, temperature=temperature, max_tokens=max_tokens, n=n, stop=stop)
@@ -43,3 +46,25 @@ def gpt_usage(backend="gpt-4"):
     elif backend == "gpt-3.5-turbo":
         cost = completion_tokens / 1000 * 0.002 + prompt_tokens / 1000 * 0.0015
     return {"completion_tokens": completion_tokens, "prompt_tokens": prompt_tokens, "cost": cost}
+
+def groq(prompt, model="mixtral-8x7b-32768", temperature=0.5, max_tokens=1500, n=1, stop=None) -> list:
+    messages = [{"role": "user", "content": prompt}]
+    return groqgpt(messages, model=model, temperature=temperature, max_tokens=max_tokens, n=n, stop=stop)
+
+def groqgpt(messages, model="mixtral-8x7b-32768", temperature=0.5, max_tokens=1500, n=1, stop=None) -> list:
+    global completion_tokens, prompt_tokens
+    outputs = []
+    while n > 0:
+        # the Groq endpoint returns a single choice per request, so
+        # collect the n samples one call at a time
+        res = completions(model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, stop=stop)
+        outputs.extend([choice["message"]["content"] for choice in res["choices"]])
+        n -= len(res["choices"])
+        # log token usage
+        completion_tokens += res["usage"]["completion_tokens"]
+        prompt_tokens += res["usage"]["prompt_tokens"]
+    return outputs
+
+def groq_usage():
+    return {"completion_tokens": completion_tokens, "prompt_tokens": prompt_tokens}

src/tot/prompts/swe.py
View File

@@ -1,18 +1,33 @@
 standard_prompt = '''
-Given the problem statement of a github issue write a correct git patch to solve it. The problem statement is the following: {input}
+{input}
 '''

-cot_prompt = '''
-Given the problem statement of a github issue write a correct git patch to solve it.
-Make a plan then write. Your output should be of the following format:
-Plan:
-Your plan here.
+cot_prompt = '''I need you to solve this issue by generating a single patch file that I can apply directly to this repository using git apply.
+The patch file should be in the unified diff format. Example:
+```diff
+diff --git a/file.py b/file.py
+--- a/file.py
++++ b/file.py
+@@ -1,27 +1,35 @@
+ def euclidean(a, b):
+-    while b:
+-        a, b = b, a % b
+-    return a
++    if b == 0:
++        return a
++    return euclidean(b, a % b)
+```
+
+Given the problem statement of a github issue write a correct git patch to solve it.

 Patch:
+```diff
 Your patch here.
+```

-The problem statement is the following: {input}
+The problem statement is the following:
+{input}
 '''
@@ -22,5 +37,5 @@ vote_prompt = '''Given an instruction and several choices, decide which choice i
 compare_prompt = '''Briefly analyze the degree of correctness of the following two patches. Conclude in the last line "The more correct patch is 1", "The more correct patch is 2", or "The two patches are equally correct".
 '''

-score_prompt = '''Analyze the following patch, then at the last line conclude "Therefore the correctness score is {s}", where s is an integer from 1 to 10.
+score_prompt = '''Analyze the following patch, then at the last line conclude "Therefore the correctness score is {s}", where {s} is an integer from 1 to 10.
 '''

src/tot/tasks/swe.py
View File

@@ -1,7 +1,7 @@
-import re
+import re, os
 from tot.tasks.base import Task
 from tot.prompts.swe import *
-from tot.models import gpt
+from tot.models import gpt, groq

 class SWETask(Task):
     """
@@ -24,6 +24,10 @@ class SWETask(Task):
     def test_output(self, idx: int, output: str):
         output = output.split('Patch:\n')[-1]
         prompt = score_prompt + output
-        score_outputs = gpt(prompt, n=5, model='gpt-4')
+        api_base = os.getenv("OPENAI_API_BASE", "")
+        if api_base == 'https://api.groq.com/openai/v1':
+            score_outputs = groq(prompt, n=3, model='mixtral-8x7b-32768')
+        else:
+            score_outputs = gpt(prompt, n=5, model='gpt-4')
         scores = []
         for score_output in score_outputs:
@@ -89,3 +93,34 @@ class SWETask(Task):
         else:
             print(f'-----------------compare no match: {[compare_output]}')
             return -1
+
+    @staticmethod
+    def parse_diff_block(text: str):
+        """Extracts the first block of unified diff format.
+
+        Args:
+            text (str): The large text containing one or more diff blocks.
+
+        Returns:
+            str: The first block between ```diff and ``` if found, otherwise None.
+        """
+        start_pattern = r"```diff"
+        end_pattern = r"```"
+        in_diff_block = False
+        diff_block = []
+
+        for line in text.splitlines():
+            if start_pattern in line:
+                in_diff_block = True
+                continue  # Skip the line with the start marker
+            if in_diff_block:
+                if end_pattern in line:
+                    in_diff_block = False
+                    break  # End of the diff block
+                else:
+                    diff_block.append(line)
+
+        return "\n".join(diff_block) if diff_block else None