""" This file come from: https://github.com/microsoft/ToRA/blob/main/src/utils/parser.py """ import re from typing import Any, Dict def _fix_fracs(string): substrs = string.split("\\frac") new_str = substrs[0] if len(substrs) > 1: substrs = substrs[1:] for substr in substrs: new_str += "\\frac" if len(substr) > 0 and substr[0] == "{": new_str += substr else: try: assert len(substr) >= 2 except: return string a = substr[0] b = substr[1] if b != "{": if len(substr) > 2: post_substr = substr[2:] new_str += "{" + a + "}{" + b + "}" + post_substr else: new_str += "{" + a + "}{" + b + "}" else: if len(substr) > 2: post_substr = substr[2:] new_str += "{" + a + "}" + b + post_substr else: new_str += "{" + a + "}" + b string = new_str return string def _fix_a_slash_b(string): if len(string.split("/")) != 2: return string a = string.split("/")[0] b = string.split("/")[1] try: if "sqrt" not in a: a = int(a) if "sqrt" not in b: b = int(b) assert string == "{}/{}".format(a, b) new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" return new_string except: return string def _fix_sqrt(string): _string = re.sub(r"\\sqrt(\w+)", r"\\sqrt{\1}", string) return _string def strip_string(string): string = str(string).strip() # linebreaks string = string.replace("\n", "") # right "." string = string.rstrip(".") # remove inverse spaces string = string.replace("\\!", "") string = string.replace("\\ ", "") # replace \\ with \ string = string.replace("\\\\", "\\") string = string.replace("\\\\", "\\") # replace tfrac and dfrac with frac string = string.replace("tfrac", "frac") string = string.replace("dfrac", "frac") # remove \left and \right string = string.replace("\\left", "") string = string.replace("\\right", "") # Remove unit: miles, dollars if after is not none _string = re.sub(r"\\text{.*?}$", "", string).strip() if _string != "" and _string != string: # print("Warning: unit not removed: '{}' -> '{}'".format(string, _string)) string = _string # Remove circ (degrees) string = string.replace("^{\\circ}", "") string = string.replace("^\\circ", "") # remove dollar signs string = string.replace("\\$", "") string = string.replace("$", "") string = string.replace("\\text", "") string = string.replace("x\\in", "") # remove percentage string = string.replace("\\%", "") string = string.replace("\%", "") string = string.replace("%", "") # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string string = string.replace(" .", " 0.") string = string.replace("{.", "{0.") # cdot string = string.replace("\\cdot", "") # inf string = string.replace("infinity", "\\infty") if "\\infty" not in string: string = string.replace("inf", "\\infty") string = string.replace("+\\inity", "\\infty") # and string = string.replace("and", "") string = string.replace("\\mathbf", "") # use regex to remove \mbox{...} string = re.sub(r"\\mbox{.*?}", "", string) # quote string.replace("'", "") string.replace("\"", "") # i, j if "j" in string and "i" not in string: string = string.replace("j", "i") # replace a.000b where b is not number or b is end, with ab, use regex string = re.sub(r"(\d+)\.0+([^\d])", r"\1\2", string) string = re.sub(r"(\d+)\.0+$", r"\1", string) # if empty, return empty string if len(string) == 0: return string if string[0] == ".": string = "0" + string # to consider: get rid of e.g. "k = " or "q = " at beginning if len(string.split("=")) == 2: if len(string.split("=")[0]) <= 2: string = string.split("=")[1] string = _fix_sqrt(string) string = string.replace(" ", "") # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b} string = _fix_fracs(string) # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y string = _fix_a_slash_b(string) return string def extract_answer(pred_str): if 'boxed' in pred_str: ans = pred_str.split('boxed')[-1] if len(ans) == 0: return "" elif (ans[0] == '{'): stack = 1 a = '' for c in ans[1:]: if (c == '{'): stack += 1 a += c elif (c == '}'): stack -= 1 if (stack == 0): break a += c else: a += c else: a = ans.split('$')[0].strip() pred=a elif ('he answer is' in pred_str): pred = pred_str.split('he answer is')[-1].strip() elif extract_program_output(pred_str) != "": # fall back to program pred = extract_program_output(pred_str) else: # use the last number pattern = '-?\d*\.?\d+' pred = re.findall(pattern, pred_str.replace(",", "")) if(len(pred) >= 1): pred = pred[-1] else: pred = '' # multiple line pred = pred.split("\n")[0] if pred != "" and pred[0] == ":": pred = pred[1:] if pred != "" and pred[-1] == ".": pred = pred[:-1] if pred != "" and pred[-1] == "/": pred = pred[:-1] pred = strip_string(pred) return pred def extract_program(result: str, last_only=True): """ extract the program after "```python", and before "```" """ program = "" start = False for line in result.split("\n"): if line.startswith("```python"): if last_only: program = "" # only extract the last program else: program += "\n# ========\n" start = True elif line.startswith("```"): start = False elif start: program += line + "\n" return program def extract_program_output(pred_str): """ extract output between the last ```output\n...\n``` """ if "```output" not in pred_str: return "" if '```output' in pred_str: pred_str = pred_str.split('```output')[-1] if '```' in pred_str: pred_str = pred_str.split('```')[0] output = pred_str.strip() return output def parse_ground_truth(example: Dict[str, Any], data_name): if 'gt_cot' in example: return example['gt_cot'], strip_string(example['gt']) # parse ground truth if data_name in ["math", 'ocw']: gt_cot = example['solution'] gt_ans = extract_answer(gt_cot) elif data_name == "gsm8k": gt_cot, gt_ans = example['answer'].split("####") elif data_name == "gsm-hard": gt_cot, gt_ans = example['code'], example['target'] elif data_name == "svamp": gt_cot, gt_ans = example['Equation'], example['Answer'] elif data_name == "asdiv": gt_cot = example['formula'] gt_ans = re.sub(r"\(.*?\)", "", example['answer']) elif data_name == "mawps": gt_cot, gt_ans = None, example['target'] elif data_name == "tabmwp": gt_cot = example['solution'] gt_ans = example['answer'] if example['ans_type'] in ['integer_number', 'decimal_number']: if '/' in gt_ans: gt_ans = int(gt_ans.split('/')[0]) / int(gt_ans.split('/')[1]) elif ',' in gt_ans: gt_ans = float(gt_ans.replace(',', '')) elif '%' in gt_ans: gt_ans = float(gt_ans.split('%')[0]) / 100 else: gt_ans = float(gt_ans) elif data_name == "bbh": gt_cot, gt_ans = None, example['target'] else: raise NotImplementedError(data_name) # post process gt_cot = str(gt_cot).strip() gt_ans = strip_string(gt_ans) return gt_cot, gt_ans def parse_question(example, data_name): question = "" if data_name == "asdiv": question = f"{example['body'].strip()} {example['question'].strip()}" elif data_name == "svamp": body = example["Body"].strip() if not body.endswith("."): body = body + "." question = f'{body} {example["Question"].strip()}' elif data_name == "tabmwp": title_str = f'regarding "{example["table_title"]}" ' if example['table_title'] else "" question = f'Read the following table {title_str}and answer a question:\n' question += f'{example["table"]}\n{example["question"]}' if example['choices']: question += f' Please select from the following options: {example["choices"]}' else: for key in ['question', 'problem', 'Question', 'input']: if key in example: question = example[key] break assert question != "" return question.strip() def run_execute(executor, result, prompt_type, execute=False): if not result or result == 'error': return None, None report = None if "program_only" in prompt_type: prediction = extract_program_output(result) elif prompt_type in ["pot", "pal"] and execute: code = extract_program(result) prediction, report = executor.apply(code) else: prediction = extract_answer(result) prediction = strip_string(prediction) return prediction, report