"""Helpers for parsing benchmark examples and grading math answers."""
import multiprocessing
import queue
import re
from math import isclose
from typing import Any, Dict, Union

import numpy as np
import regex
from sympy import N, simplify
from sympy.parsing.latex import parse_latex
from sympy.parsing.sympy_parser import parse_expr

from data_processing.answer_extraction import extract_answer, extract_program_output, strip_string

# Both matrix environments handled below have markers of identical length,
# which is why the pmatrix markers can be used for slicing either of them.
_MATRIX_HEADS = ("\\begin{pmatrix}", "\\begin{bmatrix}")
_MATRIX_TAILS = ("\\end{pmatrix}", "\\end{bmatrix}")

# How long call_with_timeout waits for the result of a child that has
# already exited before concluding that no result was produced.
_RESULT_GRACE_SECONDS = 1.0


def extract_program(result: str, last_only=True):
    """Collect the code inside the ```python fenced blocks of *result*.

    Args:
        result: Text that may contain fenced code blocks.
        last_only: When true, only the lines of the last ```python block
            are returned. Otherwise all blocks are concatenated and each
            block (including the first) is preceded by "\\n# ========\\n".

    Returns:
        The collected lines, each terminated by a newline, or "" when no
        ```python block is present.
    """
    program = ""
    start = False
    for line in result.split("\n"):
        if line.startswith("```python"):
            if last_only:
                program = ""  # only extract the last program
            else:
                program += "\n# ========\n"
            start = True
        elif line.startswith("```"):
            # Any other fence line closes the current block.
            start = False
        elif start:
            program += line + "\n"
    return program


def parse_ground_truth(example: Dict[str, Any], data_name):
    """Return ``(gt_cot, gt_ans)`` for *example* of dataset *data_name*.

    If the example already carries a 'gt_cot' key, that value is returned
    together with ``strip_string(example['gt'])``. Otherwise the fields are
    read from dataset-specific keys, the chain of thought is converted to a
    stripped string and the answer is passed through ``strip_string``.

    Raises:
        NotImplementedError: If *data_name* is not one of the supported
            dataset names.
    """
    if 'gt_cot' in example:
        return example['gt_cot'], strip_string(example['gt'])

    # parse ground truth
    if data_name in ["math", 'ocw']:
        gt_cot = example['solution']
        gt_ans = extract_answer(gt_cot)
    elif data_name == "gsm8k":
        gt_cot, gt_ans = example['answer'].split("####")
    elif data_name == "gsm-hard":
        gt_cot, gt_ans = example['code'], example['target']
    elif data_name == "svamp":
        gt_cot, gt_ans = example['Equation'], example['Answer']
    elif data_name == "asdiv":
        gt_cot = example['formula']
        # Drop parenthesised parts of the answer.
        gt_ans = re.sub(r"\(.*?\)", "", example['answer'])
    elif data_name == "mawps":
        gt_cot, gt_ans = None, example['target']
    elif data_name == "tabmwp":
        gt_cot = example['solution']
        gt_ans = example['answer']
        if example['ans_type'] in ['integer_number', 'decimal_number']:
            # Numeric answers may be written as a fraction, with thousands
            # separators, or as a percentage.
            if '/' in gt_ans:
                gt_ans = int(gt_ans.split('/')[0]) / int(gt_ans.split('/')[1])
            elif ',' in gt_ans:
                gt_ans = float(gt_ans.replace(',', ''))
            elif '%' in gt_ans:
                gt_ans = float(gt_ans.split('%')[0]) / 100
            else:
                gt_ans = float(gt_ans)
    elif data_name == "bbh":
        gt_cot, gt_ans = None, example['target']
    else:
        raise NotImplementedError(data_name)

    # post process
    gt_cot = str(gt_cot).strip()
    gt_ans = strip_string(gt_ans)
    return gt_cot, gt_ans


