init project

This commit is contained in:
Dejian
2023-11-02 22:07:09 +08:00
commit a4ba628dfd
111 changed files with 26064 additions and 0 deletions

View File

@@ -0,0 +1,151 @@
"""
This logic is largely copied from the Hendrycks' MATH release (math_equivalence), and borrowed from:
- https://github.com/microsoft/ProphetNet/tree/master/CRITIC
- https://github.com/openai/prm800k
"""
import multiprocessing
from math import isclose
from typing import Union
from sympy import simplify, N
from sympy.parsing.sympy_parser import parse_expr
from sympy.parsing.latex import parse_latex
def is_digit(s):
    """Tell whether ``s`` reads as a float once "," separators are removed.

    The value is first converted with ``str`` so numbers and strings are
    handled alike.
    """
    try:
        float(str(s).replace(",", ""))
    except ValueError:
        return False
    else:
        return True
def math_equal(prediction: Union[bool, float, str],
               reference: Union[float, str],
               include_percentage: bool = True,
               is_close: bool = True,
               timeout: bool = False,
               ) -> bool:
    """Decide whether ``prediction`` and ``reference`` denote the same answer.

    The checks run in this order and the first decisive one wins:

    1. Numerical: when both values parse as floats they are compared
       directly. With ``include_percentage`` the reference is also tried
       divided and multiplied by 100. With ``is_close`` the comparison uses
       ``math.isclose`` with ``rel_tol=1e-4``, otherwise exact equality.
    2. Textual: after stripping surrounding/inner brackets and braces the
       two strings are compared literally.
    3. Element-wise: ``[a, b]`` vs ``[c, d]`` (or ``(...)`` vs ``(...)``)
       are split on "," and compared item by item with this same function.
    4. Symbolic: ``symbolic_equal`` is used, inside a child process with a
       time limit when ``timeout`` is true.

    Args:
        prediction: Candidate answer.
        reference: Ground-truth answer.
        include_percentage: Also accept reference/100 and reference*100.
        is_close: Use a relative tolerance instead of exact equality.
        timeout: Guard the symbolic comparison with ``call_with_timeout``.

    Returns:
        True when any check considers the two answers equal, else False.
    """
    # 1. numerical equal
    try:
        if is_digit(prediction) and is_digit(reference):
            prediction = float(str(prediction).replace(",", ""))
            reference = float(str(reference).replace(",", ""))
            if include_percentage:
                gt_result = [reference / 100, reference, reference * 100]
            else:
                gt_result = [reference]
            for item in gt_result:
                try:
                    if is_close:
                        if isclose(item, prediction, rel_tol=1e-4):
                            return True
                    elif item == prediction:
                        return True
                except Exception:
                    continue
            # Both sides are plain numbers, so no symbolic fallback is tried.
            return False
    except Exception:
        # Best effort: fall through to the string/symbolic comparison, but do
        # not swallow KeyboardInterrupt/SystemExit as a bare except would.
        pass

    # Empty predictions ("" / None) never match; 0 and False are real values.
    if not prediction and prediction not in [0, False]:
        return False

    # 2. symbolic equal
    reference = str(reference).strip()
    prediction = str(prediction).strip()

    # deal with [], (), {}
    pred_str, ref_str = prediction, reference
    pred_in_brackets = prediction.startswith("[") and prediction.endswith("]")
    pred_in_parens = prediction.startswith("(") and prediction.endswith(")")
    if (pred_in_brackets and not reference.startswith("(")) or \
            (pred_in_parens and not reference.startswith("[")):
        pred_str = pred_str.strip("[]()")
        ref_str = ref_str.strip("[]()")
    for s in ['{', "}", "(", ")"]:
        ref_str = ref_str.replace(s, "")
        pred_str = pred_str.replace(s, "")
    if pred_str == ref_str:
        return True

    # [a, b] vs. [c, d], return a==c and b==d
    ref_in_brackets = reference.startswith("[") and reference.endswith("]")
    ref_in_parens = reference.startswith("(") and reference.endswith(")")
    if (pred_in_brackets and ref_in_brackets) or (pred_in_parens and ref_in_parens):
        pred_parts = prediction[1:-1].split(",")
        ref_parts = reference[1:-1].split(",")
        if len(pred_parts) == len(ref_parts):
            # Forward ``timeout`` so each element gets the same protection
            # against a hanging sympy call as the top-level comparison.
            if all(
                math_equal(pred_part, ref_part, include_percentage, is_close, timeout)
                for pred_part, ref_part in zip(pred_parts, ref_parts)
            ):
                return True

    # symbolic equal with sympy
    if timeout:
        if call_with_timeout(symbolic_equal_process, prediction, reference):
            return True
    else:
        if symbolic_equal(prediction, reference):
            return True
    return False
