tot package

This commit is contained in:
ysymyth 2023-07-03 22:16:03 -04:00
parent 7382f2416e
commit 733b009f62
33 changed files with 1579 additions and 1502 deletions

4
MANIFEST.in Normal file
View File

@ -0,0 +1,4 @@
include src/tot/data/24/24.csv
include src/tot/data/crosswords/mini0505_0_100_5.json
include src/tot/data/crosswords/mini0505.json
include src/tot/data/text/data_100_random_text.txt

View File

Before

Width:  |  Height:  |  Size: 84 KiB

After

Width:  |  Height:  |  Size: 84 KiB

View File

Before

Width:  |  Height:  |  Size: 99 KiB

After

Width:  |  Height:  |  Size: 99 KiB

35
pyproject.toml Normal file
View File

@ -0,0 +1,35 @@
# Build configuration: the package is built with the setuptools backend.
[build-system]
requires = ["setuptools >= 61.0.0"]
build-backend = "setuptools.build_meta"
# Static project metadata for the "tot" distribution.
[project]
name = "tot"
version = "0.1.0"
description = 'Official Implementation of "Tree of Thoughts: Deliberate Problem Solving with Large Language Models"'
readme = "README.md"
requires-python = ">= 3.7"
authors = [{ name = "Shunyu Yao", email = "shunyuyao.cs@gmail.com" }]
license = { text = "MIT License" }
keywords = ["tree-search", "large-language-models", "llm", "prompting", "tree-of-thoughts"]
classifiers = [
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
'Intended Audience :: Science/Research',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
]
# "dependencies" is not listed statically; it is resolved from the file named
# in [tool.setuptools.dynamic] below.
dynamic=["dependencies"]
[tool.setuptools.dynamic]
dependencies = {file = ["requirements.txt"]}
# Package discovery: packages are looked up under the src/ directory.
[tool.setuptools.packages.find]
where = ["src"] # list of folders that contain the packages (["."] by default)
[project.urls]
Homepage = "https://github.com/princeton-nlp/tree-of-thought-llm"

View File

