mirror of
https://github.com/princeton-nlp/tree-of-thought-llm
synced 2025-05-08 13:44:22 +00:00
SWE ToT final
This commit is contained in:
parent
c1ada54223
commit
1f82624d3e
@ -2,9 +2,35 @@ import argparse
|
||||
from tot.methods.bfs import solve
|
||||
from tot.tasks.swe import SWETask
|
||||
from datasets import load_dataset
|
||||
import json
|
||||
import time
|
||||
|
||||
print("Downloading dataset...")
|
||||
train_dataset = load_dataset("princeton-nlp/SWE-bench_Lite", split = "dev", cache_dir='datasets_cache')
|
||||
train_dataset = load_dataset("princeton-nlp/SWE-bench_Lite", split = "test", cache_dir='datasets_cache')
|
||||
preds_path = "preds.jsonl"
|
||||
try:
|
||||
with open(preds_path, "r") as file:
|
||||
preds_jsonl = file.read()
|
||||
except FileNotFoundError:
|
||||
preds_jsonl = ""
|
||||
|
||||
def update_jsonl(instance_id, model_patch, model_name_or_path, jsonl_object):
|
||||
data = [json.loads(line) for line in jsonl_object.splitlines()]
|
||||
|
||||
new_obj = {
|
||||
"instance_id": instance_id,
|
||||
"model_patch": model_patch,
|
||||
"model_name_or_path": model_name_or_path
|
||||
}
|
||||
|
||||
data.append(new_obj)
|
||||
updated_jsonl = "\n".join([json.dumps(obj) for obj in data])
|
||||
|
||||
return updated_jsonl
|
||||
|
||||
def save_jsonl(jsonl_object, file_path="preds.jsonl"):
|
||||
with open(file_path, "w") as file:
|
||||
file.write(jsonl_object) if jsonl_object.strip() else file.write("")
|
||||
|
||||
args = argparse.Namespace(
|
||||
backend='mixtral-8x7b-32768',
|
||||
@ -21,10 +47,10 @@ args = argparse.Namespace(
|
||||
|
||||
print("Solving...")
|
||||
task = SWETask(train_dataset)
|
||||
i = 10
|
||||
ys, infos = solve(args, task, i, to_print=False)
|
||||
print("Solution:")
|
||||
print(SWETask.parse_diff_block(ys[0]))
|
||||
|
||||
print("Expected solution:")
|
||||
print(train_dataset[i]["patch"])
|
||||
for index in range(110,150):
|
||||
ys, infos = solve(args, task, index, to_print=False)
|
||||
preds_jsonl = update_jsonl(train_dataset[index]["instance_id"], SWETask.parse_diff_block(ys[0]), args.backend, preds_jsonl)
|
||||
save_jsonl(preds_jsonl, preds_path)
|
||||
print(f"Solution {index} done.")
|
||||
time.sleep(60)
|
||||
|
@ -2,7 +2,13 @@ standard_prompt = '''
|
||||
{input}
|
||||
'''
|
||||
|
||||
cot_prompt = '''I need you to solve this issue by generating a single patch file that I can apply directly to this repository using git apply.
|
||||
cot_prompt = '''Given the Repository url, Base commit and Problem statement of a github issue. Please write a correct git patch to solve it.
|
||||
The output must have this format:
|
||||
Patch:
|
||||
```diff
|
||||
Your patch here.
|
||||
```
|
||||
|
||||
The patch file should be in the unified diff format. Example:
|
||||
|
||||
```diff
|
||||
@ -19,14 +25,6 @@ diff --git a/file.py b/file.py
|
||||
+ return euclidean(b, a % b)
|
||||
```
|
||||
|
||||
Given the problem statement of a github issue write a correct git patch to solve it.
|
||||
|
||||
Patch:
|
||||
```diff
|
||||
Your patch here.
|
||||
```
|
||||
|
||||
The problem statement is the following:
|
||||
{input}
|
||||
'''
|
||||
|
||||
|
@ -19,14 +19,20 @@ class SWETask(Task):
|
||||
return len(self.data)
|
||||
|
||||
def get_input(self, idx: int) -> str:
|
||||
return self.data[idx]["problem_statement"]
|
||||
swe_prompt = '''
|
||||
Repository url: https://github.com/{repo}
|
||||
Base commit: {base_commit}
|
||||
Problem statement:
|
||||
{problem_statement}
|
||||
'''
|
||||
return swe_prompt.format(repo=self.data[idx]['repo'], base_commit=self.data[idx]['base_commit'], problem_statement=self.data[idx]['problem_statement'])
|
||||
|
||||
def test_output(self, idx: int, output: str):
|
||||
output = output.split('Patch:\n')[-1]
|
||||
prompt = score_prompt + output
|
||||
api_base = os.getenv("OPENAI_API_BASE", "")
|
||||
if api_base == 'https://api.groq.com/openai/v1':
|
||||
score_output = groq(prompt, n=3, model='mixtral-8x7b-32768')
|
||||
score_output = groq(prompt, n=5, model='mixtral-8x7b-32768')
|
||||
else:
|
||||
score_outputs = gpt(prompt, n=5, model='gpt-4')
|
||||
scores = []
|
||||
|
Loading…
Reference in New Issue
Block a user