def math_equal_process(param):
    """Compare the last two items of ``param`` with ``math_equal``.

    ``param[-2]`` is used as the prediction and ``param[-1]`` as the
    reference. The pair and the verdict are printed, and the verdict is
    returned. The comparison is evaluated once and reused for both the
    printout and the return value, since it may involve sympy.
    """
    prediction, reference = param[-2], param[-1]
    verdict = math_equal(prediction, reference)
    print(prediction, reference, verdict)
    return verdict
def symbolic_equal(a, b):
    """Compare two expressions symbolically, then numerically.

    Each input is parsed with ``parse_latex`` and, if that fails, with
    ``parse_expr``; when both fail the raw value is kept. The expressions
    are equal when ``simplify(a - b)`` is 0, or when their numeric values
    ``N(a)`` and ``N(b)`` agree within ``rel_tol=1e-3``.

    Any error raised by parsing or comparison is treated as "not equal by
    this method", so the function never raises for ordinary failures.

    Returns:
        True when either comparison succeeds, else False.
    """
    def _parse(s):
        for f in [parse_latex, parse_expr]:
            try:
                return f(s)
            except Exception:
                # Try the next parser; fall back to the raw value below.
                pass
        return s

    a = _parse(a)
    b = _parse(b)

    try:
        if simplify(a - b) == 0:
            return True
    except Exception:
        # e.g. unparsed strings cannot be subtracted; try the numeric check.
        pass

    try:
        if isclose(N(a), N(b), rel_tol=1e-3):
            return True
    except Exception:
        pass
    return False
def symbolic_equal_process(a, b, output_queue):
    """Run ``symbolic_equal(a, b)`` and put its verdict on ``output_queue``.

    Shaped as a process target: the result is delivered through the queue
    rather than returned.
    """
    output_queue.put(symbolic_equal(a, b))
def call_with_timeout(func, *args, timeout=1, **kwargs):
    """Run ``func`` in a child process and return what it puts on a queue.

    ``func`` is called as ``func(*args, output_queue, **kwargs)`` and is
    expected to ``put`` exactly one result on ``output_queue``.

    Args:
        func: Process target; receives the queue as its last positional arg.
        *args: Positional arguments placed before the queue.
        timeout: Seconds to wait for the child to finish.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        The value the child put on the queue, or False when the child did
        not finish in time or exited without producing a result.
    """
    # Local import keeps this block self-contained; only the exception type
    # raised by ``Queue.get(timeout=...)`` is needed.
    from queue import Empty

    output_queue = multiprocessing.Queue()
    process = multiprocessing.Process(
        target=func, args=args + (output_queue,), kwargs=kwargs
    )
    process.start()
    process.join(timeout)

    if process.is_alive():
        process.terminate()
        process.join()
        return False

    try:
        # A child that crashed (or returned without ``put``) leaves the queue
        # empty; an unbounded ``get()`` would then block forever.
        return output_queue.get(timeout=timeout)
    except Empty:
        return False
def _test_math_equal():
    """Print the verdict for one symbolic comparison guarded by the timeout."""
    # Further manual checks, currently disabled:
    #   print(math_equal("0.0833333333333333", "\\frac{1}{12}"))
    #   print(math_equal("(1,4.5)", "(1,\\frac{9}{2})"))
    split_fraction = "\\frac{x}{7}+\\frac{2}{7}"
    joined_fraction = "\\frac{x+2}{7}"
    print(math_equal(split_fraction, joined_fraction, timeout=True))


# Run the smoke check only when executed as a script.
if __name__ == "__main__":
    _test_math_equal()

