Merge pull request #35 from princeton-nlp/pypi

Pypi
Shunyu Yao 2023-07-04 22:57:39 -04:00 committed by GitHub
commit 3bc03cf11d
33 changed files with 1629 additions and 1527 deletions

MANIFEST.in (new file, +4 lines)

@@ -0,0 +1,4 @@
include src/tot/data/24/24.csv
include src/tot/data/crosswords/mini0505_0_100_5.json
include src/tot/data/crosswords/mini0505.json
include src/tot/data/text/data_100_random_text.txt

pics/fake.png (binary image diff; 84 KiB before and after)

pics/teaser.png (binary image diff; 99 KiB before and after)

pyproject.toml (new file, +35 lines)

@@ -0,0 +1,35 @@
[build-system]
requires = ["setuptools >= 61.0.0"]
build-backend = "setuptools.build_meta"

[project]
name = "tot"
version = "0.1.0"
description = 'Official Implementation of "Tree of Thoughts: Deliberate Problem Solving with Large Language Models"'
readme = "README.md"
requires-python = ">= 3.7"
authors = [{ name = "Shunyu Yao", email = "shunyuyao.cs@gmail.com" }]
license = { text = "MIT License" }
keywords = ["tree-search", "large-language-models", "llm", "prompting", "tree-of-thoughts"]
classifiers = [
    "License :: OSI Approved :: MIT License",
    "Programming Language :: Python :: 3",
    "Programming Language :: Python :: 3.7",
    "Programming Language :: Python :: 3.8",
    "Programming Language :: Python :: 3.9",
    "Programming Language :: Python :: 3.10",
    "Programming Language :: Python :: 3.11",
    'Intended Audience :: Science/Research',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
]
dynamic = ["dependencies"]

[tool.setuptools.dynamic]
dependencies = { file = ["requirements.txt"] }

[tool.setuptools.packages.find]
where = ["src"]  # list of folders that contain the packages (["."] by default)

[project.urls]
Homepage = "https://github.com/princeton-nlp/tree-of-thought-llm"
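A note on this configuration: ``dynamic = ["dependencies"]`` plus the ``[tool.setuptools.dynamic]`` table tells setuptools to read install requirements from ``requirements.txt`` at build time, and the data files listed in MANIFEST.in above end up in the distribution because ``setup.py`` (further below) sets ``include_package_data=True``. A minimal sketch of checking this locally, assuming the ``build`` package is available and the archive name follows the version above:

```bash
pip install build                                       # PEP 517 build frontend (assumed)
python -m build                                         # writes dist/tot-0.1.0.tar.gz and a wheel
tar -tzf dist/tot-0.1.0.tar.gz | grep 'src/tot/data'    # the MANIFEST.in data files should be listed
```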

README.md

@@ -2,44 +2,58 @@
 [![DOI](https://zenodo.org/badge/642099326.svg)](https://zenodo.org/badge/latestdoi/642099326)
 <details>
-<summary>Note: https://github.com/kyegomez/tree-of-thoughts is NOT the correct implementation to replicate paper results. </summary>
-In fact, people have reported that his code [cannot](https://github.com/kyegomez/tree-of-thoughts/issues/52) [properly](https://github.com/kyegomez/tree-of-thoughts/issues/41) [run](https://github.com/kyegomez/tree-of-thoughts/issues/60), and is [probably automatically generated by ChatGPT](fake.png), and [kyegomez has done so for other popular ML methods](https://twitter.com/qbitium/status/1663954096741814272), while intentionally refusing to link to official implementations for his own interests (see https://github.com/kyegomez/tree-of-thoughts/issues/54, https://github.com/kyegomez/tree-of-thoughts/issues/55, https://github.com/kyegomez/tree-of-thoughts/issues/56).
+<summary>Note: https://github.com/kyegomez/tree-of-thoughts CANNOT replicate paper results. </summary>
+In fact, people have reported that his code [cannot](https://github.com/kyegomez/tree-of-thoughts/issues/52) [properly](https://github.com/kyegomez/tree-of-thoughts/issues/41) [run](https://github.com/kyegomez/tree-of-thoughts/issues/60), and is [probably automatically generated by ChatGPT](pics/fake.png), and [kyegomez has done so for other popular ML methods](https://twitter.com/qbitium/status/1663954096741814272), while intentionally refusing to link to official implementations for his own interests (see https://github.com/kyegomez/tree-of-thoughts/issues/54, https://github.com/kyegomez/tree-of-thoughts/issues/55, https://github.com/kyegomez/tree-of-thoughts/issues/56).
 Unfortunately, Google/Github searches go to kyegomez's malicious repo by default as it has more stars. **Please DE-STAR his repo and STAR this to help other people avoid being misled, thanks!**
 </details>
-![teaser](teaser.png)
+![teaser](pics/teaser.png)
 Official implementation for paper [Tree of Thoughts: Deliberate Problem Solving with Large Language Models](https://arxiv.org/abs/2305.10601) with code, prompts, model outputs.
 Also check [its tweet thread](https://twitter.com/ShunyuYao12/status/1659357547474681857) in 1min.
-Please cite the paper and star this repo if you use ToT and find it interesting/useful. Thanks!
-```bibtex
-@misc{yao2023tree,
-      title={{Tree of Thoughts}: Deliberate Problem Solving with Large Language Models},
-      author={Shunyu Yao and Dian Yu and Jeffrey Zhao and Izhak Shafran and Thomas L. Griffiths and Yuan Cao and Karthik Narasimhan},
-      year={2023},
-      eprint={2305.10601},
-      archivePrefix={arXiv},
-      primaryClass={cs.CL}
-}
-```
 ## Setup
-You need to first have an OpenAI API key and store it in the environment variable ``OPENAI_API_KEY`` (see [here](https://help.openai.com/en/articles/5112595-best-practices-for-api-key-safety)). If you use custom base url, set it by environment variable ``OPENAI_API_BASE`` (e.g. https://api.openai.com/v1).
+- Set up OpenAI API key and store in environment variable ``OPENAI_API_KEY`` (see [here](https://help.openai.com/en/articles/5112595-best-practices-for-api-key-safety)).
-Package requirement: ``pip install openai backoff sympy numpy``
+- Install dependencies and `tot` package (PyPI package coming soon):
+```bash
+git clone https://github.com/princeton-nlp/tree-of-thought-llm
+cd tree-of-thought-llm
+pip install -r requirements.txt
+pip install -e .  # install `tot` package
+```
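For reference, the first setup bullet amounts to something like the following (an editor's sketch; the key value is a placeholder, and ``OPENAI_API_BASE`` is only needed for custom endpoints, per the removed paragraph above):

```bash
export OPENAI_API_KEY="sk-..."                      # placeholder; keep real keys out of version control
export OPENAI_API_BASE="https://api.openai.com/v1"  # optional: only if you use a custom base URL
```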
-## Experiments
+## Quick Start
+The following minimal script will attempt to solve the game of 24 with `4 5 6 10` (might be a bit slow as it's using GPT-4):
+```python
+import argparse
+from tot.methods.bfs import solve
+from tot.tasks.game24 import Game24Task
+args = argparse.Namespace(backend='gpt-4', temperature=0.7, task='game24', naive_run=False, prompt_sample=None, method_generate='propose', method_evaluate='value', method_select='greedy', n_generate_sample=1, n_evaluate_sample=3, n_select_sample=5)
+task = Game24Task()
+ys, infos = solve(args, task, 900)
+print(ys[0])
+```
+And the output would be something like (note it's not deterministic, and sometimes the output can be wrong):
+```
+10 - 4 = 6 (left: 5 6 6)
+5 * 6 = 30 (left: 6 30)
+30 - 6 = 24 (left: 24)
+Answer: (5 * (10 - 4)) - 6 = 24
+```
+## Paper Experiments
 Run experiments via ``sh scripts/{game24, text, crosswords}/{standard_sampling, cot_sampling, bfs}.sh``, except in crosswords we use a DFS algorithm for ToT, which can be run via ``scripts/crosswords/search_crosswords-dfs.ipynb``.
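Concretely, the brace notation expands to individual scripts included in this PR, e.g. (run from the repo root):

```bash
sh scripts/game24/bfs.sh          # ToT (BFS) on game of 24
sh scripts/text/cot_sampling.sh   # CoT sampling baseline on the text task
```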
@@ -55,13 +69,24 @@ The very simple ``run.py`` implements the ToT + BFS algorithm, as well as the naive IO/CoT sampling.
-## Trajectories
+## Paper Trajectories
 ``logs/`` contains all the trajectories from the paper's experiments, except for ``logs/game24/gpt-4_0.7_propose1_value3_greedy5_start900_end1000.json`` which was reproduced after the paper (as the original experiment was done in a notebook) and achieved a 69\% score instead of the original 74\% score due to randomness in GPT decoding. We hope to aggregate multiple runs in the future to account for sampling randomness and update the paper, but this shouldn't affect the main conclusions of the paper.
 ## How to Add A New Task
 Setting up a new task is easy, and mainly involves two steps.
 * Set up a new task class in ``tot/tasks/`` and task files in ``tot/data/``. See ``tot/tasks/game24.py`` for an example. Add the task to ``tot/tasks/__init__.py``.
 * Set up task-specific prompts in ``tot/prompts/``. See ``tot/prompts/game24.py`` for an example. Depending on the nature of the task, choose ``--method_generate`` (choices=[``sample``, ``propose``]) and ``--method_evaluate`` (choices=[``value``, ``vote``]) and their corresponding prompts.
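To make these two steps concrete, here is a minimal hypothetical sketch (an editor's illustration, not code from this PR: the class, file path, and prompt strings are invented, but the hooks it implements — ``get_input``, ``steps``, ``stops``, the ``*_prompt_wrap``/``*_outputs_unwrap`` methods, and ``value_cache`` — mirror what ``src/tot/methods/bfs.py`` below actually calls):

```python
# src/tot/tasks/mytask.py -- hypothetical example, not part of this PR
from tot.tasks.base import Task

class MyTask(Task):
    def __init__(self):
        super().__init__()
        self.data = ['example input']  # a real task would load files from tot/data/ instead
        self.steps = 2                 # number of thought steps the BFS will run
        self.stops = ['\n', None]      # per-step stop sequences passed to generation
        self.value_cache = {}          # required when using --method_evaluate value

    def get_input(self, idx: int) -> str:
        return self.data[idx]

    # used with --method_generate propose
    def propose_prompt_wrap(self, x: str, y: str = '') -> str:
        return f'Given the problem:\n{x}\nand the partial solution:\n{y}\npropose possible next lines.'

    # used with --method_evaluate value
    def value_prompt_wrap(self, x: str, y: str) -> str:
        return f'Judge if this partial solution can still succeed (sure/maybe/impossible):\n{x}\n{y}'

    def value_outputs_unwrap(self, x: str, y: str, value_outputs: list) -> float:
        # toy scoring: count optimistic judgments across the n sampled evaluations
        return sum('sure' in o for o in value_outputs)
```

A task set up this way would then be registered in ``tot/tasks/__init__.py`` so that ``--task mytask`` resolves to it.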
 ## Citations
-Please cite the paper and star this repo if you use ToT and find it interesting/useful, thanks! Feel free to contact shunyuyao.cs@gmail.com or open an issue if you have any questions.
+## Questions
+Feel free to contact shunyuyao.cs@gmail.com or open an issue if you have any questions.
+```bibtex
+@misc{yao2023tree,
+      title={{Tree of Thoughts}: Deliberate Problem Solving with Large Language Models},
+      author={Shunyu Yao and Dian Yu and Jeffrey Zhao and Izhak Shafran and Thomas L. Griffiths and Yuan Cao and Karthik Narasimhan},
+      year={2023},
+      eprint={2305.10601},
+      archivePrefix={arXiv},
+      primaryClass={cs.CL}
+}
+```

requirements.txt

@@ -16,3 +16,4 @@ sympy==1.12
 tqdm==4.65.0
 urllib3==2.0.2
 yarl==1.9.2
+pandas==2.0.3

run.py (105 changed lines)

@@ -1,108 +1,18 @@
 import os
 import json
-import itertools
 import argparse
-import numpy as np
-from functools import partial
-from models import gpt, gpt_usage
-from tasks import get_task
-def get_value(task, x, y, n_evaluate_sample, cache_value=True):
-    value_prompt = task.value_prompt_wrap(x, y)
-    if cache_value and value_prompt in task.value_cache:
-        return task.value_cache[value_prompt]
-    value_outputs = gpt(value_prompt, n=n_evaluate_sample, stop=None)
-    value = task.value_outputs_unwrap(x, y, value_outputs)
-    if cache_value:
-        task.value_cache[value_prompt] = value
-    return value
-def get_values(task, x, ys, n_evaluate_sample, cache_value=True):
-    values = []
-    local_value_cache = {}
-    for y in ys:  # each partial output
-        if y in local_value_cache:  # avoid duplicate candidates
-            value = 0
-        else:
-            value = get_value(task, x, y, n_evaluate_sample, cache_value=cache_value)
-            local_value_cache[y] = value
-        values.append(value)
-    return values
-def get_votes(task, x, ys, n_evaluate_sample):
-    vote_prompt = task.vote_prompt_wrap(x, ys)
-    vote_outputs = gpt(vote_prompt, n=n_evaluate_sample, stop=None)
-    values = task.vote_outputs_unwrap(vote_outputs, len(ys))
-    return values
-def get_proposals(task, x, y):
-    propose_prompt = task.propose_prompt_wrap(x, y)
-    proposals = gpt(propose_prompt, n=1, stop=None)[0].split('\n')
-    return [y + _ + '\n' for _ in proposals]
-def get_samples(task, x, y, n_generate_sample, prompt_sample, stop):
-    if prompt_sample == 'standard':
-        prompt = task.standard_prompt_wrap(x, y)
-    elif prompt_sample == 'cot':
-        prompt = task.cot_prompt_wrap(x, y)
-    else:
-        raise ValueError(f'prompt_sample {prompt_sample} not recognized')
-    samples = gpt(prompt, n=n_generate_sample, stop=stop)
-    return [y + _ for _ in samples]
-def solve(args, task, idx, to_print=True):
-    print(gpt)
-    x = task.get_input(idx)  # input
-    ys = ['']  # current output candidates
-    infos = []
-    for step in range(task.steps):
-        # generation
-        if args.method_generate == 'sample':
-            new_ys = [get_samples(task, x, y, args.n_generate_sample, prompt_sample=args.prompt_sample, stop=task.stops[step]) for y in ys]
-        elif args.method_generate == 'propose':
-            new_ys = [get_proposals(task, x, y) for y in ys]
-        new_ys = list(itertools.chain(*new_ys))
-        ids = list(range(len(new_ys)))
-        # evaluation
-        if args.method_evaluate == 'vote':
-            values = get_votes(task, x, new_ys, args.n_evaluate_sample)
-        elif args.method_evaluate == 'value':
-            values = get_values(task, x, new_ys, args.n_evaluate_sample)
-        # selection
-        if args.method_select == 'sample':
-            ps = np.array(values) / sum(values)
-            select_ids = np.random.choice(ids, size=args.n_select_sample, p=ps).tolist()
-        elif args.method_select == 'greedy':
-            select_ids = sorted(ids, key=lambda x: values[x], reverse=True)[:args.n_select_sample]
-        select_new_ys = [new_ys[select_id] for select_id in select_ids]
-        # log
-        if to_print:
-            sorted_new_ys, sorted_values = zip(*sorted(zip(new_ys, values), key=lambda x: x[1], reverse=True))
-            print(f'-- new_ys --: {sorted_new_ys}\n-- sol values --: {sorted_values}\n-- choices --: {select_new_ys}\n')
-        infos.append({'step': step, 'x': x, 'ys': ys, 'new_ys': new_ys, 'values': values, 'select_new_ys': select_new_ys})
-        ys = select_new_ys
-    if to_print:
-        print(ys)
-    return ys, {'steps': infos}
-def naive_solve(args, task, idx, to_print=True):
-    x = task.get_input(idx)  # input
-    ys = get_samples(task, x, '', args.n_generate_sample, args.prompt_sample, stop=None)
-    return ys, {}
+from tot.tasks import get_task
+from tot.methods.bfs import solve, naive_solve
+from tot.models import gpt_usage
 def run(args):
-    task = get_task(args.task, args.task_file_path)
+    task = get_task(args.task)
     logs, cnt_avg, cnt_any = [], 0, 0
-    global gpt
-    gpt = partial(gpt, model=args.backend, temperature=args.temperature)
     if args.naive_run:
-        file = f'logs/{args.task}/{args.backend}_{args.temperature}_naive_{args.prompt_sample}_sample_{args.n_generate_sample}_start{args.task_start_index}_end{args.task_end_index}.json'
+        file = f'./logs/{args.task}/{args.backend}_{args.temperature}_naive_{args.prompt_sample}_sample_{args.n_generate_sample}_start{args.task_start_index}_end{args.task_end_index}.json'
     else:
-        file = f'logs/{args.task}/{args.backend}_{args.temperature}_{args.method_generate}{args.n_generate_sample}_{args.method_evaluate}{args.n_evaluate_sample}_{args.method_select}{args.n_select_sample}_start{args.task_start_index}_end{args.task_end_index}.json'
+        file = f'./logs/{args.task}/{args.backend}_{args.temperature}_{args.method_generate}{args.n_generate_sample}_{args.method_evaluate}{args.n_evaluate_sample}_{args.method_select}{args.n_select_sample}_start{args.task_start_index}_end{args.task_end_index}.json'
     os.makedirs(os.path.dirname(file), exist_ok=True)
     for i in range(args.task_start_index, args.task_end_index):
@@ -136,7 +46,6 @@ def parse_args():
     args.add_argument('--temperature', type=float, default=0.7)
     args.add_argument('--task', type=str, required=True, choices=['game24', 'text', 'crosswords'])
-    args.add_argument('--task_file_path', type=str, required=True)
     args.add_argument('--task_start_index', type=int, default=900)
     args.add_argument('--task_end_index', type=int, default=1000)
@@ -145,7 +54,7 @@ def parse_args():
     args.add_argument('--method_generate', type=str, choices=['sample', 'propose'])
     args.add_argument('--method_evaluate', type=str, choices=['value', 'vote'])
-    args.add_argument('--method_select', type=str, choices=['sample', 'greedy'])
+    args.add_argument('--method_select', type=str, choices=['sample', 'greedy'], default='greedy')
     args.add_argument('--n_generate_sample', type=int, default=1)  # only thing needed if naive_run
     args.add_argument('--n_evaluate_sample', type=int, default=1)
     args.add_argument('--n_select_sample', type=int, default=1)

scripts/crosswords/cot_sampling.sh

@@ -1,6 +1,5 @@
 python run.py \
     --task crosswords \
-    --task_file_path mini0505_0_100_5.json \
     --task_start_index 0 \
     --task_end_index 20 \
     --naive_run \

scripts/crosswords/search_crosswords-dfs.ipynb

@@ -14,7 +14,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "cd ../.."
+    "cd .."
    ]
   },
   {
@@ -24,9 +24,9 @@
    "outputs": [],
    "source": [
     "import json\n",
-    "from prompts.crosswords import propose_prompt, value_prompt\n",
-    "from models import gpt\n",
-    "from tasks.crosswords import MiniCrosswordsEnv\n",
+    "from tot.prompts.crosswords import propose_prompt, value_prompt\n",
+    "from tot.models import gpt\n",
+    "from tot.tasks.crosswords import MiniCrosswordsEnv\n",
     "\n",
     "env = MiniCrosswordsEnv()"
    ]
@@ -61,7 +61,7 @@
    "source": [
     "import re\n",
     "import copy\n",
-    "from models import gpt\n",
+    "from tot.models import gpt\n",
     "\n",
     "def parse_line(input_str):\n",
     "    # regular expression pattern to match the input string format\n",

scripts/crosswords/standard_sampling.sh

@@ -1,6 +1,5 @@
 python run.py \
     --task crosswords \
-    --task_file_path mini0505_0_100_5.json \
     --task_start_index 0 \
     --task_end_index 20 \
     --naive_run \

scripts/game24/bfs.sh

@@ -1,6 +1,5 @@
 python run.py \
     --task game24 \
-    --task_file_path 24.csv \
     --task_start_index 900 \
     --task_end_index 1000 \
     --method_generate propose \

scripts/game24/cot_sampling.sh

@@ -1,6 +1,5 @@
 python run.py \
     --task game24 \
-    --task_file_path 24.csv \
     --task_start_index 900 \
     --task_end_index 1000 \
     --naive_run \

scripts/game24/standard_sampling.sh

@@ -1,6 +1,5 @@
 python run.py \
     --task game24 \
-    --task_file_path 24.csv \
     --task_start_index 900 \
     --task_end_index 1000 \
     --naive_run \

scripts/text/bfs.sh

@@ -1,8 +1,7 @@
 python run.py \
     --task text \
-    --task_file_path data_100_random_text.txt \
     --task_start_index 0 \
-    --task_end_index 1 \
+    --task_end_index 100 \
     --method_generate sample \
     --method_evaluate vote \
     --method_select greedy \

scripts/text/cot_sampling.sh

@@ -1,8 +1,7 @@
 python run.py \
     --task text \
-    --task_file_path data_100_random_text.txt \
     --task_start_index 0 \
-    --task_end_index 1 \
+    --task_end_index 100 \
     --naive_run \
     --prompt_sample cot \
     --n_generate_sample 10 \

scripts/text/standard_sampling.sh

@@ -1,8 +1,7 @@
 python run.py \
     --task text \
-    --task_file_path data_100_random_text.txt \
     --task_start_index 0 \
-    --task_end_index 1 \
+    --task_end_index 100 \
     --naive_run \
     --prompt_sample standard \
     --n_generate_sample 10 \

setup.py (new file, +37 lines)

@@ -0,0 +1,37 @@
import setuptools

with open('README.md', 'r', encoding='utf-8') as fh:
    long_description = fh.read()

setuptools.setup(
    name='tot',
    author='Shunyu Yao',
    author_email='shunyuyao.cs@gmail.com',
    description='Official Implementation of "Tree of Thoughts: Deliberate Problem Solving with Large Language Models"',
    keywords='tree-search, large-language-models, llm, prompting, tree-of-thoughts',
    long_description=long_description,
    long_description_content_type='text/markdown',
    url='https://github.com/princeton-nlp/tree-of-thought-llm',
    project_urls={
        'Homepage': 'https://github.com/princeton-nlp/tree-of-thought-llm',
    },
    package_dir={'': 'src'},
    packages=setuptools.find_packages(where='src'),
    classifiers=[
        "License :: OSI Approved :: MIT License",
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.7",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.11",
        'Intended Audience :: Science/Research',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
    ],
    python_requires='>=3.7',
    install_requires=[
        'setuptools',
    ],
    include_package_data=True,
)

src/tot/__init__.py (new file, +1 line)

@@ -0,0 +1 @@
__version__ = "0.1.0"

File diff suppressed because it is too large

src/tot/methods/bfs.py (new file, +96 lines)

@@ -0,0 +1,96 @@
import itertools
import numpy as np
from functools import partial
from tot.models import gpt

def get_value(task, x, y, n_evaluate_sample, cache_value=True):
    value_prompt = task.value_prompt_wrap(x, y)
    if cache_value and value_prompt in task.value_cache:
        return task.value_cache[value_prompt]
    value_outputs = gpt(value_prompt, n=n_evaluate_sample, stop=None)
    value = task.value_outputs_unwrap(x, y, value_outputs)
    if cache_value:
        task.value_cache[value_prompt] = value
    return value

def get_values(task, x, ys, n_evaluate_sample, cache_value=True):
    values = []
    local_value_cache = {}
    for y in ys:  # each partial output
        if y in local_value_cache:  # avoid duplicate candidates
            value = 0
        else:
            value = get_value(task, x, y, n_evaluate_sample, cache_value=cache_value)
            local_value_cache[y] = value
        values.append(value)
    return values

def get_votes(task, x, ys, n_evaluate_sample):
    vote_prompt = task.vote_prompt_wrap(x, ys)
    vote_outputs = gpt(vote_prompt, n=n_evaluate_sample, stop=None)
    values = task.vote_outputs_unwrap(vote_outputs, len(ys))
    return values

def get_proposals(task, x, y):
    propose_prompt = task.propose_prompt_wrap(x, y)
    proposals = gpt(propose_prompt, n=1, stop=None)[0].split('\n')
    return [y + _ + '\n' for _ in proposals]

def get_samples(task, x, y, n_generate_sample, prompt_sample, stop):
    if prompt_sample == 'standard':
        prompt = task.standard_prompt_wrap(x, y)
    elif prompt_sample == 'cot':
        prompt = task.cot_prompt_wrap(x, y)
    else:
        raise ValueError(f'prompt_sample {prompt_sample} not recognized')
    samples = gpt(prompt, n=n_generate_sample, stop=stop)
    return [y + _ for _ in samples]

def solve(args, task, idx, to_print=True):
    global gpt
    gpt = partial(gpt, model=args.backend, temperature=args.temperature)
    print(gpt)
    x = task.get_input(idx)  # input
    ys = ['']  # current output candidates
    infos = []
    for step in range(task.steps):
        # generation
        if args.method_generate == 'sample':
            new_ys = [get_samples(task, x, y, args.n_generate_sample, prompt_sample=args.prompt_sample, stop=task.stops[step]) for y in ys]
        elif args.method_generate == 'propose':
            new_ys = [get_proposals(task, x, y) for y in ys]
        new_ys = list(itertools.chain(*new_ys))
        ids = list(range(len(new_ys)))
        # evaluation
        if args.method_evaluate == 'vote':
            values = get_votes(task, x, new_ys, args.n_evaluate_sample)
        elif args.method_evaluate == 'value':
            values = get_values(task, x, new_ys, args.n_evaluate_sample)
        # selection
        if args.method_select == 'sample':
            ps = np.array(values) / sum(values)
            select_ids = np.random.choice(ids, size=args.n_select_sample, p=ps).tolist()
        elif args.method_select == 'greedy':
            select_ids = sorted(ids, key=lambda x: values[x], reverse=True)[:args.n_select_sample]
        select_new_ys = [new_ys[select_id] for select_id in select_ids]
        # log
        if to_print:
            sorted_new_ys, sorted_values = zip(*sorted(zip(new_ys, values), key=lambda x: x[1], reverse=True))
            print(f'-- new_ys --: {sorted_new_ys}\n-- sol values --: {sorted_values}\n-- choices --: {select_new_ys}\n')
        infos.append({'step': step, 'x': x, 'ys': ys, 'new_ys': new_ys, 'values': values, 'select_new_ys': select_new_ys})
        ys = select_new_ys
    if to_print:
        print(ys)
    return ys, {'steps': infos}

def naive_solve(args, task, idx, to_print=True):
    global gpt
    gpt = partial(gpt, model=args.backend, temperature=args.temperature)
    print(gpt)
    x = task.get_input(idx)  # input
    ys = get_samples(task, x, '', args.n_generate_sample, args.prompt_sample, stop=None)
    return ys, {}
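For the naive baselines, the entry point is ``naive_solve``; a hypothetical invocation mirroring the README's Quick Start (an editor's sketch — the argument values are illustrative):

```python
import argparse
from tot.methods.bfs import naive_solve
from tot.tasks.game24 import Game24Task

# CoT sampling baseline: 10 samples, no tree search (values are illustrative)
args = argparse.Namespace(backend='gpt-4', temperature=0.7, naive_run=True,
                          prompt_sample='cot', method_generate='sample',
                          method_evaluate=None, method_select=None,
                          n_generate_sample=10, n_evaluate_sample=1, n_select_sample=1)
task = Game24Task()
ys, info = naive_solve(args, task, 900)  # 10 candidate solutions for puzzle index 900
```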

src/tot/models.py

@ -41,5 +41,5 @@ def gpt_usage(backend="gpt-4"):
if backend == "gpt-4":
cost = completion_tokens / 1000 * 0.06 + prompt_tokens / 1000 * 0.03
elif backend == "gpt-3.5-turbo":
cost = completion_tokens / 1000 * 0.002 + prompt_tokens / 1000 * 0.0015
cost = (completion_tokens + prompt_tokens) / 1000 * 0.0002
return {"completion_tokens": completion_tokens, "prompt_tokens": prompt_tokens, "cost": cost}

src/tot/tasks/__init__.py (new file, +12 lines)

@@ -0,0 +1,12 @@
def get_task(name):
    if name == 'game24':
        from tot.tasks.game24 import Game24Task
        return Game24Task()
    elif name == 'text':
        from tot.tasks.text import TextTask
        return TextTask()
    elif name == 'crosswords':
        from tot.tasks.crosswords import MiniCrosswordsTask
        return MiniCrosswordsTask()
    else:
        raise NotImplementedError
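Note the signature change: tasks no longer take a file argument — data paths are now resolved inside each task via the ``DATA_PATH`` change below. A sketch of how callers use the registry after this PR:

```python
from tot.tasks import get_task

task = get_task('game24')   # previously: get_task('game24', '24.csv')
print(task.get_input(900))  # per the Quick Start, index 900 is the puzzle '4 5 6 10'
```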

src/tot/tasks/base.py

@@ -1,4 +1,5 @@
-DATA_PATH = './data'
+import os
+DATA_PATH = os.path.join(os.path.dirname(__file__), '..', 'data')
 class Task:
     def __init__(self):

src/tot/tasks/crosswords.py

@@ -1,13 +1,14 @@
 import re
-import json
 import os
-from tasks.base import Task, DATA_PATH
-from prompts.crosswords import *
-from models import gpt
+import json
+from tot.tasks.base import Task, DATA_PATH
+from tot.prompts.crosswords import *
+from tot.models import gpt
 class MiniCrosswordsEnv:
     def __init__(self, file='mini0505.json'):
-        self.file = f'data/crosswords/{file}'
+        self.file = os.path.join(DATA_PATH, 'crosswords', file)
         self.file = json.load(open(self.file))
         self.n = len(self.file)
         self.cache = {}

src/tot/tasks/game24.py

@@ -2,8 +2,8 @@ import re
 import os
 import sympy
 import pandas as pd
-from tasks.base import Task, DATA_PATH
-from prompts.game24 import *
+from tot.tasks.base import Task, DATA_PATH
+from tot.prompts.game24 import *
 def get_current_numbers(y: str) -> str:

src/tot/tasks/text.py

@@ -1,8 +1,8 @@
 import os
 import re
-from tasks.base import Task, DATA_PATH
-from prompts.text import *
-from models import gpt
+from tot.tasks.base import Task, DATA_PATH
+from tot.prompts.text import *
+from tot.models import gpt
 class TextTask(Task):

tasks/__init__.py (deleted)

@@ -1,12 +0,0 @@
-def get_task(name, file=None):
-    if name == 'game24':
-        from .game24 import Game24Task
-        return Game24Task(file)
-    elif name == 'text':
-        from .text import TextTask
-        return TextTask(file)
-    elif name == 'crosswords':
-        from .crosswords import MiniCrosswordsTask
-        return MiniCrosswordsTask(file)
-    else:
-        raise NotImplementedError