mirror of
https://github.com/princeton-nlp/tree-of-thought-llm
synced 2025-01-22 10:35:31 +00:00
commit
3bc03cf11d
4
MANIFEST.in
Normal file
4
MANIFEST.in
Normal file
@ -0,0 +1,4 @@
|
||||
include src/tot/data/24/24.csv
|
||||
include src/tot/data/crosswords/mini0505_0_100_5.json
|
||||
include src/tot/data/crosswords/mini0505.json
|
||||
include src/tot/data/text/data_100_random_text.txt
|
Before Width: | Height: | Size: 84 KiB After Width: | Height: | Size: 84 KiB |
Before Width: | Height: | Size: 99 KiB After Width: | Height: | Size: 99 KiB |
35
pyproject.toml
Normal file
35
pyproject.toml
Normal file
@ -0,0 +1,35 @@
|
||||
[build-system]
|
||||
requires = ["setuptools >= 61.0.0"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "tot"
|
||||
version = "0.1.0"
|
||||
description = 'Official Implementation of "Tree of Thoughts: Deliberate Problem Solving with Large Language Models"'
|
||||
readme = "README.md"
|
||||
requires-python = ">= 3.7"
|
||||
authors = [{ name = "Shunyu Yao", email = "shunyuyao.cs@gmail.com" }]
|
||||
license = { text = "MIT License" }
|
||||
keywords = ["tree-search", "large-language-models", "llm", "prompting", "tree-of-thoughts"]
|
||||
classifiers = [
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.7",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
'Intended Audience :: Science/Research',
|
||||
'Topic :: Scientific/Engineering :: Artificial Intelligence',
|
||||
]
|
||||
dynamic=["dependencies"]
|
||||
|
||||
|
||||
[tool.setuptools.dynamic]
|
||||
dependencies = {file = ["requirements.txt"]}
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"] # list of folders that contain the packages (["."] by default)
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://github.com/princeton-nlp/tree-of-thought-llm"
|
79
readme.md
79
readme.md
@ -2,44 +2,58 @@
|
||||
[![DOI](https://zenodo.org/badge/642099326.svg)](https://zenodo.org/badge/latestdoi/642099326)
|
||||
|
||||
<details>
|
||||
<summary>Note: https://github.com/kyegomez/tree-of-thoughts is NOT the correct implementation to replicate paper results. </summary>
|
||||
|
||||
In fact, people have reported that his code [cannot](https://github.com/kyegomez/tree-of-thoughts/issues/52) [properly](https://github.com/kyegomez/tree-of-thoughts/issues/41) [run](https://github.com/kyegomez/tree-of-thoughts/issues/60), and is [probably automatically generated by ChatGPT](fake.png), and [kyegomez has done so for other popular ML methods](https://twitter.com/qbitium/status/1663954096741814272), while intentionally refusing to link to official implementations for his own interests (see https://github.com/kyegomez/tree-of-thoughts/issues/54, https://github.com/kyegomez/tree-of-thoughts/issues/55, https://github.com/kyegomez/tree-of-thoughts/issues/56).
|
||||
<summary>Note: https://github.com/kyegomez/tree-of-thoughts CANNOT replicate paper results. </summary>
|
||||
|
||||
In fact, people have reported that his code [cannot](https://github.com/kyegomez/tree-of-thoughts/issues/52) [properly](https://github.com/kyegomez/tree-of-thoughts/issues/41) [run](https://github.com/kyegomez/tree-of-thoughts/issues/60), and is [probably automatically generated by ChatGPT](pics/fake.png), and [kyegomez has done so for other popular ML methods](https://twitter.com/qbitium/status/1663954096741814272), while intentionally refusing to link to official implementations for his own interests (see https://github.com/kyegomez/tree-of-thoughts/issues/54, https://github.com/kyegomez/tree-of-thoughts/issues/55, https://github.com/kyegomez/tree-of-thoughts/issues/56).
|
||||
Unfortunately, Google/Github searches go to kyegomez's malicious repo by default as it has more stars. **Please DE-STAR his repo and STAR this to help other people avoid being misled, thanks!**
|
||||
</details>
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
![teaser](teaser.png)
|
||||
![teaser](pics/teaser.png)
|
||||
|
||||
Official implementation for paper [Tree of Thoughts: Deliberate Problem Solving with Large Language Models](https://arxiv.org/abs/2305.10601) with code, prompts, model outputs.
|
||||
Also check [its tweet thread](https://twitter.com/ShunyuYao12/status/1659357547474681857) in 1min.
|
||||
|
||||
|
||||
Please cite the paper and star this repo if you use ToT and find it interesting/useful. Thanks!
|
||||
|
||||
```bibtex
|
||||
@misc{yao2023tree,
|
||||
title={{Tree of Thoughts}: Deliberate Problem Solving with Large Language Models},
|
||||
author={Shunyu Yao and Dian Yu and Jeffrey Zhao and Izhak Shafran and Thomas L. Griffiths and Yuan Cao and Karthik Narasimhan},
|
||||
year={2023},
|
||||
eprint={2305.10601},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.CL}
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
|
||||
## Setup
|
||||
You need to first have an OpenAI API key and store it in the environment variable ``OPENAI_API_KEY`` (see [here](https://help.openai.com/en/articles/5112595-best-practices-for-api-key-safety)). If you use custom base url, set it by environment variable ``OPENAI_API_BASE`` (e.g. https://api.openai.com/v1).
|
||||
- Set up OpenAI API key and store in environment variable ``OPENAI_API_KEY`` (see [here](https://help.openai.com/en/articles/5112595-best-practices-for-api-key-safety)).
|
||||
|
||||
Package requirement: ``pip install openai backoff sympy numpy``
|
||||
- Install dependencies and `tot` package (PyPI package coming soon):
|
||||
```bash
|
||||
git clone https://github.com/princeton-nlp/tree-of-thought-llm
|
||||
cd tree-of-thought-llm
|
||||
pip install -r requirements.txt
|
||||
pip install -e . # install `tot` package
|
||||
```
|
||||
|
||||
|
||||
## Experiments
|
||||
## Quick Start
|
||||
The following minimal script will attempt to solve the game of 24 with `4 5 6 10` (might be a bit slow as it's using GPT-4):
|
||||
```python
|
||||
import argparse
|
||||
from tot.methods.bfs import solve
|
||||
from tot.tasks.game24 import Game24Task
|
||||
|
||||
args = argparse.Namespace(backend='gpt-4', temperature=0.7, task='game24', naive_run=False, prompt_sample=None, method_generate='propose', method_evaluate='value', method_select='greedy', n_generate_sample=1, n_evaluate_sample=3, n_select_sample=5)
|
||||
|
||||
task = Game24Task()
|
||||
ys, infos = solve(args, task, 900)
|
||||
print(ys[0])
|
||||
```
|
||||
|
||||
And the output would be something like (note it's not deterministic, and sometimes the output can be wrong):
|
||||
```
|
||||
10 - 4 = 6 (left: 5 6 6)
|
||||
5 * 6 = 30 (left: 6 30)
|
||||
30 - 6 = 24 (left: 24)
|
||||
Answer: (5 * (10 - 4)) - 6 = 24
|
||||
```
|
||||
|
||||
## Paper Experiments
|
||||
|
||||
Run experiments via ``sh scripts/{game24, text, crosswords}/{standard_sampling, cot_sampling, bfs}.sh``, except in crosswords we use a DFS algorithm for ToT, which can be run via ``scripts/crosswords/search_crosswords-dfs.ipynb``.
|
||||
|
||||
@ -55,13 +69,24 @@ The very simple ``run.py`` implements the ToT + BFS algorithm, as well as the na
|
||||
|
||||
|
||||
|
||||
## Trajectories
|
||||
## Paper Trajectories
|
||||
``logs/`` contains all the trajectories from the paper's experiments, except for ``logs/game24/gpt-4_0.7_propose1_value3_greedy5_start900_end1000.json`` which was reproduced after the paper (as the original experiment was done in a notebook) and achieved a 69\% score instead of the original 74\% score due to randomness in GPT decoding. We hope to aggregate multiple runs in the future to account for sampling randomness and update the paper, but this shouldn't affect the main conclusions of the paper.
|
||||
|
||||
## How to Add A New Task
|
||||
Setting up a new task is easy, and mainly involves two steps.
|
||||
* Set up a new task class in ``tot/tasks/`` and task files in ``tot/data/``. See ``tot/tasks/game24.py`` for an example. Add the task to ``tot/tasks/__init__.py``.
|
||||
* Set up task-specific prompts in ``tot/prompts/``. See ``tot/prompts/game24.py`` for an example. Depending on the nature of the task, choose ``--method_generate`` (choices=[``sample``, ``propose``]) and ``--method_evaluate`` (choices=[``value``, ``vote``]) and their corresponding prompts.
|
||||
|
||||
## Citations
|
||||
Please cite the paper and star this repo if you use ToT and find it interesting/useful, thanks! Feel free to contact shunyuyao.cs@gmail.com or open an issue if you have any questions.
|
||||
|
||||
## Questions
|
||||
Feel free to contact shunyuyao.cs@gmail.com or open an issue if you have any questions.
|
||||
|
||||
|
||||
|
||||
```bibtex
|
||||
@misc{yao2023tree,
|
||||
title={{Tree of Thoughts}: Deliberate Problem Solving with Large Language Models},
|
||||
author={Shunyu Yao and Dian Yu and Jeffrey Zhao and Izhak Shafran and Thomas L. Griffiths and Yuan Cao and Karthik Narasimhan},
|
||||
year={2023},
|
||||
eprint={2305.10601},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.CL}
|
||||
}
|
||||
```
|
@ -16,3 +16,4 @@ sympy==1.12
|
||||
tqdm==4.65.0
|
||||
urllib3==2.0.2
|
||||
yarl==1.9.2
|
||||
pandas==2.0.3
|
105
run.py
105
run.py
@ -1,108 +1,18 @@
|
||||
import os
|
||||
import json
|
||||
import itertools
|
||||
import argparse
|
||||
import numpy as np
|
||||
from functools import partial
|
||||
from models import gpt, gpt_usage
|
||||
from tasks import get_task
|
||||
|
||||
def get_value(task, x, y, n_evaluate_sample, cache_value=True):
|
||||
value_prompt = task.value_prompt_wrap(x, y)
|
||||
if cache_value and value_prompt in task.value_cache:
|
||||
return task.value_cache[value_prompt]
|
||||
value_outputs = gpt(value_prompt, n=n_evaluate_sample, stop=None)
|
||||
value = task.value_outputs_unwrap(x, y, value_outputs)
|
||||
if cache_value:
|
||||
task.value_cache[value_prompt] = value
|
||||
return value
|
||||
|
||||
def get_values(task, x, ys, n_evaluate_sample, cache_value=True):
|
||||
values = []
|
||||
local_value_cache = {}
|
||||
for y in ys: # each partial output
|
||||
if y in local_value_cache: # avoid duplicate candidates
|
||||
value = 0
|
||||
else:
|
||||
value = get_value(task, x, y, n_evaluate_sample, cache_value=cache_value)
|
||||
local_value_cache[y] = value
|
||||
values.append(value)
|
||||
return values
|
||||
|
||||
def get_votes(task, x, ys, n_evaluate_sample):
|
||||
vote_prompt = task.vote_prompt_wrap(x, ys)
|
||||
vote_outputs = gpt(vote_prompt, n=n_evaluate_sample, stop=None)
|
||||
values = task.vote_outputs_unwrap(vote_outputs, len(ys))
|
||||
return values
|
||||
|
||||
def get_proposals(task, x, y):
|
||||
propose_prompt = task.propose_prompt_wrap(x, y)
|
||||
proposals = gpt(propose_prompt, n=1, stop=None)[0].split('\n')
|
||||
return [y + _ + '\n' for _ in proposals]
|
||||
|
||||
def get_samples(task, x, y, n_generate_sample, prompt_sample, stop):
|
||||
if prompt_sample == 'standard':
|
||||
prompt = task.standard_prompt_wrap(x, y)
|
||||
elif prompt_sample == 'cot':
|
||||
prompt = task.cot_prompt_wrap(x, y)
|
||||
else:
|
||||
raise ValueError(f'prompt_sample {prompt_sample} not recognized')
|
||||
samples = gpt(prompt, n=n_generate_sample, stop=stop)
|
||||
return [y + _ for _ in samples]
|
||||
|
||||
def solve(args, task, idx, to_print=True):
|
||||
print(gpt)
|
||||
x = task.get_input(idx) # input
|
||||
ys = [''] # current output candidates
|
||||
infos = []
|
||||
for step in range(task.steps):
|
||||
# generation
|
||||
if args.method_generate == 'sample':
|
||||
new_ys = [get_samples(task, x, y, args.n_generate_sample, prompt_sample=args.prompt_sample, stop=task.stops[step]) for y in ys]
|
||||
elif args.method_generate == 'propose':
|
||||
new_ys = [get_proposals(task, x, y) for y in ys]
|
||||
new_ys = list(itertools.chain(*new_ys))
|
||||
ids = list(range(len(new_ys)))
|
||||
# evaluation
|
||||
if args.method_evaluate == 'vote':
|
||||
values = get_votes(task, x, new_ys, args.n_evaluate_sample)
|
||||
elif args.method_evaluate == 'value':
|
||||
values = get_values(task, x, new_ys, args.n_evaluate_sample)
|
||||
|
||||
# selection
|
||||
if args.method_select == 'sample':
|
||||
ps = np.array(values) / sum(values)
|
||||
select_ids = np.random.choice(ids, size=args.n_select_sample, p=ps).tolist()
|
||||
elif args.method_select == 'greedy':
|
||||
select_ids = sorted(ids, key=lambda x: values[x], reverse=True)[:args.n_select_sample]
|
||||
select_new_ys = [new_ys[select_id] for select_id in select_ids]
|
||||
|
||||
# log
|
||||
if to_print:
|
||||
sorted_new_ys, sorted_values = zip(*sorted(zip(new_ys, values), key=lambda x: x[1], reverse=True))
|
||||
print(f'-- new_ys --: {sorted_new_ys}\n-- sol values --: {sorted_values}\n-- choices --: {select_new_ys}\n')
|
||||
|
||||
infos.append({'step': step, 'x': x, 'ys': ys, 'new_ys': new_ys, 'values': values, 'select_new_ys': select_new_ys})
|
||||
ys = select_new_ys
|
||||
|
||||
if to_print:
|
||||
print(ys)
|
||||
return ys, {'steps': infos}
|
||||
|
||||
def naive_solve(args, task, idx, to_print=True):
|
||||
x = task.get_input(idx) # input
|
||||
ys = get_samples(task, x, '', args.n_generate_sample, args.prompt_sample, stop=None)
|
||||
return ys, {}
|
||||
from tot.tasks import get_task
|
||||
from tot.methods.bfs import solve, naive_solve
|
||||
from tot.models import gpt_usage
|
||||
|
||||
def run(args):
|
||||
task = get_task(args.task, args.task_file_path)
|
||||
task = get_task(args.task)
|
||||
logs, cnt_avg, cnt_any = [], 0, 0
|
||||
global gpt
|
||||
gpt = partial(gpt, model=args.backend, temperature=args.temperature)
|
||||
if args.naive_run:
|
||||
file = f'logs/{args.task}/{args.backend}_{args.temperature}_naive_{args.prompt_sample}_sample_{args.n_generate_sample}_start{args.task_start_index}_end{args.task_end_index}.json'
|
||||
file = f'./logs/{args.task}/{args.backend}_{args.temperature}_naive_{args.prompt_sample}_sample_{args.n_generate_sample}_start{args.task_start_index}_end{args.task_end_index}.json'
|
||||
else:
|
||||
file = f'logs/{args.task}/{args.backend}_{args.temperature}_{args.method_generate}{args.n_generate_sample}_{args.method_evaluate}{args.n_evaluate_sample}_{args.method_select}{args.n_select_sample}_start{args.task_start_index}_end{args.task_end_index}.json'
|
||||
file = f'./logs/{args.task}/{args.backend}_{args.temperature}_{args.method_generate}{args.n_generate_sample}_{args.method_evaluate}{args.n_evaluate_sample}_{args.method_select}{args.n_select_sample}_start{args.task_start_index}_end{args.task_end_index}.json'
|
||||
os.makedirs(os.path.dirname(file), exist_ok=True)
|
||||
|
||||
for i in range(args.task_start_index, args.task_end_index):
|
||||
@ -136,7 +46,6 @@ def parse_args():
|
||||
args.add_argument('--temperature', type=float, default=0.7)
|
||||
|
||||
args.add_argument('--task', type=str, required=True, choices=['game24', 'text', 'crosswords'])
|
||||
args.add_argument('--task_file_path', type=str, required=True)
|
||||
args.add_argument('--task_start_index', type=int, default=900)
|
||||
args.add_argument('--task_end_index', type=int, default=1000)
|
||||
|
||||
@ -145,7 +54,7 @@ def parse_args():
|
||||
|
||||
args.add_argument('--method_generate', type=str, choices=['sample', 'propose'])
|
||||
args.add_argument('--method_evaluate', type=str, choices=['value', 'vote'])
|
||||
args.add_argument('--method_select', type=str, choices=['sample', 'greedy'])
|
||||
args.add_argument('--method_select', type=str, choices=['sample', 'greedy'], default='greedy')
|
||||
args.add_argument('--n_generate_sample', type=int, default=1) # only thing needed if naive_run
|
||||
args.add_argument('--n_evaluate_sample', type=int, default=1)
|
||||
args.add_argument('--n_select_sample', type=int, default=1)
|
||||
|
@ -1,6 +1,5 @@
|
||||
python run.py \
|
||||
--task crosswords \
|
||||
--task_file_path mini0505_0_100_5.json \
|
||||
--task_start_index 0 \
|
||||
--task_end_index 20 \
|
||||
--naive_run \
|
||||
|
@ -14,7 +14,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"cd ../.."
|
||||
"cd .."
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -24,9 +24,9 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import json\n",
|
||||
"from prompts.crosswords import propose_prompt, value_prompt\n",
|
||||
"from models import gpt\n",
|
||||
"from tasks.crosswords import MiniCrosswordsEnv\n",
|
||||
"from tot.prompts.crosswords import propose_prompt, value_prompt\n",
|
||||
"from tot.models import gpt\n",
|
||||
"from tot.tasks.crosswords import MiniCrosswordsEnv\n",
|
||||
"\n",
|
||||
"env = MiniCrosswordsEnv()"
|
||||
]
|
||||
@ -61,7 +61,7 @@
|
||||
"source": [
|
||||
"import re\n",
|
||||
"import copy\n",
|
||||
"from models import gpt\n",
|
||||
"from tot.models import gpt\n",
|
||||
"\n",
|
||||
"def parse_line(input_str):\n",
|
||||
" # regular expression pattern to match the input string format\n",
|
||||
|
@ -1,6 +1,5 @@
|
||||
python run.py \
|
||||
--task crosswords \
|
||||
--task_file_path mini0505_0_100_5.json \
|
||||
--task_start_index 0 \
|
||||
--task_end_index 20 \
|
||||
--naive_run \
|
||||
|
@ -1,6 +1,5 @@
|
||||
python run.py \
|
||||
--task game24 \
|
||||
--task_file_path 24.csv \
|
||||
--task_start_index 900 \
|
||||
--task_end_index 1000 \
|
||||
--method_generate propose \
|
||||
|
@ -1,6 +1,5 @@
|
||||
python run.py \
|
||||
--task game24 \
|
||||
--task_file_path 24.csv \
|
||||
--task_start_index 900 \
|
||||
--task_end_index 1000 \
|
||||
--naive_run \
|
||||
|
@ -1,6 +1,5 @@
|
||||
python run.py \
|
||||
--task game24 \
|
||||
--task_file_path 24.csv \
|
||||
--task_start_index 900 \
|
||||
--task_end_index 1000 \
|
||||
--naive_run \
|
||||
|
@ -1,8 +1,7 @@
|
||||
python run.py \
|
||||
--task text \
|
||||
--task_file_path data_100_random_text.txt \
|
||||
--task_start_index 0 \
|
||||
--task_end_index 1 \
|
||||
--task_end_index 100 \
|
||||
--method_generate sample \
|
||||
--method_evaluate vote \
|
||||
--method_select greedy \
|
||||
|
@ -1,8 +1,7 @@
|
||||
python run.py \
|
||||
--task text \
|
||||
--task_file_path data_100_random_text.txt \
|
||||
--task_start_index 0 \
|
||||
--task_end_index 1 \
|
||||
--task_end_index 100 \
|
||||
--naive_run \
|
||||
--prompt_sample cot \
|
||||
--n_generate_sample 10 \
|
||||
|
@ -1,8 +1,7 @@
|
||||
python run.py \
|
||||
--task text \
|
||||
--task_file_path data_100_random_text.txt \
|
||||
--task_start_index 0 \
|
||||
--task_end_index 1 \
|
||||
--task_end_index 100 \
|
||||
--naive_run \
|
||||
--prompt_sample standard \
|
||||
--n_generate_sample 10 \
|
||||
|
37
setup.py
Normal file
37
setup.py
Normal file
@ -0,0 +1,37 @@
|
||||
import setuptools
|
||||
|
||||
with open('README.md', 'r', encoding='utf-8') as fh:
|
||||
long_description = fh.read()
|
||||
|
||||
|
||||
setuptools.setup(
|
||||
name='tot',
|
||||
author='Shunyu Yao',
|
||||
author_email='shunyuyao.cs@gmail.com',
|
||||
description='Official Implementation of "Tree of Thoughts: Deliberate Problem Solving with Large Language Models"',
|
||||
keywords='tree-search, large-language-models, llm, prompting, tree-of-thoughts',
|
||||
long_description=long_description,
|
||||
long_description_content_type='text/markdown',
|
||||
url='https://github.com/princeton-nlp/tree-of-thought-llm',
|
||||
project_urls={
|
||||
'Homepage': 'https://github.com/princeton-nlp/tree-of-thought-llm',
|
||||
},
|
||||
package_dir={'': 'src'},
|
||||
packages=setuptools.find_packages(where='src'),
|
||||
classifiers=[
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.7",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
'Intended Audience :: Science/Research',
|
||||
'Topic :: Scientific/Engineering :: Artificial Intelligence',
|
||||
],
|
||||
python_requires='>=3.7',
|
||||
install_requires=[
|
||||
'setuptools',
|
||||
],
|
||||
include_package_data=True,
|
||||
)
|
1
src/tot/__init__.py
Normal file
1
src/tot/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
__version__ = "0.1.0"
|
File diff suppressed because it is too large
Load Diff
96
src/tot/methods/bfs.py
Normal file
96
src/tot/methods/bfs.py
Normal file
@ -0,0 +1,96 @@
|
||||
import itertools
|
||||
import numpy as np
|
||||
from functools import partial
|
||||
from tot.models import gpt
|
||||
|
||||
def get_value(task, x, y, n_evaluate_sample, cache_value=True):
|
||||
value_prompt = task.value_prompt_wrap(x, y)
|
||||
if cache_value and value_prompt in task.value_cache:
|
||||
return task.value_cache[value_prompt]
|
||||
value_outputs = gpt(value_prompt, n=n_evaluate_sample, stop=None)
|
||||
value = task.value_outputs_unwrap(x, y, value_outputs)
|
||||
if cache_value:
|
||||
task.value_cache[value_prompt] = value
|
||||
return value
|
||||
|
||||
def get_values(task, x, ys, n_evaluate_sample, cache_value=True):
|
||||
values = []
|
||||
local_value_cache = {}
|
||||
for y in ys: # each partial output
|
||||
if y in local_value_cache: # avoid duplicate candidates
|
||||
value = 0
|
||||
else:
|
||||
value = get_value(task, x, y, n_evaluate_sample, cache_value=cache_value)
|
||||
local_value_cache[y] = value
|
||||
values.append(value)
|
||||
return values
|
||||
|
||||
def get_votes(task, x, ys, n_evaluate_sample):
|
||||
vote_prompt = task.vote_prompt_wrap(x, ys)
|
||||
vote_outputs = gpt(vote_prompt, n=n_evaluate_sample, stop=None)
|
||||
values = task.vote_outputs_unwrap(vote_outputs, len(ys))
|
||||
return values
|
||||
|
||||
def get_proposals(task, x, y):
|
||||
propose_prompt = task.propose_prompt_wrap(x, y)
|
||||
proposals = gpt(propose_prompt, n=1, stop=None)[0].split('\n')
|
||||
return [y + _ + '\n' for _ in proposals]
|
||||
|
||||
def get_samples(task, x, y, n_generate_sample, prompt_sample, stop):
|
||||
if prompt_sample == 'standard':
|
||||
prompt = task.standard_prompt_wrap(x, y)
|
||||
elif prompt_sample == 'cot':
|
||||
prompt = task.cot_prompt_wrap(x, y)
|
||||
else:
|
||||
raise ValueError(f'prompt_sample {prompt_sample} not recognized')
|
||||
samples = gpt(prompt, n=n_generate_sample, stop=stop)
|
||||
return [y + _ for _ in samples]
|
||||
|
||||
def solve(args, task, idx, to_print=True):
|
||||
global gpt
|
||||
gpt = partial(gpt, model=args.backend, temperature=args.temperature)
|
||||
print(gpt)
|
||||
x = task.get_input(idx) # input
|
||||
ys = [''] # current output candidates
|
||||
infos = []
|
||||
for step in range(task.steps):
|
||||
# generation
|
||||
if args.method_generate == 'sample':
|
||||
new_ys = [get_samples(task, x, y, args.n_generate_sample, prompt_sample=args.prompt_sample, stop=task.stops[step]) for y in ys]
|
||||
elif args.method_generate == 'propose':
|
||||
new_ys = [get_proposals(task, x, y) for y in ys]
|
||||
new_ys = list(itertools.chain(*new_ys))
|
||||
ids = list(range(len(new_ys)))
|
||||
# evaluation
|
||||
if args.method_evaluate == 'vote':
|
||||
values = get_votes(task, x, new_ys, args.n_evaluate_sample)
|
||||
elif args.method_evaluate == 'value':
|
||||
values = get_values(task, x, new_ys, args.n_evaluate_sample)
|
||||
|
||||
# selection
|
||||
if args.method_select == 'sample':
|
||||
ps = np.array(values) / sum(values)
|
||||
select_ids = np.random.choice(ids, size=args.n_select_sample, p=ps).tolist()
|
||||
elif args.method_select == 'greedy':
|
||||
select_ids = sorted(ids, key=lambda x: values[x], reverse=True)[:args.n_select_sample]
|
||||
select_new_ys = [new_ys[select_id] for select_id in select_ids]
|
||||
|
||||
# log
|
||||
if to_print:
|
||||
sorted_new_ys, sorted_values = zip(*sorted(zip(new_ys, values), key=lambda x: x[1], reverse=True))
|
||||
print(f'-- new_ys --: {sorted_new_ys}\n-- sol values --: {sorted_values}\n-- choices --: {select_new_ys}\n')
|
||||
|
||||
infos.append({'step': step, 'x': x, 'ys': ys, 'new_ys': new_ys, 'values': values, 'select_new_ys': select_new_ys})
|
||||
ys = select_new_ys
|
||||
|
||||
if to_print:
|
||||
print(ys)
|
||||
return ys, {'steps': infos}
|
||||
|
||||
def naive_solve(args, task, idx, to_print=True):
|
||||
global gpt
|
||||
gpt = partial(gpt, model=args.backend, temperature=args.temperature)
|
||||
print(gpt)
|
||||
x = task.get_input(idx) # input
|
||||
ys = get_samples(task, x, '', args.n_generate_sample, args.prompt_sample, stop=None)
|
||||
return ys, {}
|
@ -41,5 +41,5 @@ def gpt_usage(backend="gpt-4"):
|
||||
if backend == "gpt-4":
|
||||
cost = completion_tokens / 1000 * 0.06 + prompt_tokens / 1000 * 0.03
|
||||
elif backend == "gpt-3.5-turbo":
|
||||
cost = completion_tokens / 1000 * 0.002 + prompt_tokens / 1000 * 0.0015
|
||||
cost = (completion_tokens + prompt_tokens) / 1000 * 0.0002
|
||||
return {"completion_tokens": completion_tokens, "prompt_tokens": prompt_tokens, "cost": cost}
|
12
src/tot/tasks/__init__.py
Normal file
12
src/tot/tasks/__init__.py
Normal file
@ -0,0 +1,12 @@
|
||||
def get_task(name):
|
||||
if name == 'game24':
|
||||
from tot.tasks.game24 import Game24Task
|
||||
return Game24Task()
|
||||
elif name == 'text':
|
||||
from tot.tasks.text import TextTask
|
||||
return TextTask()
|
||||
elif name == 'crosswords':
|
||||
from tot.tasks.crosswords import MiniCrosswordsTask
|
||||
return MiniCrosswordsTask()
|
||||
else:
|
||||
raise NotImplementedError
|
@ -1,4 +1,5 @@
|
||||
DATA_PATH = './data'
|
||||
import os
|
||||
DATA_PATH = os.path.join(os.path.dirname(__file__), '..', 'data')
|
||||
|
||||
class Task:
|
||||
def __init__(self):
|
@ -1,13 +1,14 @@
|
||||
import re
|
||||
import json
|
||||
import os
|
||||
from tasks.base import Task, DATA_PATH
|
||||
from prompts.crosswords import *
|
||||
from models import gpt
|
||||
import json
|
||||
from tot.tasks.base import Task, DATA_PATH
|
||||
from tot.prompts.crosswords import *
|
||||
from tot.models import gpt
|
||||
|
||||
class MiniCrosswordsEnv:
|
||||
def __init__(self, file='mini0505.json'):
|
||||
self.file = f'data/crosswords/{file}'
|
||||
self.file = os.path.join(DATA_PATH, 'crosswords', file)
|
||||
|
||||
self.file = json.load(open(self.file))
|
||||
self.n = len(self.file)
|
||||
self.cache = {}
|
@ -2,8 +2,8 @@ import re
|
||||
import os
|
||||
import sympy
|
||||
import pandas as pd
|
||||
from tasks.base import Task, DATA_PATH
|
||||
from prompts.game24 import *
|
||||
from tot.tasks.base import Task, DATA_PATH
|
||||
from tot.prompts.game24 import *
|
||||
|
||||
|
||||
def get_current_numbers(y: str) -> str:
|
@ -1,8 +1,8 @@
|
||||
import os
|
||||
import re
|
||||
from tasks.base import Task, DATA_PATH
|
||||
from prompts.text import *
|
||||
from models import gpt
|
||||
from tot.tasks.base import Task, DATA_PATH
|
||||
from tot.prompts.text import *
|
||||
from tot.models import gpt
|
||||
|
||||
|
||||
class TextTask(Task):
|
@ -1,12 +0,0 @@
|
||||
def get_task(name, file=None):
|
||||
if name == 'game24':
|
||||
from .game24 import Game24Task
|
||||
return Game24Task(file)
|
||||
elif name == 'text':
|
||||
from .text import TextTask
|
||||
return TextTask(file)
|
||||
elif name == 'crosswords':
|
||||
from .crosswords import MiniCrosswordsTask
|
||||
return MiniCrosswordsTask(file)
|
||||
else:
|
||||
raise NotImplementedError
|
Loading…
Reference in New Issue
Block a user