View File

@@ -0,0 +1,320 @@
"""
This file comes from: https://github.com/microsoft/ToRA/blob/main/src/utils/parser.py
"""
import re
from typing import Any, Dict
def _fix_fracs(string):
substrs = string.split("\\frac")
new_str = substrs[0]
if len(substrs) > 1:
substrs = substrs[1:]
for substr in substrs:
new_str += "\\frac"
if len(substr) > 0 and substr[0] == "{":
new_str += substr
else:
try:
assert len(substr) >= 2
except:
return string
a = substr[0]
b = substr[1]
if b != "{":
if len(substr) > 2:
post_substr = substr[2:]
new_str += "{" + a + "}{" + b + "}" + post_substr
else:
new_str += "{" + a + "}{" + b + "}"
else:
if len(substr) > 2:
post_substr = substr[2:]
new_str += "{" + a + "}" + b + post_substr
else:
new_str += "{" + a + "}" + b
string = new_str
return string
def _fix_a_slash_b(string):
if len(string.split("/")) != 2:
return string
a = string.split("/")[0]
b = string.split("/")[1]
try:
if "sqrt" not in a:
a = int(a)
if "sqrt" not in b:
b = int(b)
assert string == "{}/{}".format(a, b)
new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
return new_string
except:
return string
def _fix_sqrt(string):
_string = re.sub(r"\\sqrt(\w+)", r"\\sqrt{\1}", string)
return _string
def strip_string(string):
    """Normalize an answer string so equivalent answers compare equal.

    The input is converted with ``str`` and then passed through a fixed
    sequence of textual clean-ups: whitespace and trailing-period removal,
    LaTeX spacing/sizing commands, a trailing ``\\text{...}`` unit, degree,
    dollar and percent signs, infinity spellings, quotes, ``j`` -> ``i``,
    trailing ``.000``, a short ``x =`` prefix, and finally the
    ``\\sqrt`` / ``\\frac`` / ``a/b`` rewrites done by the ``_fix_*``
    helpers. The order of the steps matters and is kept as is.

    Returns:
        The normalized string (possibly empty).
    """
    string = str(string).strip()
    # linebreaks
    string = string.replace("\n", "")
    # right "."
    string = string.rstrip(".")
    # remove inverse spaces
    string = string.replace("\\!", "")
    string = string.replace("\\ ", "")
    # collapse doubled backslashes; two passes so a run of four also
    # ends up as a single backslash
    string = string.replace("\\\\", "\\")
    string = string.replace("\\\\", "\\")
    # replace tfrac and dfrac with frac
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")
    # remove \left and \right
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")
    # Remove a trailing \text{...} unit, unless that would leave nothing
    _string = re.sub(r"\\text{.*?}$", "", string).strip()
    if _string != "" and _string != string:
        string = _string
    # Remove circ (degrees)
    string = string.replace("^{\\circ}", "")
    string = string.replace("^\\circ", "")
    # remove dollar signs
    string = string.replace("\\$", "")
    string = string.replace("$", "")
    string = string.replace("\\text", "")
    string = string.replace("x\\in", "")
    # remove percentage (escaped form first, then the bare sign)
    string = string.replace("\\%", "")
    string = string.replace("%", "")
    # " 0." equivalent to " ." and "{0." equivalent to "{."
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")
    # cdot
    string = string.replace("\\cdot", "")
    # inf
    string = string.replace("infinity", "\\infty")
    if "\\infty" not in string:
        string = string.replace("inf", "\\infty")
    # "+\infty" and "\infty" mean the same thing (the original searched for
    # the misspelled "+\inity" and therefore never matched)
    string = string.replace("+\\infty", "\\infty")
    # and
    string = string.replace("and", "")
    string = string.replace("\\mathbf", "")
    # use regex to remove \mbox{...}
    string = re.sub(r"\\mbox{.*?}", "", string)
    # quote: str.replace returns a new string, so the result must be kept
    string = string.replace("'", "")
    string = string.replace("\"", "")
    # i, j
    if "j" in string and "i" not in string:
        string = string.replace("j", "i")
    # replace a.000b where b is not number or b is end, with ab, use regex
    string = re.sub(r"(\d+)\.0+([^\d])", r"\1\2", string)
    string = re.sub(r"(\d+)\.0+$", r"\1", string)
    # if empty, return empty string
    if len(string) == 0:
        return string
    if string[0] == ".":
        string = "0" + string
    # get rid of e.g. "k = " or "q = " at beginning (left side of at most
    # two characters)
    sides = string.split("=")
    if len(sides) == 2 and len(sides[0]) <= 2:
        string = sides[1]
    string = _fix_sqrt(string)
    string = string.replace(" ", "")
    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works
    # with \frac1{72} (but not \frac{72}1).
    string = _fix_fracs(string)
    # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix
    # in case the model output is X/Y
    string = _fix_a_slash_b(string)
    return string
