Work with Groq API

This commit is contained in:
ricardo-larosa 2024-04-05 18:03:12 +02:00
parent a0fe552731
commit c1ada54223
7 changed files with 120 additions and 31 deletions

1
.gitignore vendored
View File

@ -4,3 +4,4 @@ src/tree_of_thoughts_llm.egg-info/
.env
*.pyc
*.DS_Store
/datasets_cache/

View File

@ -62,4 +62,5 @@ dependencies:
- urllib3==2.0.2
- yarl==1.9.2
- datasets
- python-dotenv
prefix: /Users/ricardo/anaconda3/envs/tot

View File

@ -3,14 +3,14 @@ from tot.methods.bfs import solve
from tot.tasks.swe import SWETask
from datasets import load_dataset
train_dataset = load_dataset("princeton-nlp/SWE-bench_Lite", split = "test")
print("Downloading dataset...")
train_dataset = load_dataset("princeton-nlp/SWE-bench_Lite", split = "dev", cache_dir='datasets_cache')
args = argparse.Namespace(
backend='gpt-4',
temperature=0.1,
backend='mixtral-8x7b-32768',
temperature=0.2,
task='swe',
naive_run=False,
naive_run=False,
prompt_sample='cot',
method_generate='sample',
method_evaluate='vote',
@ -19,8 +19,12 @@ args = argparse.Namespace(
n_evaluate_sample=3,
n_select_sample=5)
print("Solving...")
task = SWETask(train_dataset)
i = 10
ys, infos = solve(args, task, i, to_print=False)
print("Solution:")
print(SWETask.parse_diff_block(ys[0]))
ys, infos = solve(args, task, 1)
print(ys[0])
print("Expected solution:")
print(train_dataset[i]["patch"])

View File

@ -1,13 +1,21 @@
import itertools
import itertools, os
import numpy as np
from functools import partial
from tot.models import gpt
api_base = os.getenv("OPENAI_API_BASE", "")
if api_base == 'https://api.groq.com/openai/v1':
from tot.models import groq
platform = groq
else:
from tot.models import gpt
platform = gpt
def get_value(task, x, y, n_evaluate_sample, cache_value=True):
value_prompt = task.value_prompt_wrap(x, y)
if cache_value and value_prompt in task.value_cache:
return task.value_cache[value_prompt]
value_outputs = gpt(value_prompt, n=n_evaluate_sample, stop=None)
value_outputs = platform(value_prompt, n=n_evaluate_sample, stop=None)
value = task.value_outputs_unwrap(x, y, value_outputs)
if cache_value:
task.value_cache[value_prompt] = value
@ -27,13 +35,13 @@ def get_values(task, x, ys, n_evaluate_sample, cache_value=True):
def get_votes(task, x, ys, n_evaluate_sample):
vote_prompt = task.vote_prompt_wrap(x, ys)
vote_outputs = gpt(vote_prompt, n=n_evaluate_sample, stop=None)
vote_outputs = platform(vote_prompt, n=n_evaluate_sample, stop=None)
values = task.vote_outputs_unwrap(vote_outputs, len(ys))
return values
def get_proposals(task, x, y):
propose_prompt = task.propose_prompt_wrap(x, y)
proposals = gpt(propose_prompt, n=1, stop=None)[0].split('\n')
proposals = platform(propose_prompt, n=1, stop=None)[0].split('\n')
return [y + _ + '\n' for _ in proposals]
def get_samples(task, x, y, n_generate_sample, prompt_sample, stop):
@ -43,13 +51,13 @@ def get_samples(task, x, y, n_generate_sample, prompt_sample, stop):
prompt = task.cot_prompt_wrap(x, y)
else:
raise ValueError(f'prompt_sample {prompt_sample} not recognized')
samples = gpt(prompt, n=n_generate_sample, stop=stop)
samples = platform(prompt, n=n_generate_sample, stop=stop)
return [y + _ for _ in samples]
def solve(args, task, idx, to_print=True):
global gpt
gpt = partial(gpt, model=args.backend, temperature=args.temperature)
print(gpt)
global platform
platform = partial(platform, model=args.backend, temperature=args.temperature)
print(platform)
x = task.get_input(idx) # input
ys = [''] # current output candidates
infos = []
@ -88,9 +96,9 @@ def solve(args, task, idx, to_print=True):
return ys, {'steps': infos}
def naive_solve(args, task, idx, to_print=True):
global gpt
gpt = partial(gpt, model=args.backend, temperature=args.temperature)
print(gpt)
global platform
platform = partial(platform, model=args.backend, temperature=args.temperature)
print(platform)
x = task.get_input(idx) # input
ys = get_samples(task, x, '', args.n_generate_sample, args.prompt_sample, stop=None)
return ys, {}

View File

@ -19,6 +19,9 @@ if api_base != "":
def completions_with_backoff(**kwargs):
return openai.ChatCompletion.create(**kwargs)
def completions(**kwargs):
return openai.ChatCompletion.create(**kwargs)
def gpt(prompt, model="gpt-4", temperature=0.7, max_tokens=1000, n=1, stop=None) -> list:
messages = [{"role": "user", "content": prompt}]
return chatgpt(messages, model=model, temperature=temperature, max_tokens=max_tokens, n=n, stop=stop)
@ -43,3 +46,25 @@ def gpt_usage(backend="gpt-4"):
elif backend == "gpt-3.5-turbo":
cost = completion_tokens / 1000 * 0.002 + prompt_tokens / 1000 * 0.0015
return {"completion_tokens": completion_tokens, "prompt_tokens": prompt_tokens, "cost": cost}
def groq(prompt, model="mixtral-8x7b-32768", temperature=0.5, max_tokens=1500, n=1, stop=None) -> list:
global completion_tokens, prompt_tokens
messages = [{"role": "user", "content": prompt}]
return groqgpt(messages, model=model, temperature=temperature, max_tokens=max_tokens, n=n, stop=stop)
def groqgpt(messages, model="mixtral-8x7b-32768", temperature=0.5, max_tokens=1500,n=1, stop=None) -> list:
global completion_tokens, prompt_tokens
outputs = []
while n > 0:
cnt = min(n, 20)
n -= cnt
res = completions(model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, stop=stop)
outputs.extend([choice["message"]["content"] for choice in res["choices"]])
# log completion tokens
completion_tokens += res["usage"]["completion_tokens"]
prompt_tokens += res["usage"]["prompt_tokens"]
return outputs
def groq_usage():
global completion_tokens, prompt_tokens
return {"completion_tokens": completion_tokens, "prompt_tokens": prompt_tokens}

View File

@ -1,18 +1,33 @@
standard_prompt = '''
Given the problem statement of a github issue write a correct git patch to solve it. The problem statement is the following: {input}
{input}
'''
cot_prompt = '''
Given the problem statement of a github issue write a correct git patch to solve it.
Make a plan then write. Your output should be of the following format:
cot_prompt = '''I need you to solve this issue by generating a single patch file that I can apply directly to this repository using git apply.
The patch file should be in the unified diff format. Example:
Plan:
Your plan here.
```diff
diff --git a/file.py b/file.py
--- a/file.py
+++ b/file.py
@@ -1,27 +1,35 @@
def euclidean(a, b):
- while b:
- a, b = b, a % b
- return a
+ if b == 0:
+ return a
+ return euclidean(b, a % b)
```
Given the problem statement of a github issue write a correct git patch to solve it.
Patch:
```diff
Your patch here.
```
The problem statement is the following: {input}
The problem statement is the following:
{input}
'''
@ -22,5 +37,5 @@ vote_prompt = '''Given an instruction and several choices, decide which choice i
compare_prompt = '''Briefly analyze the degree of correctness of the following two patches. Conclude in the last line "The more correct patch is 1", "The more correct patch is 2", or "The two patches are equally correct".
'''
score_prompt = '''Analyze the following patch, then at the last line conclude "Therefore the correctness score is {s}", where s is an integer from 1 to 10.
score_prompt = '''Analyze the following patch, then at the last line conclude "Therefore the correctness score is {s}", where {s} is an integer from 1 to 10.
'''

View File

@ -1,7 +1,7 @@
import re
import re, os
from tot.tasks.base import Task
from tot.prompts.swe import *
from tot.models import gpt
from tot.models import gpt, groq
class SWETask(Task):
"""
@ -24,7 +24,11 @@ class SWETask(Task):
def test_output(self, idx: int, output: str):
output = output.split('Patch:\n')[-1]
prompt = score_prompt + output
score_outputs = gpt(prompt, n=5, model='gpt-4')
api_base = os.getenv("OPENAI_API_BASE", "")
if api_base == 'https://api.groq.com/openai/v1':
score_output = groq(prompt, n=3, model='mixtral-8x7b-32768')
else:
score_outputs = gpt(prompt, n=5, model='gpt-4')
scores = []
for score_output in score_outputs:
print("score_output: ",score_output)
@ -89,3 +93,34 @@ class SWETask(Task):
else:
print(f'-----------------compare no match: {[compare_output]}')
return -1
@staticmethod
def parse_diff_block(text: str):
"""Extracts the first block of unified diff format.
Args:
text (str): The large text containing one or more diff blocks.
Returns:
str: The first block between `diff and ` if found, otherwise None.
"""
start_pattern = r"```diff"
end_pattern = r"```"
in_diff_block = False
diff_block = []
for line in text.splitlines():
if start_pattern in line:
in_diff_block = True
continue # Skip the line with the start marker
if in_diff_block:
if end_pattern in line:
in_diff_block = False
break # End of the diff block
else:
diff_block.append(line)
return "\n".join(diff_block) if diff_block else None