Mirror of https://github.com/princeton-nlp/tree-of-thought-llm, synced 2025-06-26 18:26:00 +00:00
Work with Groq API
This commit is contained in: parent a0fe552731, commit c1ada54223
.gitignore (vendored)
@@ -4,3 +4,4 @@ src/tree_of_thoughts_llm.egg-info/
 .env
 *.pyc
 *.DS_Store
+/datasets_cache/
env.yml
@@ -62,4 +62,5 @@ dependencies:
 - urllib3==2.0.2
 - yarl==1.9.2
 - datasets
+- python-dotenv
 prefix: /Users/ricardo/anaconda3/envs/tot
@@ -3,14 +3,14 @@ from tot.methods.bfs import solve
 from tot.tasks.swe import SWETask
 from datasets import load_dataset
 
 
-train_dataset = load_dataset("princeton-nlp/SWE-bench_Lite", split = "test")
+print("Downloading dataset...")
+train_dataset = load_dataset("princeton-nlp/SWE-bench_Lite", split = "dev", cache_dir='datasets_cache')
 
 args = argparse.Namespace(
-    backend='gpt-4',
-    temperature=0.1,
+    backend='mixtral-8x7b-32768',
+    temperature=0.2,
     task='swe',
     naive_run=False,
     prompt_sample='cot',
    method_generate='sample',
    method_evaluate='vote',
@@ -19,8 +19,12 @@ args = argparse.Namespace(
     n_evaluate_sample=3,
     n_select_sample=5)
 
+print("Solving...")
 task = SWETask(train_dataset)
-ys, infos = solve(args, task, 1)
-print(ys[0])
+i = 10
+ys, infos = solve(args, task, i, to_print=False)
+print("Solution:")
+print(SWETask.parse_diff_block(ys[0]))
 
+print("Expected solution:")
+print(train_dataset[i]["patch"])
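Taken together with the .gitignore and env.yml changes, the script is driven by an OPENAI_API_BASE environment variable. A minimal sketch of the intended wiring, assuming the endpoint and key live in a local .env file (the exact key variable name is not shown in this diff):

```python
# Sketch: load the environment before importing tot.methods.bfs, since the
# Groq-vs-OpenAI branch below is evaluated at import time.
import os
from dotenv import load_dotenv  # python-dotenv, the dependency added to env.yml

load_dotenv()  # reads .env, which .gitignore keeps out of version control
print(os.getenv("OPENAI_API_BASE"))  # 'https://api.groq.com/openai/v1' selects Groq
```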
tot/methods/bfs.py
@@ -1,13 +1,21 @@
-import itertools
+import itertools, os
 import numpy as np
 from functools import partial
-from tot.models import gpt
+
+api_base = os.getenv("OPENAI_API_BASE", "")
+if api_base == 'https://api.groq.com/openai/v1':
+    from tot.models import groq
+    platform = groq
+else:
+    from tot.models import gpt
+    platform = gpt
 
 
 def get_value(task, x, y, n_evaluate_sample, cache_value=True):
     value_prompt = task.value_prompt_wrap(x, y)
     if cache_value and value_prompt in task.value_cache:
         return task.value_cache[value_prompt]
-    value_outputs = gpt(value_prompt, n=n_evaluate_sample, stop=None)
+    value_outputs = platform(value_prompt, n=n_evaluate_sample, stop=None)
     value = task.value_outputs_unwrap(x, y, value_outputs)
     if cache_value:
         task.value_cache[value_prompt] = value
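The effect of this change is that bfs.py routes every model call through a single `platform` callable chosen once at import time. A toy sketch of the dispatch pattern with stub backends (not the real clients):

```python
# Stub illustration of the platform dispatch plus the functools.partial
# binding that solve()/naive_solve() perform below; gpt/groq here are fakes.
from functools import partial

def gpt(prompt, model="gpt-4", temperature=0.7, n=1, stop=None):
    return [f"gpt({model}): {prompt}"] * n

def groq(prompt, model="mixtral-8x7b-32768", temperature=0.5, n=1, stop=None):
    return [f"groq({model}): {prompt}"] * n

platform = groq  # what the OPENAI_API_BASE check selects
platform = partial(platform, model="mixtral-8x7b-32768", temperature=0.2)
print(platform("hello", n=2))  # two stubbed Groq-style completions
```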
@@ -27,13 +35,13 @@ def get_values(task, x, ys, n_evaluate_sample, cache_value=True):
 
 def get_votes(task, x, ys, n_evaluate_sample):
     vote_prompt = task.vote_prompt_wrap(x, ys)
-    vote_outputs = gpt(vote_prompt, n=n_evaluate_sample, stop=None)
+    vote_outputs = platform(vote_prompt, n=n_evaluate_sample, stop=None)
     values = task.vote_outputs_unwrap(vote_outputs, len(ys))
     return values
 
 def get_proposals(task, x, y):
     propose_prompt = task.propose_prompt_wrap(x, y)
-    proposals = gpt(propose_prompt, n=1, stop=None)[0].split('\n')
+    proposals = platform(propose_prompt, n=1, stop=None)[0].split('\n')
     return [y + _ + '\n' for _ in proposals]
 
 def get_samples(task, x, y, n_generate_sample, prompt_sample, stop):
@@ -43,13 +51,13 @@ def get_samples(task, x, y, n_generate_sample, prompt_sample, stop):
         prompt = task.cot_prompt_wrap(x, y)
     else:
         raise ValueError(f'prompt_sample {prompt_sample} not recognized')
-    samples = gpt(prompt, n=n_generate_sample, stop=stop)
+    samples = platform(prompt, n=n_generate_sample, stop=stop)
     return [y + _ for _ in samples]
 
 def solve(args, task, idx, to_print=True):
-    global gpt
-    gpt = partial(gpt, model=args.backend, temperature=args.temperature)
-    print(gpt)
+    global platform
+    platform = partial(platform, model=args.backend, temperature=args.temperature)
+    print(platform)
     x = task.get_input(idx)  # input
     ys = ['']  # current output candidates
     infos = []
@@ -88,9 +96,9 @@ def solve(args, task, idx, to_print=True):
     return ys, {'steps': infos}
 
 def naive_solve(args, task, idx, to_print=True):
-    global gpt
-    gpt = partial(gpt, model=args.backend, temperature=args.temperature)
-    print(gpt)
+    global platform
+    platform = partial(platform, model=args.backend, temperature=args.temperature)
+    print(platform)
     x = task.get_input(idx)  # input
     ys = get_samples(task, x, '', args.n_generate_sample, args.prompt_sample, stop=None)
     return ys, {}
tot/models.py
@@ -19,6 +19,9 @@ if api_base != "":
 def completions_with_backoff(**kwargs):
     return openai.ChatCompletion.create(**kwargs)
 
+def completions(**kwargs):
+    return openai.ChatCompletion.create(**kwargs)
+
 def gpt(prompt, model="gpt-4", temperature=0.7, max_tokens=1000, n=1, stop=None) -> list:
     messages = [{"role": "user", "content": prompt}]
     return chatgpt(messages, model=model, temperature=temperature, max_tokens=max_tokens, n=n, stop=stop)
@@ -43,3 +46,25 @@ def gpt_usage(backend="gpt-4"):
     elif backend == "gpt-3.5-turbo":
         cost = completion_tokens / 1000 * 0.002 + prompt_tokens / 1000 * 0.0015
     return {"completion_tokens": completion_tokens, "prompt_tokens": prompt_tokens, "cost": cost}
+
+def groq(prompt, model="mixtral-8x7b-32768", temperature=0.5, max_tokens=1500, n=1, stop=None) -> list:
+    global completion_tokens, prompt_tokens
+    messages = [{"role": "user", "content": prompt}]
+    return groqgpt(messages, model=model, temperature=temperature, max_tokens=max_tokens, n=n, stop=stop)
+
+def groqgpt(messages, model="mixtral-8x7b-32768", temperature=0.5, max_tokens=1500, n=1, stop=None) -> list:
+    global completion_tokens, prompt_tokens
+    outputs = []
+    while n > 0:
+        cnt = min(n, 20)
+        n -= cnt
+        res = completions(model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, stop=stop)
+        outputs.extend([choice["message"]["content"] for choice in res["choices"]])
+        # log completion tokens
+        completion_tokens += res["usage"]["completion_tokens"]
+        prompt_tokens += res["usage"]["prompt_tokens"]
+    return outputs
+
+def groq_usage():
+    global completion_tokens, prompt_tokens
+    return {"completion_tokens": completion_tokens, "prompt_tokens": prompt_tokens}
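Assuming the openai client is pointed at Groq's OpenAI-compatible base URL (as the rest of the commit implies), the new wrapper would be used like this. Note that groqgpt() decrements n in chunks of 20 but, as written, never forwards a per-request n to completions(), so each iteration returns the endpoint's default single choice:

```python
# Hypothetical call into the new Groq wrapper defined above.
samples = groq("Summarize the issue in one sentence.", n=1)
print(samples[0])
print(groq_usage())  # running token totals accumulated by groqgpt()
```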
tot/prompts/swe.py
@@ -1,18 +1,33 @@
 standard_prompt = '''
-Given the problem statement of a github issue write a correct git patch to solve it. The problem statement is the following: {input}
+{input}
 '''
 
-cot_prompt = '''
-Given the problem statement of a github issue write a correct git patch to solve it.
-Make a plan then write. Your output should be of the following format:
+cot_prompt = '''I need you to solve this issue by generating a single patch file that I can apply directly to this repository using git apply.
+The patch file should be in the unified diff format. Example:
 
 Plan:
 Your plan here.
+```diff
+diff --git a/file.py b/file.py
+--- a/file.py
++++ b/file.py
+@@ -1,27 +1,35 @@
+ def euclidean(a, b):
+-    while b:
+-        a, b = b, a % b
+-    return a
++    if b == 0:
++        return a
++    return euclidean(b, a % b)
+```
 
+Given the problem statement of a github issue write a correct git patch to solve it.
+
 Patch:
 ```diff
 Your patch here.
 ```
 
-The problem statement is the following: {input}
+The problem statement is the following:
+{input}
 '''
@@ -22,5 +37,5 @@ vote_prompt = '''Given an instruction and several choices, decide which choice i
 compare_prompt = '''Briefly analyze the degree of correctness of the following two patches. Conclude in the last line "The more correct patch is 1", "The more correct patch is 2", or "The two patches are equally correct".
 '''
 
-score_prompt = '''Analyze the following patch, then at the last line conclude "Therefore the correctness score is {s}", where s is an integer from 1 to 10.
+score_prompt = '''Analyze the following patch, then at the last line conclude "Therefore the correctness score is {s}", where {s} is an integer from 1 to 10.
 '''
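Since the prompts are plain Python format strings, instantiating the rewritten cot_prompt for one SWE-bench instance is a single format() call; problem_statement as the source of {input} is an assumption about how SWETask wraps the dataset:

```python
# Sketch: fill the {input} slot of cot_prompt for one SWE-bench_Lite row.
problem = train_dataset[10]["problem_statement"]  # assumed SWE-bench field
prompt = cot_prompt.format(input=problem)
print(prompt[:300])
```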
tot/tasks/swe.py
@@ -1,7 +1,7 @@
-import re
+import re, os
 from tot.tasks.base import Task
 from tot.prompts.swe import *
-from tot.models import gpt
+from tot.models import gpt, groq
 
 class SWETask(Task):
     """
@@ -24,7 +24,11 @@ class SWETask(Task):
     def test_output(self, idx: int, output: str):
         output = output.split('Patch:\n')[-1]
         prompt = score_prompt + output
-        score_outputs = gpt(prompt, n=5, model='gpt-4')
+        api_base = os.getenv("OPENAI_API_BASE", "")
+        if api_base == 'https://api.groq.com/openai/v1':
+            score_outputs = groq(prompt, n=3, model='mixtral-8x7b-32768')
+        else:
+            score_outputs = gpt(prompt, n=5, model='gpt-4')
         scores = []
         for score_output in score_outputs:
             print("score_output: ", score_output)
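Further down, test_output turns each score_output into an integer. A hedged sketch of that extraction, matching the sentence score_prompt asks the model to emit (the actual parsing code in this file may differ):

```python
import re

def extract_score(score_output: str) -> int:
    # score_prompt requests a final line "Therefore the correctness score is N"
    m = re.search(r"correctness score is (\d+)", score_output)
    return int(m.group(1)) if m else -1

print(extract_score("Therefore the correctness score is 7"))  # 7
```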
@@ -89,3 +93,34 @@ class SWETask(Task):
         else:
             print(f'-----------------compare no match: {[compare_output]}')
             return -1
+
+    @staticmethod
+    def parse_diff_block(text: str):
+        """Extracts the first block of unified diff format.
+
+        Args:
+            text (str): The large text containing one or more diff blocks.
+
+        Returns:
+            str: The first block between ```diff and ``` if found, otherwise None.
+        """
+
+        start_pattern = r"```diff"
+        end_pattern = r"```"
+
+        in_diff_block = False
+        diff_block = []
+
+        for line in text.splitlines():
+            if start_pattern in line:
+                in_diff_block = True
+                continue  # Skip the line with the start marker
+
+            if in_diff_block:
+                if end_pattern in line:
+                    in_diff_block = False
+                    break  # End of the diff block
+                else:
+                    diff_block.append(line)
+
+        return "\n".join(diff_block) if diff_block else None
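A quick usage check for the new helper, with the fence assembled at runtime so the example stays self-contained:

```python
from tot.tasks.swe import SWETask

fence = "`" * 3  # a literal ``` marker, built without nesting fences here
response = (
    "Patch:\n" + fence + "diff\n"
    "--- a/file.py\n+++ b/file.py\n@@ -1 +1 @@\n-x = 1\n+x = 2\n" + fence
)
# Returns only the body between ```diff and ```, or None if no block exists.
print(SWETask.parse_diff_block(response))
```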