diff --git a/runs/swe-bench.py b/runs/swe-bench.py
index 32a797b..bc813e4 100644
--- a/runs/swe-bench.py
+++ b/runs/swe-bench.py
@@ -2,9 +2,35 @@ import argparse
 from tot.methods.bfs import solve
 from tot.tasks.swe import SWETask
 from datasets import load_dataset
+import json
+import time
 
 print("Downloading dataset...")
-train_dataset = load_dataset("princeton-nlp/SWE-bench_Lite", split = "dev", cache_dir='datasets_cache')
+train_dataset = load_dataset("princeton-nlp/SWE-bench_Lite", split = "test", cache_dir='datasets_cache')
+preds_path = "preds.jsonl"
+try:
+    with open(preds_path, "r") as file:
+        preds_jsonl = file.read()
+except FileNotFoundError:
+    preds_jsonl = ""
+
+def update_jsonl(instance_id, model_patch, model_name_or_path, jsonl_object):
+    data = [json.loads(line) for line in jsonl_object.splitlines()]
+
+    new_obj = {
+        "instance_id": instance_id,
+        "model_patch": model_patch,
+        "model_name_or_path": model_name_or_path
+    }
+
+    data.append(new_obj)
+    updated_jsonl = "\n".join([json.dumps(obj) for obj in data])
+
+    return updated_jsonl
+
+def save_jsonl(jsonl_object, file_path="preds.jsonl"):
+    with open(file_path, "w") as file:
+        file.write(jsonl_object) if jsonl_object.strip() else file.write("")
 
 args = argparse.Namespace(
     backend='mixtral-8x7b-32768',
@@ -21,10 +47,10 @@ args = argparse.Namespace(
 
 print("Solving...")
 task = SWETask(train_dataset)
-i = 10
-ys, infos = solve(args, task, i, to_print=False)
-print("Solution:")
-print(SWETask.parse_diff_block(ys[0]))
-print("Expected solution:")
-print(train_dataset[i]["patch"])
\ No newline at end of file
+for index in range(110,150):
+    ys, infos = solve(args, task, index, to_print=False)
+    preds_jsonl = update_jsonl(train_dataset[index]["instance_id"], SWETask.parse_diff_block(ys[0]), args.backend, preds_jsonl)
+    save_jsonl(preds_jsonl, preds_path)
+    print(f"Solution {index} done.")
+    time.sleep(60)
diff --git a/src/tot/prompts/swe.py b/src/tot/prompts/swe.py
index 256b86f..1f6dbbb 100644
--- a/src/tot/prompts/swe.py
+++ b/src/tot/prompts/swe.py
@@ -2,7 +2,13 @@ standard_prompt = '''
 {input}
 '''
 
-cot_prompt = '''I need you to solve this issue by generating a single patch file that I can apply directly to this repository using git apply.
+cot_prompt = '''Given the Repository url, Base commit and Problem statement of a GitHub issue, please write a correct git patch to solve it.
+The output must have this format:
+Patch:
+```diff
+Your patch here.
+```
+
 The patch file should be in the unified diff format. Example:
 
 ```diff
@@ -19,14 +25,6 @@ diff --git a/file.py b/file.py
 +    return euclidean(b, a % b)
 ```
 
-Given the problem statement of a github issue write a correct git patch to solve it.
-
-Patch:
-```diff
-Your patch here.
-```
-
-The problem statement is the following:
 {input}
 '''
diff --git a/src/tot/tasks/swe.py b/src/tot/tasks/swe.py
index 5ff86ea..2abc375 100644
--- a/src/tot/tasks/swe.py
+++ b/src/tot/tasks/swe.py
@@ -19,14 +19,20 @@ class SWETask(Task):
         return len(self.data)
 
     def get_input(self, idx: int) -> str:
-        return self.data[idx]["problem_statement"]
+        swe_prompt = '''
+Repository url: https://github.com/{repo}
+Base commit: {base_commit}
+Problem statement:
+{problem_statement}
+'''
+        return swe_prompt.format(repo=self.data[idx]['repo'], base_commit=self.data[idx]['base_commit'], problem_statement=self.data[idx]['problem_statement'])
 
     def test_output(self, idx: int, output: str):
         output = output.split('Patch:\n')[-1]
         prompt = score_prompt + output
         api_base = os.getenv("OPENAI_API_BASE", "")
         if api_base == 'https://api.groq.com/openai/v1':
-            score_output = groq(prompt, n=3, model='mixtral-8x7b-32768')
+            score_output = groq(prompt, n=5, model='mixtral-8x7b-32768')
         else:
             score_outputs = gpt(prompt, n=5, model='gpt-4')
         scores = []
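For reference, a minimal usage sketch (not part of the patch) of the two JSONL helpers added in `runs/swe-bench.py`. It assumes `update_jsonl` and `save_jsonl` from the hunk above are in scope; the instance id and patch text below are hypothetical placeholders, and the record layout simply mirrors the keys assembled in `update_jsonl`:

```python
# Illustrative only: exercise update_jsonl/save_jsonl outside the main loop.
preds_jsonl = ""  # no existing preds.jsonl yet

# Hypothetical values; the runner fills these from the dataset and the solver output.
preds_jsonl = update_jsonl(
    instance_id="example__repo-1234",                   # placeholder instance id
    model_patch="diff --git a/file.py b/file.py\n...",  # placeholder patch text
    model_name_or_path="mixtral-8x7b-32768",
    jsonl_object=preds_jsonl,
)
save_jsonl(preds_jsonl, "preds.jsonl")

# Each line of preds.jsonl is one JSON object of the form:
# {"instance_id": "...", "model_patch": "...", "model_name_or_path": "..."}
```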