tree-of-thought-llm/scripts/crosswords/search_crosswords-dfs.ipynb

257 lines
8.3 KiB
Plaintext
Raw Normal View History

2023-05-23 22:34:41 +00:00
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Environment setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
2023-07-04 02:16:03 +00:00
"# Explicit line magic: works even when IPython automagic is switched off.\n",
"# NOTE(review): this notebook lives in scripts/crosswords/, so `..` is scripts/;\n",
"# confirm that is where the relative logs/crosswords/ path used below should resolve.\n",
"%cd .."
2023-05-23 22:34:41 +00:00
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
2023-07-04 02:16:03 +00:00
"from tot.prompts.crosswords import propose_prompt, value_prompt\n",
"from tot.models import gpt\n",
"from tot.tasks.crosswords import MiniCrosswordsEnv\n",
2023-05-23 22:34:41 +00:00
"\n",
"# A single environment instance, shared by the cells below, which reset and step it in place.\n",
"env = MiniCrosswordsEnv()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Prompting, proposal parsing, and a greedy baseline"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def prompt_wrap(obs):\n",
"    \"\"\"Fill the crossword propose prompt with the rendered board `obs`.\"\"\"\n",
"    prompt = propose_prompt.format(input=obs)\n",
"    return prompt\n",
"\n",
"# Preview the prompt for puzzle 0; env.reset returns the observation fed to prompt_wrap.\n",
"first_prompt = prompt_wrap(env.reset(0))\n",
"print(first_prompt)\n",
"# To also preview the prompt after filling one clue:\n",
"# print('---------')\n",
"# print(prompt_wrap(env.step('h2. value')[0]))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import re\n",
"import copy\n",
2023-07-04 02:16:03 +00:00
"from tot.models import gpt\n",
2023-05-23 22:34:41 +00:00
"\n",
"def parse_line(input_str):\n",
"    \"\"\"Parse one proposal line such as 'h1. apple (high) ...'.\n",
"\n",
"    Returns [position, word, confidence] when the line matches the\n",
"    '<h|v><1-5>. <five letters> (<certain|high|medium|low>)' format, else None.\n",
"    \"\"\"\n",
"    pattern = r'^([hv][1-5])\\. ([a-zA-Z]{5,5}) \\((certain|high|medium|low)\\).*$'\n",
"    found = re.match(pattern, input_str)\n",
"    if found is None:\n",
"        return None\n",
"    # keep only the clue position, the word, and the confidence label\n",
"    return list(found.group(1, 2, 3))\n",
"\n",
"# Score contributed by each confidence label. TODO: the weights are ad hoc.\n",
"confidence_to_value = dict(certain=1, high=0.5, medium=0.2, low=0.1)\n",
"\n",
"def parse_response(response):\n",
"    \"\"\"Turn one model response into a list of (action, score) pairs.\n",
"\n",
"    Each action is '<position>. <word>' in lower case and each score is looked\n",
"    up in `confidence_to_value`. Lines that do not match the proposal format\n",
"    are skipped; returns None when no line matched.\n",
"    \"\"\"\n",
"    pairs = []\n",
"    for raw_line in response.split('\\n'):\n",
"        fields = parse_line(raw_line)\n",
"        if fields is None:\n",
"            continue\n",
"        position, word, confidence = fields\n",
"        action = position.lower() + '. ' + word.lower()\n",
"        pairs.append((action, confidence_to_value.get(confidence, 0)))\n",
"    return pairs or None\n",
"\n",
"\n",
"def get_candidates_to_scores(env):\n",
"    \"\"\"Return {action: summed confidence score} proposals for env's current board.\n",
"\n",
"    Samples 8 gpt-4 responses for the rendered board and adds up each action's\n",
"    score across them. The result (possibly empty) is memoised in `env.cache`,\n",
"    keyed by the rendered board, so revisiting a board does not call gpt again.\n",
"    \"\"\"\n",
"    board_text = env.render()\n",
"    if board_text in env.cache:\n",
"        print('cache hit')\n",
"        return env.cache[board_text]\n",
"\n",
"    print('call gpt')\n",
"    totals = {}\n",
"    for response in gpt(prompt_wrap(board_text), model='gpt-4', n=8):\n",
"        for action, score in parse_response(response) or []:\n",
"            totals[action] = totals.get(action, 0) + score\n",
"    env.cache[board_text] = totals\n",
"    return totals\n",
"\n",
"def propose_score(env, idx):\n",
"    \"\"\"Greedily solve puzzle `idx`, committing one proposed action per step.\n",
"\n",
"    Each step samples 5 gpt-4 proposals for the current observation, sums every\n",
"    action's confidence scores across the samples, and commits the best-scoring\n",
"    action that leaves no entry of `env.status` equal to 2 (the same\n",
"    constraint check `dfs` uses). Stops when `env.step` reports done, when no\n",
"    proposal parses, or when every proposal would violate a constraint.\n",
"\n",
"    Returns the list of `info` values from each committed `env.step` call.\n",
"    \"\"\"\n",
"    obs = env.reset(idx)\n",
"    done = False\n",
"    infos = []\n",
"    while not done:\n",
"        responses = gpt(prompt_wrap(obs), model='gpt-4', n=5)\n",
"        candidates_to_scores = {}\n",
"        for response in responses:\n",
"            for candidate, score in parse_response(response) or []:\n",
"                candidates_to_scores[candidate] = candidates_to_scores.get(candidate, 0) + score\n",
"        print(sorted(candidates_to_scores.items(), key=lambda x: x[1], reverse=True))\n",
"        if not candidates_to_scores:\n",
"            break\n",
"\n",
"        # choose the highest-scoring candidate that keeps the board consistent,\n",
"        # testing each one on a scratch copy of the env\n",
"        # NOTE(review): deepcopy also copies env.cache (filled by get_candidates_to_scores),\n",
"        # which can grow large\n",
"        for candidate in sorted(candidates_to_scores, key=candidates_to_scores.get, reverse=True):\n",
"            env_ = copy.deepcopy(env)\n",
"            env_.step(candidate)\n",
"            if not any(s == 2 for s in env_.status):\n",
"                break\n",
"        else:\n",
"            # Without this branch the loop would fall through with `candidate` bound to\n",
"            # the lowest-scored proposal and commit it despite the violation.\n",
"            print('no candidate keeps the board consistent; stopping')\n",
"            break\n",
"\n",
"        print(candidate)\n",
"        obs, _, done, info = env.step(candidate)\n",
"        print(obs)\n",
"        print(env.steps, info)\n",
"        print('-------------------\\n\\n\\n')\n",
"        infos.append(info)\n",
"    return infos"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Depth-first search (DFS)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def dfs(env, actions, infos, time_limit, prune, max_per_state, max_steps=10):\n",
"    \"\"\"Depth-first search over crossword fills, logging one record per explored state.\n",
"\n",
"    Candidates for the current board come from `get_candidates_to_scores` and are\n",
"    tried from highest to lowest score. A candidate is recorded (and, unless\n",
"    pruned, expanded recursively) only if, after stepping, `env.steps < max_steps`\n",
"    and no entry of `env.status` equals 2 (no existing constraint is violated).\n",
"\n",
"    Args:\n",
"        env: environment stepped in place; its board/status/steps are restored\n",
"            from a backup after every candidate.\n",
"        actions: stack of actions leading to the current state; pushed before\n",
"            recursing and popped afterwards.\n",
"        infos: output list; one record dict is appended per explored state.\n",
"        time_limit: stop once `infos` holds this many records.\n",
"        prune: if True, do not expand states for which `env.prompt_status()`\n",
"            reports at least one 'impossible' entry.\n",
"        max_per_state: expand at most this many valid candidates per state.\n",
"        max_steps: skip a candidate once `env.steps` reaches this value (default 10).\n",
"\n",
"    Returns:\n",
"        (0, [], []) when no candidate is proposed for the current board, otherwise\n",
"        None. Results are communicated through `infos`.\n",
"    \"\"\"\n",
"    # Check the budget before asking the model: once `infos` is full no candidate\n",
"    # could be recorded, so the gpt call inside get_candidates_to_scores is wasted.\n",
"    if len(infos) >= time_limit:\n",
"        return\n",
"\n",
"    # get candidate thoughts\n",
"    candidates_to_scores = get_candidates_to_scores(env)\n",
"    if len(candidates_to_scores) == 0:\n",
"        return 0, [], []\n",
"    print(sorted(candidates_to_scores.items(), key=lambda x: x[1], reverse=True))\n",
"\n",
"    # back up current state\n",
"    board, status, steps = env.board.copy(), env.status.copy(), env.steps\n",
"\n",
"    # try each candidate, best first\n",
"    cnt_per_state = 0\n",
"    for action in sorted(candidates_to_scores, key=candidates_to_scores.get, reverse=True):\n",
"        if len(infos) >= time_limit:\n",
"            break  # budget spent; remaining candidates could not be recorded\n",
"        _, _, _, info = env.step(action)\n",
"        try:\n",
"            # skip candidates that hit the step cap or violate an existing constraint\n",
"            if env.steps >= max_steps or any(s == 2 for s in env.status):\n",
"                continue\n",
"            cnt_per_state += 1\n",
"            if cnt_per_state > max_per_state:\n",
"                break\n",
"            count = env.prompt_status()\n",
"            actions.append(action)\n",
"\n",
"            print(len(infos))\n",
"            print(actions)\n",
"            print(env.render_board())\n",
"            print(info)\n",
"            print(count)\n",
"            if infos:\n",
"                best = max(infos, key=lambda x: x['info']['r_word'])\n",
"                print('best', best)\n",
"            print('--------------')\n",
"            print()\n",
"\n",
"            record = {'total_step': len(infos), 'env_step': env.steps, 'actions': actions.copy(), 'info': info, 'count': count}\n",
"            infos.append(record)\n",
"            if not prune or count['impossible'] < 1:  # only continue if the current status is possible\n",
"                dfs(env, actions, infos, time_limit, prune, max_per_state, max_steps)\n",
"            actions.pop()\n",
"        finally:\n",
"            # undo this candidate (also on `continue`/`break`) before trying the next one\n",
"            env.reset(env.idx, board=board.copy(), status=status.copy(), steps=steps)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# DFS with pruning: states that prompt_status() marks 'impossible' are not expanded.\n",
"import os\n",
"\n",
"PRUNE_LOG_PATH = 'logs/crosswords/infoss_dfs_prune.json'\n",
"# open(..., 'w') does not create missing folders\n",
"os.makedirs(os.path.dirname(PRUNE_LOG_PATH), exist_ok=True)\n",
"\n",
"infoss = []\n",
"for i in range(0, 100, 5):  # every 5th puzzle index in [0, 100)\n",
"    env.reset(i)\n",
"    infos = []\n",
"    actions = []\n",
"    dfs(env, actions, infos, time_limit=100, prune=True, max_per_state=3)\n",
"    infoss.append(infos)\n",
"    # rewrite the file after every puzzle so an interrupted run keeps its results\n",
"    with open(PRUNE_LOG_PATH, 'w') as fout:\n",
"        json.dump(infoss, fout)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# DFS without pruning: every constraint-respecting state is expanded (within the limits).\n",
"import os\n",
"\n",
"NO_PRUNE_LOG_PATH = 'logs/crosswords/infoss_dfs_no_prune.json'\n",
"# open(..., 'w') does not create missing folders\n",
"os.makedirs(os.path.dirname(NO_PRUNE_LOG_PATH), exist_ok=True)\n",
"\n",
"infoss = []\n",
"for i in range(0, 100, 5):  # every 5th puzzle index in [0, 100)\n",
"    env.reset(i)\n",
"    infos = []\n",
"    actions = []\n",
"    dfs(env, actions, infos, time_limit=100, prune=False, max_per_state=3)\n",
"    infoss.append(infos)\n",
"    # rewrite the file after every puzzle so an interrupted run keeps its results\n",
"    with open(NO_PRUNE_LOG_PATH, 'w') as fout:\n",
"        json.dump(infoss, fout)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}