def parse_question(example, data_name):
    """Build the question text for *example* of dataset *data_name*.

    "asdiv", "svamp" and "tabmwp" are assembled from several fields; for
    every other dataset the first of the keys 'question', 'problem',
    'Question', 'input' present in the example is used.

    Returns:
        The question with surrounding whitespace removed.

    Raises:
        AssertionError: If no question text could be found.
    """
    question = ""
    if data_name == "asdiv":
        question = f"{example['body'].strip()} {example['question'].strip()}"
    elif data_name == "svamp":
        body = example["Body"].strip()
        if not body.endswith("."):
            body = body + "."
        question = f'{body} {example["Question"].strip()}'
    elif data_name == "tabmwp":
        title_str = f'regarding "{example["table_title"]}" ' if example['table_title'] else ""
        question = f'Read the following table {title_str}and answer a question:\n'
        question += f'{example["table"]}\n{example["question"]}'
        if example['choices']:
            question += f' Please select from the following options: {example["choices"]}'
    else:
        for key in ['question', 'problem', 'Question', 'input']:
            if key in example:
                question = example[key]
                break
    assert question != ""
    return question.strip()


def run_execute(executor, result, prompt_type, execute=False):
    """Turn a raw model *result* into ``(prediction, report)``.

    Args:
        executor: Object whose ``apply(code)`` returns ``(prediction,
            report)``; only used for "pot"/"pal" prompts when *execute* is
            true.
        result: Raw model output.
        prompt_type: Prompt style; selects how the prediction is obtained.
        execute: Whether extracted programs should be run.

    Returns:
        ``(None, None)`` when *result* is empty or equals 'error';
        otherwise the prediction passed through ``strip_string`` and the
        executor report (``None`` when nothing was executed).
    """
    if not result or result == 'error':
        return None, None
    report = None

    if "program_only" in prompt_type:
        prediction = extract_program_output(result)
    elif prompt_type in ["pot", "pal"] and execute:
        code = extract_program(result)
        prediction, report = executor.apply(code)
    else:
        prediction = extract_answer(result)

    prediction = strip_string(prediction)
    return prediction, report


def parse_digits(num):
    """Parse *num* as a float, accepting "1,234.5" and "23%" / "23\\%".

    Commas are removed first. A trailing percent sign (optionally preceded
    by a backslash) divides the value by 100.

    Returns:
        The parsed float, or ``None`` when *num* is not numeric.
    """
    # format: 234.23 || 23%
    num = regex.sub(',', '', str(num))
    try:
        return float(num)
    except ValueError:
        if num.endswith('%'):
            num = num[:-1]
            if num.endswith('\\'):
                num = num[:-1]
            try:
                return float(num) / 100
            except ValueError:
                pass
    return None


def is_digit(num):
    """Return True when ``parse_digits(num)`` yields a number."""
    # paired with parse_digits
    return parse_digits(num) is not None


def _parse_symbolic(s):
    """Try to parse *s* as LaTeX, then as a SymPy expression.

    Returns the first successful parse, or *s* itself when both parsers
    fail.
    """
    # NOTE(security): sympy's parse_expr evaluates its input with eval();
    # callers pass model-generated text here, so this must only run on
    # input from a trusted or sandboxed source.
    for parser in (parse_latex, parse_expr):
        try:
            return parser(s)
        except Exception:
            continue
    return s