def extract_answer(pred_str):
    """Pull the final answer out of a model response and normalize it.

    Sources are tried in this order:

    1. The text after the last ``boxed``: the balanced ``{...}`` group if
       one follows directly, otherwise everything up to the next "$".
    2. The text after the last ``he answer is``.
    3. The last program output block (see ``extract_program_output``).
    4. The last number found in the text (commas removed first).

    Only the first line of the candidate is kept; one leading ":" and one
    trailing "." or "/" are dropped before ``strip_string`` is applied.

    Returns:
        The normalized answer string, or "" when nothing was found.
    """
    if 'boxed' in pred_str:
        ans = pred_str.split('boxed')[-1]
        if len(ans) == 0:
            return ""
        if ans[0] == '{':
            # Collect characters until the brace opened by "boxed{" closes.
            depth = 1
            a = ''
            for c in ans[1:]:
                if c == '{':
                    depth += 1
                elif c == '}':
                    depth -= 1
                    if depth == 0:
                        break
                a += c
        else:
            a = ans.split('$')[0].strip()
        pred = a
    elif 'he answer is' in pred_str:
        pred = pred_str.split('he answer is')[-1].strip()
    else:
        # Evaluate the program output once and reuse it.
        program_output = extract_program_output(pred_str)
        if program_output != "":
            # fall back to program
            pred = program_output
        else:
            # use the last number; raw string avoids invalid "\d" escapes
            numbers = re.findall(r'-?\d*\.?\d+', pred_str.replace(",", ""))
            pred = numbers[-1] if numbers else ''

    # multiple line
    pred = pred.split("\n")[0]
    if pred != "" and pred[0] == ":":
        pred = pred[1:]
    if pred != "" and pred[-1] == ".":
        pred = pred[:-1]
    if pred != "" and pred[-1] == "/":
        pred = pred[:-1]
    pred = strip_string(pred)
    return pred
def extract_program(result: str, last_only=True):
    """Collect the code found inside ```python fenced blocks.

    Lines between a line starting with "```python" and the next line
    starting with "```" are gathered, each followed by a newline. With
    ``last_only`` every new opening fence discards what was gathered so
    far; otherwise blocks are concatenated, each preceded by a
    "# ========" separator line.
    """
    pieces = []
    inside_block = False
    for line in result.split("\n"):
        if line.startswith("```python"):
            if last_only:
                # only extract the last program
                pieces.clear()
            else:
                pieces.append("\n# ========\n")
            inside_block = True
        elif line.startswith("```"):
            inside_block = False
        elif inside_block:
            pieces.append(line + "\n")
    return "".join(pieces)
def extract_program_output(pred_str):
    """Return the text of the last ```output fenced block, stripped.

    The result is whatever follows the last "```output" marker up to the
    next "```" (or the end of the text). "" is returned when the marker is
    absent.
    """
    marker = "```output"
    if marker not in pred_str:
        return ""
    after_marker = pred_str.rpartition(marker)[2]
    body = after_marker.partition("```")[0]
    return body.strip()
