mirror of
https://github.com/princeton-nlp/tree-of-thought-llm
synced 2024-11-21 15:47:51 +00:00
tot package
This commit is contained in:
parent
7382f2416e
commit
733b009f62
4
MANIFEST.in
Normal file
4
MANIFEST.in
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
include src/tot/data/24/24.csv
|
||||||
|
include src/tot/data/crosswords/mini0505_0_100_5.json
|
||||||
|
include src/tot/data/crosswords/mini0505.json
|
||||||
|
include src/tot/data/text/data_100_random_text.txt
|
Before Width: | Height: | Size: 84 KiB After Width: | Height: | Size: 84 KiB |
Before Width: | Height: | Size: 99 KiB After Width: | Height: | Size: 99 KiB |
35
pyproject.toml
Normal file
35
pyproject.toml
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
[build-system]
|
||||||
|
requires = ["setuptools >= 61.0.0"]
|
||||||
|
build-backend = "setuptools.build_meta"
|
||||||
|
|
||||||
|
[project]
|
||||||
|
name = "tot"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = 'Official Implementation of "Tree of Thoughts: Deliberate Problem Solving with Large Language Models"'
|
||||||
|
readme = "README.md"
|
||||||
|
requires-python = ">= 3.7"
|
||||||
|
authors = [{ name = "Shunyu Yao", email = "shunyuyao.cs@gmail.com" }]
|
||||||
|
license = { text = "MIT License" }
|
||||||
|
keywords = ["tree-search", "large-language-models", "llm", "prompting", "tree-of-thoughts"]
|
||||||
|
classifiers = [
|
||||||
|
"License :: OSI Approved :: MIT License",
|
||||||
|
"Programming Language :: Python :: 3",
|
||||||
|
"Programming Language :: Python :: 3.7",
|
||||||
|
"Programming Language :: Python :: 3.8",
|
||||||
|
"Programming Language :: Python :: 3.9",
|
||||||
|
"Programming Language :: Python :: 3.10",
|
||||||
|
"Programming Language :: Python :: 3.11",
|
||||||
|
'Intended Audience :: Science/Research',
|
||||||
|
'Topic :: Scientific/Engineering :: Artificial Intelligence',
|
||||||
|
]
|
||||||
|
dynamic=["dependencies"]
|
||||||
|
|
||||||
|
|
||||||
|
[tool.setuptools.dynamic]
|
||||||
|
dependencies = {file = ["requirements.txt"]}
|
||||||
|
|
||||||
|
[tool.setuptools.packages.find]
|
||||||
|
where = ["src"] # list of folders that contain the packages (["."] by default)
|
||||||
|
|
||||||
|
[project.urls]
|
||||||
|
Homepage = "https://github.com/princeton-nlp/tree-of-thought-llm"
|
@ -4,7 +4,7 @@
|
|||||||
<details>
|
<details>
|
||||||
<summary>Note: https://github.com/kyegomez/tree-of-thoughts is NOT the correct implementation to replicate paper results. </summary>
|
<summary>Note: https://github.com/kyegomez/tree-of-thoughts is NOT the correct implementation to replicate paper results. </summary>
|
||||||
|
|
||||||
In fact, people have reported that his code [cannot](https://github.com/kyegomez/tree-of-thoughts/issues/52) [properly](https://github.com/kyegomez/tree-of-thoughts/issues/41) [run](https://github.com/kyegomez/tree-of-thoughts/issues/60), and is [probably automatically generated by ChatGPT](fake.png), and [kyegomez has done so for other popular ML methods](https://twitter.com/qbitium/status/1663954096741814272), while intentionally refusing to link to official implementations for his own interests (see https://github.com/kyegomez/tree-of-thoughts/issues/54, https://github.com/kyegomez/tree-of-thoughts/issues/55, https://github.com/kyegomez/tree-of-thoughts/issues/56).
|
In fact, people have reported that his code [cannot](https://github.com/kyegomez/tree-of-thoughts/issues/52) [properly](https://github.com/kyegomez/tree-of-thoughts/issues/41) [run](https://github.com/kyegomez/tree-of-thoughts/issues/60), and is [probably automatically generated by ChatGPT](pics/fake.png), and [kyegomez has done so for other popular ML methods](https://twitter.com/qbitium/status/1663954096741814272), while intentionally refusing to link to official implementations for his own interests (see https://github.com/kyegomez/tree-of-thoughts/issues/54, https://github.com/kyegomez/tree-of-thoughts/issues/55, https://github.com/kyegomez/tree-of-thoughts/issues/56).
|
||||||
Unfortunately, Google/Github searches go to kyegomez's malicious repo by default as it has more stars. **Please DE-STAR his repo and STAR this to help other people avoid being misled, thanks!**
|
Unfortunately, Google/Github searches go to kyegomez's malicious repo by default as it has more stars. **Please DE-STAR his repo and STAR this to help other people avoid being misled, thanks!**
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
@ -12,7 +12,7 @@ Unfortunately, Google/Github searches go to kyegomez's malicious repo by default
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
![teaser](teaser.png)
|
![teaser](pics/teaser.png)
|
||||||
|
|
||||||
Official implementation for paper [Tree of Thoughts: Deliberate Problem Solving with Large Language Models](https://arxiv.org/abs/2305.10601) with code, prompts, model outputs.
|
Official implementation for paper [Tree of Thoughts: Deliberate Problem Solving with Large Language Models](https://arxiv.org/abs/2305.10601) with code, prompts, model outputs.
|
||||||
Also check [its tweet thread](https://twitter.com/ShunyuYao12/status/1659357547474681857) in 1min.
|
Also check [its tweet thread](https://twitter.com/ShunyuYao12/status/1659357547474681857) in 1min.
|
||||||
|
@ -16,3 +16,4 @@ sympy==1.12
|
|||||||
tqdm==4.65.0
|
tqdm==4.65.0
|
||||||
urllib3==2.0.2
|
urllib3==2.0.2
|
||||||
yarl==1.9.2
|
yarl==1.9.2
|
||||||
|
pandas==2.0.3
|
105
run.py
105
run.py
@ -1,108 +1,18 @@
|
|||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
import itertools
|
|
||||||
import argparse
|
import argparse
|
||||||
import numpy as np
|
|
||||||
from functools import partial
|
|
||||||
from models import gpt, gpt_usage
|
|
||||||
from tasks import get_task
|
|
||||||
|
|
||||||
def get_value(task, x, y, n_evaluate_sample, cache_value=True):
|
from tot.tasks import get_task
|
||||||
value_prompt = task.value_prompt_wrap(x, y)
|
from tot.methods.bfs import solve, naive_solve
|
||||||
if cache_value and value_prompt in task.value_cache:
|
from tot.models import gpt_usage
|
||||||
return task.value_cache[value_prompt]
|
|
||||||
value_outputs = gpt(value_prompt, n=n_evaluate_sample, stop=None)
|
|
||||||
value = task.value_outputs_unwrap(x, y, value_outputs)
|
|
||||||
if cache_value:
|
|
||||||
task.value_cache[value_prompt] = value
|
|
||||||
return value
|
|
||||||
|
|
||||||
def get_values(task, x, ys, n_evaluate_sample, cache_value=True):
|
|
||||||
values = []
|
|
||||||
local_value_cache = {}
|
|
||||||
for y in ys: # each partial output
|
|
||||||
if y in local_value_cache: # avoid duplicate candidates
|
|
||||||
value = 0
|
|
||||||
else:
|
|
||||||
value = get_value(task, x, y, n_evaluate_sample, cache_value=cache_value)
|
|
||||||
local_value_cache[y] = value
|
|
||||||
values.append(value)
|
|
||||||
return values
|
|
||||||
|
|
||||||
def get_votes(task, x, ys, n_evaluate_sample):
|
|
||||||
vote_prompt = task.vote_prompt_wrap(x, ys)
|
|
||||||
vote_outputs = gpt(vote_prompt, n=n_evaluate_sample, stop=None)
|
|
||||||
values = task.vote_outputs_unwrap(vote_outputs, len(ys))
|
|
||||||
return values
|
|
||||||
|
|
||||||
def get_proposals(task, x, y):
|
|
||||||
propose_prompt = task.propose_prompt_wrap(x, y)
|
|
||||||
proposals = gpt(propose_prompt, n=1, stop=None)[0].split('\n')
|
|
||||||
return [y + _ + '\n' for _ in proposals]
|
|
||||||
|
|
||||||
def get_samples(task, x, y, n_generate_sample, prompt_sample, stop):
|
|
||||||
if prompt_sample == 'standard':
|
|
||||||
prompt = task.standard_prompt_wrap(x, y)
|
|
||||||
elif prompt_sample == 'cot':
|
|
||||||
prompt = task.cot_prompt_wrap(x, y)
|
|
||||||
else:
|
|
||||||
raise ValueError(f'prompt_sample {prompt_sample} not recognized')
|
|
||||||
samples = gpt(prompt, n=n_generate_sample, stop=stop)
|
|
||||||
return [y + _ for _ in samples]
|
|
||||||
|
|
||||||
def solve(args, task, idx, to_print=True):
|
|
||||||
print(gpt)
|
|
||||||
x = task.get_input(idx) # input
|
|
||||||
ys = [''] # current output candidates
|
|
||||||
infos = []
|
|
||||||
for step in range(task.steps):
|
|
||||||
# generation
|
|
||||||
if args.method_generate == 'sample':
|
|
||||||
new_ys = [get_samples(task, x, y, args.n_generate_sample, prompt_sample=args.prompt_sample, stop=task.stops[step]) for y in ys]
|
|
||||||
elif args.method_generate == 'propose':
|
|
||||||
new_ys = [get_proposals(task, x, y) for y in ys]
|
|
||||||
new_ys = list(itertools.chain(*new_ys))
|
|
||||||
ids = list(range(len(new_ys)))
|
|
||||||
# evaluation
|
|
||||||
if args.method_evaluate == 'vote':
|
|
||||||
values = get_votes(task, x, new_ys, args.n_evaluate_sample)
|
|
||||||
elif args.method_evaluate == 'value':
|
|
||||||
values = get_values(task, x, new_ys, args.n_evaluate_sample)
|
|
||||||
|
|
||||||
# selection
|
|
||||||
if args.method_select == 'sample':
|
|
||||||
ps = np.array(values) / sum(values)
|
|
||||||
select_ids = np.random.choice(ids, size=args.n_select_sample, p=ps).tolist()
|
|
||||||
elif args.method_select == 'greedy':
|
|
||||||
select_ids = sorted(ids, key=lambda x: values[x], reverse=True)[:args.n_select_sample]
|
|
||||||
select_new_ys = [new_ys[select_id] for select_id in select_ids]
|
|
||||||
|
|
||||||
# log
|
|
||||||
if to_print:
|
|
||||||
sorted_new_ys, sorted_values = zip(*sorted(zip(new_ys, values), key=lambda x: x[1], reverse=True))
|
|
||||||
print(f'-- new_ys --: {sorted_new_ys}\n-- sol values --: {sorted_values}\n-- choices --: {select_new_ys}\n')
|
|
||||||
|
|
||||||
infos.append({'step': step, 'x': x, 'ys': ys, 'new_ys': new_ys, 'values': values, 'select_new_ys': select_new_ys})
|
|
||||||
ys = select_new_ys
|
|
||||||
|
|
||||||
if to_print:
|
|
||||||
print(ys)
|
|
||||||
return ys, {'steps': infos}
|
|
||||||
|
|
||||||
def naive_solve(args, task, idx, to_print=True):
|
|
||||||
x = task.get_input(idx) # input
|
|
||||||
ys = get_samples(task, x, '', args.n_generate_sample, args.prompt_sample, stop=None)
|
|
||||||
return ys, {}
|
|
||||||
|
|
||||||
def run(args):
|
def run(args):
|
||||||
task = get_task(args.task, args.task_file_path)
|
task = get_task(args.task)
|
||||||
logs, cnt_avg, cnt_any = [], 0, 0
|
logs, cnt_avg, cnt_any = [], 0, 0
|
||||||
global gpt
|
|
||||||
gpt = partial(gpt, model=args.backend, temperature=args.temperature)
|
|
||||||
if args.naive_run:
|
if args.naive_run:
|
||||||
file = f'logs/{args.task}/{args.backend}_{args.temperature}_naive_{args.prompt_sample}_sample_{args.n_generate_sample}_start{args.task_start_index}_end{args.task_end_index}.json'
|
file = f'./logs/{args.task}/{args.backend}_{args.temperature}_naive_{args.prompt_sample}_sample_{args.n_generate_sample}_start{args.task_start_index}_end{args.task_end_index}.json'
|
||||||
else:
|
else:
|
||||||
file = f'logs/{args.task}/{args.backend}_{args.temperature}_{args.method_generate}{args.n_generate_sample}_{args.method_evaluate}{args.n_evaluate_sample}_{args.method_select}{args.n_select_sample}_start{args.task_start_index}_end{args.task_end_index}.json'
|
file = f'./logs/{args.task}/{args.backend}_{args.temperature}_{args.method_generate}{args.n_generate_sample}_{args.method_evaluate}{args.n_evaluate_sample}_{args.method_select}{args.n_select_sample}_start{args.task_start_index}_end{args.task_end_index}.json'
|
||||||
os.makedirs(os.path.dirname(file), exist_ok=True)
|
os.makedirs(os.path.dirname(file), exist_ok=True)
|
||||||
|
|
||||||
for i in range(args.task_start_index, args.task_end_index):
|
for i in range(args.task_start_index, args.task_end_index):
|
||||||
@ -136,7 +46,6 @@ def parse_args():
|
|||||||
args.add_argument('--temperature', type=float, default=0.7)
|
args.add_argument('--temperature', type=float, default=0.7)
|
||||||
|
|
||||||
args.add_argument('--task', type=str, required=True, choices=['game24', 'text', 'crosswords'])
|
args.add_argument('--task', type=str, required=True, choices=['game24', 'text', 'crosswords'])
|
||||||
args.add_argument('--task_file_path', type=str, required=True)
|
|
||||||
args.add_argument('--task_start_index', type=int, default=900)
|
args.add_argument('--task_start_index', type=int, default=900)
|
||||||
args.add_argument('--task_end_index', type=int, default=1000)
|
args.add_argument('--task_end_index', type=int, default=1000)
|
||||||
|
|
||||||
@ -145,7 +54,7 @@ def parse_args():
|
|||||||
|
|
||||||
args.add_argument('--method_generate', type=str, choices=['sample', 'propose'])
|
args.add_argument('--method_generate', type=str, choices=['sample', 'propose'])
|
||||||
args.add_argument('--method_evaluate', type=str, choices=['value', 'vote'])
|
args.add_argument('--method_evaluate', type=str, choices=['value', 'vote'])
|
||||||
args.add_argument('--method_select', type=str, choices=['sample', 'greedy'])
|
args.add_argument('--method_select', type=str, choices=['sample', 'greedy'], default='greedy')
|
||||||
args.add_argument('--n_generate_sample', type=int, default=1) # only thing needed if naive_run
|
args.add_argument('--n_generate_sample', type=int, default=1) # only thing needed if naive_run
|
||||||
args.add_argument('--n_evaluate_sample', type=int, default=1)
|
args.add_argument('--n_evaluate_sample', type=int, default=1)
|
||||||
args.add_argument('--n_select_sample', type=int, default=1)
|
args.add_argument('--n_select_sample', type=int, default=1)
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
python run.py \
|
python run.py \
|
||||||
--task crosswords \
|
--task crosswords \
|
||||||
--task_file_path mini0505_0_100_5.json \
|
|
||||||
--task_start_index 0 \
|
--task_start_index 0 \
|
||||||
--task_end_index 20 \
|
--task_end_index 20 \
|
||||||
--naive_run \
|
--naive_run \
|
||||||
|
@ -14,7 +14,7 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"cd ../.."
|
"cd .."
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -24,9 +24,9 @@
|
|||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"import json\n",
|
"import json\n",
|
||||||
"from prompts.crosswords import propose_prompt, value_prompt\n",
|
"from tot.prompts.crosswords import propose_prompt, value_prompt\n",
|
||||||
"from models import gpt\n",
|
"from tot.models import gpt\n",
|
||||||
"from tasks.crosswords import MiniCrosswordsEnv\n",
|
"from tot.tasks.crosswords import MiniCrosswordsEnv\n",
|
||||||
"\n",
|
"\n",
|
||||||
"env = MiniCrosswordsEnv()"
|
"env = MiniCrosswordsEnv()"
|
||||||
]
|
]
|
||||||
@ -61,7 +61,7 @@
|
|||||||
"source": [
|
"source": [
|
||||||
"import re\n",
|
"import re\n",
|
||||||
"import copy\n",
|
"import copy\n",
|
||||||
"from models import gpt\n",
|
"from tot.models import gpt\n",
|
||||||
"\n",
|
"\n",
|
||||||
"def parse_line(input_str):\n",
|
"def parse_line(input_str):\n",
|
||||||
" # regular expression pattern to match the input string format\n",
|
" # regular expression pattern to match the input string format\n",
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
python run.py \
|
python run.py \
|
||||||
--task crosswords \
|
--task crosswords \
|
||||||
--task_file_path mini0505_0_100_5.json \
|
|
||||||
--task_start_index 0 \
|
--task_start_index 0 \
|
||||||
--task_end_index 20 \
|
--task_end_index 20 \
|
||||||
--naive_run \
|
--naive_run \
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
python run.py \
|
python run.py \
|
||||||
--task game24 \
|
--task game24 \
|
||||||
--task_file_path 24.csv \
|
|
||||||
--task_start_index 900 \
|
--task_start_index 900 \
|
||||||
--task_end_index 1000 \
|
--task_end_index 1000 \
|
||||||
--method_generate propose \
|
--method_generate propose \
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
python run.py \
|
python run.py \
|
||||||
--task game24 \
|
--task game24 \
|
||||||
--task_file_path 24.csv \
|
|
||||||
--task_start_index 900 \
|
--task_start_index 900 \
|
||||||
--task_end_index 1000 \
|
--task_end_index 1000 \
|
||||||
--naive_run \
|
--naive_run \
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
python run.py \
|
python run.py \
|
||||||
--task game24 \
|
--task game24 \
|
||||||
--task_file_path 24.csv \
|
|
||||||
--task_start_index 900 \
|
--task_start_index 900 \
|
||||||
--task_end_index 1000 \
|
--task_end_index 1000 \
|
||||||
--naive_run \
|
--naive_run \
|
||||||
|
@ -1,8 +1,7 @@
|
|||||||
python run.py \
|
python run.py \
|
||||||
--task text \
|
--task text \
|
||||||
--task_file_path data_100_random_text.txt \
|
|
||||||
--task_start_index 0 \
|
--task_start_index 0 \
|
||||||
--task_end_index 1 \
|
--task_end_index 100 \
|
||||||
--method_generate sample \
|
--method_generate sample \
|
||||||
--method_evaluate vote \
|
--method_evaluate vote \
|
||||||
--method_select greedy \
|
--method_select greedy \
|
||||||
|
@ -1,8 +1,7 @@
|
|||||||
python run.py \
|
python run.py \
|
||||||
--task text \
|
--task text \
|
||||||
--task_file_path data_100_random_text.txt \
|
|
||||||
--task_start_index 0 \
|
--task_start_index 0 \
|
||||||
--task_end_index 1 \
|
--task_end_index 100 \
|
||||||
--naive_run \
|
--naive_run \
|
||||||
--prompt_sample cot \
|
--prompt_sample cot \
|
||||||
--n_generate_sample 10 \
|
--n_generate_sample 10 \
|
||||||
|
@ -1,8 +1,7 @@
|
|||||||
python run.py \
|
python run.py \
|
||||||
--task text \
|
--task text \
|
||||||
--task_file_path data_100_random_text.txt \
|
|
||||||
--task_start_index 0 \
|
--task_start_index 0 \
|
||||||
--task_end_index 1 \
|
--task_end_index 100 \
|
||||||
--naive_run \
|
--naive_run \
|
||||||
--prompt_sample standard \
|
--prompt_sample standard \
|
||||||
--n_generate_sample 10 \
|
--n_generate_sample 10 \
|
||||||
|
37
setup.py
Normal file
37
setup.py
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
import setuptools

# Packaging script for the `tot` distribution; the package sources live
# under `src/`, so discovery and `package_dir` both point there.

REPO_URL = 'https://github.com/princeton-nlp/tree-of-thought-llm'

CLASSIFIERS = [
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3',
    'Programming Language :: Python :: 3.7',
    'Programming Language :: Python :: 3.8',
    'Programming Language :: Python :: 3.9',
    'Programming Language :: Python :: 3.10',
    'Programming Language :: Python :: 3.11',
    'Intended Audience :: Science/Research',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
]

# The README is reused verbatim as the long description.
with open('README.md', 'r', encoding='utf-8') as readme_file:
    long_description = readme_file.read()

setuptools.setup(
    name='tot',
    author='Shunyu Yao',
    author_email='shunyuyao.cs@gmail.com',
    description='Official Implementation of "Tree of Thoughts: Deliberate Problem Solving with Large Language Models"',
    keywords='tree-search, large-language-models, llm, prompting, tree-of-thoughts',
    long_description=long_description,
    long_description_content_type='text/markdown',
    url=REPO_URL,
    project_urls={'Homepage': REPO_URL},
    package_dir={'': 'src'},
    packages=setuptools.find_packages(where='src'),
    classifiers=CLASSIFIERS,
    python_requires='>=3.7',
    install_requires=['setuptools'],
    include_package_data=True,
)
|
1
src/tot/__init__.py
Normal file
1
src/tot/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
__version__ = "0.1.0"
|
File diff suppressed because it is too large
Load Diff
96
src/tot/methods/bfs.py
Normal file
96
src/tot/methods/bfs.py
Normal file
@ -0,0 +1,96 @@
|
|||||||
|
import itertools
|
||||||
|
import numpy as np
|
||||||
|
from functools import partial
|
||||||
|
from tot.models import gpt
|
||||||
|
|
||||||
|
def get_value(task, x, y, n_evaluate_sample, cache_value=True):
    """Score the partial output `y` for input `x` via the task's value prompt.

    The prompt built by `task.value_prompt_wrap` is also the key into
    `task.value_cache`. When `cache_value` is true, a prompt that is already
    cached is answered from the cache without calling `gpt`, and a freshly
    computed value is stored back under that prompt.
    """
    prompt = task.value_prompt_wrap(x, y)

    def _evaluate():
        # Ask for n_evaluate_sample completions and let the task reduce them
        # to a single value.
        outputs = gpt(prompt, n=n_evaluate_sample, stop=None)
        return task.value_outputs_unwrap(x, y, outputs)

    if not cache_value:
        return _evaluate()
    if prompt in task.value_cache:
        return task.value_cache[prompt]
    score = _evaluate()
    task.value_cache[prompt] = score
    return score
|
||||||
|
|
||||||
|
def get_values(task, x, ys, n_evaluate_sample, cache_value=True):
    """Score every candidate in `ys`, in order, and return the list of values.

    The first occurrence of a candidate is scored with `get_value`; any later
    repeat of the same candidate is given the value 0 so that duplicates are
    not favoured during selection.
    """
    already_scored = set()
    scores = []
    for candidate in ys:  # each partial output
        if candidate in already_scored:  # avoid duplicate candidates
            scores.append(0)
            continue
        score = get_value(task, x, candidate, n_evaluate_sample, cache_value=cache_value)
        already_scored.add(candidate)
        scores.append(score)
    return scores
|
||||||
|
|
||||||
|
def get_votes(task, x, ys, n_evaluate_sample):
    """Evaluate all candidates in `ys` jointly with the task's vote prompt.

    A single prompt covering every candidate is sent to `gpt` with
    `n_evaluate_sample` completions requested; `task.vote_outputs_unwrap`
    turns those completions into the returned values, given the number of
    candidates.
    """
    outputs = gpt(task.vote_prompt_wrap(x, ys), n=n_evaluate_sample, stop=None)
    return task.vote_outputs_unwrap(outputs, len(ys))
|
||||||
|
|
||||||
|
def get_proposals(task, x, y):
    """Extend the partial output `y` with each proposal line returned by `gpt`.

    One completion is requested for the task's propose prompt; each of its
    lines becomes a separate candidate of the form `y + line + '\\n'`.
    """
    prompt = task.propose_prompt_wrap(x, y)
    completion = gpt(prompt, n=1, stop=None)[0]
    candidates = []
    for line in completion.split('\n'):
        candidates.append(y + line + '\n')
    return candidates
|
||||||
|
|
||||||
|
def get_samples(task, x, y, n_generate_sample, prompt_sample, stop):
    """Sample `n_generate_sample` continuations of the partial output `y`.

    `prompt_sample` selects the prompt builder: 'standard' uses
    `task.standard_prompt_wrap`, 'cot' uses `task.cot_prompt_wrap`. Each
    completion returned by `gpt` is appended to `y`.

    Raises:
        ValueError: if `prompt_sample` is neither 'standard' nor 'cot'.
    """
    # Reject unknown prompt styles before touching the task or the model.
    if prompt_sample != 'standard' and prompt_sample != 'cot':
        raise ValueError(f'prompt_sample {prompt_sample} not recognized')

    if prompt_sample == 'standard':
        prompt = task.standard_prompt_wrap(x, y)
    else:
        prompt = task.cot_prompt_wrap(x, y)

    completions = gpt(prompt, n=n_generate_sample, stop=stop)
    return [y + completion for completion in completions]
|
||||||
|
|
||||||
|
def _select_ids(ids, values, method_select, n_select_sample):
    """Return the candidate indices kept for the next step."""
    if method_select == 'sample':
        total = sum(values)
        # Dividing by a zero total would produce NaN probabilities, which
        # np.random.choice rejects; p=None makes it draw uniformly instead.
        ps = np.array(values) / total if total != 0 else None
        return np.random.choice(ids, size=n_select_sample, p=ps).tolist()
    if method_select == 'greedy':
        return sorted(ids, key=lambda i: values[i], reverse=True)[:n_select_sample]
    raise ValueError(f'method_select {method_select} not recognized')


def solve(args, task, idx, to_print=True):
    """Run the step-wise generate / evaluate / select search on one task input.

    Args:
        args: namespace providing `backend`, `temperature`, `method_generate`
            ('sample' or 'propose'), `method_evaluate` ('vote' or 'value'),
            `method_select` ('sample' or 'greedy'), `n_generate_sample`,
            `n_evaluate_sample`, `n_select_sample` and `prompt_sample`.
        task: task object providing `get_input(idx)`, `steps` and `stops`.
        idx: index of the input passed to `task.get_input`.
        to_print: when true, print the ranked candidates after every step and
            the final candidates at the end.

    Returns:
        A pair `(ys, info)`: `ys` is the list of candidates selected in the
        last step and `info` is `{'steps': [...]}` with one record per step.

    Raises:
        ValueError: if `method_generate`, `method_evaluate` or `method_select`
            is not one of the recognised values.
    """
    global gpt

    # Fail fast with a clear message instead of a NameError inside the loop.
    if args.method_generate not in ('sample', 'propose'):
        raise ValueError(f'method_generate {args.method_generate} not recognized')
    if args.method_evaluate not in ('vote', 'value'):
        raise ValueError(f'method_evaluate {args.method_evaluate} not recognized')
    if args.method_select not in ('sample', 'greedy'):
        raise ValueError(f'method_select {args.method_select} not recognized')

    # NOTE: this rebinds the module-level `gpt` on every call so that the
    # sibling helpers pick up the backend/temperature; the keywords given
    # here override those bound by an earlier call.
    gpt = partial(gpt, model=args.backend, temperature=args.temperature)
    print(gpt)

    x = task.get_input(idx)  # input
    ys = ['']  # current output candidates
    infos = []
    for step in range(task.steps):
        # generation
        if args.method_generate == 'sample':
            new_ys = [
                get_samples(task, x, y, args.n_generate_sample,
                            prompt_sample=args.prompt_sample, stop=task.stops[step])
                for y in ys
            ]
        else:  # 'propose'
            new_ys = [get_proposals(task, x, y) for y in ys]
        new_ys = list(itertools.chain.from_iterable(new_ys))
        ids = list(range(len(new_ys)))

        # evaluation
        if args.method_evaluate == 'vote':
            values = get_votes(task, x, new_ys, args.n_evaluate_sample)
        else:  # 'value'
            values = get_values(task, x, new_ys, args.n_evaluate_sample)

        # selection
        select_ids = _select_ids(ids, values, args.method_select, args.n_select_sample)
        select_new_ys = [new_ys[select_id] for select_id in select_ids]

        # log
        if to_print:
            # Built without zip(*...) unpacking so an empty candidate list
            # prints empty tuples instead of raising.
            ranked = sorted(zip(new_ys, values), key=lambda pair: pair[1], reverse=True)
            sorted_new_ys = tuple(cand for cand, _ in ranked)
            sorted_values = tuple(val for _, val in ranked)
            print(f'-- new_ys --: {sorted_new_ys}\n-- sol values --: {sorted_values}\n-- choices --: {select_new_ys}\n')

        infos.append({'step': step, 'x': x, 'ys': ys, 'new_ys': new_ys, 'values': values, 'select_new_ys': select_new_ys})
        ys = select_new_ys

    if to_print:
        print(ys)
    return ys, {'steps': infos}
|
||||||
|
|
||||||
|
def naive_solve(args, task, idx, to_print=True):
    """Sample complete outputs for one task input in a single shot, no search.

    Rebinds the module-level `gpt` with `args.backend` and `args.temperature`,
    then draws `args.n_generate_sample` samples from an empty partial output
    using the prompt style named by `args.prompt_sample`. `to_print` is
    accepted but not used here.

    Returns:
        A pair `(ys, {})` where `ys` is the list of sampled outputs.
    """
    global gpt
    gpt = partial(gpt, model=args.backend, temperature=args.temperature)
    print(gpt)

    task_input = task.get_input(idx)  # input
    samples = get_samples(
        task,
        task_input,
        '',
        args.n_generate_sample,
        args.prompt_sample,
        stop=None,
    )
    return samples, {}
|
@ -41,5 +41,5 @@ def gpt_usage(backend="gpt-4"):
|
|||||||
if backend == "gpt-4":
|
if backend == "gpt-4":
|
||||||
cost = completion_tokens / 1000 * 0.06 + prompt_tokens / 1000 * 0.03
|
cost = completion_tokens / 1000 * 0.06 + prompt_tokens / 1000 * 0.03
|
||||||
elif backend == "gpt-3.5-turbo":
|
elif backend == "gpt-3.5-turbo":
|
||||||
cost = completion_tokens / 1000 * 0.002 + prompt_tokens / 1000 * 0.0015
|
cost = (completion_tokens + prompt_tokens) / 1000 * 0.0002
|
||||||
return {"completion_tokens": completion_tokens, "prompt_tokens": prompt_tokens, "cost": cost}
|
return {"completion_tokens": completion_tokens, "prompt_tokens": prompt_tokens, "cost": cost}
|
12
src/tot/tasks/__init__.py
Normal file
12
src/tot/tasks/__init__.py
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
def get_task(name):
    """Create the task object registered under `name`.

    The task module is imported inside the matching branch, so only the
    selected task's module is imported.

    Args:
        name: one of 'game24', 'text' or 'crosswords'.

    Returns:
        A new `Game24Task`, `TextTask` or `MiniCrosswordsTask` instance.

    Raises:
        NotImplementedError: if `name` is not one of the supported tasks.
    """
    if name == 'game24':
        from tot.tasks.game24 import Game24Task
        return Game24Task()
    if name == 'text':
        from tot.tasks.text import TextTask
        return TextTask()
    if name == 'crosswords':
        from tot.tasks.crosswords import MiniCrosswordsTask
        return MiniCrosswordsTask()
    # Name the offending value so a typo on the command line is easy to spot.
    raise NotImplementedError(
        f'task {name!r} is not supported; expected one of: game24, text, crosswords'
    )
|
@ -1,4 +1,5 @@
|
|||||||
DATA_PATH = './data'
|
import os
|
||||||
|
DATA_PATH = os.path.join(os.path.dirname(__file__), '..', 'data')
|
||||||
|
|
||||||
class Task:
|
class Task:
|
||||||
def __init__(self):
|
def __init__(self):
|
@ -1,13 +1,14 @@
|
|||||||
import re
|
import re
|
||||||
import json
|
|
||||||
import os
|
import os
|
||||||
from tasks.base import Task, DATA_PATH
|
import json
|
||||||
from prompts.crosswords import *
|
from tot.tasks.base import Task, DATA_PATH
|
||||||
from models import gpt
|
from tot.prompts.crosswords import *
|
||||||
|
from tot.models import gpt
|
||||||
|
|
||||||
class MiniCrosswordsEnv:
|
class MiniCrosswordsEnv:
|
||||||
def __init__(self, file='mini0505.json'):
|
def __init__(self, file='mini0505.json'):
|
||||||
self.file = f'data/crosswords/{file}'
|
self.file = os.path.join(DATA_PATH, 'crosswords', file)
|
||||||
|
|
||||||
self.file = json.load(open(self.file))
|
self.file = json.load(open(self.file))
|
||||||
self.n = len(self.file)
|
self.n = len(self.file)
|
||||||
self.cache = {}
|
self.cache = {}
|
@ -2,8 +2,8 @@ import re
|
|||||||
import os
|
import os
|
||||||
import sympy
|
import sympy
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from tasks.base import Task, DATA_PATH
|
from tot.tasks.base import Task, DATA_PATH
|
||||||
from prompts.game24 import *
|
from tot.prompts.game24 import *
|
||||||
|
|
||||||
|
|
||||||
def get_current_numbers(y: str) -> str:
|
def get_current_numbers(y: str) -> str:
|
@ -1,8 +1,8 @@
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from tasks.base import Task, DATA_PATH
|
from tot.tasks.base import Task, DATA_PATH
|
||||||
from prompts.text import *
|
from tot.prompts.text import *
|
||||||
from models import gpt
|
from tot.models import gpt
|
||||||
|
|
||||||
|
|
||||||
class TextTask(Task):
|
class TextTask(Task):
|
@ -1,12 +0,0 @@
|
|||||||
def get_task(name, file=None):
|
|
||||||
if name == 'game24':
|
|
||||||
from .game24 import Game24Task
|
|
||||||
return Game24Task(file)
|
|
||||||
elif name == 'text':
|
|
||||||
from .text import TextTask
|
|
||||||
return TextTask(file)
|
|
||||||
elif name == 'crosswords':
|
|
||||||
from .crosswords import MiniCrosswordsTask
|
|
||||||
return MiniCrosswordsTask(file)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError
|
|
Loading…
Reference in New Issue
Block a user