def normalize_prediction(prediction):
    """Return a normalized string form of *prediction*.

    Numeric predictions are rounded to six decimals. Values that pass
    ``is_digit`` but cannot be converted by ``float`` (for example "50%")
    go through bracket handling and symbolic parsing, after which braces
    and parentheses are removed from the resulting text.
    """
    try:  # 1. numerical equal
        if is_digit(prediction):
            prediction = np.round(float(str(prediction).replace(",", "")), 6)
        # NOTE(review): every input whose float conversion does not raise
        # returns here, so the section below is only reached by values
        # such as "50%". Confirm that this is the intended behaviour.
        return str(prediction)
    except (TypeError, ValueError):
        pass

    # 2. symbolic equal
    prediction = str(prediction).strip()

    ## deal with [], (), {}
    brackets = []
    while (prediction.startswith("[") and prediction.endswith("]")) or \
            (prediction.startswith("(") and prediction.endswith(")")):
        # Remember each stripped bracket so it can be restored below; the
        # list used to stay empty, which disabled both following branches.
        brackets.append(prediction[0])
        prediction = prediction[1:-1]

    if brackets and ',' in prediction:
        pred_parts = [normalize_prediction(part) for part in prediction.split(",")]
        prediction = ",".join(pred_parts)

    for b in reversed(brackets):
        if b == '[':
            prediction = '[' + prediction + ']'
        else:
            assert b == '('
            prediction = '(' + prediction + ')'

    # A successful parse yields a SymPy object whose .replace() is a
    # structural substitution that rejects plain strings, so convert back
    # to text before removing characters.
    prediction = str(_parse_symbolic(prediction))

    for s in ['{', "}", "(", ")"]:
        prediction = prediction.replace(s, "")

    return prediction


def _matrix_rows(matrix):
    """Split a pmatrix/bmatrix LaTeX string into its non-empty rows."""
    body = matrix[len("\\begin{pmatrix}"): -len("\\end{pmatrix}")]
    return [row.strip() for row in body.split("\\\\") if row.strip()]


def math_equal(prediction: Union[bool, float, str],
               reference: Union[float, str],
               include_percentage: bool = True,
               is_close: bool = True,
               timeout: bool = False,
               ) -> bool:
    """
    Exact match of math if and only if:
    1. numerical equal: both can convert to float and are equal
    2. symbolic equal: both can convert to sympy expression and are equal

    Args:
        prediction: Predicted answer.
        reference: Ground-truth answer.
        include_percentage: Also accept the prediction when it equals the
            reference divided or multiplied by 100.
        is_close: Compare numbers with an absolute tolerance of 1e-3
            instead of exact equality.
        timeout: Run the final SymPy comparison in a separate process via
            ``call_with_timeout``.

    Returns:
        True when the two answers are judged equal.
    """
    if str(prediction) == str(reference):
        return True

    try:  # 1. numerical equal
        if is_digit(prediction) and is_digit(reference):
            prediction = parse_digits(prediction)
            reference = parse_digits(reference)
            # number questions
            if include_percentage:
                gt_result = [reference / 100, reference, reference * 100]
            else:
                gt_result = [reference]
            for item in gt_result:
                try:
                    if is_close:
                        if isclose(item, prediction, abs_tol=1e-3):
                            return True
                    else:
                        if item == prediction:
                            return True
                except Exception:
                    continue
            return False
    except Exception:
        # Best effort: fall back to the symbolic comparison.
        pass

    if not prediction and prediction not in [0, False]:
        return False

    # 2. symbolic equal
    reference = str(reference).strip()
    prediction = str(prediction).strip()

    # NOTE(review): the recursive element-wise comparisons below do not
    # forward `timeout`, so they always run SymPy in-process.

    # Tuples / lists / intervals: compare element by element. fullmatch is
    # required because the slices below assume that the first and last
    # characters are the brackets; with a start-anchored match, inputs such
    # as "(1)2" and "(1)3" were both cut to "1)" and reported as equal.
    bracketed = r'(\(|\[).+(\)|\])'
    if regex.fullmatch(bracketed, prediction) is not None and \
            regex.fullmatch(bracketed, reference) is not None:
        pred_parts = prediction[1:-1].split(",")
        ref_parts = reference[1:-1].split(",")
        if len(pred_parts) == len(ref_parts):
            if all(math_equal(p, r, include_percentage, is_close)
                   for p, r in zip(pred_parts, ref_parts)):
                return True

    # Matrices: same number of rows, same number of cells per row, and
    # every pair of cells equal.
    if prediction.startswith(_MATRIX_HEADS) and prediction.endswith(_MATRIX_TAILS) and \
            reference.startswith(_MATRIX_HEADS) and reference.endswith(_MATRIX_TAILS):
        pred_lines = _matrix_rows(prediction)
        ref_lines = _matrix_rows(reference)

        def _rows_match(pred_line, ref_line):
            pred_cells = pred_line.split("&")
            ref_cells = ref_line.split("&")
            return len(pred_cells) == len(ref_cells) and all(
                math_equal(p, r, include_percentage, is_close)
                for p, r in zip(pred_cells, ref_cells)
            )

        if len(pred_lines) == len(ref_lines) and all(
                _rows_match(p, r) for p, r in zip(pred_lines, ref_lines)):
            return True

    if prediction.count('=') == 1 and reference.count('=') == 1:
        # Two equations: compare "lhs - (rhs)" of both, up to sign.
        pred = prediction.split('=')
        pred = f"{pred[0].strip()} - ({pred[1].strip()})"
        ref = reference.split('=')
        ref = f"{ref[0].strip()} - ({ref[1].strip()})"
        if symbolic_equal(pred, ref) or symbolic_equal(f"-({pred})", ref):
            return True
    elif prediction.count('=') == 1 and len(prediction.split('=')[0].strip()) <= 2 \
            and '=' not in reference:
        # "x = 5" against "5": compare the right-hand side only.
        if math_equal(prediction.split('=')[1], reference, include_percentage, is_close):
            return True
    elif reference.count('=') == 1 and len(reference.split('=')[0].strip()) <= 2 \
            and '=' not in prediction:
        if math_equal(prediction, reference.split('=')[1], include_percentage, is_close):
            return True

    # symbolic equal with sympy
    if timeout:
        if call_with_timeout(symbolic_equal_process, prediction, reference):
            return True
    else:
        if symbolic_equal(prediction, reference):
            return True

    return False