def parse_ground_truth(example: Dict[str, Any], data_name):
    """Return ``(reasoning, answer)`` for one dataset record.

    Records that already carry ``gt_cot`` / ``gt`` are used directly (the
    reasoning is returned untouched). Otherwise the fields are looked up
    according to ``data_name``; the reasoning is converted with ``str`` and
    stripped, and the answer is normalized with ``strip_string``.

    Raises:
        NotImplementedError: for an unknown ``data_name``.
    """
    if 'gt_cot' in example:
        return example['gt_cot'], strip_string(example['gt'])

    # parse ground truth
    if data_name in ("math", 'ocw'):
        gt_cot = example['solution']
        gt_ans = extract_answer(gt_cot)
    elif data_name == "gsm8k":
        gt_cot, gt_ans = example['answer'].split("####")
    elif data_name == "gsm-hard":
        gt_cot = example['code']
        gt_ans = example['target']
    elif data_name == "svamp":
        gt_cot = example['Equation']
        gt_ans = example['Answer']
    elif data_name == "asdiv":
        gt_cot = example['formula']
        # drop any parenthesised part of the answer
        gt_ans = re.sub(r"\(.*?\)", "", example['answer'])
    elif data_name == "mawps":
        gt_cot = None
        gt_ans = example['target']
    elif data_name == "tabmwp":
        gt_cot = example['solution']
        gt_ans = example['answer']
        if example['ans_type'] in ('integer_number', 'decimal_number'):
            # numeric answers may be written as fraction, with thousands
            # separators, or as a percentage
            if '/' in gt_ans:
                pieces = gt_ans.split('/')
                gt_ans = int(pieces[0]) / int(pieces[1])
            elif ',' in gt_ans:
                gt_ans = float(gt_ans.replace(',', ''))
            elif '%' in gt_ans:
                gt_ans = float(gt_ans.split('%')[0]) / 100
            else:
                gt_ans = float(gt_ans)
    elif data_name == "bbh":
        gt_cot = None
        gt_ans = example['target']
    else:
        raise NotImplementedError(data_name)

    # post process
    return str(gt_cot).strip(), strip_string(gt_ans)
def parse_question(example, data_name):
    """Build the question text for one dataset record.

    ``asdiv`` and ``svamp`` join a body and a question, ``tabmwp`` renders
    the table with its optional title and choices, and every other dataset
    uses the first of the keys ``question``, ``problem``, ``Question``,
    ``input`` that is present. The result is stripped; an empty question
    trips the assertion.
    """
    question = ""
    if data_name == "asdiv":
        body = example['body'].strip()
        ask = example['question'].strip()
        question = f"{body} {ask}"
    elif data_name == "svamp":
        body = example["Body"].strip()
        if not body.endswith("."):
            body += "."
        ask = example["Question"].strip()
        question = f'{body} {ask}'
    elif data_name == "tabmwp":
        title = example['table_title']
        title_str = f'regarding "{title}" ' if title else ""
        question = (
            f'Read the following table {title_str}and answer a question:\n'
            f'{example["table"]}\n{example["question"]}'
        )
        if example['choices']:
            question += f' Please select from the following options: {example["choices"]}'
    else:
        candidates = ('question', 'problem', 'Question', 'input')
        question = next((example[key] for key in candidates if key in example), "")
    assert question != ""
    return question.strip()
def run_execute(executor, result, prompt_type, execute=False):
    """Turn a raw generation into ``(prediction, report)``.

    * A falsy ``result`` or the literal 'error' gives ``(None, None)``.
    * Prompt types containing "program_only" read the program output block.
    * "pot" / "pal" prompts, when ``execute`` is set, run the extracted
      program through ``executor.apply`` and keep its report.
    * Everything else uses ``extract_answer``.

    The prediction is normalized with ``strip_string`` before returning;
    ``report`` stays None unless the executor was used.
    """
    if not result or result == 'error':
        return None, None

    report = None
    if "program_only" in prompt_type:
        answer = extract_program_output(result)
    elif prompt_type in ("pot", "pal") and execute:
        program = extract_program(result)
        answer, report = executor.apply(program)
    else:
        answer = extract_answer(result)

    return strip_string(answer), report

View File

