mirror of
https://github.com/princeton-nlp/tree-of-thought-llm
synced 2025-01-22 02:25:28 +00:00
tot package
This commit is contained in:
parent
7382f2416e
commit
733b009f62
4
MANIFEST.in
Normal file
4
MANIFEST.in
Normal file
@ -0,0 +1,4 @@
|
||||
include src/tot/data/24/24.csv
|
||||
include src/tot/data/crosswords/mini0505_0_100_5.json
|
||||
include src/tot/data/crosswords/mini0505.json
|
||||
include src/tot/data/text/data_100_random_text.txt
|
Before Width: | Height: | Size: 84 KiB After Width: | Height: | Size: 84 KiB |
Before Width: | Height: | Size: 99 KiB After Width: | Height: | Size: 99 KiB |
35
pyproject.toml
Normal file
35
pyproject.toml
Normal file
@ -0,0 +1,35 @@
|
||||
[build-system]
|
||||
requires = ["setuptools >= 61.0.0"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "tot"
|
||||
version = "0.1.0"
|
||||
description = 'Official Implementation of "Tree of Thoughts: Deliberate Problem Solving with Large Language Models"'
|
||||
readme = "README.md"
|
||||
requires-python = ">= 3.7"
|
||||
authors = [{ name = "Shunyu Yao", email = "shunyuyao.cs@gmail.com" }]
|
||||
license = { text = "MIT License" }
|
||||
keywords = ["tree-search", "large-language-models", "llm", "prompting", "tree-of-thoughts"]
|
||||
classifiers = [
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.7",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
'Intended Audience :: Science/Research',
|
||||
'Topic :: Scientific/Engineering :: Artificial Intelligence',
|
||||
]
|
||||
dynamic=["dependencies"]
|
||||
|
||||
|
||||
[tool.setuptools.dynamic]
|
||||
dependencies = {file = ["requirements.txt"]}
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"] # list of folders that contain the packages (["."] by default)
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://github.com/princeton-nlp/tree-of-thought-llm"
|
@ -4,7 +4,7 @@
|
||||
<details>
|
||||
<summary>Note: https://github.com/kyegomez/tree-of-thoughts is NOT the correct implementation to replicate paper results. </summary>
|
||||
|
||||
In fact, people have reported that his code [cannot](https://github.com/kyegomez/tree-of-thoughts/issues/52) [properly](https://github.com/kyegomez/tree-of-thoughts/issues/41) [run](https://github.com/kyegomez/tree-of-thoughts/issues/60), and is [probably automatically generated by ChatGPT](fake.png), and [kyegomez has done so for other popular ML methods](https://twitter.com/qbitium/status/1663954096741814272), while intentionally refusing to link to official implementations for his own interests (see https://github.com/kyegomez/tree-of-thoughts/issues/54, https://github.com/kyegomez/tree-of-thoughts/issues/55, https://github.com/kyegomez/tree-of-thoughts/issues/56).
|
||||
In fact, people have reported that his code [cannot](https://github.com/kyegomez/tree-of-thoughts/issues/52) [properly](https://github.com/kyegomez/tree-of-thoughts/issues/41) [run](https://github.com/kyegomez/tree-of-thoughts/issues/60), and is [probably automatically generated by ChatGPT](pics/fake.png), and [kyegomez has done so for other popular ML methods](https://twitter.com/qbitium/status/1663954096741814272), while intentionally refusing to link to official implementations for his own interests (see https://github.com/kyegomez/tree-of-thoughts/issues/54, https://github.com/kyegomez/tree-of-thoughts/issues/55, https://github.com/kyegomez/tree-of-thoughts/issues/56).
|
||||
Unfortunately, Google/Github searches go to kyegomez's malicious repo by default as it has more stars. **Please DE-STAR his repo and STAR this to help other people avoid being misled, thanks!**
|
||||
</details>
|
||||
|
||||
@ -12,7 +12,7 @@ Unfortunately, Google/Github searches go to kyegomez's malicious repo by default
|
||||
|
||||
|
||||
|
||||
![teaser](teaser.png)
|
||||
![teaser](pics/teaser.png)
|
||||
|
||||
Official implementation for paper [Tree of Thoughts: Deliberate Problem Solving with Large Language Models](https://arxiv.org/abs/2305.10601) with code, prompts, model outputs.
|
||||
Also check [its tweet thread](https://twitter.com/ShunyuYao12/status/1659357547474681857) in 1min.
|
||||
|
@ -16,3 +16,4 @@ sympy==1.12
|
||||
tqdm==4.65.0
|
||||
urllib3==2.0.2
|
||||
yarl==1.9.2
|
||||
pandas==2.0.3
|
105
run.py
105
run.py
@ -1,108 +1,18 @@
|
||||
import os
|
||||
import json
|
||||
import itertools
|
||||
import argparse
|
||||
import numpy as np
|
||||
from functools import partial
|
||||
from models import gpt, gpt_usage
|
||||
from tasks import get_task
|
||||
|
||||
def get_value(task, x, y, n_evaluate_sample, cache_value=True):
|
||||
value_prompt = task.value_prompt_wrap(x, y)
|
||||
if cache_value and value_prompt in task.value_cache:
|
||||
return task.value_cache[value_prompt]
|
||||
value_outputs = gpt(value_prompt, n=n_evaluate_sample, stop=None)
|
||||
value = task.value_outputs_unwrap(x, y, value_outputs)
|
||||
if cache_value:
|
||||
task.value_cache[value_prompt] = value
|
||||
return value
|
||||
|
||||
def get_values(task, x, ys, n_evaluate_sample, cache_value=True):
|
||||
values = []
|
||||
local_value_cache = {}
|
||||
for y in ys: # each partial output
|
||||
if y in local_value_cache: # avoid duplicate candidates
|
||||
value = 0
|
||||
else:
|
||||
value = get_value(task, x, y, n_evaluate_sample, cache_value=cache_value)
|
||||
local_value_cache[y] = value
|
||||
values.append(value)
|
||||
return values
|
||||
|
||||
def get_votes(task, x, ys, n_evaluate_sample):
|
||||
vote_prompt = task.vote_prompt_wrap(x, ys)
|
||||
vote_outputs = gpt(vote_prompt, n=n_evaluate_sample, stop=None)
|
||||
values = task.vote_outputs_unwrap(vote_outputs, len(ys))
|
||||
return values
|
||||
|
||||
def get_proposals(task, x, y):
|
||||
propose_prompt = task.propose_prompt_wrap(x, y)
|
||||
proposals = gpt(propose_prompt, n=1, stop=None)[0].split('\n')
|
||||
return [y + _ + '\n' for _ in proposals]
|
||||
|
||||
def get_samples(task, x, y, n_generate_sample, prompt_sample, stop):
|
||||
if prompt_sample == 'standard':
|
||||
prompt = task.standard_prompt_wrap(x, y)
|
||||
elif prompt_sample == 'cot':
|
||||
prompt = task.cot_prompt_wrap(x, y)
|
||||
else:
|
||||
raise ValueError(f'prompt_sample {prompt_sample} not recognized')
|
||||
samples = gpt(prompt, n=n_generate_sample, stop=stop)
|
||||
return [y + _ for _ in samples]
|
||||
|
||||
def solve(args, task, idx, to_print=True):
|
||||
print(gpt)
|
||||
x = task.get_input(idx) # input
|
||||
ys = [''] # current output candidates
|
||||
infos = []
|
||||
for step in range(task.steps):
|
||||
# generation
|
||||
if args.method_generate == 'sample':
|
||||
new_ys = [get_samples(task, x, y, args.n_generate_sample, prompt_sample=args.prompt_sample, stop=task.stops[step]) for y in ys]
|
||||
elif args.method_generate == 'propose':
|
||||
new_ys = [get_proposals(task, x, y) for y in ys]
|
||||
new_ys = list(itertools.chain(*new_ys))
|
||||
ids = list(range(len(new_ys)))
|
||||
# evaluation
|
||||
if args.method_evaluate == 'vote':
|
||||
values = get_votes(task, x, new_ys, args.n_evaluate_sample)
|
||||
elif args.method_evaluate == 'value':
|
||||
values = get_values(task, x, new_ys, args.n_evaluate_sample)
|
||||
|
||||
# selection
|
||||
if args.method_select == 'sample':
|
||||
ps = np.array(values) / sum(values)
|
||||
select_ids = np.random.choice(ids, size=args.n_select_sample, p=ps).tolist()
|
||||
elif args.method_select == 'greedy':
|
||||
select_ids = sorted(ids, key=lambda x: values[x], reverse=True)[:args.n_select_sample]
|
||||
select_new_ys = [new_ys[select_id] for select_id in select_ids]
|
||||
|
||||
# log
|
||||
if to_print:
|
||||
sorted_new_ys, sorted_values = zip(*sorted(zip(new_ys, values), key=lambda x: x[1], reverse=True))
|
||||
print(f'-- new_ys --: {sorted_new_ys}\n-- sol values --: {sorted_values}\n-- choices --: {select_new_ys}\n')
|
||||
|
||||
infos.append({'step': step, 'x': x, 'ys': ys, 'new_ys': new_ys, 'values': values, 'select_new_ys': select_new_ys})
|
||||
ys = select_new_ys
|
||||
|
||||
if to_print:
|
||||
print(ys)
|
||||
return ys, {'steps': infos}
|
||||
|
||||
def naive_solve(args, task, idx, to_print=True):
|
||||
x = task.get_input(idx) # input
|
||||
ys = get_samples(task, x, '', args.n_generate_sample, args.prompt_sample, stop=None)
|
||||
return ys, {}
|
||||
from tot.tasks import get_task
|
||||
from tot.methods.bfs import solve, naive_solve
|
||||
from tot.models import gpt_usage
|
||||
|
||||
def run(args):
|
||||
task = get_task(args.task, args.task_file_path)
|
||||
task = get_task(args.task)
|
||||
logs, cnt_avg, cnt_any = [], 0, 0
|
||||
global gpt
|
||||
gpt = partial(gpt, model=args.backend, temperature=args.temperature)
|
||||
if args.naive_run:
|
||||
file = f'logs/{args.task}/{args.backend}_{args.temperature}_naive_{args.prompt_sample}_sample_{args.n_generate_sample}_start{args.task_start_index}_end{args.task_end_index}.json'
|
||||
file = f'./logs/{args.task}/{args.backend}_{args.temperature}_naive_{args.prompt_sample}_sample_{args.n_generate_sample}_start{args.task_start_index}_end{args.task_end_index}.json'
|
||||
else:
|
||||
file = f'logs/{args.task}/{args.backend}_{args.temperature}_{args.method_generate}{args.n_generate_sample}_{args.method_evaluate}{args.n_evaluate_sample}_{args.method_select}{args.n_select_sample}_start{args.task_start_index}_end{args.task_end_index}.json'
|
||||
file = f'./logs/{args.task}/{args.backend}_{args.temperature}_{args.method_generate}{args.n_generate_sample}_{args.method_evaluate}{args.n_evaluate_sample}_{args.method_select}{args.n_select_sample}_start{args.task_start_index}_end{args.task_end_index}.json'
|
||||
os.makedirs(os.path.dirname(file), exist_ok=True)
|
||||
|
||||
for i in range(args.task_start_index, args.task_end_index):
|
||||
@ -136,7 +46,6 @@ def parse_args():
|
||||
args.add_argument('--temperature', type=float, default=0.7)
|
||||
|
||||
args.add_argument('--task', type=str, required=True, choices=['game24', 'text', 'crosswords'])
|
||||
args.add_argument('--task_file_path', type=str, required=True)
|
||||
args.add_argument('--task_start_index', type=int, default=900)
|
||||
args.add_argument('--task_end_index', type=int, default=1000)
|
||||
|
||||
@ -145,7 +54,7 @@ def parse_args():
|
||||
|
||||
args.add_argument('--method_generate', type=str, choices=['sample', 'propose'])
|
||||
args.add_argument('--method_evaluate', type=str, choices=['value', 'vote'])
|
||||
args.add_argument('--method_select', type=str, choices=['sample', 'greedy'])
|
||||
args.add_argument('--method_select', type=str, choices=['sample', 'greedy'], default='greedy')
|
||||
args.add_argument('--n_generate_sample', type=int, default=1) # only thing needed if naive_run
|
||||
args.add_argument('--n_evaluate_sample', type=int, default=1)
|
||||
args.add_argument('--n_select_sample', type=int, default=1)
|
||||
|
@ -1,6 +1,5 @@
|
||||
python run.py \
|
||||
--task crosswords \
|
||||
--task_file_path mini0505_0_100_5.json \
|
||||
--task_start_index 0 \
|
||||
--task_end_index 20 \
|
||||
--naive_run \
|
||||
|
@ -14,7 +14,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"cd ../.."
|
||||
"cd .."
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -24,9 +24,9 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import json\n",
|
||||
"from prompts.crosswords import propose_prompt, value_prompt\n",
|
||||
"from models import gpt\n",
|
||||
"from tasks.crosswords import MiniCrosswordsEnv\n",
|
||||
"from tot.prompts.crosswords import propose_prompt, value_prompt\n",
|
||||
"from tot.models import gpt\n",
|
||||
"from tot.tasks.crosswords import MiniCrosswordsEnv\n",
|
||||
"\n",
|
||||
"env = MiniCrosswordsEnv()"
|
||||
]
|
||||
@ -61,7 +61,7 @@
|
||||
"source": [
|
||||
"import re\n",
|
||||
"import copy\n",
|
||||
"from models import gpt\n",
|
||||
"from tot.models import gpt\n",
|
||||
"\n",
|
||||
"def parse_line(input_str):\n",
|
||||
" # regular expression pattern to match the input string format\n",
|
||||
|
@ -1,6 +1,5 @@
|
||||
python run.py \
|
||||
--task crosswords \
|
||||
--task_file_path mini0505_0_100_5.json \
|
||||
--task_start_index 0 \
|
||||
--task_end_index 20 \
|
||||
--naive_run \
|
||||
|
@ -1,6 +1,5 @@
|
||||
python run.py \
|
||||
--task game24 \
|
||||
--task_file_path 24.csv \
|
||||
--task_start_index 900 \
|
||||
--task_end_index 1000 \
|
||||
--method_generate propose \
|
||||
|
@ -1,6 +1,5 @@
|
||||
python run.py \
|
||||
--task game24 \
|
||||
--task_file_path 24.csv \
|
||||
--task_start_index 900 \
|
||||
--task_end_index 1000 \
|
||||
--naive_run \
|
||||
|
@ -1,6 +1,5 @@
|
||||
python run.py \
|
||||
--task game24 \
|
||||
--task_file_path 24.csv \
|
||||
--task_start_index 900 \
|
||||
--task_end_index 1000 \
|
||||
--naive_run \
|
||||
|
@ -1,8 +1,7 @@
|
||||
python run.py \
|
||||
--task text \
|
||||
--task_file_path data_100_random_text.txt \
|
||||
--task_start_index 0 \
|
||||
--task_end_index 1 \
|
||||
--task_end_index 100 \
|
||||
--method_generate sample \
|
||||
--method_evaluate vote \
|
||||
--method_select greedy \
|
||||
|
@ -1,8 +1,7 @@
|
||||
python run.py \
|
||||
--task text \
|
||||
--task_file_path data_100_random_text.txt \
|
||||
--task_start_index 0 \
|
||||
--task_end_index 1 \
|
||||
--task_end_index 100 \
|
||||
--naive_run \
|
||||
--prompt_sample cot \
|
||||
--n_generate_sample 10 \
|
||||
|
@ -1,8 +1,7 @@
|
||||
python run.py \
|
||||
--task text \
|
||||
--task_file_path data_100_random_text.txt \
|
||||
--task_start_index 0 \
|
||||
--task_end_index 1 \
|
||||
--task_end_index 100 \
|
||||
--naive_run \
|
||||
--prompt_sample standard \
|
||||
--n_generate_sample 10 \
|
||||
|
37
setup.py
Normal file
37
setup.py
Normal file
@ -0,0 +1,37 @@
|
||||
import setuptools
|
||||
|
||||
with open('README.md', 'r', encoding='utf-8') as fh:
|
||||
long_description = fh.read()
|
||||
|
||||
|
||||
setuptools.setup(
|
||||
name='tot',
|
||||
author='Shunyu Yao',
|
||||
author_email='shunyuyao.cs@gmail.com',
|
||||
description='Official Implementation of "Tree of Thoughts: Deliberate Problem Solving with Large Language Models"',
|
||||
keywords='tree-search, large-language-models, llm, prompting, tree-of-thoughts',
|
||||
long_description=long_description,
|
||||
long_description_content_type='text/markdown',
|
||||
url='https://github.com/princeton-nlp/tree-of-thought-llm',
|
||||
project_urls={
|
||||
'Homepage': 'https://github.com/princeton-nlp/tree-of-thought-llm',
|
||||
},
|
||||
package_dir={'': 'src'},
|
||||
packages=setuptools.find_packages(where='src'),
|
||||
classifiers=[
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.7",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
'Intended Audience :: Science/Research',
|
||||
'Topic :: Scientific/Engineering :: Artificial Intelligence',
|
||||
],
|
||||
python_requires='>=3.7',
|
||||
install_requires=[
|
||||
'setuptools',
|
||||
],
|
||||
include_package_data=True,
|
||||
)
|
1
src/tot/__init__.py
Normal file
1
src/tot/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
__version__ = "0.1.0"
|
File diff suppressed because it is too large
Load Diff
96
src/tot/methods/bfs.py
Normal file
96
src/tot/methods/bfs.py
Normal file
@ -0,0 +1,96 @@
|
||||
import itertools
|
||||
import numpy as np
|
||||
from functools import partial
|
||||
from tot.models import gpt
|
||||
|
||||
def get_value(task, x, y, n_evaluate_sample, cache_value=True):
|
||||
value_prompt = task.value_prompt_wrap(x, y)
|
||||
if cache_value and value_prompt in task.value_cache:
|
||||
return task.value_cache[value_prompt]
|
||||
value_outputs = gpt(value_prompt, n=n_evaluate_sample, stop=None)
|
||||
value = task.value_outputs_unwrap(x, y, value_outputs)
|
||||
if cache_value:
|
||||
task.value_cache[value_prompt] = value
|
||||
return value
|
||||
|
||||
def get_values(task, x, ys, n_evaluate_sample, cache_value=True):
|
||||
values = []
|
||||
local_value_cache = {}
|
||||
for y in ys: # each partial output
|
||||
if y in local_value_cache: # avoid duplicate candidates
|
||||
value = 0
|
||||
else:
|
||||
value = get_value(task, x, y, n_evaluate_sample, cache_value=cache_value)
|
||||
local_value_cache[y] = value
|
||||
values.append(value)
|
||||
return values
|
||||
|
||||
def get_votes(task, x, ys, n_evaluate_sample):
|
||||
vote_prompt = task.vote_prompt_wrap(x, ys)
|
||||
vote_outputs = gpt(vote_prompt, n=n_evaluate_sample, stop=None)
|
||||
values = task.vote_outputs_unwrap(vote_outputs, len(ys))
|
||||
return values
|
||||
|
||||
def get_proposals(task, x, y):
|
||||
propose_prompt = task.propose_prompt_wrap(x, y)
|
||||
proposals = gpt(propose_prompt, n=1, stop=None)[0].split('\n')
|
||||
return [y + _ + '\n' for _ in proposals]
|
||||
|
||||
def get_samples(task, x, y, n_generate_sample, prompt_sample, stop):
|
||||
if prompt_sample == 'standard':
|
||||
prompt = task.standard_prompt_wrap(x, y)
|
||||
elif prompt_sample == 'cot':
|
||||
prompt = task.cot_prompt_wrap(x, y)
|
||||
else:
|
||||
raise ValueError(f'prompt_sample {prompt_sample} not recognized')
|
||||
samples = gpt(prompt, n=n_generate_sample, stop=stop)
|
||||
return [y + _ for _ in samples]
|
||||
|
||||
def solve(args, task, idx, to_print=True):
|
||||
global gpt
|
||||
gpt = partial(gpt, model=args.backend, temperature=args.temperature)
|
||||
print(gpt)
|
||||
x = task.get_input(idx) # input
|
||||
ys = [''] # current output candidates
|
||||
infos = []
|
||||
for step in range(task.steps):
|
||||
# generation
|
||||
if args.method_generate == 'sample':
|
||||
new_ys = [get_samples(task, x, y, args.n_generate_sample, prompt_sample=args.prompt_sample, stop=task.stops[step]) for y in ys]
|
||||
elif args.method_generate == 'propose':
|
||||
new_ys = [get_proposals(task, x, y) for y in ys]
|
||||
new_ys = list(itertools.chain(*new_ys))
|
||||
ids = list(range(len(new_ys)))
|
||||
# evaluation
|
||||
if args.method_evaluate == 'vote':
|
||||
values = get_votes(task, x, new_ys, args.n_evaluate_sample)
|
||||
elif args.method_evaluate == 'value':
|
||||
values = get_values(task, x, new_ys, args.n_evaluate_sample)
|
||||
|
||||
# selection
|
||||
if args.method_select == 'sample':
|
||||
ps = np.array(values) / sum(values)
|
||||
select_ids = np.random.choice(ids, size=args.n_select_sample, p=ps).tolist()
|
||||
elif args.method_select == 'greedy':
|
||||
select_ids = sorted(ids, key=lambda x: values[x], reverse=True)[:args.n_select_sample]
|
||||
select_new_ys = [new_ys[select_id] for select_id in select_ids]
|
||||
|
||||
# log
|
||||
if to_print:
|
||||
sorted_new_ys, sorted_values = zip(*sorted(zip(new_ys, values), key=lambda x: x[1], reverse=True))
|
||||
print(f'-- new_ys --: {sorted_new_ys}\n-- sol values --: {sorted_values}\n-- choices --: {select_new_ys}\n')
|
||||
|
||||
infos.append({'step': step, 'x': x, 'ys': ys, 'new_ys': new_ys, 'values': values, 'select_new_ys': select_new_ys})
|
||||
ys = select_new_ys
|
||||
|
||||
if to_print:
|
||||
print(ys)
|
||||
return ys, {'steps': infos}
|
||||
|
||||
def naive_solve(args, task, idx, to_print=True):
|
||||
global gpt
|
||||
gpt = partial(gpt, model=args.backend, temperature=args.temperature)
|
||||
print(gpt)
|
||||
x = task.get_input(idx) # input
|
||||
ys = get_samples(task, x, '', args.n_generate_sample, args.prompt_sample, stop=None)
|
||||
return ys, {}
|
@ -41,5 +41,5 @@ def gpt_usage(backend="gpt-4"):
|
||||
if backend == "gpt-4":
|
||||
cost = completion_tokens / 1000 * 0.06 + prompt_tokens / 1000 * 0.03
|
||||
elif backend == "gpt-3.5-turbo":
|
||||
cost = completion_tokens / 1000 * 0.002 + prompt_tokens / 1000 * 0.0015
|
||||
cost = (completion_tokens + prompt_tokens) / 1000 * 0.0002
|
||||
return {"completion_tokens": completion_tokens, "prompt_tokens": prompt_tokens, "cost": cost}
|
12
src/tot/tasks/__init__.py
Normal file
12
src/tot/tasks/__init__.py
Normal file
@ -0,0 +1,12 @@
|
||||
def get_task(name):
|
||||
if name == 'game24':
|
||||
from tot.tasks.game24 import Game24Task
|
||||
return Game24Task()
|
||||
elif name == 'text':
|
||||
from tot.tasks.text import TextTask
|
||||
return TextTask()
|
||||
elif name == 'crosswords':
|
||||
from tot.tasks.crosswords import MiniCrosswordsTask
|
||||
return MiniCrosswordsTask()
|
||||
else:
|
||||
raise NotImplementedError
|
@ -1,4 +1,5 @@
|
||||
DATA_PATH = './data'
|
||||
import os
|
||||
DATA_PATH = os.path.join(os.path.dirname(__file__), '..', 'data')
|
||||
|
||||
class Task:
|
||||
def __init__(self):
|
@ -1,13 +1,14 @@
|
||||
import re
|
||||
import json
|
||||
import os
|
||||
from tasks.base import Task, DATA_PATH
|
||||
from prompts.crosswords import *
|
||||
from models import gpt
|
||||
import json
|
||||
from tot.tasks.base import Task, DATA_PATH
|
||||
from tot.prompts.crosswords import *
|
||||
from tot.models import gpt
|
||||
|
||||
class MiniCrosswordsEnv:
|
||||
def __init__(self, file='mini0505.json'):
|
||||
self.file = f'data/crosswords/{file}'
|
||||
self.file = os.path.join(DATA_PATH, 'crosswords', file)
|
||||
|
||||
self.file = json.load(open(self.file))
|
||||
self.n = len(self.file)
|
||||
self.cache = {}
|
@ -2,8 +2,8 @@ import re
|
||||
import os
|
||||
import sympy
|
||||
import pandas as pd
|
||||
from tasks.base import Task, DATA_PATH
|
||||
from prompts.game24 import *
|
||||
from tot.tasks.base import Task, DATA_PATH
|
||||
from tot.prompts.game24 import *
|
||||
|
||||
|
||||
def get_current_numbers(y: str) -> str:
|
@ -1,8 +1,8 @@
|
||||
import os
|
||||
import re
|
||||
from tasks.base import Task, DATA_PATH
|
||||
from prompts.text import *
|
||||
from models import gpt
|
||||
from tot.tasks.base import Task, DATA_PATH
|
||||
from tot.prompts.text import *
|
||||
from tot.models import gpt
|
||||
|
||||
|
||||
class TextTask(Task):
|
@ -1,12 +0,0 @@
|
||||
def get_task(name, file=None):
|
||||
if name == 'game24':
|
||||
from .game24 import Game24Task
|
||||
return Game24Task(file)
|
||||
elif name == 'text':
|
||||
from .text import TextTask
|
||||
return TextTask(file)
|
||||
elif name == 'crosswords':
|
||||
from .crosswords import MiniCrosswordsTask
|
||||
return MiniCrosswordsTask(file)
|
||||
else:
|
||||
raise NotImplementedError
|
Loading…
Reference in New Issue
Block a user