import re
import json
import os
from tasks.base import Task, DATA_PATH
from prompts.crosswords import *
from models import gpt


class MiniCrosswordsEnv:
    """Stateful environment for a 5x5 mini crossword.

    The board is a flat list of 25 cells in row-major order; ``'_'`` marks an
    empty cell. Answers are indexed 0-4 for the horizontal words (h1-h5) and
    5-9 for the vertical words (v1-v5). ``self.status`` tracks each answer:
    0 = unfilled, 1 = filled, 2 = filled and later changed by another word.
    """

    def __init__(self, file='mini0505.json'):
        """Load the puzzle collection from ``data/crosswords/<file>``.

        After construction ``self.file`` holds the parsed JSON (a sequence of
        ``(clues, ground_truth_board)`` pairs, as unpacked in ``reset``).
        """
        # NOTE(review): DATA_PATH is imported at module level but the path
        # here is relative to the working directory -- confirm which is meant.
        path = f'data/crosswords/{file}'
        # Context manager so the file handle is closed deterministically.
        with open(path) as f:
            self.file = json.load(f)
        self.n = len(self.file)
        self.cache = {}
        self.idx = None
        self.times = 0
        self.prompt_status_cache = {}

    def __len__(self):
        """Return the number of puzzles loaded."""
        return self.n

    def reset(self, idx, board=None, status=None, steps=None):
        """Select puzzle ``idx`` and reset (or restore) the game state.

        Args:
            idx: index of the puzzle in the loaded collection.
            board: optional 25-cell board to restore instead of an empty one.
            status: optional 10-entry status list to restore.
            steps: optional step counter to restore.

        Returns:
            The rendered state, as produced by ``render()``.
        """
        self.idx = idx
        self.data, self.board_gt = self.file[idx]
        self.board = ['_'] * 25
        self.ans = ['_____'] * 10
        self.ans_gt = self.get_ans(self.board_gt)
        self.steps = 0
        self.status = [0] * 10  # 0: unfilled; 1: filled; 2: filled then changed
        if board is not None:
            # Copy: step() mutates the board in place and must not write
            # through to the caller's object.
            self.board = list(board)
            self.ans = self.get_ans(self.board)
        if status is not None:
            # Copy for the same reason: step() assigns into self.status.
            self.status = list(status)
        if steps is not None:
            self.steps = steps
        return self.render()

    def prompt_status(self):
        """Ask the model to judge every sufficiently-filled answer.

        Answers with four or more blanks are skipped. Model responses are
        memoised per prompt in ``self.prompt_status_cache``.

        Returns:
            A dict counting how many answers were judged ``'sure'``,
            ``'maybe'`` and ``'impossible'`` (taken from the last line of
            each response; any other verdict is ignored).
        """
        count = {'sure': 0, 'maybe': 0, 'impossible': 0}
        for ans, data in zip(self.ans, self.data):
            if ans.count('_') >= 4:
                continue
            spaced = ' '.join(ans.lower())
            line = f'{data}: {spaced}'
            prompt = value_prompt.format(input=line)
            if prompt in self.prompt_status_cache:
                res = self.prompt_status_cache[prompt]
            else:
                res = gpt(prompt)[0]
                self.prompt_status_cache[prompt] = res
            verdict = res.split('\n')[-1].strip()
            if verdict in count:
                count[verdict] += 1
        return count

    def render_gt_board(self):
        """Return the ground-truth board, one space-separated row per line."""
        s = "GT Board:\n"
        for i in range(5):
            s += ' '.join(self.board_gt[i*5:(i+1)*5]) + '\n'
        return s

    def render_board(self):
        """Return the current board, one row of five cells per line."""
        s = "Current Board:\n"
        for i in range(5):
            s += ''.join(self.board[i*5:(i+1)*5]) + '\n'
        return s

    def _clue_lines(self, status=None, answers=None):
        """Render one line per clue, shared by the ``render_*`` methods.

        Only clues whose status equals ``status`` are included (all clues
        when ``status`` is None). When ``answers`` is given, ``': <answer>'``
        is appended to each line.
        """
        lines = []
        for i in range(10):
            if status is not None and self.status[i] != status:
                continue
            # Indices 0-4 are h1-h5, indices 5-9 are v1-v5.
            label = 'h' + str(i + 1) if i < 5 else 'v' + str(i - 5 + 1)
            line = label + '. ' + self.data[i]
            if answers is not None:
                line += ': ' + answers[i]
            lines.append(line + '\n')
        return ''.join(lines)

    def render_clues(self, status=None):
        """Return the clues (optionally filtered by status), one per line."""
        return self._clue_lines(status)

    def render_ans(self, status=None):
        """Return clues with the current answers, optionally filtered by status."""
        return self._clue_lines(status, self.ans)

    def render_gt_ans(self, status=None):
        """Return clues with the ground-truth answers, optionally filtered by status."""
        return self._clue_lines(status, self.ans_gt)

    def render(self, status=True):
        """Return the board followed by the answers.

        With ``status`` truthy the answers are grouped into Unfilled / Filled /
        Changed sections; otherwise all answers are listed in one block.
        """
        if status:
            return (self.render_board()
                    + '\nUnfilled:\n' + self.render_ans(status=0)
                    + '\nFilled:\n' + self.render_ans(status=1)
                    + '\nChanged:\n' + self.render_ans(status=2))
        return self.render_board() + '\n' + self.render_ans()

    def get_ans(self, board):
        """Return the 10 words of ``board``: five rows, then five columns."""
        ans = [''] * 10
        for i in range(5):
            ans[i] = ''.join(board[i*5:(i+1)*5])
        for i in range(5):
            ans[i+5] = ''.join(board[i::5])
        return ans

    def step(self, action):
        """Apply one action of the form ``"h1. apple"`` / ``"v3. tiger"``.

        Only the last line of ``action`` is parsed. The step counter is
        incremented even when the action is rejected.

        Returns:
            ``(observation, reward, done, info)``. For a rejected action the
            observation is an ``'Invalid! ...'`` message, reward is 0, done is
            False and info is ``{}``. Otherwise the observation is the
            rendered state, reward is whether the board equals the ground
            truth, done is that flag or ``steps >= 20``, and info carries
            ``r_letter``, ``r_word`` and ``r_game``.
        """
        self.steps += 1
        action = action.split('\n')[-1]
        action = action.split('. ')
        if len(action) != 2:
            return 'Invalid! Format should be like "h1. apple"', 0, False, {}
        pos, word = action
        letters = list(word.upper())
        # Also check the upper-cased length: some characters expand when
        # upper-cased, and a 6-element slice assignment would grow the board.
        if len(word) != 5 or len(letters) != 5:
            return 'Invalid! Word should have 5 letters.', 0, False, {}

        direction, number = pos[:1], pos[1:]
        try:
            idx = int(number) - 1
        except ValueError:
            idx = -1
        # Reject out-of-range indices up front: slice assignment with a bad
        # index would silently resize the board (h0/h6) or raise (v6).
        if direction not in ('h', 'v') or not 0 <= idx < 5:
            return 'Invalid! Position should be h1-h5 or v1-v5', 0, False, {}

        if direction == 'h':
            self.board[idx*5:(idx+1)*5] = letters
        else:
            self.board[idx::5] = letters
            idx += 5  # vertical answers live at status/ans indices 5-9

        self.new_ans = self.get_ans(self.board)
        # An answer becomes "changed" (2) when any previously filled letter
        # was overwritten with a different one.
        self.status = [
            2 if any(letter != new_letter and letter != '_'
                     for letter, new_letter in zip(ans, new_ans)) else status
            for status, ans, new_ans in zip(self.status, self.ans, self.new_ans)
        ]
        self.status[idx] = 1
        self.ans = self.new_ans

        r_all = (self.board == self.board_gt)
        r_letter = sum(a == b for a, b in zip(self.board, self.board_gt)) / 25
        r_word = sum(a == b for a, b in zip(self.ans, self.ans_gt)) / 10
        info = {'r_letter': r_letter, 'r_word': r_word, 'r_game': r_all}
        return self.render(), r_all, (r_all or self.steps >= 20), info