@@ -0,0 +1,180 @@
"""
This file comes from: https://github.com/microsoft/ToRA/blob/main/src/utils/python_executor.py
"""
import io
import regex
import pickle
import traceback
import copy
import datetime
import multiprocessing
import dateutil.relativedelta
import multiprocess
from multiprocess import Pool
from typing import Any, Dict, Optional
from pebble import ProcessPool
from tqdm import tqdm
from concurrent.futures import TimeoutError
from functools import partial
from timeout_decorator import timeout
from contextlib import redirect_stdout
class GenericRuntime:
    """Holds a namespace in which code snippets are executed and evaluated.

    Subclasses customise the starting namespace through the class
    attributes below.
    """

    # Names pre-populated in each instance's namespace (shallow-copied).
    GLOBAL_DICT = {}
    # Optional local namespace template; copied to ``_local_vars`` but not
    # passed to ``exec`` / ``eval`` by this class.
    LOCAL_DICT = None
    # Code snippets executed once when an instance is created.
    HEADERS = []

    def __init__(self):
        self._global_vars = copy.copy(self.GLOBAL_DICT)
        self._local_vars = copy.copy(self.LOCAL_DICT) if self.LOCAL_DICT else None
        for c in self.HEADERS:
            self.exec_code(c)

    def exec_code(self, code_piece: str) -> None:
        """Execute ``code_piece`` in this runtime's namespace.

        Raises:
            RuntimeError: if the code contains ``input(`` or ``os.system(``.
        """
        # The "." is escaped so that only a literal "os.system(" is rejected.
        if regex.search(r'input\(', code_piece) or regex.search(r'os\.system\(', code_piece):
            raise RuntimeError("code uses input() or os.system(), which is not allowed")
        # SECURITY: exec runs arbitrary code; the text filter above is not a
        # sandbox. Only feed this code you are prepared to run.
        exec(code_piece, self._global_vars)

    def eval_code(self, expr: str) -> Any:
        """Evaluate ``expr`` in this runtime's namespace and return the value."""
        # SECURITY: eval on untrusted text is as dangerous as exec above.
        return eval(expr, self._global_vars)

    def inject(self, var_dict: Dict[str, Any]) -> None:
        """Copy every item of ``var_dict`` into the namespace."""
        for k, v in var_dict.items():
            self._global_vars[k] = v

    @property
    def answer(self):
        """Value bound to the name ``answer`` in the namespace (KeyError if unset)."""
        return self._global_vars['answer']
class DateRuntime(GenericRuntime):
    """Runtime whose namespace starts with date helpers.

    Both ``timedelta`` and ``relativedelta`` are bound to
    ``dateutil.relativedelta.relativedelta``.
    """

    GLOBAL_DICT = dict(
        datetime=datetime.datetime,
        timedelta=dateutil.relativedelta.relativedelta,
        relativedelta=dateutil.relativedelta.relativedelta,
    )
class CustomDict(dict):
    """A dict whose iteration walks a snapshot of the keys."""

    def __iter__(self):
        # Materialise the keys first, then iterate over that list.
        snapshot = list(super().__iter__())
        return iter(snapshot)
class ColorObjectRuntime(GenericRuntime):
    """Runtime in which the name ``dict`` resolves to ``CustomDict``."""

    GLOBAL_DICT = {"dict": CustomDict}
