First running version of ToT with SWE-bench

ricardo-larosa 2024-04-03 15:36:36 +02:00
parent f26589b0c0
commit a0fe552731
6 changed files with 27 additions and 27 deletions

View File

@@ -61,4 +61,5 @@ dependencies:
- openai==0.27.7
- urllib3==2.0.2
- yarl==1.9.2
- datasets
prefix: /Users/ricardo/anaconda3/envs/tot
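The only substantive change here is the new unpinned datasets dependency, which the run script below uses to fetch SWE-bench Lite from the Hugging Face hub. A minimal sketch of what that dependency provides, assuming the field names published for princeton-nlp/SWE-bench_Lite (instance_id and problem_statement appear in the dataset card; the print targets are illustrative):

from datasets import load_dataset

dataset = load_dataset("princeton-nlp/SWE-bench_Lite", split="test")
record = dataset[0]
print(record["instance_id"])        # repo-plus-issue identifier
print(record["problem_statement"])  # the raw GitHub issue text fed to the model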

View File

@@ -3,6 +3,7 @@ from tot.methods.bfs import solve
from tot.tasks.swe import SWETask
from datasets import load_dataset
train_dataset = load_dataset("princeton-nlp/SWE-bench_Lite", split="test")
args = argparse.Namespace(
@@ -10,8 +11,8 @@ args = argparse.Namespace(
temperature=0.1,
task='swe',
naive_run=False,
prompt_sample=None,
method_generate='propose',
prompt_sample='cot',
method_generate='sample',
method_evaluate='vote',
method_select='greedy',
n_generate_sample=1,
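The switch from prompt_sample=None / method_generate='propose' to 'cot' / 'sample' means BFS now samples chain-of-thought completions and ranks them by vote. A hedged sketch of how this script presumably ties the pieces together, reusing the imports already at the top of the file (solve's signature and return shape follow the repo's bfs.solve; the index is illustrative):

task = SWETask(train_dataset)
ys, info = solve(args, task, idx=0)   # candidate patches for the first instance
print(ys[0])                          # best-ranked patch text per the greedy selector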

View File

@@ -1,25 +1,26 @@
standard_prompt = '''
Write a coherent passage of 4 short paragraphs. The end sentence of each paragraph must be: {input}
Given the problem statement of a GitHub issue, write a correct git patch to solve it. The problem statement is the following: {input}
'''
cot_prompt = '''
Write a coherent passage of 4 short paragraphs. The end sentence of each paragraph must be: {input}
Given the problem statement of a GitHub issue, write a correct git patch to solve it.
Make a plan, then write. Your output should be of the following format:
Plan:
Your plan here.
Passage:
Your passage here.
Patch:
Your patch here.
The problem statement is the following: {input}
'''
vote_prompt = '''Given an instruction and several choices, decide which choice is most promising. Analyze each choice in detail, then conclude in the last line "The best choice is {s}", where s the integer id of the choice.
vote_prompt = '''Given an instruction and several choices, decide which choice is most promising. Analyze each choice in detail, then conclude in the last line "The best choice is {s}", where {s} is the integer id of the choice.
'''
compare_prompt = '''Briefly analyze the coherency of the following two passages. Conclude in the last line "The more coherent passage is 1", "The more coherent passage is 2", or "The two passages are similarly coherent".
compare_prompt = '''Briefly analyze the correctness of the following two patches. Conclude in the last line "The more correct patch is 1", "The more correct patch is 2", or "The two patches are equally correct".
'''
score_prompt = '''Analyze the following passage, then at the last line conclude "Thus the coherency score is {s}", where s is an integer from 1 to 10.
score_prompt = '''Analyze the following patch, then at the last line conclude "Therefore the correctness score is {s}", where {s} is an integer from 1 to 10.
'''
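Note that {s} in vote_prompt and score_prompt is shown to the model verbatim, not filled by .format(); the model is expected to echo a concrete integer, which a regex then recovers. A sketch of the vote-counting side, mirroring the text task's vote_outputs_unwrap that the 'vote' evaluator presumably reuses (the function name and pattern follow tot/tasks/text.py; treat them as assumptions here):

import re

def vote_outputs_unwrap(vote_outputs: list, n_candidates: int) -> list:
    # Tally one vote per output that concludes "The best choice is N".
    vote_results = [0] * n_candidates
    for vote_output in vote_outputs:
        match = re.match(r".*best choice is .*(\d+).*", vote_output, re.DOTALL)
        if match:
            vote = int(match.groups()[0]) - 1   # model ids are 1-based
            if 0 <= vote < n_candidates:
                vote_results[vote] += 1
    return vote_results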

View File

@@ -1,37 +1,34 @@
import re
from tot.tasks.base import Task
from tot.prompts.text import *
from tot.prompts.swe import *
from tot.models import gpt
class SWETask(Task):
"""
Input (x) : a task instruction
Output (y) : a text generation
Input (x) : a problem statement
Output (y) : a patch generation
Reward (r) : # TODO
"""
def __init__(self, dataset):
"""
dataset: a SWE-bench dataset split; each record contains a problem statement
"""
super().__init__()
self.data = dataset
self.steps = 2
self.stops = ['\nPassage:\n', None]
self.stops = ['\nPatch:\n', None]
def __len__(self) -> int:
return len(self.data)
def get_input(self, idx: int) -> str:
return self.data[idx]
return self.data[idx]["problem_statement"]
def test_output(self, idx: int, output: str):
output = output.split('Passage:\n')[-1]
output = output.split('Patch:\n')[-1]
prompt = score_prompt + output
score_outputs = gpt(prompt, n=5, model='gpt-4')
scores = []
for score_output in score_outputs:
# print(score_output)
pattern = r".*coherency score is (\d+).*"
print("score_output: ",score_output)
pattern = r".*correctness score is (\d+).*"
match = re.match(pattern, score_output, re.DOTALL)
if match:
score = int(match.groups()[0])
@@ -39,7 +36,7 @@ class SWETask(Task):
else:
print(f'------------------score no match: {[score_output]}')
print(scores)
# print('------------')
print('------------')
info = {'rs': scores, 'r': sum(scores) / len(scores) if scores else 0}
return info
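End to end, test_output strips everything before the final "Patch:" marker, asks gpt-4 to grade the patch five times, and averages whatever scores the regex recovers. A hedged usage sketch (the output string is a placeholder standing in for a real model completion):

task = SWETask(dataset)
output = "Plan:\n1. Fix the off-by-one.\nPatch:\ndiff --git a/foo.py b/foo.py\n..."
info = task.test_output(0, output)
print(info['r'])   # mean of the parsed scores, or 0 if none matched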
@@ -77,17 +74,17 @@ class SWETask(Task):
@staticmethod
def compare_prompt_wrap(x: str, ys: list) -> str:
assert len(ys) == 2, 'compare prompt only supports 2 candidates'
ys = [y.split('Passage:\n')[-1] for y in ys]
prompt = compare_prompt + f'Passage 1:\n{ys[0]}\n\nPassage 2:\n{ys[1]}\n'
ys = [y.split('Patch:\n')[-1] for y in ys]
prompt = compare_prompt + f'Patch 1:\n{ys[0]}\n\nPatch 2:\n{ys[1]}\n'
return prompt
@staticmethod
def compare_output_unwrap(compare_output: str):
if 'more coherent passage is 1' in compare_output:
if 'more correct patch is 1' in compare_output:
return 0
elif 'more coherent passage is 2' in compare_output:
elif 'more correct patch is 2' in compare_output:
return 1
elif 'two passages are similarly coherent' in compare_output:
elif 'two patches are equally correct' in compare_output:
return 0.5
else:
print(f'-----------------compare no match: {[compare_output]}')
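For completeness, the pairwise path: compare_prompt_wrap builds one prompt holding both candidate patches, and compare_output_unwrap maps the verdict to 0, 1, or 0.5. A sketch of the round trip (the gpt call's n and model arguments are assumptions, mirroring test_output above; the patches are placeholders):

from tot.models import gpt
from tot.tasks.swe import SWETask

patch_a = "Patch:\ndiff --git a/foo.py b/foo.py\n..."
patch_b = "Patch:\ndiff --git a/bar.py b/bar.py\n..."
prompt = SWETask.compare_prompt_wrap("", [patch_a, patch_b])   # x is unused by the wrapper body shown
verdict = gpt(prompt, n=1, model='gpt-4')[0]
winner = SWETask.compare_output_unwrap(verdict)   # 0 -> patch 1, 1 -> patch 2, 0.5 -> tie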