mirror of
https://github.com/princeton-nlp/tree-of-thought-llm
synced 2024-11-16 13:42:27 +00:00
92 lines
3.1 KiB
Python
92 lines
3.1 KiB
Python
|
import re
|
||
|
import os
|
||
|
import sympy
|
||
|
import pandas as pd
|
||
|
from tasks.base import Task, DATA_PATH
|
||
|
from prompts.game24 import *
|
||
|
|
||
|
|
||
|
def get_current_numbers(y: str) -> str:
    """Return the numbers still available after the latest step in *y*.

    Only the final line of ``y`` (after stripping surrounding whitespace) is
    inspected. The result is the text following the last ``'left: '`` marker
    on that line, cut at the first ``')'``. If the marker is absent, the
    line itself is used (still cut at the first ``')'``), so a raw puzzle
    input such as ``"1 2 3 4"`` comes back unchanged.
    """
    final_line = y.strip().split('\n')[-1]
    # rpartition yields ('', '', final_line) when the marker is missing,
    # which matches taking the last element of split('left: ').
    _, _, tail = final_line.rpartition('left: ')
    remaining, _, _ = tail.partition(')')
    return remaining
|
||
|
|
||
|
|
||
|
class Game24Task(Task):
    """
    Input (x)   : a string of 4 numbers
    Output (y)  : a trajectory of 3 steps to reach 24, followed by a final
                  line holding a single expression that uses all 4 numbers
    Reward (r)  : 0 or 1, depending on whether the trajectory is correct
    Input Example:
        1 2 3 4
    Output Example:
        1 + 2 = 3 (left: 3 3 4)
        3 + 3 = 6 (left: 4 6)
        6 * 4 = 24 (left: 24)
        (1 + 2 + 3) * 4 = 24
    """

    # Characters permitted in a final-answer expression before it is handed
    # to sympy. sympy parses strings via eval(), so anything beyond plain
    # arithmetic (identifiers, attribute access, calls) is rejected up front.
    _ALLOWED_EXPRESSION_RE = re.compile(r'[\d\s+\-*/().]*')

    def __init__(self, file='24.csv'):
        """
        file: a csv file (fixed), read from DATA_PATH/24/<file>; its
              'Puzzles' column supplies the puzzle strings.
        """
        super().__init__()
        path = os.path.join(DATA_PATH, '24', file)
        self.data = list(pd.read_csv(path)['Puzzles'])
        self.value_cache = {}
        # 3 intermediate steps + 1 final answer line (see class docstring).
        self.steps = 4
        # One '\n' stop entry per step; not read within this class —
        # presumably consumed by the search driver (confirm with caller).
        self.stops = ['\n'] * 4

    def __len__(self) -> int:
        """Number of puzzles loaded from the csv file."""
        return len(self.data)

    def get_input(self, idx: int) -> str:
        """Return puzzle string ``idx`` (e.g. ``'1 2 3 4'``)."""
        return self.data[idx]

    def test_output(self, idx: int, output: str):
        """Score the final line of *output* against puzzle *idx*.

        The last line is lower-cased, any ``'answer: '`` text is removed, and
        everything from the first ``'='`` on is dropped. The remaining
        expression scores 1 only if it uses exactly the puzzle's numbers,
        contains nothing but plain arithmetic (``+ - * /``, parentheses,
        whitespace), and evaluates to 24.

        Returns:
            ``{'r': 1}`` on a correct answer, ``{'r': 0}`` otherwise.
        """
        expression = output.strip().split('\n')[-1].lower().replace('answer: ', '').split('=')[0]
        numbers = re.findall(r'\d+', expression)
        problem_numbers = re.findall(r'\d+', self.data[idx])
        if sorted(numbers) != sorted(problem_numbers):
            return {'r': 0}
        # The expression is model output and sympy evaluates strings with
        # eval(), so only arithmetic characters are allowed. '**' is rejected
        # too: exponentiation is not a Game of 24 operation, and a power
        # tower such as 13**13**13**13 would hang the evaluation.
        if not self._ALLOWED_EXPRESSION_RE.fullmatch(expression) or '**' in expression:
            return {'r': 0}
        try:
            return {'r': int(sympy.simplify(expression) == 24)}
        except Exception:
            # Best-effort scoring: anything sympy cannot parse or evaluate
            # (e.g. unbalanced parentheses) counts as a wrong answer.
            return {'r': 0}

    @staticmethod
    def standard_prompt_wrap(x: str, y: str = '') -> str:
        """Fill the standard (IO) prompt with puzzle *x* and append *y*."""
        return standard_prompt.format(input=x) + y

    @staticmethod
    def cot_prompt_wrap(x: str, y: str = '') -> str:
        """Fill the chain-of-thought prompt with puzzle *x* and append *y*."""
        return cot_prompt.format(input=x) + y

    @staticmethod
    def propose_prompt_wrap(x: str, y: str = '') -> str:
        """Build the prompt for proposing the next step.

        When the numbers left after the partial trajectory *y* (or the raw
        puzzle *x* if *y* is empty) are exactly ``'24'``, the CoT prompt is
        returned with ``'Steps:'`` and *y* appended, so the model writes the
        final answer expression. Otherwise the propose prompt is filled with
        the remaining numbers.
        """
        current_numbers = get_current_numbers(y if y else x)
        if current_numbers == '24':
            prompt = cot_prompt.format(input=x) + 'Steps:' + y
        else:
            prompt = propose_prompt.format(input=current_numbers)
        return prompt

    @staticmethod
    def value_prompt_wrap(x: str, y: str) -> str:
        """Build the prompt used to evaluate the partial trajectory *y*.

        If the last line of *y* has no ``'left: '`` marker, it is treated as
        the final answer: it is lower-cased, any ``'answer: '`` text is
        removed, and it is judged with the last-step prompt against puzzle
        *x*. Otherwise the remaining numbers are judged with the value prompt.
        """
        last_line = y.strip().split('\n')[-1]
        if 'left: ' not in last_line:  # last step
            ans = last_line.lower().replace('answer: ', '')
            return value_last_step_prompt.format(input=x, answer=ans)
        current_numbers = get_current_numbers(y)
        return value_prompt.format(input=current_numbers)

    @staticmethod
    def value_outputs_unwrap(x: str, y: str, value_outputs: list) -> float:
        """Aggregate sampled evaluator outputs into a scalar value for *y*.

        A trajectory with exactly 4 lines and no ``'answer'`` text
        (case-insensitive) is worth 0 — presumably all steps were spent
        without producing an answer. Otherwise the last line of each
        evaluator output is matched exactly against ``'impossible'``,
        ``'likely'`` and ``'sure'``, and the matching weights are summed.
        """
        if len(y.strip().split('\n')) == 4 and 'answer' not in y.lower():
            return 0
        value_names = [output.split('\n')[-1] for output in value_outputs]
        value_map = {'impossible': 0.001, 'likely': 1, 'sure': 20}  # TODO: ad hoc
        return sum(weight * value_names.count(name) for name, weight in value_map.items())
|