mirror of
https://github.com/princeton-nlp/tree-of-thought-llm
synced 2025-04-27 17:31:07 +00:00
Initial exploration of ToT with SWEBench
This commit is contained in:
parent
ab400345c5
commit
025fe57daf
env.yml (new file, 64 lines)
@@ -0,0 +1,64 @@
name: tot
channels:
  - defaults
  - conda-forge
  - pytorch
dependencies:
  - async-timeout=4.0.2=py311hca03da5_0
  - attrs=23.1.0=py311hca03da5_0
  - backoff=2.2.1=py311hca03da5_1
  - blas=1.0=openblas
  - bottleneck=1.3.7=py311hb9f6ed7_0
  - brotli-python=1.0.9=py311h313beb8_7
  - bzip2=1.0.8=h80987f9_5
  - ca-certificates=2023.12.12=hca03da5_0
  - certifi=2023.5.7=py311hca03da5_0
  - charset-normalizer=2.0.4=pyhd3eb1b0_0
  - frozenlist=1.3.3=py311h80987f9_0
  - gmp=6.2.1=hc377ac9_3
  - gmpy2=2.1.2=py311h40f64dc_0
  - idna=3.4=py311hca03da5_0
  - libcxx=14.0.6=h848a8c0_0
  - libffi=3.4.4=hca03da5_0
  - libgfortran=5.0.0=11_3_0_hca03da5_28
  - libgfortran5=11.3.0=h009349e_28
  - libopenblas=0.3.21=h269037a_0
  - llvm-openmp=14.0.6=hc6e5704_0
  - mpc=1.1.0=h8c48613_1
  - mpfr=4.0.2=h695f6f0_1
  - mpmath=1.3.0=py311hca03da5_0
  - multidict=6.0.4=py311h80987f9_0
  - ncurses=6.4=h313beb8_0
  - numexpr=2.8.7=py311h6dc990b_0
  - numpy=1.24.3=py311hb57d4eb_0
  - numpy-base=1.24.3=py311h1d85a46_0
  - openssl=3.0.13=h1a28f6b_0
  - pandas=2.0.3=py311h7aedaa7_0
  - pip=23.3.1=py311hca03da5_0
  - pysocks=1.7.1=py311hca03da5_0
  - python=3.11.8=hb885b13_0
  - python-dateutil=2.8.2=pyhd3eb1b0_0
  - python-tzdata=2023.3=pyhd3eb1b0_0
  - pytz=2023.3.post1=py311hca03da5_0
  - readline=8.2=h1a28f6b_0
  - requests=2.31.0=py311hca03da5_1
  - setuptools=68.2.2=py311hca03da5_0
  - six=1.16.0=pyhd3eb1b0_1
  - sqlite=3.41.2=h80987f9_0
  - sympy=1.12=py311hca03da5_0
  - tk=8.6.12=hb8d0fd4_0
  - tqdm=4.65.0=py311hb6e6a13_0
  - tzdata=2024a=h04d1e81_0
  - urllib3=2.1.0=py311hca03da5_1
  - wheel=0.41.2=py311hca03da5_0
  - xz=5.4.6=h80987f9_0
  - zlib=1.2.13=h5a0b063_0
  - pip
  - pip:
      - aiohttp==3.8.4
      - aiosignal==1.3.1
      - charset-normalizer==3.1.0
      - openai==0.27.7
      - urllib3==2.0.2
      - yarl==1.9.2
prefix: /Users/ricardo/anaconda3/envs/tot
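The environment can typically be reproduced with conda env create -f env.yml; the machine-specific prefix: line is an artifact of conda env export and can usually be deleted before sharing.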
run.py (2 lines changed)
@@ -45,7 +45,7 @@ def parse_args():
     args.add_argument('--backend', type=str, choices=['gpt-4', 'gpt-3.5-turbo'], default='gpt-4')
     args.add_argument('--temperature', type=float, default=0.7)
-    args.add_argument('--task', type=str, required=True, choices=['game24', 'text', 'crosswords'])
+    args.add_argument('--task', type=str, required=True, choices=['game24', 'text', 'crosswords', 'swe'])
     args.add_argument('--task_start_index', type=int, default=900)
     args.add_argument('--task_end_index', type=int, default=1000)
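With 'swe' registered as a task choice, the new task can be driven through the existing entry point; a minimal invocation (the index range here is illustrative, and other method flags may still be required) might look like python run.py --task swe --backend gpt-4 --task_start_index 0 --task_end_index 10.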
src/tot/data/swe/dataset.py (new file, 35 lines)
@@ -0,0 +1,35 @@
from typing import TypedDict
from typing import cast
from datasets import load_dataset, Dataset


class SwebenchInstance(TypedDict):
    repo: str
    instance_id: str
    base_commit: str
    patch: str
    test_patch: str
    problem_statement: str
    hints_text: str
    created_at: str
    version: str
    FAIL_TO_PASS: str
    PASS_TO_PASS: str
    environment_setup_commit: str


def get_dataset() -> list[SwebenchInstance]:
    dataset = cast(Dataset, load_dataset("princeton-nlp/SWE-bench", split="dev+test"))
    return [cast(SwebenchInstance, instance) for instance in dataset]


def get_categories() -> list[str]:
    # NOTE: "category" is not among the fields declared in SwebenchInstance above,
    # so this will raise a KeyError unless the dataset actually carries such a column.
    dataset = cast(Dataset, load_dataset("princeton-nlp/SWE-bench", split="dev+test"))
    return list(set([instance["category"] for instance in dataset]))


def main():
    swe_dataset = load_dataset("princeton-nlp/SWE-bench", split="dev+test")
    # print each key-value pair in the first instance
    for key, value in swe_dataset[0].items():
        print(f"{key}: {value}")
    # print the unique set of values for the key "repo"
    print(set([instance["repo"] for instance in swe_dataset]))
    # instances of the dataset whose "repo" key is "pytest-dev/pytest"
    pytest_repo = [instance for instance in swe_dataset if instance["repo"] == "pytest-dev/pytest"]


if __name__ == "__main__":
    main()
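A quick sanity check of the loader (a minimal sketch, assuming the package is importable as tot.data.swe.dataset and that the sizeable dev+test download succeeds; field names follow the SwebenchInstance definition above):

from tot.data.swe.dataset import get_dataset

instances = get_dataset()
first = instances[0]
print(first["instance_id"], first["repo"])   # an instance id plus its source repo
print(first["problem_statement"][:200])      # preview of the issue text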
src/tot/prompts/swe.py (new file, 25 lines)
@@ -0,0 +1,25 @@
# NOTE: these appear to be the text-task passage prompts, carried over as
# placeholders until SWE-specific prompts are written.

standard_prompt = '''
Write a coherent passage of 4 short paragraphs. The end sentence of each paragraph must be: {input}
'''

cot_prompt = '''
Write a coherent passage of 4 short paragraphs. The end sentence of each paragraph must be: {input}

Make a plan then write. Your output should be of the following format:

Plan:
Your plan here.

Passage:
Your passage here.
'''


vote_prompt = '''Given an instruction and several choices, decide which choice is most promising. Analyze each choice in detail, then conclude in the last line "The best choice is {s}", where s is the integer id of the choice.
'''

compare_prompt = '''Briefly analyze the coherency of the following two passages. Conclude in the last line "The more coherent passage is 1", "The more coherent passage is 2", or "The two passages are similarly coherent".
'''

score_prompt = '''Analyze the following passage, then at the last line conclude "Thus the coherency score is {s}", where s is an integer from 1 to 10.
'''
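The {s} placeholders in vote_prompt and score_prompt are shown to the model literally; the task code concatenates rather than .format()s these strings, so the bare braces are intentional. A sketch of how the vote prompt is assembled downstream (it mirrors SWETask.vote_prompt_wrap in src/tot/tasks/swe.py below; the candidate strings are hypothetical):

from tot.prompts.swe import vote_prompt

candidates = ["first draft ...", "second draft ..."]  # hypothetical candidate outputs
prompt = vote_prompt
for i, y in enumerate(candidates, 1):
    prompt += f'Choice {i}:\n{y}\n'
print(prompt)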
src/tot/tasks/__init__.py (3 lines added)
@@ -8,5 +8,8 @@ def get_task(name):
     elif name == 'crosswords':
         from tot.tasks.crosswords import MiniCrosswordsTask
         return MiniCrosswordsTask()
+    elif name == 'swe':
+        from tot.tasks.swe import SWETask
+        return SWETask()
     else:
         raise NotImplementedError
src/tot/tasks/swe.py (new file, 101 lines)
@@ -0,0 +1,101 @@
import os
import re
from tot.tasks.base import Task, DATA_PATH
from tot.prompts.swe import *  # the placeholder SWE prompts added above (identical to the text-task prompts)
from tot.models import gpt


class SWETask(Task):
    """
    Input (x)  : a text instruction
    Output (y) : a text generation
    Reward (r) : # TODO
    Input Example:
    Output Example:
    """
    def __init__(self, file='data_100_random_text.txt'):
        """
        file: a text file; each line is some sentences
        """
        super().__init__()
        # placeholder: still reads the text task's data file
        path = os.path.join(DATA_PATH, 'text', file)
        self.data = open(path).readlines()
        self.steps = 2
        self.stops = ['\nPassage:\n', None]

    def __len__(self) -> int:
        return len(self.data)

    def get_input(self, idx: int) -> str:
        return self.data[idx]

    def test_output(self, idx: int, output: str):
        output = output.split('Passage:\n')[-1]
        prompt = score_prompt + output
        score_outputs = gpt(prompt, n=5, model='gpt-4')
        scores = []
        for score_output in score_outputs:
            pattern = r".*coherency score is (\d+).*"
            match = re.match(pattern, score_output, re.DOTALL)
            if match:
                score = int(match.groups()[0])
                scores.append(score)
            else:
                print(f'------------------score no match: {[score_output]}')
        print(scores)
        info = {'rs': scores, 'r': sum(scores) / len(scores) if scores else 0}
        return info

    @staticmethod
    def standard_prompt_wrap(x: str, y: str = '') -> str:
        return standard_prompt.format(input=x) + y

    @staticmethod
    def cot_prompt_wrap(x: str, y: str = '') -> str:
        return cot_prompt.format(input=x) + y

    @staticmethod
    def vote_prompt_wrap(x: str, ys: list) -> str:
        prompt = vote_prompt
        for i, y in enumerate(ys, 1):
            # TODO: truncate the plan part?
            prompt += f'Choice {i}:\n{y}\n'
        return prompt

    @staticmethod
    def vote_outputs_unwrap(vote_outputs: list, n_candidates: int) -> list:
        vote_results = [0] * n_candidates
        for vote_output in vote_outputs:
            pattern = r".*best choice is .*(\d+).*"
            match = re.match(pattern, vote_output, re.DOTALL)
            if match:
                vote = int(match.groups()[0]) - 1
                if vote in range(n_candidates):
                    vote_results[vote] += 1
            else:
                print(f'vote no match: {[vote_output]}')
        return vote_results

    @staticmethod
    def compare_prompt_wrap(x: str, ys: list) -> str:
        assert len(ys) == 2, 'compare prompt only supports 2 candidates'
        ys = [y.split('Passage:\n')[-1] for y in ys]
        prompt = compare_prompt + f'Passage 1:\n{ys[0]}\n\nPassage 2:\n{ys[1]}\n'
        return prompt

    @staticmethod
    def compare_output_unwrap(compare_output: str):
        if 'more coherent passage is 1' in compare_output:
            return 0
        elif 'more coherent passage is 2' in compare_output:
            return 1
        elif 'two passages are similarly coherent' in compare_output:
            return 0.5
        else:
            print(f'-----------------compare no match: {[compare_output]}')
            return -1
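A quick check of the vote parser, with made-up model outputs (a minimal sketch; it assumes the package and its dependencies are importable):

from tot.tasks.swe import SWETask

votes = SWETask.vote_outputs_unwrap(
    ["Choice 2 handles edge cases better. The best choice is 2",
     "The best choice is 1",
     "no conclusion given"],
    n_candidates=3,
)
print(votes)  # [1, 1, 0]; the unparseable output is logged and ignored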
tasks/game24.py (new file, 19 lines)
@@ -0,0 +1,19 @@
import argparse
from tot.methods.bfs import solve
from tot.tasks.game24 import Game24Task

args = argparse.Namespace(backend='gpt-4',
                          temperature=0.7,
                          task='game24',
                          naive_run=False,
                          prompt_sample=None,
                          method_generate='propose',
                          method_evaluate='value',
                          method_select='greedy',
                          n_generate_sample=1,
                          n_evaluate_sample=3,
                          n_select_sample=5)

task = Game24Task()
ys, infos = solve(args=args, task=task, idx=900)
print(ys[0])
tasks/swe-bench.py (new file, 19 lines)
@@ -0,0 +1,19 @@
import argparse
from tot.methods.bfs import solve
from tot.tasks.swe import SWETask

# NOTE: SWETask defines no propose-style prompt wrapper; if the BFS 'propose'
# generation path requires one, method_generate='sample' with prompt_sample='cot'
# (as in tasks/text.py) may be the working configuration.
args = argparse.Namespace(backend='gpt-4',
                          temperature=0.7,
                          task='swe',
                          naive_run=False,
                          prompt_sample=None,
                          method_generate='propose',
                          method_evaluate='vote',
                          method_select='greedy',
                          n_generate_sample=1,
                          n_evaluate_sample=3,
                          n_select_sample=5)

task = SWETask()
ys, infos = solve(args=args, task=task, idx=1)
print(ys[0])
tasks/text.py (new file, 19 lines)
@@ -0,0 +1,19 @@
import argparse
from tot.methods.bfs import solve
from tot.tasks.text import TextTask

args = argparse.Namespace(backend='gpt-4',
                          temperature=1.0,
                          task='text',
                          naive_run=False,
                          prompt_sample='cot',
                          method_generate='sample',
                          method_evaluate='vote',
                          method_select='greedy',
                          n_generate_sample=1,
                          n_evaluate_sample=3,
                          n_select_sample=5)

task = TextTask()
ys, infos = solve(args, task, 0)
print(ys[0])