class PythonExecutor:
    """Runs code snippets in worker processes and collects their answers.

    How the answer is read from a finished snippet depends on the options
    given at construction time, checked in this order: last line written to
    stdout, a named variable, an expression, or else the value of the
    snippet's last line.
    """

    def __init__(
        self,
        runtime: Optional[Any] = None,
        get_answer_symbol: Optional[str] = None,
        get_answer_expr: Optional[str] = None,
        get_answer_from_stdout: bool = False,
        timeout_length: int = 5,
    ) -> None:
        """Store the runtime and the answer-extraction options.

        Args:
            runtime: Runtime used to run code; a ``GenericRuntime`` is
                created when omitted (or falsy).
            get_answer_symbol: Variable name holding the answer.
            get_answer_expr: Expression evaluated to obtain the answer.
            get_answer_from_stdout: Take the last printed line as answer.
            timeout_length: Timeout in seconds handed to the pool and to
                ``execute``.
        """
        self.runtime = runtime if runtime else GenericRuntime()
        self.answer_symbol = get_answer_symbol
        self.answer_expr = get_answer_expr
        self.get_answer_from_stdout = get_answer_from_stdout
        self.timeout_length = timeout_length

    def process_generation_to_code(self, gens: list):
        """Split every snippet in ``gens`` into its list of lines."""
        return [g.split('\n') for g in gens]

    @staticmethod
    def execute(
        code,
        get_answer_from_stdout=None,
        runtime=None,
        answer_symbol=None,
        answer_expr=None,
        timeout_length=10,
    ):
        """Run one snippet (a list of lines) and return ``(result, info)``.

        On success ``info`` is "Done". On any failure, including a result
        that cannot be converted with ``str`` or pickled, ``result`` is ''
        and ``info`` is the last line of the traceback.
        """
        try:
            if get_answer_from_stdout:
                program_io = io.StringIO()
                with redirect_stdout(program_io):
                    timeout(timeout_length)(runtime.exec_code)('\n'.join(code))
                program_io.seek(0)
                # Raises IndexError (handled below) when nothing was printed.
                result = program_io.readlines()[-1]
            elif answer_symbol:
                timeout(timeout_length)(runtime.exec_code)('\n'.join(code))
                result = runtime._global_vars[answer_symbol]
            elif answer_expr:
                timeout(timeout_length)(runtime.exec_code)('\n'.join(code))
                result = timeout(timeout_length)(runtime.eval_code)(answer_expr)
            else:
                # Execute everything but the last line, then evaluate it.
                timeout(timeout_length)(runtime.exec_code)('\n'.join(code[:-1]))
                result = timeout(timeout_length)(runtime.eval_code)(code[-1])
            exec_info = "Done"
            str(result)
            pickle.dumps(result)  # serialization check
        except (Exception, SystemExit):
            # SystemExit is caught on purpose: a snippet calling exit() must
            # yield an error result rather than kill the worker process.
            # KeyboardInterrupt is deliberately left to propagate.
            result = ''
            exec_info = traceback.format_exc().split('\n')[-2]
        return result, exec_info

    def apply(self, code):
        """Run a single snippet given as one string; return its result pair."""
        return self.batch_apply([code])[0]

    def batch_apply(self, batch_code):
        """Run every snippet of ``batch_code`` in a process pool.

        Returns:
            A list of ``(result, exec_info)`` tuples in iteration order. A
            snippet whose result times out is reported as
            ``("", "Timeout Error")``.
        """
        all_code_snippets = self.process_generation_to_code(batch_code)
        if not all_code_snippets:
            # Nothing to run; also avoids asking the pool for zero workers.
            return []

        all_exec_results = []
        worker_count = min(len(all_code_snippets), multiprocessing.cpu_count())
        with ProcessPool(max_workers=worker_count) as pool:
            executor = partial(
                self.execute,
                get_answer_from_stdout=self.get_answer_from_stdout,
                runtime=self.runtime,
                answer_symbol=self.answer_symbol,
                answer_expr=self.answer_expr,
                # NOTE: the original author marked this inner timeout as not
                # working; the pool-level timeout below is what is relied on.
                timeout_length=self.timeout_length,
            )
            future = pool.map(executor, all_code_snippets, timeout=self.timeout_length)
            iterator = future.result()

            # Only show a progress bar for large batches.
            if len(all_code_snippets) > 100:
                progress_bar = tqdm(total=len(all_code_snippets), desc="Execute")
            else:
                progress_bar = None

            while True:
                try:
                    result = next(iterator)
                    all_exec_results.append(result)
                except StopIteration:
                    break
                except TimeoutError as error:
                    print(error)
                    all_exec_results.append(("", "Timeout Error"))
                except Exception as error:
                    # NOTE(review): any other failure terminates the whole
                    # program; kept as in the original, confirm it is wanted.
                    print(error)
                    exit()
                if progress_bar is not None:
                    progress_bar.update(1)

            if progress_bar is not None:
                progress_bar.close()

        return [(result, exec_info) for result, exec_info in all_exec_results]
def _test():
    """Run a one-line hello-world snippet and print what the executor returns."""
    snippet = """
print("Hello world!")
"""
    runner = PythonExecutor(get_answer_from_stdout=True)
    print(runner.apply(snippet))


# Run the smoke check only when executed as a script.
if __name__ == '__main__':
    _test()