mirror of
https://github.com/princeton-nlp/tree-of-thought-llm
synced 2025-01-22 18:45:32 +00:00
257 lines
8.3 KiB
Plaintext
257 lines
8.3 KiB
Plaintext
|
{
|
||
|
"cells": [
|
||
|
{
|
||
|
"attachments": {},
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"# Env"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"cd ../.."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"import json\n",
|
||
|
"from prompts.crosswords import propose_prompt, value_prompt\n",
|
||
|
"from models import gpt\n",
|
||
|
"from tasks.crosswords import MiniCrosswordsEnv\n",
|
||
|
"\n",
|
||
|
"env = MiniCrosswordsEnv()"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"attachments": {},
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"# Prompt"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"def prompt_wrap(obs):\n",
|
||
|
" return propose_prompt.format(input=obs)\n",
|
||
|
"\n",
|
||
|
"print(prompt_wrap(env.reset(0)))\n",
|
||
|
"# print('---------')\n",
|
||
|
"# print(prompt_wrap(env.step('h2. value')[0]))"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"import re\n",
|
||
|
"import copy\n",
|
||
|
"from models import gpt\n",
|
||
|
"\n",
|
||
|
"def parse_line(input_str):\n",
|
||
|
" # regular expression pattern to match the input string format\n",
|
||
|
" pattern = r'^([hv][1-5])\\. ([a-zA-Z]{5,5}) \\((certain|high|medium|low)\\).*$'\n",
|
||
|
"\n",
|
||
|
" # use regex to extract the parts of the input string\n",
|
||
|
" match = re.match(pattern, input_str)\n",
|
||
|
"\n",
|
||
|
" if match:\n",
|
||
|
" # extract the matched groups\n",
|
||
|
" parts = [match.group(1), match.group(2), match.group(3)]\n",
|
||
|
" return parts\n",
|
||
|
" else:\n",
|
||
|
" return None\n",
|
||
|
"\n",
|
||
|
"confidence_to_value = {'certain': 1, 'high': 0.5, 'medium': 0.2, 'low': 0.1} # TODO: ad hoc\n",
|
||
|
"\n",
|
||
|
"def parse_response(response):\n",
|
||
|
" # split the response into lines\n",
|
||
|
" lines = response.split('\\n')\n",
|
||
|
"\n",
|
||
|
" # parse each line\n",
|
||
|
" parsed_lines = [parse_line(line) for line in lines]\n",
|
||
|
"\n",
|
||
|
" # filter out the lines that didn't match the format\n",
|
||
|
" parsed_lines = [(line[0].lower() + '. ' + line[1].lower(), confidence_to_value.get(line[2], 0)) for line in parsed_lines if line is not None]\n",
|
||
|
"\n",
|
||
|
" return parsed_lines if len(parsed_lines) >= 1 else None\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"def get_candidates_to_scores(env):\n",
|
||
|
" obs = env.render()\n",
|
||
|
" if obs in env.cache: \n",
|
||
|
" print('cache hit')\n",
|
||
|
" return env.cache[obs]\n",
|
||
|
" print('call gpt')\n",
|
||
|
" responses = gpt(prompt_wrap(obs), model='gpt-4', n=8)\n",
|
||
|
" candidates_to_scores = {}\n",
|
||
|
" for response in responses:\n",
|
||
|
" parsed_response = parse_response(response)\n",
|
||
|
" if parsed_response:\n",
|
||
|
" for candidate, score in parsed_response:\n",
|
||
|
" candidates_to_scores[candidate] = candidates_to_scores.get(candidate, 0) + score\n",
|
||
|
" # choose candiate with highest score\n",
|
||
|
" # print(sorted(candidates_to_scores.items(), key=lambda x: x[1], reverse=True))\n",
|
||
|
" env.cache[obs] = candidates_to_scores\n",
|
||
|
" return candidates_to_scores\n",
|
||
|
"\n",
|
||
|
"def propose_score(env, idx):\n",
|
||
|
" obs = env.reset(idx)\n",
|
||
|
" done = False\n",
|
||
|
" infos = []\n",
|
||
|
" while not done:\n",
|
||
|
" responses = gpt(prompt_wrap(obs), model='gpt-4', n=5)\n",
|
||
|
" candidates_to_scores = {}\n",
|
||
|
" for response in responses:\n",
|
||
|
" parsed_response = parse_response(response)\n",
|
||
|
" if parsed_response:\n",
|
||
|
" for candidate, score in parsed_response:\n",
|
||
|
" candidates_to_scores[candidate] = candidates_to_scores.get(candidate, 0) + score\n",
|
||
|
" # choose candiate with highest score\n",
|
||
|
" print(sorted(candidates_to_scores.items(), key=lambda x: x[1], reverse=True))\n",
|
||
|
" if len(candidates_to_scores) == 0:\n",
|
||
|
" break\n",
|
||
|
" candidates = sorted(candidates_to_scores, key=candidates_to_scores.get, reverse=True)\n",
|
||
|
" for candidate in candidates:\n",
|
||
|
" env_ = copy.deepcopy(env)\n",
|
||
|
" env_.step(candidate)\n",
|
||
|
" if not any(_ == 2 for _ in env_.status):\n",
|
||
|
" break\n",
|
||
|
" print(candidate)\n",
|
||
|
" # candidate = input()\n",
|
||
|
" obs, r, done, info = env.step(candidate)\n",
|
||
|
" print(obs)\n",
|
||
|
" print(env.steps, info)\n",
|
||
|
" print('-------------------\\n\\n\\n')\n",
|
||
|
" infos.append(info)\n",
|
||
|
" return infos"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"attachments": {},
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"# DFS"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"def dfs(env, actions, infos, time_limit, prune, max_per_state):\n",
|
||
|
" # get candidate thoughts\n",
|
||
|
" candidates_to_scores = get_candidates_to_scores(env)\n",
|
||
|
" if len(candidates_to_scores) == 0: return 0, [], []\n",
|
||
|
" print(sorted(candidates_to_scores.items(), key=lambda x: x[1], reverse=True))\n",
|
||
|
"\n",
|
||
|
" # back up current state\n",
|
||
|
" board, status, steps = env.board.copy(), env.status.copy(), env.steps\n",
|
||
|
"\n",
|
||
|
" # try each candidate\n",
|
||
|
" cnt_per_state = 0\n",
|
||
|
" for action in sorted(candidates_to_scores, key=candidates_to_scores.get, reverse=True):\n",
|
||
|
" obs, r, done, info = env.step(action)\n",
|
||
|
" r = info['r_word']\n",
|
||
|
" if len(infos) < time_limit and env.steps < 10 and not any(_ == 2 for _ in env.status): # not violating any existing constraints\n",
|
||
|
" cnt_per_state += 1\n",
|
||
|
" if cnt_per_state > max_per_state: break\n",
|
||
|
" count = env.prompt_status() \n",
|
||
|
" actions.append(action) \n",
|
||
|
"\n",
|
||
|
" print(len(infos))\n",
|
||
|
" print(actions)\n",
|
||
|
" print(env.render_board())\n",
|
||
|
" print(info)\n",
|
||
|
" print(count)\n",
|
||
|
" if infos:\n",
|
||
|
" best = max(infos, key=lambda x: x['info']['r_word'])\n",
|
||
|
" print('best', best)\n",
|
||
|
" print('--------------')\n",
|
||
|
" print()\n",
|
||
|
"\n",
|
||
|
" info = {'total_step': len(infos), 'env_step': env.steps, 'actions': actions.copy(), 'info': info, 'count': count}\n",
|
||
|
" infos.append(info)\n",
|
||
|
" if not prune or count['impossible'] < 1: # only continue if the current status is possible\n",
|
||
|
" dfs(env, actions, infos, time_limit, prune, max_per_state)\n",
|
||
|
" actions.pop()\n",
|
||
|
" env.reset(env.idx, board=board.copy(), status=status.copy(), steps=steps)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# dfs with pruning\n",
|
||
|
"infoss = []\n",
|
||
|
"for i in range(0, 100, 5):\n",
|
||
|
" env.reset(i)\n",
|
||
|
" infos = []\n",
|
||
|
" actions = []\n",
|
||
|
" dfs(env, actions, infos, 100, prune=True, max_per_state=3)\n",
|
||
|
" infoss.append(infos)\n",
|
||
|
" with open('logs/crosswords/infoss_dfs_prune.json', 'w') as fout:\n",
|
||
|
" json.dump(infoss, fout)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# dfs without pruning\n",
|
||
|
"infoss = []\n",
|
||
|
"for i in range(0, 100, 5):\n",
|
||
|
" env.reset(i)\n",
|
||
|
" infos = []\n",
|
||
|
" actions = []\n",
|
||
|
" dfs(env, actions, infos, 100, prune=False, max_per_state=3)\n",
|
||
|
" infoss.append(infos)\n",
|
||
|
" with open('logs/crosswords/infoss_dfs_no_prune.json', 'w') as fout:\n",
|
||
|
" json.dump(infoss, fout)"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"metadata": {
|
||
|
"kernelspec": {
|
||
|
"display_name": "Python 3",
|
||
|
"language": "python",
|
||
|
"name": "python3"
|
||
|
},
|
||
|
"language_info": {
|
||
|
"codemirror_mode": {
|
||
|
"name": "ipython",
|
||
|
"version": 3
|
||
|
},
|
||
|
"file_extension": ".py",
|
||
|
"mimetype": "text/x-python",
|
||
|
"name": "python",
|
||
|
"nbconvert_exporter": "python",
|
||
|
"pygments_lexer": "ipython3",
|
||
|
"version": "3.7.4"
|
||
|
}
|
||
|
},
|
||
|
"nbformat": 4,
|
||
|
"nbformat_minor": 2
|
||
|
}
|