class MiniCrosswordsTask(Task):
    """Task wrapper around ``MiniCrosswordsEnv``.

    Input (x)   : Description of a 5x5 mini crossword (its rendered clues)
    Output (y)  : Words to fill in the crossword
    Reward (r)  : word level and game level
    """

    def __init__(self, file):
        """Build the task from a JSON puzzle file (see ``MiniCrosswordsEnv``)."""
        super().__init__()
        self.env = MiniCrosswordsEnv(file)  # use it as a stateless tool
        self.xs = []
        for idx in range(len(self.env)):
            self.env.reset(idx)
            self.xs.append(self.env.render_clues())
        self.steps = 10  # TODO: variable steps??
        self.cache_proposals = {}

    def __len__(self) -> int:
        """Return the number of puzzles."""
        return len(self.env)

    def get_input(self, idx: int) -> str:
        """Reset the environment to puzzle ``idx`` and return its clues."""
        self.env.reset(idx)
        return self.env.render_clues()

    def test_output(self, idx: int, output: str):
        """Score ``output`` against puzzle ``idx``.

        The text after the last ``'Output:\\n'`` is used; its last (up to)
        five lines are played as rows h1..h5, each built from the first five
        space-separated tokens and padded with ``'_'`` to five characters.

        Returns:
            The info dict of the last accepted step (``r_letter``, ``r_word``,
            ``r_game``) with ``'r'`` set to ``r_word``; all zeros when no row
            was accepted.
        """
        self.env.reset(idx)
        output = output.split('Output:\n')[-1]
        info = {'r_word': 0, 'r_letter': 0, 'r_game': 0}
        for i, line in enumerate(output.strip().split('\n')[-5:], 1):
            letters = line.split(' ')[:5]
            word = ''.join(letters)
            word = word + '_' * (5 - len(word))
            action = f'h{i}. {word}'
            _, _, _, step_info = self.env.step(action)
            # A rejected step returns an empty info dict and leaves the board
            # untouched; keep the previous info so 'r_word' is always present.
            if step_info:
                info = step_info
        info['r'] = info['r_word']
        return info

    def set_status(self, x: str, y: str):
        """Replay ``y`` on the puzzle whose clues equal ``x`` (updates ``self.env``)."""
        idx = self.xs.index(x)
        self.test_output(idx, y)  # update self.env

    @staticmethod
    def standard_prompt_wrap(x: str, y: str = '') -> str:
        """Return the standard prompt for ``x`` followed by ``y``."""
        return standard_prompt.format(input=x) + y

    @staticmethod
    def cot_prompt_wrap(x: str, y: str = '') -> str:
        """Return the chain-of-thought prompt for ``x`` followed by ``y``."""
        return cot_prompt.format(input=x) + y

    def propose_prompt_wrap(self, x: str, y: str = '') -> str:
        """Replay ``y`` and return the propose prompt for the resulting state."""
        self.set_status(x, y)
        return propose_prompt.format(input=self.env.render())

    def propose_outputs_unwrap(self, x: str, y: str, outputs: list, n_max_propose: int) -> list:
        """Parse model proposals and return the best continuations of ``y``.

        Lines of the form ``"h1. apple (certain)"`` are collected; scores of
        identical proposals are summed across all outputs. Proposals are
        sorted by score (highest first), truncated to ``n_max_propose``
        unless it is -1, and each is returned as ``y + proposal + '\\n'``.
        The result is also stored in ``self.cache_proposals``.
        """
        confidence_to_value = {'certain': 1, 'high': 0.5, 'medium': 0.2, 'low': 0.1}  # TODO: ad hoc
        pattern = re.compile(r'^([hv][1-5])\. ([a-zA-Z]{5,5}) \((certain|high|medium|low)\).*$')
        proposals_to_scores = {}
        for output in outputs:
            for line in output.split('\n'):
                match = pattern.match(line)
                if not match:
                    continue
                position, word, confidence = match.group(1), match.group(2), match.group(3)
                proposal = position.lower() + '. ' + word.lower()
                score = confidence_to_value.get(confidence, 0)
                proposals_to_scores[proposal] = proposals_to_scores.get(proposal, 0) + score

        ranked = sorted(proposals_to_scores.items(), key=lambda item: item[1], reverse=True)
        if n_max_propose != -1:
            ranked = ranked[:n_max_propose]
        proposals = [y + proposal + '\n' for proposal, _ in ranked]
        self.cache_proposals[(x, y, n_max_propose)] = proposals
        return proposals

    def evaluate(self, x: str, y: str, n_evaluate_sample: int) -> dict:
        """Replay ``y`` and ask the model to judge each sufficiently-filled answer.

        Answers with four or more blanks are skipped. Each judged line, the
        raw response and the final tally are printed.

        Returns:
            A dict counting ``'sure'``, ``'maybe'`` and ``'impossible'``
            verdicts (last line of each response; other verdicts ignored).
        """
        self.set_status(x, y)
        assert n_evaluate_sample == 1  # TODO: ad hoc
        count = {'sure': 0, 'maybe': 0, 'impossible': 0}
        for ans, data in zip(self.env.ans, self.env.data):
            if ans.count('_') >= 4:
                continue
            spaced = ' '.join(ans.lower())
            line = f'{data}: {spaced}'
            prompt = value_prompt.format(input=line)
            res = gpt(prompt)[0]
            print(line)
            print(res)
            print()
            verdict = res.split('\n')[-1].strip()
            if verdict in count:
                count[verdict] += 1
        print(count)
        return count