def math_equal_process(param):
    """Call ``math_equal`` on the last two items of *param*.

    Takes a single sequence argument: ``param[-2]`` is the prediction and
    ``param[-1]`` the reference.
    """
    return math_equal(param[-2], param[-1])


def symbolic_equal(a, b):
    """Return True when *a* and *b* are equal according to SymPy.

    Both inputs are parsed (LaTeX first, then SymPy syntax). They are
    equal when their difference simplifies to zero or when their numeric
    values agree within an absolute tolerance of 1e-3. Any failure while
    parsing or comparing counts as "not equal".
    """
    a = _parse_symbolic(a)
    b = _parse_symbolic(b)

    try:
        if simplify(a - b) == 0:
            return True
    except Exception:
        # Includes the case where parsing failed and a/b are still strings.
        pass

    try:
        if isclose(N(a), N(b), abs_tol=1e-3):
            return True
    except Exception:
        pass
    return False


def symbolic_equal_process(a, b, output_queue):
    """Run ``symbolic_equal(a, b)`` and put the result on *output_queue*."""
    result = symbolic_equal(a, b)
    output_queue.put(result)


def call_with_timeout(func, *args, timeout=1, **kwargs):
    """Run *func* in a child process and return the value it reports.

    *func* is called as ``func(*args, output_queue, **kwargs)`` and is
    expected to put exactly one result on the queue.

    Args:
        func: Callable to run in the child process.
        *args: Positional arguments placed before the queue.
        timeout: Seconds to wait for the child before terminating it.
        **kwargs: Keyword arguments forwarded to *func*.

    Returns:
        The value the child put on the queue, or False when the child
        timed out or exited without reporting a result.
    """
    output_queue = multiprocessing.Queue()
    process_args = args + (output_queue,)
    process = multiprocessing.Process(target=func, args=process_args, kwargs=kwargs)
    process.start()
    process.join(timeout)

    if process.is_alive():
        process.terminate()
        process.join()
        return False

    # The child has exited. If it died before putting a result, an
    # unbounded get() would block forever, so wait only briefly.
    try:
        return output_queue.get(timeout=_RESULT_GRACE_SECONDS)
    except queue.Empty:
        return False