diff --git a/.gitignore b/.gitignore
index eb63277..6dc729f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,3 +4,4 @@ src/tree_of_thoughts_llm.egg-info/
 .env
 *.pyc
 *.DS_Store
+/datasets_cache/
\ No newline at end of file
diff --git a/env.yml b/env.yml
index 92b8328..34be92c 100644
--- a/env.yml
+++ b/env.yml
@@ -62,4 +62,5 @@ dependencies:
   - urllib3==2.0.2
   - yarl==1.9.2
   - datasets
+  - python-dotenv
 prefix: /Users/ricardo/anaconda3/envs/tot
diff --git a/runs/swe-bench.py b/runs/swe-bench.py
index e61fa86..32a797b 100644
--- a/runs/swe-bench.py
+++ b/runs/swe-bench.py
@@ -3,14 +3,14 @@
 from tot.methods.bfs import solve
 from tot.tasks.swe import SWETask
 from datasets import load_dataset
-
-train_dataset = load_dataset("princeton-nlp/SWE-bench_Lite", split = "test")
+print("Downloading dataset...")
+train_dataset = load_dataset("princeton-nlp/SWE-bench_Lite", split = "dev", cache_dir='datasets_cache')
 
 args = argparse.Namespace(
-    backend='gpt-4',
-    temperature=0.1,
+    backend='mixtral-8x7b-32768',
+    temperature=0.2,
     task='swe',
-    naive_run=False,
+    naive_run=False,
     prompt_sample='cot',
     method_generate='sample',
     method_evaluate='vote',
@@ -19,8 +19,12 @@ args = argparse.Namespace(
     n_evaluate_sample=3,
     n_select_sample=5)
 
+print("Solving...")
 task = SWETask(train_dataset)
 
+i = 10
+ys, infos = solve(args, task, i, to_print=False)
+print("Solution:")
+print(SWETask.parse_diff_block(ys[0]))
-ys, infos = solve(args, task, 1)
-print(ys[0])
- 
\ No newline at end of file
+print("Expected solution:")
+print(train_dataset[i]["patch"])
\ No newline at end of file
diff --git a/src/tot/methods/bfs.py b/src/tot/methods/bfs.py
index 4675bb9..2418d6e 100644
--- a/src/tot/methods/bfs.py
+++ b/src/tot/methods/bfs.py
@@ -1,13 +1,21 @@
-import itertools
+import itertools, os
 import numpy as np
 from functools import partial
-from tot.models import gpt
+
+api_base = os.getenv("OPENAI_API_BASE", "")
+if api_base == 'https://api.groq.com/openai/v1':
+    from tot.models import groq
+    platform = groq
+else:
+    from tot.models import gpt
+    platform = gpt
+
 
 def get_value(task, x, y, n_evaluate_sample, cache_value=True):
     value_prompt = task.value_prompt_wrap(x, y)
     if cache_value and value_prompt in task.value_cache:
         return task.value_cache[value_prompt]
-    value_outputs = gpt(value_prompt, n=n_evaluate_sample, stop=None)
+    value_outputs = platform(value_prompt, n=n_evaluate_sample, stop=None)
     value = task.value_outputs_unwrap(x, y, value_outputs)
     if cache_value:
         task.value_cache[value_prompt] = value
@@ -27,13 +35,13 @@ def get_values(task, x, ys, n_evaluate_sample, cache_value=True):
 
 def get_votes(task, x, ys, n_evaluate_sample):
     vote_prompt = task.vote_prompt_wrap(x, ys)
-    vote_outputs = gpt(vote_prompt, n=n_evaluate_sample, stop=None)
+    vote_outputs = platform(vote_prompt, n=n_evaluate_sample, stop=None)
     values = task.vote_outputs_unwrap(vote_outputs, len(ys))
     return values
 
 def get_proposals(task, x, y):
     propose_prompt = task.propose_prompt_wrap(x, y)
-    proposals = gpt(propose_prompt, n=1, stop=None)[0].split('\n')
+    proposals = platform(propose_prompt, n=1, stop=None)[0].split('\n')
     return [y + _ + '\n' for _ in proposals]
 
 def get_samples(task, x, y, n_generate_sample, prompt_sample, stop):
@@ -43,13 +51,13 @@
         prompt = task.cot_prompt_wrap(x, y)
     else:
         raise ValueError(f'prompt_sample {prompt_sample} not recognized')
-    samples = gpt(prompt, n=n_generate_sample, stop=stop)
+    samples = platform(prompt, n=n_generate_sample, stop=stop)
     return [y + _ for _ in samples]
 
 def solve(args, task, idx, to_print=True):
-    global gpt
-    gpt = partial(gpt, model=args.backend, temperature=args.temperature)
-    print(gpt)
+    global platform
+    platform = partial(platform, model=args.backend, temperature=args.temperature)
+    print(platform)
     x = task.get_input(idx)  # input
     ys = ['']  # current output candidates
     infos = []
@@ -88,9 +96,9 @@ def solve(args, task, idx, to_print=True):
     return ys, {'steps': infos}
 
 def naive_solve(args, task, idx, to_print=True):
-    global gpt
-    gpt = partial(gpt, model=args.backend, temperature=args.temperature)
-    print(gpt)
+    global platform
+    platform = partial(platform, model=args.backend, temperature=args.temperature)
+    print(platform)
     x = task.get_input(idx)  # input
     ys = get_samples(task, x, '', args.n_generate_sample, args.prompt_sample, stop=None)
     return ys, {}
\ No newline at end of file
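For context on how the pieces above fit together: `bfs.py` now picks the completion function once at import time based on `OPENAI_API_BASE`, and `solve`/`naive_solve` then freeze the model and temperature with `functools.partial`. The sketch below is a minimal, self-contained illustration of that flow; the `gpt`/`groq` functions here are hypothetical stand-ins for the real wrappers in `tot/models.py`, and the `.env` route is an assumption based on `python-dotenv` being added to `env.yml` (no `load_dotenv()` call appears in these hunks).

```python
import os
from functools import partial

# Hypothetical stand-ins for tot.models.gpt / tot.models.groq, used only to show
# the routing; each returns a list of n completions for a single prompt.
def gpt(prompt, model="gpt-4", temperature=0.7, max_tokens=1000, n=1, stop=None):
    return [f"[gpt:{model}] reply to: {prompt}"] * n

def groq(prompt, model="mixtral-8x7b-32768", temperature=0.5, max_tokens=1500, n=1, stop=None):
    return [f"[groq:{model}] reply to: {prompt}"] * n

# Same selection rule as the top of bfs.py: OPENAI_API_BASE decides the backend.
# A .env file read by python-dotenv is one way to set it before import.
api_base = os.getenv("OPENAI_API_BASE", "")
platform = groq if api_base == "https://api.groq.com/openai/v1" else gpt

# solve()/naive_solve() then bind the run's backend and temperature once, mirroring
# `platform = partial(platform, model=args.backend, temperature=args.temperature)`.
platform = partial(platform, model="mixtral-8x7b-32768", temperature=0.2)
print(platform("toy prompt", n=2, stop=None))
```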
diff --git a/src/tot/models.py b/src/tot/models.py
index b3c4fe0..db91fdd 100644
--- a/src/tot/models.py
+++ b/src/tot/models.py
@@ -19,6 +19,9 @@ if api_base != "":
 def completions_with_backoff(**kwargs):
     return openai.ChatCompletion.create(**kwargs)
 
+def completions(**kwargs):
+    return openai.ChatCompletion.create(**kwargs)
+
 def gpt(prompt, model="gpt-4", temperature=0.7, max_tokens=1000, n=1, stop=None) -> list:
     messages = [{"role": "user", "content": prompt}]
     return chatgpt(messages, model=model, temperature=temperature, max_tokens=max_tokens, n=n, stop=stop)
@@ -43,3 +46,25 @@ def gpt_usage(backend="gpt-4"):
     elif backend == "gpt-3.5-turbo":
         cost = completion_tokens / 1000 * 0.002 + prompt_tokens / 1000 * 0.0015
     return {"completion_tokens": completion_tokens, "prompt_tokens": prompt_tokens, "cost": cost}
+
+def groq(prompt, model="mixtral-8x7b-32768", temperature=0.5, max_tokens=1500, n=1, stop=None) -> list:
+    global completion_tokens, prompt_tokens
+    messages = [{"role": "user", "content": prompt}]
+    return groqgpt(messages, model=model, temperature=temperature, max_tokens=max_tokens, n=n, stop=stop)
+
+def groqgpt(messages, model="mixtral-8x7b-32768", temperature=0.5, max_tokens=1500, n=1, stop=None) -> list:
+    global completion_tokens, prompt_tokens
+    outputs = []
+    while n > 0:
+        # one request per sample; loop until n completions have been collected
+        n -= 1
+        res = completions(model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, stop=stop)
+        outputs.extend([choice["message"]["content"] for choice in res["choices"]])
+        # log completion tokens
+        completion_tokens += res["usage"]["completion_tokens"]
+        prompt_tokens += res["usage"]["prompt_tokens"]
+    return outputs
+
+def groq_usage():
+    global completion_tokens, prompt_tokens
+    return {"completion_tokens": completion_tokens, "prompt_tokens": prompt_tokens}
\ No newline at end of file
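One difference worth noting: `completions_with_backoff` in the unchanged part of `models.py` is wrapped in a retry decorator, while the new `completions()` helper used by the Groq path calls `openai.ChatCompletion.create` directly. Below is a hedged sketch of giving the Groq helper the same resilience; it assumes the existing decorator is `backoff.on_exception(backoff.expo, openai.error.OpenAIError)` (as in upstream tree-of-thought-llm) and is a suggestion, not part of this patch.

```python
import backoff
import openai

# Assumed to mirror the retry policy already applied to completions_with_backoff:
# retry ChatCompletion calls with exponential backoff on OpenAI-compatible errors,
# which also covers rate-limit responses from the Groq endpoint.
@backoff.on_exception(backoff.expo, openai.error.OpenAIError)
def completions(**kwargs):
    return openai.ChatCompletion.create(**kwargs)
```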
diff --git a/src/tot/prompts/swe.py b/src/tot/prompts/swe.py
index 16fa7e6..256b86f 100644
--- a/src/tot/prompts/swe.py
+++ b/src/tot/prompts/swe.py
@@ -1,18 +1,33 @@
 standard_prompt = '''
-Given the problem statement of a github issue write a correct git patch to solve it. The problem statement is the following: {input}
+{input}
 '''
 
-cot_prompt = '''
-Given the problem statement of a github issue write a correct git patch to solve it.
-Make a plan then write. Your output should be of the following format:
+cot_prompt = '''I need you to solve this issue by generating a single patch file that I can apply directly to this repository using git apply.
+The patch file should be in the unified diff format. Example:
 
-Plan:
-Your plan here.
+```diff
+diff --git a/file.py b/file.py
+--- a/file.py
++++ b/file.py
+@@ -1,27 +1,35 @@
+ def euclidean(a, b):
+-    while b:
+-        a, b = b, a % b
+-    return a
++    if b == 0:
++        return a
++    return euclidean(b, a % b)
+```
+
+Given the problem statement of a GitHub issue, write a correct git patch to solve it.
 
 Patch:
+```diff
 Your patch here.
+```
 
-The problem statement is the following: {input}
+The problem statement is the following:
+{input}
 '''
 
 
@@ -22,5 +37,5 @@ vote_prompt = '''Given an instruction and several choices, decide which choice i
 compare_prompt = '''Briefly analyze the degree of correctness of the following two patches. Conclude in the last line "The more correct patch is 1", "The more correct patch is 2", or "The two patches are equally correct".
 '''
 
-score_prompt = '''Analyze the following patch, then at the last line conclude "Therefore the correctness score is {s}", where s is an integer from 1 to 10.
+score_prompt = '''Analyze the following patch, then at the last line conclude "Therefore the correctness score is {s}", where {s} is an integer from 1 to 10.
 '''
\ No newline at end of file
diff --git a/src/tot/tasks/swe.py b/src/tot/tasks/swe.py
index eb163ff..5ff86ea 100644
--- a/src/tot/tasks/swe.py
+++ b/src/tot/tasks/swe.py
@@ -1,7 +1,7 @@
-import re
+import re, os
 from tot.tasks.base import Task
 from tot.prompts.swe import *
-from tot.models import gpt
+from tot.models import gpt, groq
 
 class SWETask(Task):
     """
@@ -24,7 +24,11 @@ class SWETask(Task):
     def test_output(self, idx: int, output: str):
         output = output.split('Patch:\n')[-1]
         prompt = score_prompt + output
-        score_outputs = gpt(prompt, n=5, model='gpt-4')
+        api_base = os.getenv("OPENAI_API_BASE", "")
+        if api_base == 'https://api.groq.com/openai/v1':
+            score_outputs = groq(prompt, n=3, model='mixtral-8x7b-32768')
+        else:
+            score_outputs = gpt(prompt, n=5, model='gpt-4')
         scores = []
         for score_output in score_outputs:
             print("score_output: ",score_output)
@@ -89,3 +93,34 @@
         else:
             print(f'-----------------compare no match: {[compare_output]}')
             return -1
+
+    @staticmethod
+    def parse_diff_block(text: str):
+        """Extracts the first block of unified diff format.
+
+        Args:
+            text (str): The large text containing one or more diff blocks.
+
+        Returns:
+            str: The first block between the ```diff and ``` fences if found, otherwise None.
+        """
+
+        start_pattern = r"```diff"
+        end_pattern = r"```"
+
+        in_diff_block = False
+        diff_block = []
+
+        for line in text.splitlines():
+            if start_pattern in line:
+                in_diff_block = True
+                continue  # Skip the line with the start marker
+
+            if in_diff_block:
+                if end_pattern in line:
+                    in_diff_block = False
+                    break  # End of the diff block
+                else:
+                    diff_block.append(line)
+
+        return "\n".join(diff_block) if diff_block else None
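A small usage sketch for the new helper together with the fenced output format that `cot_prompt` asks the model for. The model response below is invented for illustration; the snippet assumes it is run inside this repository's environment after the patch is applied.

```python
from tot.tasks.swe import SWETask

# A made-up model response in the format requested by cot_prompt:
# free-form reasoning, then the patch inside a ```diff fenced block.
model_output = "\n".join([
    "The bug is a missing space in the greeting.",
    "",
    "Patch:",
    "```diff",
    "diff --git a/hello.py b/hello.py",
    "--- a/hello.py",
    "+++ b/hello.py",
    "@@ -1 +1 @@",
    "-print('hello,world')",
    "+print('hello, world')",
    "```",
])

patch = SWETask.parse_diff_block(model_output)
print(patch)  # only the lines between the ```diff and ``` fences

# When no fenced diff block is present, the helper returns None.
assert SWETask.parse_diff_block("no patch produced") is None
```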