SWE ToT final

This commit is contained in:
ricardo-larosa 2024-04-07 14:54:52 +02:00
parent c1ada54223
commit 1f82624d3e
3 changed files with 48 additions and 18 deletions

View File

@ -2,9 +2,35 @@ import argparse
from tot.methods.bfs import solve
from tot.tasks.swe import SWETask
from datasets import load_dataset
import json
import time
print("Downloading dataset...")
train_dataset = load_dataset("princeton-nlp/SWE-bench_Lite", split = "dev", cache_dir='datasets_cache')
train_dataset = load_dataset("princeton-nlp/SWE-bench_Lite", split = "test", cache_dir='datasets_cache')
preds_path = "preds.jsonl"
# Pick up predictions left by an earlier run; start from an empty document
# when the predictions file does not exist yet.
try:
    file = open(preds_path, "r")
except FileNotFoundError:
    preds_jsonl = ""
else:
    with file:
        preds_jsonl = file.read()
def update_jsonl(instance_id, model_patch, model_name_or_path, jsonl_object):
    """Append one prediction record to an in-memory JSONL document.

    Args:
        instance_id: Value stored under the "instance_id" key.
        model_patch: Value stored under the "model_patch" key.
        model_name_or_path: Value stored under the "model_name_or_path" key.
        jsonl_object: Existing JSONL text, one JSON value per line. May be
            empty and may contain blank lines, which are ignored.

    Returns:
        The JSONL text with every existing record re-serialized and the new
        record appended as the last line. Lines are joined with "\n" and
        there is no trailing newline.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Skip blank / whitespace-only lines: json.loads("") raises, so a single
    # stray empty line in the existing file would otherwise abort the run.
    records = [
        json.loads(line)
        for line in jsonl_object.splitlines()
        if line.strip()
    ]
    records.append(
        {
            "instance_id": instance_id,
            "model_patch": model_patch,
            "model_name_or_path": model_name_or_path,
        }
    )
    return "\n".join(json.dumps(record) for record in records)
def save_jsonl(jsonl_object, file_path="preds.jsonl"):
    """Overwrite ``file_path`` with the given JSONL text.

    Args:
        jsonl_object: Text to persist. When it is empty or contains only
            whitespace, the file is left empty.
        file_path: Destination path; defaults to "preds.jsonl".
    """
    with open(file_path, "w") as file:
        # Opening in "w" mode has already truncated the file, so blank input
        # needs no write at all.
        if jsonl_object.strip():
            file.write(jsonl_object)
args = argparse.Namespace(
backend='mixtral-8x7b-32768',
@ -21,10 +47,10 @@ args = argparse.Namespace(
print("Solving...")
task = SWETask(train_dataset)
i = 10
ys, infos = solve(args, task, i, to_print=False)
print("Solution:")
print(SWETask.parse_diff_block(ys[0]))
print("Expected solution:")
print(train_dataset[i]["patch"])
# Produce a prediction for dataset indices 110..149, rewriting the
# predictions file after every instance so completed work is already on disk
# if a later instance fails.
for index in range(110, 150):
    ys, infos = solve(args, task, index, to_print=False)
    instance_id = train_dataset[index]["instance_id"]
    model_patch = SWETask.parse_diff_block(ys[0])
    preds_jsonl = update_jsonl(instance_id, model_patch, args.backend, preds_jsonl)
    save_jsonl(preds_jsonl, preds_path)
    print(f"Solution {index} done.")
    # NOTE(review): presumably a pause to stay under the backend's rate
    # limit — confirm before changing the interval.
    time.sleep(60)

View File

@ -2,7 +2,13 @@ standard_prompt = '''
{input}
'''
cot_prompt = '''I need you to solve this issue by generating a single patch file that I can apply directly to this repository using git apply.
cot_prompt = '''Given the Repository url, Base commit and Problem statement of a github issue. Please write a correct git patch to solve it.
The output must have this format:
Patch:
```diff
Your patch here.
```
The patch file should be in the unified diff format. Example:
```diff
@ -19,14 +25,6 @@ diff --git a/file.py b/file.py
+ return euclidean(b, a % b)
```
Given the problem statement of a github issue write a correct git patch to solve it.
Patch:
```diff
Your patch here.
```
The problem statement is the following:
{input}
'''

View File

@ -19,14 +19,20 @@ class SWETask(Task):
return len(self.data)
def get_input(self, idx: int) -> str:
return self.data[idx]["problem_statement"]
swe_prompt = '''
Repository url: https://github.com/{repo}
Base commit: {base_commit}
Problem statement:
{problem_statement}
'''
return swe_prompt.format(repo=self.data[idx]['repo'], base_commit=self.data[idx]['base_commit'], problem_statement=self.data[idx]['problem_statement'])
def test_output(self, idx: int, output: str):
output = output.split('Patch:\n')[-1]
prompt = score_prompt + output
api_base = os.getenv("OPENAI_API_BASE", "")
if api_base == 'https://api.groq.com/openai/v1':
score_output = groq(prompt, n=3, model='mixtral-8x7b-32768')
score_output = groq(prompt, n=5, model='mixtral-8x7b-32768')
else:
score_outputs = gpt(prompt, n=5, model='gpt-4')
scores = []