@ -4,7 +4,7 @@
<details> <details>
<summary>Note: https://github.com/kyegomez/tree-of-thoughts is NOT the correct implementation to replicate paper results. </summary> <summary>Note: https://github.com/kyegomez/tree-of-thoughts is NOT the correct implementation to replicate paper results. </summary>
In fact, people have reported that his code [cannot](https://github.com/kyegomez/tree-of-thoughts/issues/52) [properly](https://github.com/kyegomez/tree-of-thoughts/issues/41) [run](https://github.com/kyegomez/tree-of-thoughts/issues/60), and is [probably automatically generated by ChatGPT](fake.png), and [kyegomez has done so for other popular ML methods](https://twitter.com/qbitium/status/1663954096741814272), while intentionally refusing to link to official implementations for his own interests (see https://github.com/kyegomez/tree-of-thoughts/issues/54, https://github.com/kyegomez/tree-of-thoughts/issues/55, https://github.com/kyegomez/tree-of-thoughts/issues/56). In fact, people have reported that his code [cannot](https://github.com/kyegomez/tree-of-thoughts/issues/52) [properly](https://github.com/kyegomez/tree-of-thoughts/issues/41) [run](https://github.com/kyegomez/tree-of-thoughts/issues/60), and is [probably automatically generated by ChatGPT](pics/fake.png), and [kyegomez has done so for other popular ML methods](https://twitter.com/qbitium/status/1663954096741814272), while intentionally refusing to link to official implementations for his own interests (see https://github.com/kyegomez/tree-of-thoughts/issues/54, https://github.com/kyegomez/tree-of-thoughts/issues/55, https://github.com/kyegomez/tree-of-thoughts/issues/56).
Unfortunately, Google/Github searches go to kyegomez's malicious repo by default as it has more stars. **Please DE-STAR his repo and STAR this to help other people avoid being misled, thanks!** Unfortunately, Google/Github searches go to kyegomez's malicious repo by default as it has more stars. **Please DE-STAR his repo and STAR this to help other people avoid being misled, thanks!**
</details> </details>
@ -12,7 +12,7 @@ Unfortunately, Google/Github searches go to kyegomez's malicious repo by default
![teaser](teaser.png) ![teaser](pics/teaser.png)
Official implementation for paper [Tree of Thoughts: Deliberate Problem Solving with Large Language Models](https://arxiv.org/abs/2305.10601) with code, prompts, model outputs. Official implementation for paper [Tree of Thoughts: Deliberate Problem Solving with Large Language Models](https://arxiv.org/abs/2305.10601) with code, prompts, model outputs.
Also check [its tweet thread](https://twitter.com/ShunyuYao12/status/1659357547474681857) in 1min. Also check [its tweet thread](https://twitter.com/ShunyuYao12/status/1659357547474681857) in 1min.

View File

@ -16,3 +16,4 @@ sympy==1.12
tqdm==4.65.0 tqdm==4.65.0
urllib3==2.0.2 urllib3==2.0.2
yarl==1.9.2 yarl==1.9.2
pandas==2.0.3

105
run.py
View File

@ -1,108 +1,18 @@
import os import os
import json import json
import itertools
import argparse import argparse
import numpy as np
from functools import partial
from models import gpt, gpt_usage
from tasks import get_task
def get_value(task, x, y, n_evaluate_sample, cache_value=True): from tot.tasks import get_task
value_prompt = task.value_prompt_wrap(x, y) from tot.methods.bfs import solve, naive_solve
if cache_value and value_prompt in task.value_cache: from tot.models import gpt_usage
return task.value_cache[value_prompt]
value_outputs = gpt(value_prompt, n=n_evaluate_sample, stop=None)
value = task.value_outputs_unwrap(x, y, value_outputs)
if cache_value:
task.value_cache[value_prompt] = value
return value
def get_values(task, x, ys, n_evaluate_sample, cache_value=True):
values = []
local_value_cache = {}
for y in ys: # each partial output
if y in local_value_cache: # avoid duplicate candidates
value = 0
else:
value = get_value(task, x, y, n_evaluate_sample, cache_value=cache_value)
local_value_cache[y] = value
values.append(value)
return values
def get_votes(task, x, ys, n_evaluate_sample):
vote_prompt = task.vote_prompt_wrap(x, ys)
vote_outputs = gpt(vote_prompt, n=n_evaluate_sample, stop=None)
values = task.vote_outputs_unwrap(vote_outputs, len(ys))
return values
def get_proposals(task, x, y):
propose_prompt = task.propose_prompt_wrap(x, y)
proposals = gpt(propose_prompt, n=1, stop=None)[0].split('\n')
return [y + _ + '\n' for _ in proposals]
def get_samples(task, x, y, n_generate_sample, prompt_sample, stop):
if prompt_sample == 'standard':
prompt = task.standard_prompt_wrap(x, y)
elif prompt_sample == 'cot':
prompt = task.cot_prompt_wrap(x, y)
else:
raise ValueError(f'prompt_sample {prompt_sample} not recognized')
samples = gpt(prompt, n=n_generate_sample, stop=stop)
return [y + _ for _ in samples]
def solve(args, task, idx, to_print=True):
print(gpt)
x = task.get_input(idx) # input
ys = [''] # current output candidates
infos = []
for step in range(task.steps):
# generation
if args.method_generate == 'sample':
new_ys = [get_samples(task, x, y, args.n_generate_sample, prompt_sample=args.prompt_sample, stop=task.stops[step]) for y in ys]
elif args.method_generate == 'propose':
new_ys = [get_proposals(task, x, y) for y in ys]
new_ys = list(itertools.chain(*new_ys))
ids = list(range(len(new_ys)))
# evaluation
if args.method_evaluate == 'vote':
values = get_votes(task, x, new_ys, args.n_evaluate_sample)
elif args.method_evaluate == 'value':
values = get_values(task, x, new_ys, args.n_evaluate_sample)
# selection
if args.method_select == 'sample':
ps = np.array(values) / sum(values)
select_ids = np.random.choice(ids, size=args.n_select_sample, p=ps).tolist()
elif args.method_select == 'greedy':
select_ids = sorted(ids, key=lambda x: values[x], reverse=True)[:args.n_select_sample]
select_new_ys = [new_ys[select_id] for select_id in select_ids]
# log
if to_print:
sorted_new_ys, sorted_values = zip(*sorted(zip(new_ys, values), key=lambda x: x[1], reverse=True))
print(f'-- new_ys --: {sorted_new_ys}\n-- sol values --: {sorted_values}\n-- choices --: {select_new_ys}\n')
infos.append({'step': step, 'x': x, 'ys': ys, 'new_ys': new_ys, 'values': values, 'select_new_ys': select_new_ys})
ys = select_new_ys
if to_print:
print(ys)
return ys, {'steps': infos}
def naive_solve(args, task, idx, to_print=True):
x = task.get_input(idx) # input
ys = get_samples(task, x, '', args.n_generate_sample, args.prompt_sample, stop=None)
return ys, {}
def run(args): def run(args):
task = get_task(args.task, args.task_file_path) task = get_task(args.task)
logs, cnt_avg, cnt_any = [], 0, 0 logs, cnt_avg, cnt_any = [], 0, 0
global gpt
gpt = partial(gpt, model=args.backend, temperature=args.temperature)
if args.naive_run: if args.naive_run:
file = f'logs/{args.task}/{args.backend}_{args.temperature}_naive_{args.prompt_sample}_sample_{args.n_generate_sample}_start{args.task_start_index}_end{args.task_end_index}.json' file = f'./logs/{args.task}/{args.backend}_{args.temperature}_naive_{args.prompt_sample}_sample_{args.n_generate_sample}_start{args.task_start_index}_end{args.task_end_index}.json'
else: else:
file = f'logs/{args.task}/{args.backend}_{args.temperature}_{args.method_generate}{args.n_generate_sample}_{args.method_evaluate}{args.n_evaluate_sample}_{args.method_select}{args.n_select_sample}_start{args.task_start_index}_end{args.task_end_index}.json' file = f'./logs/{args.task}/{args.backend}_{args.temperature}_{args.method_generate}{args.n_generate_sample}_{args.method_evaluate}{args.n_evaluate_sample}_{args.method_select}{args.n_select_sample}_start{args.task_start_index}_end{args.task_end_index}.json'
os.makedirs(os.path.dirname(file), exist_ok=True) os.makedirs(os.path.dirname(file), exist_ok=True)
for i in range(args.task_start_index, args.task_end_index): for i in range(args.task_start_index, args.task_end_index):
@ -136,7 +46,6 @@ def parse_args():
args.add_argument('--temperature', type=float, default=0.7) args.add_argument('--temperature', type=float, default=0.7)
args.add_argument('--task', type=str, required=True, choices=['game24', 'text', 'crosswords']) args.add_argument('--task', type=str, required=True, choices=['game24', 'text', 'crosswords'])
args.add_argument('--task_file_path', type=str, required=True)
args.add_argument('--task_start_index', type=int, default=900) args.add_argument('--task_start_index', type=int, default=900)
args.add_argument('--task_end_index', type=int, default=1000) args.add_argument('--task_end_index', type=int, default=1000)
@ -145,7 +54,7 @@ def parse_args():
args.add_argument('--method_generate', type=str, choices=['sample', 'propose']) args.add_argument('--method_generate', type=str, choices=['sample', 'propose'])
args.add_argument('--method_evaluate', type=str, choices=['value', 'vote']) args.add_argument('--method_evaluate', type=str, choices=['value', 'vote'])
args.add_argument('--method_select', type=str, choices=['sample', 'greedy']) args.add_argument('--method_select', type=str, choices=['sample', 'greedy'], default='greedy')
args.add_argument('--n_generate_sample', type=int, default=1) # only thing needed if naive_run args.add_argument('--n_generate_sample', type=int, default=1) # only thing needed if naive_run
args.add_argument('--n_evaluate_sample', type=int, default=1) args.add_argument('--n_evaluate_sample', type=int, default=1)
args.add_argument('--n_select_sample', type=int, default=1) args.add_argument('--n_select_sample', type=int, default=1)

View File

@ -1,6 +1,5 @@
python run.py \ python run.py \
--task crosswords \ --task crosswords \
--task_file_path mini0505_0_100_5.json \
--task_start_index 0 \ --task_start_index 0 \
--task_end_index 20 \ --task_end_index 20 \
--naive_run \ --naive_run \

View File

@ -14,7 +14,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"cd ../.." "cd .."
] ]
}, },
{ {
@ -24,9 +24,9 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"import json\n", "import json\n",
"from prompts.crosswords import propose_prompt, value_prompt\n", "from tot.prompts.crosswords import propose_prompt, value_prompt\n",
"from models import gpt\n", "from tot.models import gpt\n",
"from tasks.crosswords import MiniCrosswordsEnv\n", "from tot.tasks.crosswords import MiniCrosswordsEnv\n",
"\n", "\n",
"env = MiniCrosswordsEnv()" "env = MiniCrosswordsEnv()"
] ]
@ -61,7 +61,7 @@
"source": [ "source": [
"import re\n", "import re\n",
"import copy\n", "import copy\n",
"from models import gpt\n", "from tot.models import gpt\n",
"\n", "\n",
"def parse_line(input_str):\n", "def parse_line(input_str):\n",
" # regular expression pattern to match the input string format\n", " # regular expression pattern to match the input string format\n",

View File

@ -1,6 +1,5 @@
python run.py \ python run.py \
--task crosswords \ --task crosswords \
--task_file_path mini0505_0_100_5.json \
--task_start_index 0 \ --task_start_index 0 \
--task_end_index 20 \ --task_end_index 20 \
--naive_run \ --naive_run \

View File

@ -1,6 +1,5 @@
python run.py \ python run.py \
--task game24 \ --task game24 \
--task_file_path 24.csv \
--task_start_index 900 \ --task_start_index 900 \
--task_end_index 1000 \ --task_end_index 1000 \
--method_generate propose \ --method_generate propose \

View File

@ -1,6 +1,5 @@
python run.py \ python run.py \
--task game24 \ --task game24 \
--task_file_path 24.csv \
--task_start_index 900 \ --task_start_index 900 \
--task_end_index 1000 \ --task_end_index 1000 \
--naive_run \ --naive_run \

View File

@ -1,6 +1,5 @@
python run.py \ python run.py \
--task game24 \ --task game24 \
--task_file_path 24.csv \
--task_start_index 900 \ --task_start_index 900 \
--task_end_index 1000 \ --task_end_index 1000 \
--naive_run \ --naive_run \

View File

@ -1,8 +1,7 @@
python run.py \ python run.py \
--task text \ --task text \
--task_file_path data_100_random_text.txt \
--task_start_index 0 \ --task_start_index 0 \
--task_end_index 1 \ --task_end_index 100 \
--method_generate sample \ --method_generate sample \
--method_evaluate vote \ --method_evaluate vote \
--method_select greedy \ --method_select greedy \

View File

@ -1,8 +1,7 @@
python run.py \ python run.py \
--task text \ --task text \
--task_file_path data_100_random_text.txt \
--task_start_index 0 \ --task_start_index 0 \
--task_end_index 1 \ --task_end_index 100 \
--naive_run \ --naive_run \
--prompt_sample cot \ --prompt_sample cot \
--n_generate_sample 10 \ --n_generate_sample 10 \

View File

@ -1,8 +1,7 @@
python run.py \ python run.py \
--task text \ --task text \
--task_file_path data_100_random_text.txt \
--task_start_index 0 \ --task_start_index 0 \
--task_end_index 1 \ --task_end_index 100 \
--naive_run \ --naive_run \
--prompt_sample standard \ --prompt_sample standard \
--n_generate_sample 10 \ --n_generate_sample 10 \

37
setup.py Normal file
View File

@ -0,0 +1,37 @@
"""Setuptools build script for the ``tot`` package (sources live under ``src/``)."""
from pathlib import Path

import setuptools

# Resolve files relative to this script so the build does not depend on the
# directory it is launched from.
_HERE = Path(__file__).resolve().parent

long_description = (_HERE / 'README.md').read_text(encoding='utf-8')


def _read_requirements(path):
    """Return the requirement specifiers listed in *path* ([] if it is absent).

    Blank lines, comments and pip option lines (``-r``, ``--index-url`` ...)
    are skipped because they are not valid ``install_requires`` entries.
    """
    if not path.is_file():
        return []
    stripped = (raw.strip() for raw in path.read_text(encoding='utf-8').splitlines())
    return [line for line in stripped if line and not line.startswith(('#', '-'))]


setuptools.setup(
    name='tot',
    # Kept in sync with pyproject.toml and src/tot/__init__.py.
    version='0.1.0',
    author='Shunyu Yao',
    author_email='shunyuyao.cs@gmail.com',
    description='Official Implementation of "Tree of Thoughts: Deliberate Problem Solving with Large Language Models"',
    keywords='tree-search, large-language-models, llm, prompting, tree-of-thoughts',
    long_description=long_description,
    long_description_content_type='text/markdown',
    url='https://github.com/princeton-nlp/tree-of-thought-llm',
    project_urls={
        'Homepage': 'https://github.com/princeton-nlp/tree-of-thought-llm',
    },
    package_dir={'': 'src'},
    packages=setuptools.find_packages(where='src'),
    classifiers=[
        "License :: OSI Approved :: MIT License",
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.7",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.11",
        'Intended Audience :: Science/Research',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
    ],
    python_requires='>=3.7',
    # Mirror pyproject.toml, which takes the runtime dependencies from
    # requirements.txt; 'setuptools' is kept from the original list.
    install_requires=[
        'setuptools',
        *_read_requirements(_HERE / 'requirements.txt'),
    ],
    include_package_data=True,
)

1
src/tot/__init__.py Normal file
View File

@ -0,0 +1 @@
__version__ = "0.1.0"

96
src/tot/methods/bfs.py Normal file
View File

@ -0,0 +1,96 @@
import itertools
import numpy as np
from functools import partial
from tot.models import gpt
def get_value(task, x, y, n_evaluate_sample, cache_value=True):
    """Score a single partial output ``y`` for input ``x``.

    The prompt produced by ``task.value_prompt_wrap`` is also the key into
    ``task.value_cache``. When ``cache_value`` is true, a cached score is
    returned without querying the model, and a freshly computed score is
    written back to the cache.

    Args:
        task: object providing ``value_prompt_wrap``, ``value_outputs_unwrap``
            and (when caching) a ``value_cache`` mapping.
        x: task input.
        y: partial output to evaluate.
        n_evaluate_sample: number of completions requested from ``gpt``.
        cache_value: whether to read from / write to ``task.value_cache``.

    Returns:
        Whatever ``task.value_outputs_unwrap`` returns for the completions.
    """
    prompt = task.value_prompt_wrap(x, y)
    if cache_value:
        # Only touch the cache when caching is requested.
        if prompt in task.value_cache:
            return task.value_cache[prompt]
    completions = gpt(prompt, n=n_evaluate_sample, stop=None)
    score = task.value_outputs_unwrap(x, y, completions)
    if cache_value:
        task.value_cache[prompt] = score
    return score
def get_values(task, x, ys, n_evaluate_sample, cache_value=True):
    """Score every partial output in ``ys``, in order.

    The first occurrence of a candidate is scored with ``get_value``; any
    later repeat of the same candidate is given a score of 0 so duplicate
    candidates are avoided.

    Returns:
        A list of scores aligned with ``ys``.
    """
    already_scored = set()
    scores = []
    for candidate in ys:  # each partial output
        if candidate in already_scored:  # avoid duplicate candidates
            scores.append(0)
            continue
        score = get_value(task, x, candidate, n_evaluate_sample, cache_value=cache_value)
        already_scored.add(candidate)
        scores.append(score)
    return scores
def get_votes(task, x, ys, n_evaluate_sample):
    """Obtain one vote-based score per candidate in ``ys``.

    A single vote prompt covering all candidates is built with
    ``task.vote_prompt_wrap``, ``n_evaluate_sample`` completions are
    requested from ``gpt``, and ``task.vote_outputs_unwrap`` turns them into
    the returned values (it is given the number of candidates).
    """
    prompt = task.vote_prompt_wrap(x, ys)
    ballots = gpt(prompt, n=n_evaluate_sample, stop=None)
    return task.vote_outputs_unwrap(ballots, len(ys))
def get_proposals(task, x, y):
    """Extend partial output ``y`` with each proposal line from the model.

    One completion is requested for the prompt built by
    ``task.propose_prompt_wrap``; it is split on newlines and every line is
    appended to ``y`` followed by a newline.

    Returns:
        A list of candidate outputs, one per line of the completion.
    """
    prompt = task.propose_prompt_wrap(x, y)
    completion = gpt(prompt, n=1, stop=None)[0]
    candidates = []
    for line in completion.split('\n'):
        candidates.append(y + line + '\n')
    return candidates
def get_samples(task, x, y, n_generate_sample, prompt_sample, stop):
    """Sample continuations of partial output ``y`` from the model.

    Args:
        task: object providing ``standard_prompt_wrap`` / ``cot_prompt_wrap``.
        x: task input.
        y: partial output that each sample is appended to.
        n_generate_sample: number of completions requested from ``gpt``.
        prompt_sample: ``'standard'`` or ``'cot'``, selecting the prompt builder.
        stop: stop argument forwarded to ``gpt``.

    Returns:
        A list with ``y`` prepended to every completion.

    Raises:
        ValueError: if ``prompt_sample`` is neither ``'standard'`` nor ``'cot'``.
    """
    if prompt_sample == 'cot':
        build_prompt = task.cot_prompt_wrap
    elif prompt_sample == 'standard':
        build_prompt = task.standard_prompt_wrap
    else:
        raise ValueError(f'prompt_sample {prompt_sample} not recognized')
    completions = gpt(build_prompt(x, y), n=n_generate_sample, stop=stop)
    extended = []
    for completion in completions:
        extended.append(y + completion)
    return extended
def solve(args, task, idx, to_print=True):
    """Run the breadth-first search over partial outputs for one task instance.

    For each of ``task.steps`` steps the current candidates are expanded
    (generation), scored (evaluation) and pruned (selection).

    Args:
        args: namespace providing ``backend``, ``temperature``,
            ``method_generate`` (``'sample'``/``'propose'``),
            ``method_evaluate`` (``'vote'``/``'value'``),
            ``method_select`` (``'sample'``/``'greedy'``),
            ``n_generate_sample``, ``n_evaluate_sample``, ``n_select_sample``
            and ``prompt_sample``.
        task: object providing ``get_input``, ``steps``, ``stops`` and the
            prompt helpers used by the generation/evaluation functions.
        idx: index of the task instance passed to ``task.get_input``.
        to_print: when true, print the ranked candidates after every step and
            the final candidates at the end.

    Returns:
        ``(ys, {'steps': infos})`` where ``ys`` is the final list of selected
        outputs and ``infos`` holds one log dict per step.

    Raises:
        ValueError: if ``args.method_generate``, ``args.method_evaluate`` or
            ``args.method_select`` is not one of the recognised values.
    """
    global gpt
    # Rebinds the module-level ``gpt`` so the helper functions in this module
    # call the model with the requested backend and temperature.
    gpt = partial(gpt, model=args.backend, temperature=args.temperature)
    print(gpt)
    x = task.get_input(idx)  # input
    ys = ['']  # current output candidates
    infos = []
    for step in range(task.steps):
        # generation
        if args.method_generate == 'sample':
            new_ys = [
                get_samples(task, x, y, args.n_generate_sample,
                            prompt_sample=args.prompt_sample, stop=task.stops[step])
                for y in ys
            ]
        elif args.method_generate == 'propose':
            new_ys = [get_proposals(task, x, y) for y in ys]
        else:
            # Fail here with a clear message instead of a NameError further down.
            raise ValueError(f'method_generate {args.method_generate} not recognized')
        new_ys = list(itertools.chain(*new_ys))
        ids = list(range(len(new_ys)))

        # evaluation
        if args.method_evaluate == 'vote':
            values = get_votes(task, x, new_ys, args.n_evaluate_sample)
        elif args.method_evaluate == 'value':
            values = get_values(task, x, new_ys, args.n_evaluate_sample)
        else:
            raise ValueError(f'method_evaluate {args.method_evaluate} not recognized')

        # selection
        if args.method_select == 'sample':
            total = sum(values)
            # When every value is 0 the normalised weights would be NaN and
            # np.random.choice would reject them; p=None samples uniformly.
            ps = np.array(values) / total if total else None
            select_ids = np.random.choice(ids, size=args.n_select_sample, p=ps).tolist()
        elif args.method_select == 'greedy':
            select_ids = sorted(ids, key=lambda i: values[i], reverse=True)[:args.n_select_sample]
        else:
            raise ValueError(f'method_select {args.method_select} not recognized')
        select_new_ys = [new_ys[select_id] for select_id in select_ids]

        # log
        if to_print:
            ranked = sorted(zip(new_ys, values), key=lambda pair: pair[1], reverse=True)
            # Built separately (not via zip(*ranked)) so an empty candidate
            # list prints as empty tuples instead of failing to unpack.
            sorted_new_ys = tuple(candidate for candidate, _ in ranked)
            sorted_values = tuple(value for _, value in ranked)
            print(f'-- new_ys --: {sorted_new_ys}\n-- sol values --: {sorted_values}\n-- choices --: {select_new_ys}\n')

        infos.append({'step': step, 'x': x, 'ys': ys, 'new_ys': new_ys, 'values': values, 'select_new_ys': select_new_ys})
        ys = select_new_ys

    if to_print:
        print(ys)
    return ys, {'steps': infos}
def naive_solve(args, task, idx, to_print=True):
    """Sample outputs for one task instance directly, with no search.

    The module-level ``gpt`` is rebound with ``args.backend`` and
    ``args.temperature`` (and printed), then ``args.n_generate_sample``
    outputs are drawn via ``get_samples`` starting from an empty partial
    output and with no stop sequence. ``to_print`` is accepted but not used
    by this function.

    Returns:
        ``(ys, {})`` where ``ys`` is the list of sampled outputs.
    """
    global gpt
    gpt = partial(gpt, model=args.backend, temperature=args.temperature)
    print(gpt)
    task_input = task.get_input(idx)  # input
    samples = get_samples(
        task,
        task_input,
        '',
        args.n_generate_sample,
        args.prompt_sample,
        stop=None,
    )
    return samples, {}

View File

@ -41,5 +41,5 @@ def gpt_usage(backend="gpt-4"):
if backend == "gpt-4": if backend == "gpt-4":
cost = completion_tokens / 1000 * 0.06 + prompt_tokens / 1000 * 0.03 cost = completion_tokens / 1000 * 0.06 + prompt_tokens / 1000 * 0.03
elif backend == "gpt-3.5-turbo": elif backend == "gpt-3.5-turbo":
cost = completion_tokens / 1000 * 0.002 + prompt_tokens / 1000 * 0.0015 cost = (completion_tokens + prompt_tokens) / 1000 * 0.0002
return {"completion_tokens": completion_tokens, "prompt_tokens": prompt_tokens, "cost": cost} return {"completion_tokens": completion_tokens, "prompt_tokens": prompt_tokens, "cost": cost}

12
src/tot/tasks/__init__.py Normal file
View File

@ -0,0 +1,12 @@
def get_task(name):
    """Instantiate the task registered under ``name``.

    Each task module is imported only when its name is requested, so asking
    for one task does not import the modules of the others.

    Args:
        name: one of ``'game24'``, ``'text'`` or ``'crosswords'``.

    Returns:
        A new ``Game24Task``, ``TextTask`` or ``MiniCrosswordsTask`` instance.

    Raises:
        NotImplementedError: if ``name`` is not one of the names above.
    """
    if name == 'game24':
        from tot.tasks.game24 import Game24Task
        return Game24Task()
    if name == 'text':
        from tot.tasks.text import TextTask
        return TextTask()
    if name == 'crosswords':
        from tot.tasks.crosswords import MiniCrosswordsTask
        return MiniCrosswordsTask()
    # Same exception type as before, now saying which name was rejected.
    raise NotImplementedError(
        f"task {name!r} not recognized; expected 'game24', 'text' or 'crosswords'"
    )

View File

@ -1,4 +1,5 @@
DATA_PATH = './data' import os
DATA_PATH = os.path.join(os.path.dirname(__file__), '..', 'data')
class Task: class Task:
def __init__(self): def __init__(self):

View File

@ -1,13 +1,14 @@
import re import re
import json
import os import os
from tasks.base import Task, DATA_PATH import json
from prompts.crosswords import * from tot.tasks.base import Task, DATA_PATH
from models import gpt from tot.prompts.crosswords import *
from tot.models import gpt
class MiniCrosswordsEnv: class MiniCrosswordsEnv:
def __init__(self, file='mini0505.json'): def __init__(self, file='mini0505.json'):
self.file = f'data/crosswords/{file}' self.file = os.path.join(DATA_PATH, 'crosswords', file)
self.file = json.load(open(self.file)) self.file = json.load(open(self.file))
self.n = len(self.file) self.n = len(self.file)
self.cache = {} self.cache = {}

View File

@ -2,8 +2,8 @@ import re
import os import os
import sympy import sympy
import pandas as pd import pandas as pd
from tasks.base import Task, DATA_PATH from tot.tasks.base import Task, DATA_PATH
from prompts.game24 import * from tot.prompts.game24 import *
def get_current_numbers(y: str) -> str: def get_current_numbers(y: str) -> str:

View File

@ -1,8 +1,8 @@
import os import os
import re import re
from tasks.base import Task, DATA_PATH from tot.tasks.base import Task, DATA_PATH
from prompts.text import * from tot.prompts.text import *
from models import gpt from tot.models import gpt
class TextTask(Task): class TextTask(Task):

View File

@ -1,12 +0,0 @@
def get_task(name, file=None):
    """Instantiate the task registered under ``name``, passing ``file`` to it.

    The task module is imported (relative to this package) only for the
    requested name.

    Args:
        name: one of ``'game24'``, ``'text'`` or ``'crosswords'``.
        file: forwarded unchanged to the task constructor.

    Raises:
        NotImplementedError: if ``name`` is not one of the names above.
    """
    if name == 'game24':
        from .game24 import Game24Task
        return Game24Task(file)
    if name == 'text':
        from .text import TextTask
        return TextTask(file)
    if name == 'crosswords':
        from .crosswords import MiniCrosswordsTask
        return MiniCrosswordsTask(file)
    raise NotImplementedError