mirror of https://github.com/princeton-nlp/tree-of-thought-llm
synced 2025-06-26 18:26:00 +00:00

commit a0fe552731 (parent f26589b0c0)
First running version of ToT with SWE-bench
env.yml
@@ -61,4 +61,5 @@ dependencies:
   - openai==0.27.7
   - urllib3==2.0.2
   - yarl==1.9.2
+  - datasets
 prefix: /Users/ricardo/anaconda3/envs/tot
@@ -3,6 +3,7 @@ from tot.methods.bfs import solve
 from tot.tasks.swe import SWETask
 from datasets import load_dataset
 
 
+train_dataset = load_dataset("princeton-nlp/SWE-bench_Lite", split = "test")
 
 args = argparse.Namespace(
@@ -10,8 +11,8 @@ args = argparse.Namespace(
     temperature=0.1,
     task='swe',
     naive_run=False,
-    prompt_sample=None,
-    method_generate='propose',
+    prompt_sample='cot',
+    method_generate='sample',
     method_evaluate='vote',
     method_select='greedy',
     n_generate_sample=1,
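The rest of the run script is not part of this diff. As a rough, hypothetical sketch of how the dataset and args above would presumably be driven (assuming solve from tot.methods.bfs keeps its usual (args, task, idx) signature and returns candidate outputs plus a trace):

# Hypothetical continuation of the run script above -- not part of this commit.
# Assumes train_dataset, args and the imports shown in the hunks above.
task = SWETask(train_dataset)              # wrap the SWE-bench split in the ToT task
for idx in range(min(3, len(task))):       # a few instances, just for illustration
    ys, info = solve(args, task, idx)      # BFS tree search over plan/patch steps
    print(task.test_output(idx, ys[0]))    # GPT-4-scored correctness of the first candidate patch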
@@ -1,25 +1,26 @@
 standard_prompt = '''
-Write a coherent passage of 4 short paragraphs. The end sentence of each paragraph must be: {input}
+Given the problem statement of a github issue write a correct git patch to solve it. The problem statement is the following: {input}
 '''
 
 cot_prompt = '''
-Write a coherent passage of 4 short paragraphs. The end sentence of each paragraph must be: {input}
-
+Given the problem statement of a github issue write a correct git patch to solve it.
 Make a plan then write. Your output should be of the following format:
 
 Plan:
 Your plan here.
 
-Passage:
-Your passage here.
+Patch:
+Your patch here.
+
+The problem statement is the following: {input}
 '''
 
 
-vote_prompt = '''Given an instruction and several choices, decide which choice is most promising. Analyze each choice in detail, then conclude in the last line "The best choice is {s}", where s the integer id of the choice.
+vote_prompt = '''Given an instruction and several choices, decide which choice is most promising. Analyze each choice in detail, then conclude in the last line "The best choice is {s}", where {s} the integer id of the choice.
 '''
 
-compare_prompt = '''Briefly analyze the coherency of the following two passages. Conclude in the last line "The more coherent passage is 1", "The more coherent passage is 2", or "The two passages are similarly coherent".
+compare_prompt = '''Briefly analyze the degree of correctness of the following two patches. Conclude in the last line "The more correct patch is 1", "The more correct patch is 2", or "The two patches are equally correct".
 '''
 
-score_prompt = '''Analyze the following passage, then at the last line conclude "Thus the coherency score is {s}", where s is an integer from 1 to 10.
+score_prompt = '''Analyze the following patch, then at the last line conclude "Therefore the correctness score is {s}", where s is an integer from 1 to 10.
 '''
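For orientation, a small hedged sketch of how these templates are presumably filled in elsewhere in the repo (the wrapper helpers themselves are not part of this hunk); the issue text and candidate patches below are made up:

from tot.prompts.swe import cot_prompt, vote_prompt   # the module edited above

problem = "TypeError when foo() is called with no arguments"   # made-up issue text
prompt = cot_prompt.format(input=problem)      # asks the model for a Plan: / Patch: answer

candidates = ["Patch:\ndiff --git a/foo.py b/foo.py\n...",
              "Patch:\ndiff --git a/bar.py b/bar.py\n..."]     # fake candidate patches
vote = vote_prompt
for i, y in enumerate(candidates, 1):
    vote += f'Choice {i}:\n{y}\n'              # number the candidates for the voting model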
@@ -1,37 +1,34 @@
 import re
 from tot.tasks.base import Task
-from tot.prompts.text import *
+from tot.prompts.swe import *
 from tot.models import gpt
 
 class SWETask(Task):
     """
-    Input (x)   : a task instruction
-    Output (y)  : a text generation
+    Input (x)   : a problem statement
+    Output (y)  : a patch generation
     Reward (r)  : # TODO
     """
     def __init__(self, dataset):
         """
         file: a text file, each line is some sentences
         """
         super().__init__()
         self.data = dataset
         self.steps = 2
-        self.stops = ['\nPassage:\n', None]
+        self.stops = ['\nPatch:\n', None]
 
     def __len__(self) -> int:
         return len(self.data)
 
     def get_input(self, idx: int) -> str:
-        return self.data[idx]
+        return self.data[idx]["problem_statement"]
 
     def test_output(self, idx: int, output: str):
-        output = output.split('Passage:\n')[-1]
+        output = output.split('Patch:\n')[-1]
         prompt = score_prompt + output
         score_outputs = gpt(prompt, n=5, model='gpt-4')
         scores = []
         for score_output in score_outputs:
-            # print(score_output)
-            pattern = r".*coherency score is (\d+).*"
+            print("score_output: ",score_output)
+            pattern = r".*correctness score is (\d+).*"
             match = re.match(pattern, score_output, re.DOTALL)
             if match:
                 score = int(match.groups()[0])
@@ -39,7 +36,7 @@ class SWETask(Task):
             else:
                 print(f'------------------score no match: {[score_output]}')
         print(scores)
-        # print('------------')
+        print('------------')
         info = {'rs': scores, 'r': sum(scores) / len(scores) if scores else 0}
         return info
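A tiny illustration of the scoring regex above on a made-up model reply (not output from an actual run):

import re
reply = "The patch adds the missing null check but no test. Therefore the correctness score is 7"
match = re.match(r".*correctness score is (\d+).*", reply, re.DOTALL)
print(int(match.groups()[0]))   # -> 7, averaged with the other sampled scores into info['r']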
@@ -77,17 +74,17 @@ class SWETask(Task):
     @staticmethod
     def compare_prompt_wrap(x: str, ys: list) -> str:
         assert len(ys) == 2, 'compare prompt only supports 2 candidates'
-        ys = [y.split('Passage:\n')[-1] for y in ys]
-        prompt = compare_prompt + f'Passage 1:\n{ys[0]}\n\nPassage 2:\n{ys[1]}\n'
+        ys = [y.split('Patch:\n')[-1] for y in ys]
+        prompt = compare_prompt + f'Patch: 1:\n{ys[0]}\n\nPatch: 2:\n{ys[1]}\n'
         return prompt
 
     @staticmethod
     def compare_output_unwrap(compare_output: str):
-        if 'more coherent passage is 1' in compare_output:
+        if 'more correct patch is 1' in compare_output:
             return 0
-        elif 'more coherent passage is 2' in compare_output:
+        elif 'more correct patch is 2' in compare_output:
             return 1
-        elif 'two passages are similarly coherent' in compare_output:
+        elif 'two patches are equally correct' in compare_output:
             return 0.5
         else:
             print(f'-----------------compare no match: {[compare_output]}')
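For reference, the unwrap above maps a verdict string to the index of the preferred patch (0.5 for a tie); with a made-up verdict:

from tot.tasks.swe import SWETask

verdict = "Patch 1 also fixes the regression test. The more correct patch is 1"
print(SWETask.compare_output_unwrap(verdict))   # -> 0, i.e. the first candidate wins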