""" This logic is largely copied from the Hendrycks' MATH release (math_equivalence), and borrowed from: - https://github.com/microsoft/ProphetNet/tree/master/CRITIC - https://github.com/openai/prm800k """ import multiprocessing from math import isclose from typing import Union from sympy import simplify, N from sympy.parsing.sympy_parser import parse_expr from sympy.parsing.latex import parse_latex def is_digit(s): try: float(str(s).replace(",", "")) return True except ValueError: return False def math_equal(prediction: Union[bool, float, str], reference: Union[float, str], include_percentage: bool = True, is_close: bool = True, timeout: bool = False, ) -> bool: """ Exact match of math if and only if: 1. numerical equal: both can convert to float and are equal 2. symbolic equal: both can convert to sympy expression and are equal """ try: # 1. numerical equal if is_digit(prediction) and is_digit(reference): prediction = float(str(prediction).replace(",", "")) reference = float(str(reference).replace(",", "")) # number questions if include_percentage: gt_result = [reference / 100, reference, reference * 100] else: gt_result = [reference] for item in gt_result: try: if is_close: if isclose(item, prediction, rel_tol=1e-4): return True else: if item == prediction: return True except Exception: continue return False except: pass if not prediction and prediction not in [0, False]: return False # 2. symbolic equal reference = str(reference).strip() prediction = str(prediction).strip() ## deal with [], (), {} pred_str, ref_str = prediction, reference if (prediction.startswith("[") and prediction.endswith("]") and not reference.startswith("(")) or \ (prediction.startswith("(") and prediction.endswith(")") and not reference.startswith("[")): pred_str = pred_str.strip("[]()") ref_str = ref_str.strip("[]()") for s in ['{', "}", "(", ")"]: ref_str = ref_str.replace(s, "") pred_str = pred_str.replace(s, "") if pred_str == ref_str: return True ## [a, b] vs. [c, d], return a==c and b==d if (prediction.startswith("[") and prediction.endswith("]")) and (reference.startswith("[") and reference.endswith("]")) or \ (prediction.startswith("(") and prediction.endswith(")")) and (reference.startswith("(") and reference.endswith(")")): pred_parts = prediction[1:-1].split(",") ref_parts = reference[1:-1].split(",") if len(pred_parts) == len(ref_parts): if all([math_equal(pred_parts[i], ref_parts[i], include_percentage, is_close) for i in range(len(pred_parts))]): return True # symbolic equal with sympy if timeout: if call_with_timeout(symbolic_equal_process, prediction, reference): return True else: if symbolic_equal(prediction, reference): return True return False def math_equal_process(param): print(param[-2], param[-1],math_equal(param[-2], param[-1])) return math_equal(param[-2], param[-1]) def symbolic_equal(a, b): def _parse(s): for f in [parse_latex, parse_expr]: try: return f(s) except: pass return s a = _parse(a) b = _parse(b) try: if simplify(a-b) == 0: return True except: pass try: if isclose(N(a), N(b), rel_tol=1e-3): return True except: pass return False def symbolic_equal_process(a, b, output_queue): result = symbolic_equal(a, b) output_queue.put(result) def call_with_timeout(func, *args, timeout=1, **kwargs): output_queue = multiprocessing.Queue() process_args = args + (output_queue,) process = multiprocessing.Process(target=func, args=process_args, kwargs=kwargs) process.start() process.join(timeout) if process.is_alive(): process.terminate() process.join() return False return output_queue.get() def _test_math_equal(): # print(math_equal("0.0833333333333333", "\\frac{1}{12}")) # print(math_equal("(1,4.5)", "(1,\\frac{9}{2})")) print(math_equal("\\frac{x}{7}+\\frac{2}{7}", "\\frac{x+2}{7}", timeout=True)) if __name__ == "__main__": _test_math_equal()