{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Env\n",
    "\n",
    "Set up the Mini Crosswords environment. All imports and run-wide settings live in this section so the notebook survives *Restart & Run All*."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run from the parent directory (assumed to be the repo root) so `tot` imports and the relative `logs/` paths resolve.\n",
    "%cd .."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "import json\n",
    "import os\n",
    "import re\n",
    "\n",
    "from tot.models import gpt\n",
    "from tot.prompts.crosswords import propose_prompt, value_prompt\n",
    "from tot.tasks.crosswords import MiniCrosswordsEnv\n",
    "\n",
    "MODEL = 'gpt-4'  # LLM used for every propose call\n",
    "LOG_DIR = 'logs/crosswords'  # DFS traces are written here (relative to the working directory)\n",
    "\n",
    "env = MiniCrosswordsEnv()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Prompt\n",
    "\n",
    "The propose prompt is filled with the rendered board. Responses are expected to contain one proposal per line in the form `h1. apple (certain)` (see `parse_line`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def prompt_wrap(obs):\n",
    "    \"\"\"Fill the crossword propose prompt with a rendered board observation.\"\"\"\n",
    "    return propose_prompt.format(input=obs)\n",
    "\n",
    "\n",
    "print(prompt_wrap(env.reset(0)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Weight of each self-reported confidence level when summing votes across samples.\n",
    "confidence_to_value = {'certain': 1, 'high': 0.5, 'medium': 0.2, 'low': 0.1}  # TODO: ad hoc\n",
    "\n",
    "# One proposal per line, e.g. 'h1. apple (certain)': clue id, 5-letter answer, confidence.\n",
    "_PROPOSAL_RE = re.compile(r'^([hv][1-5])\\. ([a-zA-Z]{5}) \\((certain|high|medium|low)\\).*$')\n",
    "\n",
    "\n",
    "def parse_line(input_str):\n",
    "    \"\"\"Parse one proposal line into [clue_id, word, confidence], or None if it does not match.\"\"\"\n",
    "    match = _PROPOSAL_RE.match(input_str)\n",
    "    return list(match.groups()) if match else None\n",
    "\n",
    "\n",
    "def parse_response(response):\n",
    "    \"\"\"Parse a model response into [(action, score), ...], or None if no line matched.\n",
    "\n",
    "    Actions are lower-cased in the form 'h1. apple', which is what gets passed to env.step.\n",
    "    \"\"\"\n",
    "    parsed = []\n",
    "    for line in response.split('\\n'):\n",
    "        fields = parse_line(line)\n",
    "        if fields is None:\n",
    "            continue\n",
    "        clue_id, word, confidence = fields\n",
    "        parsed.append((clue_id.lower() + '. ' + word.lower(), confidence_to_value.get(confidence, 0)))\n",
    "    return parsed or None\n",
    "\n",
    "\n",
    "def _aggregate_scores(responses):\n",
    "    \"\"\"Sum each candidate action's confidence score over all sampled responses.\"\"\"\n",
    "    candidates_to_scores = {}\n",
    "    for response in responses:\n",
    "        for candidate, score in parse_response(response) or []:\n",
    "            candidates_to_scores[candidate] = candidates_to_scores.get(candidate, 0) + score\n",
    "    return candidates_to_scores\n",
    "\n",
    "\n",
    "def get_candidates_to_scores(env, n=8):\n",
    "    \"\"\"Return {action: summed score} for the current board, memoised in env.cache keyed by env.render().\n",
    "\n",
    "    n is the number of model samples drawn on a cache miss.\n",
    "    \"\"\"\n",
    "    obs = env.render()\n",
    "    if obs in env.cache:\n",
    "        print('cache hit')\n",
    "        return env.cache[obs]\n",
    "    print('call gpt')\n",
    "    candidates_to_scores = _aggregate_scores(gpt(prompt_wrap(obs), model=MODEL, n=n))\n",
    "    env.cache[obs] = candidates_to_scores\n",
    "    return candidates_to_scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _first_consistent_candidate(env, candidates):\n",
    "    \"\"\"Return the first candidate that, applied to a copy of env, leaves no env.status entry equal to 2\n",
    "    (a violated constraint); None if every candidate does.\"\"\"\n",
    "    for candidate in candidates:\n",
    "        trial_env = copy.deepcopy(env)\n",
    "        trial_env.step(candidate)\n",
    "        if not any(s == 2 for s in trial_env.status):\n",
    "            return candidate\n",
    "    return None\n",
    "\n",
    "\n",
    "def propose_score(env, idx, n=5):\n",
    "    \"\"\"Greedy baseline: repeatedly apply the best-scoring proposal that keeps the board consistent.\n",
    "\n",
    "    Stops when env.step reports done, when no proposal parses, or when every proposal\n",
    "    would violate an existing constraint. Returns the list of per-step info dicts.\n",
    "    \"\"\"\n",
    "    obs = env.reset(idx)\n",
    "    done = False\n",
    "    infos = []\n",
    "    while not done:\n",
    "        candidates_to_scores = _aggregate_scores(gpt(prompt_wrap(obs), model=MODEL, n=n))\n",
    "        print(sorted(candidates_to_scores.items(), key=lambda x: x[1], reverse=True))\n",
    "        if not candidates_to_scores:\n",
    "            break\n",
    "        candidates = sorted(candidates_to_scores, key=candidates_to_scores.get, reverse=True)\n",
    "        candidate = _first_consistent_candidate(env, candidates)\n",
    "        if candidate is None:\n",
    "            # Every proposal conflicts with the board; applying one would only corrupt it.\n",
    "            print('no consistent candidate')\n",
    "            break\n",
    "        print(candidate)\n",
    "        obs, r, done, info = env.step(candidate)\n",
    "        print(obs)\n",
    "        print(env.steps, info)\n",
    "        print('-------------------\\n\\n\\n')\n",
    "        infos.append(info)\n",
    "    return infos"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# DFS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def dfs(env, actions, infos, time_limit, prune, max_per_state, max_env_steps=10):\n",
    "    \"\"\"Depth-first search over proposed fills, recording every expanded state in `infos`.\n",
    "\n",
    "    Args:\n",
    "        env: environment at the state to expand; it is restored to this state before returning.\n",
    "        actions: action path from the puzzle start to the current state (mutated, restored on return).\n",
    "        infos: shared log of expanded states (appended to); its length is the global step budget.\n",
    "        time_limit: stop expanding once len(infos) reaches this value.\n",
    "        prune: if True, do not descend into a state whose env.prompt_status() count has 'impossible' >= 1.\n",
    "        max_per_state: maximum number of children expanded from one state.\n",
    "        max_env_steps: only expand children reached in fewer than this many env steps.\n",
    "\n",
    "    Returns:\n",
    "        (0, [], []) when the model proposes no candidates, otherwise None; callers read `infos`.\n",
    "    \"\"\"\n",
    "    candidates_to_scores = get_candidates_to_scores(env)\n",
    "    if len(candidates_to_scores) == 0:\n",
    "        return 0, [], []\n",
    "    print(sorted(candidates_to_scores.items(), key=lambda x: x[1], reverse=True))\n",
    "\n",
    "    # Back up the current state so every candidate is tried from the same board.\n",
    "    board, status, steps = env.board.copy(), env.status.copy(), env.steps\n",
    "\n",
    "    def restore():\n",
    "        env.reset(env.idx, board=board.copy(), status=status.copy(), steps=steps)\n",
    "\n",
    "    cnt_per_state = 0\n",
    "    for action in sorted(candidates_to_scores, key=candidates_to_scores.get, reverse=True):\n",
    "        _, _, _, info = env.step(action)\n",
    "        # Only expand within budget and when the move violates no existing constraint (status 2).\n",
    "        if len(infos) < time_limit and env.steps < max_env_steps and not any(s == 2 for s in env.status):\n",
    "            cnt_per_state += 1\n",
    "            if cnt_per_state > max_per_state:\n",
    "                restore()  # undo this extra step before giving up on the state\n",
    "                break\n",
    "            count = env.prompt_status()\n",
    "            actions.append(action)\n",
    "\n",
    "            print(len(infos))\n",
    "            print(actions)\n",
    "            print(env.render_board())\n",
    "            print(info)\n",
    "            print(count)\n",
    "            if infos:\n",
    "                best = max(infos, key=lambda x: x['info']['r_word'])\n",
    "                print('best', best)\n",
    "            print('--------------')\n",
    "            print()\n",
    "\n",
    "            infos.append({'total_step': len(infos), 'env_step': env.steps, 'actions': actions.copy(), 'info': info, 'count': count})\n",
    "            if not prune or count['impossible'] < 1:  # only continue if the current status is possible\n",
    "                dfs(env, actions, infos, time_limit, prune, max_per_state, max_env_steps)\n",
    "            actions.pop()\n",
    "        restore()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Experiments\n",
    "\n",
    "Run DFS on every 5th puzzle (indices 0, 5, ..., 95), with and without pruning. Each run rewrites its JSON log after every puzzle, so partial results survive an interrupted run."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_dfs_experiment(out_path, prune, idxs=range(0, 100, 5), time_limit=100, max_per_state=3):\n",
    "    \"\"\"Run dfs on each puzzle index and return the list of per-puzzle `infos` logs.\n",
    "\n",
    "    The accumulated logs are re-dumped to `out_path` after every puzzle as a checkpoint.\n",
    "    \"\"\"\n",
    "    out_dir = os.path.dirname(out_path)\n",
    "    if out_dir:\n",
    "        os.makedirs(out_dir, exist_ok=True)\n",
    "    infoss = []\n",
    "    for i in idxs:\n",
    "        env.reset(i)\n",
    "        infos, actions = [], []\n",
    "        dfs(env, actions, infos, time_limit, prune=prune, max_per_state=max_per_state)\n",
    "        infoss.append(infos)\n",
    "        with open(out_path, 'w') as fout:\n",
    "            json.dump(infoss, fout)\n",
    "    return infoss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# dfs with pruning\n",
    "infoss_prune = run_dfs_experiment(os.path.join(LOG_DIR, 'infoss_dfs_prune.json'), prune=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# dfs without pruning\n",
    "infoss_no_prune = run_dfs_experiment(os.path.join(LOG_DIR, 'infoss_dfs_no_prune.json'), prune=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}