Initial exploration of ToT with SWEBench

ricardo-larosa 2024-03-23 15:15:16 +01:00
parent ab400345c5
commit 025fe57daf
9 changed files with 286 additions and 1 deletion

env.yml (new file)

@@ -0,0 +1,64 @@
name: tot
channels:
- defaults
- conda-forge
- pytorch
dependencies:
- async-timeout=4.0.2=py311hca03da5_0
- attrs=23.1.0=py311hca03da5_0
- backoff=2.2.1=py311hca03da5_1
- blas=1.0=openblas
- bottleneck=1.3.7=py311hb9f6ed7_0
- brotli-python=1.0.9=py311h313beb8_7
- bzip2=1.0.8=h80987f9_5
- ca-certificates=2023.12.12=hca03da5_0
- certifi=2023.5.7=py311hca03da5_0
- charset-normalizer=2.0.4=pyhd3eb1b0_0
- frozenlist=1.3.3=py311h80987f9_0
- gmp=6.2.1=hc377ac9_3
- gmpy2=2.1.2=py311h40f64dc_0
- idna=3.4=py311hca03da5_0
- libcxx=14.0.6=h848a8c0_0
- libffi=3.4.4=hca03da5_0
- libgfortran=5.0.0=11_3_0_hca03da5_28
- libgfortran5=11.3.0=h009349e_28
- libopenblas=0.3.21=h269037a_0
- llvm-openmp=14.0.6=hc6e5704_0
- mpc=1.1.0=h8c48613_1
- mpfr=4.0.2=h695f6f0_1
- mpmath=1.3.0=py311hca03da5_0
- multidict=6.0.4=py311h80987f9_0
- ncurses=6.4=h313beb8_0
- numexpr=2.8.7=py311h6dc990b_0
- numpy=1.24.3=py311hb57d4eb_0
- numpy-base=1.24.3=py311h1d85a46_0
- openssl=3.0.13=h1a28f6b_0
- pandas=2.0.3=py311h7aedaa7_0
- pip=23.3.1=py311hca03da5_0
- pysocks=1.7.1=py311hca03da5_0
- python=3.11.8=hb885b13_0
- python-dateutil=2.8.2=pyhd3eb1b0_0
- python-tzdata=2023.3=pyhd3eb1b0_0
- pytz=2023.3.post1=py311hca03da5_0
- readline=8.2=h1a28f6b_0
- requests=2.31.0=py311hca03da5_1
- setuptools=68.2.2=py311hca03da5_0
- six=1.16.0=pyhd3eb1b0_1
- sqlite=3.41.2=h80987f9_0
- sympy=1.12=py311hca03da5_0
- tk=8.6.12=hb8d0fd4_0
- tqdm=4.65.0=py311hb6e6a13_0
- tzdata=2024a=h04d1e81_0
- urllib3=2.1.0=py311hca03da5_1
- wheel=0.41.2=py311hca03da5_0
- xz=5.4.6=h80987f9_0
- zlib=1.2.13=h5a0b063_0
- pip
- pip:
    - aiohttp==3.8.4
    - aiosignal==1.3.1
    - charset-normalizer==3.1.0
    - openai==0.27.7
    - urllib3==2.0.2
    - yarl==1.9.2
prefix: /Users/ricardo/anaconda3/envs/tot

run.py (2 changed lines)

@@ -45,7 +45,7 @@ def parse_args():
     args.add_argument('--backend', type=str, choices=['gpt-4', 'gpt-3.5-turbo'], default='gpt-4')
     args.add_argument('--temperature', type=float, default=0.7)
-    args.add_argument('--task', type=str, required=True, choices=['game24', 'text', 'crosswords'])
+    args.add_argument('--task', type=str, required=True, choices=['game24', 'text', 'crosswords', 'swe'])
     args.add_argument('--task_start_index', type=int, default=900)
     args.add_argument('--task_end_index', type=int, default=1000)


@@ -0,0 +1,35 @@
from typing import TypedDict
from typing import cast

from datasets import load_dataset, Dataset


class SwebenchInstance(TypedDict):
    repo: str
    instance_id: str
    base_commit: str
    patch: str
    test_patch: str
    problem_statement: str
    hints_text: str
    created_at: str
    version: str
    FAIL_TO_PASS: str
    PASS_TO_PASS: str
    environment_setup_commit: str


def get_dataset() -> list[SwebenchInstance]:
    dataset = cast(Dataset, load_dataset("princeton-nlp/SWE-bench", split="dev+test"))
    return [cast(SwebenchInstance, instance) for instance in dataset]


def get_categories() -> list[str]:
    # NOTE: assumes a "category" column; SwebenchInstance above does not declare one
    dataset = cast(Dataset, load_dataset("princeton-nlp/SWE-bench", split="dev+test"))
    return list(set([instance["category"] for instance in dataset]))


def main():
    swe_dataset = load_dataset("princeton-nlp/SWE-bench", split="dev+test")
    # print each key-value pair in the first instance
    for key, value in swe_dataset[0].items():
        print(f"{key}: {value}")
    # print a unique list of values for the key "repo"
    print(set([instance["repo"] for instance in swe_dataset]))
    # instances of the dataset whose "repo" is "pytest-dev/pytest"
    pytest_repo = [instance for instance in swe_dataset if instance["repo"] == "pytest-dev/pytest"]
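
A quick usage sketch for these helpers. The import path tot.swebench is an assumption (the commit view does not show this file's name), and the datasets package is required even though it is not pinned in env.yml above:

# Minimal sketch; "tot.swebench" is a hypothetical import path for the loader above.
from tot.swebench import SwebenchInstance, get_dataset

instances = get_dataset()                      # dev + test splits concatenated
print(len(instances))                          # number of SWE-bench instances
first: SwebenchInstance = instances[0]
print(first["instance_id"], first["repo"])
# problem_statement is the issue text a ToT-style search would take as input
print(first["problem_statement"][:200])
# narrow to a single repo, e.g. pytest, as in main() above
pytest_instances = [i for i in instances if i["repo"] == "pytest-dev/pytest"]
print(len(pytest_instances))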

src/tot/prompts/swe.py (new file)

@@ -0,0 +1,25 @@
standard_prompt = '''
Write a coherent passage of 4 short paragraphs. The end sentence of each paragraph must be: {input}
'''
cot_prompt = '''
Write a coherent passage of 4 short paragraphs. The end sentence of each paragraph must be: {input}
Make a plan then write. Your output should be of the following format:
Plan:
Your plan here.
Passage:
Your passage here.
'''
vote_prompt = '''Given an instruction and several choices, decide which choice is most promising. Analyze each choice in detail, then conclude in the last line "The best choice is {s}", where s is the integer id of the choice.
'''
compare_prompt = '''Briefly analyze the coherency of the following two passages. Conclude in the last line "The more coherent passage is 1", "The more coherent passage is 2", or "The two passages are similarly coherent".
'''
score_prompt = '''Analyze the following passage, then at the last line conclude "Thus the coherency score is {s}", where s is an integer from 1 to 10.
'''
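
These prompts are still the passage-writing placeholders carried over from the text task; a small sketch of how they get instantiated (the problem statement below is a made-up stand-in, not a real SWE-bench instance):

# Sketch only: shows how the placeholders above are filled in.
from tot.prompts.swe import cot_prompt, vote_prompt

problem_statement = "pytest exits with an INTERNALERROR when ..."  # illustrative stand-in
print(cot_prompt.format(input=problem_statement))  # only {input} is substituted

# vote_prompt is never passed through .format(); SWETask.vote_prompt_wrap just
# concatenates candidates after it, so the literal "{s}" stays in the text.
print(vote_prompt + "Choice 1:\n...\nChoice 2:\n...\n")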


@@ -8,5 +8,8 @@ def get_task(name):
     elif name == 'crosswords':
         from tot.tasks.crosswords import MiniCrosswordsTask
         return MiniCrosswordsTask()
+    elif name == 'swe':
+        from tot.tasks.swe import SWETask
+        return SWETask()
     else:
         raise NotImplementedError

src/tot/tasks/swe.py (new file)

@@ -0,0 +1,101 @@
import os
import re
from tot.tasks.base import Task, DATA_PATH
from tot.prompts.swe import *
from tot.models import gpt

class SWETask(Task):
    """
    Input (x)   : a text instruction
    Output (y)  : a text generation
    Reward (r)  : # TODO
    Input Example:
    Output Example:
    """
    def __init__(self, file='data_100_random_text.txt'):
        """
        file: a text file, each line is some sentences
        """
        super().__init__()
        path = os.path.join(DATA_PATH, 'text', file)
        self.data = open(path).readlines()
        self.steps = 2
        self.stops = ['\nPassage:\n', None]

    def __len__(self) -> int:
        return len(self.data)

    def get_input(self, idx: int) -> str:
        return self.data[idx]

    def test_output(self, idx: int, output: str):
        output = output.split('Passage:\n')[-1]
        prompt = score_prompt + output
        score_outputs = gpt(prompt, n=5, model='gpt-4')
        scores = []
        for score_output in score_outputs:
            # print(score_output)
            pattern = r".*coherency score is (\d+).*"
            match = re.match(pattern, score_output, re.DOTALL)
            if match:
                score = int(match.groups()[0])
                scores.append(score)
            else:
                print(f'------------------score no match: {[score_output]}')
        print(scores)
        # print('------------')
        info = {'rs': scores, 'r': sum(scores) / len(scores) if scores else 0}
        return info

    @staticmethod
    def standard_prompt_wrap(x: str, y: str = '') -> str:
        return standard_prompt.format(input=x) + y

    @staticmethod
    def cot_prompt_wrap(x: str, y: str = '') -> str:
        return cot_prompt.format(input=x) + y

    @staticmethod
    def vote_prompt_wrap(x: str, ys: list) -> str:
        prompt = vote_prompt
        for i, y in enumerate(ys, 1):
            # y = y.replace('Plan:\n', '')
            # TODO: truncate the plan part?
            prompt += f'Choice {i}:\n{y}\n'
        return prompt

    @staticmethod
    def vote_outputs_unwrap(vote_outputs: list, n_candidates: int) -> list:
        vote_results = [0] * n_candidates
        for vote_output in vote_outputs:
            pattern = r".*best choice is .*(\d+).*"
            match = re.match(pattern, vote_output, re.DOTALL)
            if match:
                vote = int(match.groups()[0]) - 1
                if vote in range(n_candidates):
                    vote_results[vote] += 1
            else:
                print(f'vote no match: {[vote_output]}')
        return vote_results

    @staticmethod
    def compare_prompt_wrap(x: str, ys: list) -> str:
        assert len(ys) == 2, 'compare prompt only supports 2 candidates'
        ys = [y.split('Passage:\n')[-1] for y in ys]
        prompt = compare_prompt + f'Passage 1:\n{ys[0]}\n\nPassage 2:\n{ys[1]}\n'
        return prompt

    @staticmethod
    def compare_output_unwrap(compare_output: str):
        if 'more coherent passage is 1' in compare_output:
            return 0
        elif 'more coherent passage is 2' in compare_output:
            return 1
        elif 'two passages are similarly coherent' in compare_output:
            return 0.5
        else:
            print(f'-----------------compare no match: {[compare_output]}')
            return -1
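
A standalone sketch of how the regex-based unwrapping above behaves; the model outputs are invented for illustration:

import re

# Mirrors SWETask.vote_outputs_unwrap: tally "The best choice is N" verdicts.
vote_outputs = [
    "Choice 2 addresses the failing test directly.\nThe best choice is 2",
    "Both plans are plausible, but ...\nThe best choice is 1",
    "no explicit verdict in this sample",
]
n_candidates = 2
vote_results = [0] * n_candidates
for vote_output in vote_outputs:
    match = re.match(r".*best choice is .*(\d+).*", vote_output, re.DOTALL)
    if match:
        vote = int(match.groups()[0]) - 1
        if vote in range(n_candidates):
            vote_results[vote] += 1
print(vote_results)  # [1, 1]; the third output is ignored (no match)

# Mirrors SWETask.test_output: extract "coherency score is N" and average.
score_outputs = ["... Thus the coherency score is 7", "Thus the coherency score is 9"]
scores = [int(re.match(r".*coherency score is (\d+).*", s, re.DOTALL).group(1))
          for s in score_outputs]
print(sum(scores) / len(scores))  # 8.0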

tasks/game24.py (new file)

@@ -0,0 +1,19 @@
import argparse
from tot.methods.bfs import solve
from tot.tasks.game24 import Game24Task

args = argparse.Namespace(backend='gpt-4',
                          temperature=0.7,
                          task='game24',
                          naive_run=False,
                          prompt_sample=None,
                          method_generate='propose',
                          method_evaluate='value',
                          method_select='greedy',
                          n_generate_sample=1,
                          n_evaluate_sample=3,
                          n_select_sample=5)

task = Game24Task()
ys, infos = solve(args=args, task=task, idx=900)
print(ys[0])

tasks/swe-bench.py (new file)

@@ -0,0 +1,19 @@
import argparse
from tot.methods.bfs import solve
from tot.tasks.swe import SWETask

args = argparse.Namespace(backend='gpt-4',
                          temperature=0.7,
                          task='swe',
                          naive_run=False,
                          prompt_sample=None,
                          method_generate='propose',
                          method_evaluate='vote',
                          method_select='greedy',
                          n_generate_sample=1,
                          n_evaluate_sample=3,
                          n_select_sample=5)

task = SWETask()
ys, infos = solve(args=args, task=task, idx=1)
print(ys[0])
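
As written, SWETask still reads the text task's data_100_random_text.txt, so idx=1 here selects a line of that file rather than a SWE-bench instance, and method_generate='propose' has no propose_prompt_wrap on SWETask. A rough sketch of the wiring this exploration points toward, feeding SWE-bench problem statements in as task inputs and using the sample/vote configuration that SWETask's wrappers actually support; the data override and the loader import path are assumptions, not part of this commit:

import argparse
from tot.methods.bfs import solve
from tot.tasks.swe import SWETask
from tot.swebench import get_dataset  # hypothetical path for the loader above

task = SWETask()
# Sketch only: SWETask has no hook for this in the commit; override its data
# with SWE-bench problem statements so idx indexes real instances.
task.data = [inst["problem_statement"] for inst in get_dataset()]

args = argparse.Namespace(backend='gpt-4',
                          temperature=0.7,
                          task='swe',
                          naive_run=False,
                          prompt_sample='cot',       # SWETask defines cot/standard wrappers
                          method_generate='sample',  # no propose_prompt_wrap exists
                          method_evaluate='vote',
                          method_select='greedy',
                          n_generate_sample=1,
                          n_evaluate_sample=3,
                          n_select_sample=5)
ys, infos = solve(args=args, task=task, idx=1)
print(ys[0])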

tasks/text.py (new file)

@@ -0,0 +1,19 @@
import argparse
from tot.methods.bfs import solve
from tot.tasks.text import TextTask

args = argparse.Namespace(backend='gpt-4',
                          temperature=1.0,
                          task='text',
                          naive_run=False,
                          prompt_sample='cot',
                          method_generate='sample',
                          method_evaluate='vote',
                          method_select='greedy',
                          n_generate_sample=1,
                          n_evaluate_sample=3,
                          n_select_sample=5)

task = TextTask()
ys, infos = solve(args, task, 0)
print(ys[0])