mirror of
https://github.com/princeton-nlp/tree-of-thought-llm
synced 2025-05-10 06:20:33 +00:00
SWE ToT final
This commit is contained in:
parent
c1ada54223
commit
1f82624d3e
@ -2,9 +2,35 @@ import argparse
|
|||||||
from tot.methods.bfs import solve
|
from tot.methods.bfs import solve
|
||||||
from tot.tasks.swe import SWETask
|
from tot.tasks.swe import SWETask
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
|
||||||
print("Downloading dataset...")
|
print("Downloading dataset...")
|
||||||
train_dataset = load_dataset("princeton-nlp/SWE-bench_Lite", split = "dev", cache_dir='datasets_cache')
|
train_dataset = load_dataset("princeton-nlp/SWE-bench_Lite", split = "test", cache_dir='datasets_cache')
|
||||||
|
preds_path = "preds.jsonl"
|
||||||
|
try:
|
||||||
|
with open(preds_path, "r") as file:
|
||||||
|
preds_jsonl = file.read()
|
||||||
|
except FileNotFoundError:
|
||||||
|
preds_jsonl = ""
|
||||||
|
|
||||||
|
def update_jsonl(instance_id, model_patch, model_name_or_path, jsonl_object):
    """Append one prediction record to a JSONL document held as a string.

    Args:
        instance_id: Value stored under the "instance_id" key.
        model_patch: Value stored under the "model_patch" key.
        model_name_or_path: Value stored under the "model_name_or_path" key.
        jsonl_object: Existing JSONL text, one JSON value per line. May be
            empty.

    Returns:
        The JSONL text with every existing record re-serialized by
        ``json.dumps`` and the new record appended as the last line. Lines
        are joined with a single newline and there is no trailing newline.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Blank or whitespace-only lines carry no record (e.g. a trailing empty
    # line left by an editor); skip them instead of letting json.loads fail.
    records = [
        json.loads(line)
        for line in jsonl_object.splitlines()
        if line.strip()
    ]
    records.append(
        {
            "instance_id": instance_id,
            "model_patch": model_patch,
            "model_name_or_path": model_name_or_path,
        }
    )
    return "\n".join(json.dumps(record) for record in records)
|
||||||
|
|
||||||
|
def save_jsonl(jsonl_object, file_path="preds.jsonl"):
    """Overwrite ``file_path`` with the given JSONL text.

    Args:
        jsonl_object: JSONL text to persist. If it is empty or contains only
            whitespace, the file is written empty.
        file_path: Destination path. It is opened in "w" mode, so any
            previous content is replaced.
    """
    # Decide what to write before opening the file, so that a bad argument
    # (e.g. a non-string) fails without first truncating the existing file.
    content = jsonl_object if jsonl_object.strip() else ""
    with open(file_path, "w") as file:
        file.write(content)
|
||||||
|
|
||||||
args = argparse.Namespace(
|
args = argparse.Namespace(
|
||||||
backend='mixtral-8x7b-32768',
|
backend='mixtral-8x7b-32768',
|
||||||
@ -21,10 +47,10 @@ args = argparse.Namespace(
|
|||||||
|
|
||||||
print("Solving...")
|
print("Solving...")
|
||||||
task = SWETask(train_dataset)
|
task = SWETask(train_dataset)
|
||||||
i = 10
|
|
||||||
ys, infos = solve(args, task, i, to_print=False)
|
|
||||||
print("Solution:")
|
|
||||||
print(SWETask.parse_diff_block(ys[0]))
|
|
||||||
|
|
||||||
print("Expected solution:")
|
for index in range(110,150):
|
||||||
print(train_dataset[i]["patch"])
|
ys, infos = solve(args, task, index, to_print=False)
|
||||||
|
preds_jsonl = update_jsonl(train_dataset[index]["instance_id"], SWETask.parse_diff_block(ys[0]), args.backend, preds_jsonl)
|
||||||
|
save_jsonl(preds_jsonl, preds_path)
|
||||||
|
print(f"Solution {index} done.")
|
||||||
|
time.sleep(60)
|
||||||
|
@ -2,7 +2,13 @@ standard_prompt = '''
|
|||||||
{input}
|
{input}
|
||||||
'''
|
'''
|
||||||
|
|
||||||
cot_prompt = '''I need you to solve this issue by generating a single patch file that I can apply directly to this repository using git apply.
|
cot_prompt = '''Given the Repository url, Base commit and Problem statement of a github issue. Please write a correct git patch to solve it.
|
||||||
|
The output must have this format:
|
||||||
|
Patch:
|
||||||
|
```diff
|
||||||
|
Your patch here.
|
||||||
|
```
|
||||||
|
|
||||||
The patch file should be in the unified diff format. Example:
|
The patch file should be in the unified diff format. Example:
|
||||||
|
|
||||||
```diff
|
```diff
|
||||||
@ -19,14 +25,6 @@ diff --git a/file.py b/file.py
|
|||||||
+ return euclidean(b, a % b)
|
+ return euclidean(b, a % b)
|
||||||
```
|
```
|
||||||
|
|
||||||
Given the problem statement of a github issue write a correct git patch to solve it.
|
|
||||||
|
|
||||||
Patch:
|
|
||||||
```diff
|
|
||||||
Your patch here.
|
|
||||||
```
|
|
||||||
|
|
||||||
The problem statement is the following:
|
|
||||||
{input}
|
{input}
|
||||||
'''
|
'''
|
||||||
|
|
||||||
|
@ -19,14 +19,20 @@ class SWETask(Task):
|
|||||||
return len(self.data)
|
return len(self.data)
|
||||||
|
|
||||||
def get_input(self, idx: int) -> str:
|
def get_input(self, idx: int) -> str:
|
||||||
return self.data[idx]["problem_statement"]
|
swe_prompt = '''
|
||||||
|
Repository url: https://github.com/{repo}
|
||||||
|
Base commit: {base_commit}
|
||||||
|
Problem statement:
|
||||||
|
{problem_statement}
|
||||||
|
'''
|
||||||
|
return swe_prompt.format(repo=self.data[idx]['repo'], base_commit=self.data[idx]['base_commit'], problem_statement=self.data[idx]['problem_statement'])
|
||||||
|
|
||||||
def test_output(self, idx: int, output: str):
|
def test_output(self, idx: int, output: str):
|
||||||
output = output.split('Patch:\n')[-1]
|
output = output.split('Patch:\n')[-1]
|
||||||
prompt = score_prompt + output
|
prompt = score_prompt + output
|
||||||
api_base = os.getenv("OPENAI_API_BASE", "")
|
api_base = os.getenv("OPENAI_API_BASE", "")
|
||||||
if api_base == 'https://api.groq.com/openai/v1':
|
if api_base == 'https://api.groq.com/openai/v1':
|
||||||
score_output = groq(prompt, n=3, model='mixtral-8x7b-32768')
|
score_output = groq(prompt, n=5, model='mixtral-8x7b-32768')
|
||||||
else:
|
else:
|
||||||
score_outputs = gpt(prompt, n=5, model='gpt-4')
|
score_outputs = gpt(prompt, n=5, model='gpt-4')
|
||||||
scores = []
|
scores = []
|
||||||
|
Loading…
Reference in New Issue
Block a user