"""Helpers that pull a final answer out of generated reasoning text.

The module has two halves:

* ``strip_string`` and its private ``_fix_*`` helpers normalize a LaTeX-ish
  answer string (drop units, ``\\left``/``\\right``, dollar signs, brace
  shorthand fractions, ...).
* The ``extract_*`` functions locate the answer inside a longer piece of
  text.  Most of them share the ``(question, reasoning, task)`` signature;
  arguments a given extractor does not need are simply ignored.
"""
import re

try:
    import regex
except ImportError:
    # Every pattern in this module is plain ``re`` syntax, so the standard
    # library module is a drop-in replacement when ``regex`` is unavailable.
    import re as regex


def _fix_fracs(string):
    r"""Add the missing braces to shorthand ``\frac`` arguments.

    ``\frac12`` becomes ``\frac{1}{2}``, ``\frac1b`` becomes ``\frac{1}{b}``
    and ``\frac1{72}`` becomes ``\frac{1}{72}``.  A ``\frac`` that is already
    followed by ``{`` is left alone (so ``\frac{72}1`` is not repaired).

    Args:
        string: The text to fix.

    Returns:
        The fixed text, or ``string`` unchanged if any ``\frac`` is followed
        by fewer than two characters.
    """
    head, *tails = string.split("\\frac")
    parts = [head]
    for tail in tails:
        parts.append("\\frac")
        if tail.startswith("{"):
            # Numerator is already braced; keep the remainder verbatim.
            parts.append(tail)
            continue
        if len(tail) < 2:
            # Not enough characters for a numerator and a denominator.
            return string
        numerator, denominator, rest = tail[0], tail[1], tail[2:]
        if denominator != "{":
            parts.append("{" + numerator + "}{" + denominator + "}" + rest)
        else:
            # Denominator already opens a brace group: only wrap the numerator.
            parts.append("{" + numerator + "}" + denominator + rest)
    return "".join(parts)


def _fix_a_slash_b(string):
    r"""Rewrite a simple ``a/b`` as ``\frac{a}{b}``.

    The rewrite happens only when the text contains exactly one ``/`` and
    each side is either a canonical integer literal or contains ``sqrt``.

    Args:
        string: The text to fix.

    Returns:
        ``\frac{a}{b}`` when the text qualifies, otherwise ``string``
        unchanged.
    """
    pieces = string.split("/")
    if len(pieces) != 2:
        return string
    a, b = pieces
    try:
        if "sqrt" not in a:
            a = int(a)
        if "sqrt" not in b:
            b = int(b)
    except ValueError:
        return string
    # Round-trip check: rejects inputs such as "01/2" or " 1/2" that int()
    # accepts but that are not written canonically.
    if string != "{}/{}".format(a, b):
        return string
    return "\\frac{" + str(a) + "}{" + str(b) + "}"


def _fix_sqrt(string):
    r"""Brace the argument of ``\sqrt`` (``\sqrt3`` -> ``\sqrt{3}``)."""
    _string = re.sub(r"\\sqrt(-?[0-9.a-zA-Z]+)", r"\\sqrt{\1}", string)
    # Also handle a whitespace-separated argument at the very end of the text.
    _string = re.sub(r"\\sqrt\s+(\w+)$", r"\\sqrt{\1}", _string)
    return _string


def _fix_tan(string):
    r"""Brace the argument of ``\tan`` (``\tan3`` -> ``\tan{3}``)."""
    _string = re.sub(r"\\tan(-?[0-9.a-zA-Z]+)", r"\\tan{\1}", string)
    # Also handle a whitespace-separated argument at the very end of the text.
    _string = re.sub(r"\\tan\s+(\w+)$", r"\\tan{\1}", _string)
    return _string


def strip_string(string):
    r"""Normalize an answer string so that equivalent answers compare equal.

    The steps run in a fixed order and later steps rely on earlier ones, so
    do not reorder them.

    Args:
        string: The raw answer; it is converted with ``str()`` first.

    Returns:
        The normalized answer string (possibly empty).
    """
    string = str(string).strip()

    # Line breaks.
    string = string.replace("\n", "")

    # Trailing ".".
    string = string.rstrip(".")

    # Remove inverse spaces.
    string = string.replace("\\!", "")

    # Unwrap an answer that is entirely inside \text{...}.
    if string.startswith("\\text{") and string.endswith("}"):
        string = string.split("{", 1)[1][:-1]

    # Replace tfrac, dfrac and cfrac with frac.
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")
    string = string.replace("cfrac", "frac")

    # Remove \left and \right.
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")

    # Drop a trailing \text{...} (e.g. a unit such as miles or dollars),
    # unless doing so would leave nothing at all.
    _string = re.sub(r"\\text{.*?}$", "", string).strip()
    if _string != "" and _string != string:
        string = _string

    # Remove circ (degrees).
    string = string.replace("^{\\circ}", "").strip()
    string = string.replace("^\\circ", "").strip()

    # Remove braced {m} / {cm} / {mm}, optionally followed by ^2 or ^3.
    string = regex.sub(r"\{(c|m)?m\}(\^(2|3))?", "", string).strip()
    # Remove a trailing "p.m.".
    string = regex.sub(r"p\.m\.$", "", string).strip()
    # Remove a trailing "t" that follows a digit (presumably a unit suffix).
    string = regex.sub(r"(\d)\s*t$", r"\1", string).strip()

    # Remove dollar signs.
    string = string.replace("\\$", "")
    string = string.replace("$", "")

    string = string.replace("x\\in", "")

    # Unescape percent signs.  The replacement runs twice on purpose: the
    # first pass turns a doubled backslash + "%" into a single "\%", which
    # the second pass then unescapes.
    string = string.replace("\\%", "%")
    string = string.replace("\\%", "%")

    # " 0." is equivalent to " ." and "{0." is equivalent to "{.".
    # (A leading "." is handled further down.)
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")

    # cdot
    string = string.replace("\\cdot", "")

    # Infinity spellings.
    string = string.replace("infinity", "\\infty")
    if "\\infty" not in string:
        string = string.replace("inf", "\\infty")
    # Drop an explicit plus sign in front of infinity.  (The pattern used to
    # be misspelled "+\inity" and therefore never matched.)
    string = string.replace("+\\infty", "\\infty")

    string = string.replace("\\mathbf", "")
    string = string.replace("\\mathrm", "")

    # Remove \mbox{...}.
    string = re.sub(r"\\mbox{.*?}", "", string)

    # NOTE: quote characters are intentionally left in place.  Earlier code
    # called string.replace("'", "") and string.replace('"', "") here but
    # discarded the results, so quotes were never stripped; that behavior is
    # preserved.

    # Treat "j" as the imaginary unit "i" when no "i" is present.
    if "j" in string and "i" not in string:
        string = string.replace("j", "i")

    # Replace a.000b (b not a digit, or end of string) with ab.
    string = re.sub(r"(\d+)\.0+([^\d])", r"\1\2", string)
    string = re.sub(r"(\d+)\.0+$", r"\1", string)

    # If empty, return the empty string.
    if len(string) == 0:
        return string
    # Add the leading zero to ".5"-style numbers.
    if string[0] == ".":
        string = "0" + string

    string = _fix_sqrt(string)
    string = _fix_tan(string)
    string = string.replace(" ", "")

    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc.  Even works
    # with \frac1{72} (but not \frac{72}1).
    string = _fix_fracs(string)

    # X/Y --> \frac{X}{Y} in simple cases, in case the text uses X/Y.
    string = _fix_a_slash_b(string)

    # Strip trailing backslashes, commas and periods.
    string = regex.sub(r"(\\|,|\.)+$", "", string)

    return string


def extract_boxed_answers(text):
    r"""Return the contents of every ``boxed{...}`` group in ``text``.

    Braces are balanced, so ``\boxed{\frac{1}{2}}`` yields ``\frac{1}{2}``.
    A ``boxed{`` whose closing brace is never found contributes nothing.

    Args:
        text: The text to scan.

    Returns:
        A list of the extracted strings, in order of appearance.
    """
    answers = []
    for piece in text.split('boxed{')[1:]:
        depth = 0
        for i, ch in enumerate(piece):
            if ch == '{':
                depth += 1
            elif ch == '}':
                depth -= 1
                if depth < 0:
                    # This brace closes the boxed{ itself.
                    if piece[i + 1:i + 2] == '%':
                        # NOTE(review): when "%" follows the closing brace the
                        # slice keeps that closing brace (not the "%").  This
                        # looks unintended but is existing behavior - confirm
                        # before changing.
                        answers.append(piece[: i + 1])
                    else:
                        answers.append(piece[:i])
                    break
    return answers


def extract_program_output(pred_str):
    """Return the text inside the last ```` ```output ... ``` ```` block.

    Args:
        pred_str: The text to scan.

    Returns:
        The stripped block contents, or ``""`` when there is no
        ```` ```output ```` marker.
    """
    if "```output" not in pred_str:
        return ""
    tail = pred_str.split('```output')[-1]
    # Without a closing fence, split() leaves the tail untouched.
    return tail.split('```')[0].strip()


def extract_answer(pred_str, exhaust=False):
    r"""Extract and normalize the answer(s) contained in ``pred_str``.

    The first matching strategy wins:

    1. text between ``final answer is $`` and ``$. I hope``;
    2. every ``boxed{...}`` group;
    3. whatever follows the last ``he answer is``;
    4. the last ```` ```output ```` block;
    5. the last number in the text (commas removed first).

    Each candidate is cut to its first line, trimmed and passed through
    ``strip_string``.

    Args:
        pred_str: The text to scan.
        exhaust: When true, return every candidate instead of only the last.

    Returns:
        A list of answers if ``exhaust`` is true; otherwise the last answer,
        or ``""`` when nothing was found.
    """
    pred = []
    if 'final answer is $' in pred_str and '$. I hope' in pred_str:
        tmp = pred_str.split('final answer is $', 1)[1]
        pred = [tmp.split('$. I hope', 1)[0].strip()]
    elif 'boxed' in pred_str:
        pred = extract_boxed_answers(pred_str)
    elif 'he answer is' in pred_str:
        pred = [pred_str.split('he answer is')[-1].strip()]
    else:
        program_output = extract_program_output(pred_str)
        if program_output != "":
            # Fall back to the program output.
            pred.append(program_output)
        else:
            # Use the last number.
            numbers = re.findall(r'-?\d*\.?\d+', pred_str.replace(",", ""))
            if numbers:
                pred.append(numbers[-1])

    # Keep only the first line of each candidate, then normalize it.
    _pred = []
    for ans in pred:
        ans = ans.strip().split("\n")[0]
        ans = ans.lstrip(":")
        ans = ans.rstrip(".")
        ans = ans.rstrip("/")
        ans = strip_string(ans)
        _pred.append(ans)

    if exhaust:
        return _pred
    return _pred[-1] if _pred else ""


def extract_math_answer(question, reasoning, task):
    r"""Extract all answers from ``reasoning`` and split multi-part ones.

    An answer is split on commas when ``question`` contains
    ``separated by commas`` and the answer has no ``()[]`` characters, or on
    ``\text{and}`` when that separator is present.

    Args:
        question: The question text (checked for ``separated by commas``).
        reasoning: The text to extract answers from.
        task: Unused.

    Returns:
        A flat list of answer strings.
    """
    answer = []
    for ans in extract_answer(reasoning, exhaust=True):
        if 'separated by commas' in question and all(ch not in ans for ch in '()[]'):
            answer.extend([a.strip() for a in ans.split(",")])
        elif regex.search(r"\\text\{\s*and\s*\}", ans):
            answer.extend([
                a.strip()
                for a in regex.sub(r"\\text\{\s*and\s*\}", "[SEP]", ans).split("[SEP]")
            ])
        else:
            answer.append(ans.strip())
    return answer


def extract_math_few_shot_cot_answer(question, reasoning, task):
    """Like ``extract_math_answer`` but ignore text from ``Problem:`` on."""
    if 'Problem:' in reasoning:
        reasoning = reasoning.split("Problem:", 1)[0]
    return extract_math_answer(question, reasoning, task)


def extract_last_single_answer(question, reasoning, task):
    """Return only the last answer found by ``extract_answer``."""
    return extract_answer(reasoning, exhaust=False)


def extract_gsm_few_shot_cot_answer(question, reasoning, task):
    """Return the last number that appears before the first ``Q: ``.

    Returns:
        The number as a string, or ``"[invalid]"`` when there is none.
    """
    if 'Q: ' in reasoning:
        reasoning = reasoning.split("Q: ", 1)[0]
    pred = regex.findall(r'-?\d+\.?\d*', reasoning)
    if pred:
        return pred[-1]
    return "[invalid]"


def extract_agieval_gaokao_mathcloze_few_shot_cot_test(question, reasoning, task):
    """Return the first line after ``答案是`` as a one-element list.

    Text from ``问题 `` on is ignored and surrounding ``$`` are stripped.

    Returns:
        ``[answer]``, or ``['placeholder']`` when the marker is missing.
    """
    if '问题 ' in reasoning:
        reasoning = reasoning.split("问题 ", 1)[0]
    if '答案是' in reasoning:
        ans = reasoning.split('答案是', 1)[1].strip()
        ans = ans.split("\n")[0].strip()
        ans = [ans.strip("$")]
    else:
        ans = ['placeholder']
    return ans


def extract_agieval_gaokao_mathqa_few_shot_cot_test(question, reasoning, task):
    """Return the first line after ``答案是`` as a string.

    Text from ``问题 `` on is ignored.

    Returns:
        The answer line, or ``'placeholder'`` when the marker is missing.
    """
    if '问题 ' in reasoning:
        reasoning = reasoning.split("问题 ", 1)[0]
    if '答案是' in reasoning:
        ans = reasoning.split('答案是', 1)[1].strip()
        ans = ans.split("\n")[0].strip()
    else:
        ans = 'placeholder'
    return ans


def extract_sat_few_shot_answer(question, reasoning, task):
    """Return the choice letter announced by ``the final answer is``.

    Matching is case-insensitive, the letter may be parenthesized, and text
    from ``Problem:`` on is ignored.

    Returns:
        ``'A'``..``'D'``, or ``'placeholder'`` when no choice is found.
    """
    if 'Problem:' in reasoning:
        reasoning = reasoning.split("Problem:", 1)[0]
    patt = regex.search(r"the final answer is \(?(?P<ans>[abcd])\)?", reasoning.lower())
    if patt is not None:
        return patt.group('ans').upper()
    return 'placeholder'


def extract_ocwcourses_few_shot_answer(question, reasoning, task):
    """Return the text between ``final answer is`` and ``. I hope it is correct.``.

    Text from ``Problem:`` on is ignored.  When the pattern is missing, the
    reasoning is printed for debugging.

    Returns:
        The captured answer, or ``"[invalid]"`` when the pattern is missing.
    """
    if 'Problem:' in reasoning:
        reasoning = reasoning.split("Problem:", 1)[0]
    patt = regex.search(r"final answer is (?P<ans>.*)\. I hope it is correct.", reasoning)
    if patt is None:
        pred = "[invalid]"
        print(f"DEBUG >>>\n{reasoning}", flush=True)
    else:
        pred = patt.group('ans')
    return pred


def extract_mmlu_stem(question, reasoning, task):
    """Extract a choice letter exactly as ``extract_sat_few_shot_answer`` does."""
    if 'Problem:' in reasoning:
        reasoning = reasoning.split("Problem:", 1)[0]
    return extract_sat_few_shot_answer(question, reasoning, task)


def extract_minif2f_isabelle(question, reasoning, task):
    """Return the stripped text that precedes the first ``Informal:``."""
    if 'Informal:' in reasoning:
        reasoning = reasoning.split("Informal:", 1)[0]
    return reasoning.strip()


def extract_cmath_few_shot_test(question, reasoning, task):
    """Return the last number on the first line after ``答案是``.

    Text from ``问题:`` on is ignored.  Without the ``答案是`` marker the
    function falls back to ``extract_last_single_answer``.

    Returns:
        The number as a string; ``"[invalid]"`` when the marker is present
        but no number follows it on that line (the reasoning is printed for
        debugging in that case).
    """
    if '问题:' in reasoning:
        reasoning = reasoning.split("问题:", 1)[0]
    if '答案是' in reasoning:
        ans = reasoning.split('答案是', 1)[1].strip()
        ans = ans.split("\n")[0]
        ans = ans.strip(":")
        ans = ans.strip("。")
        try:
            ans = regex.findall(r'-?\d+\.?\d*', ans)[-1]
        except IndexError:
            # No number on the answer line.
            print(f"DEBUG CMATH: {reasoning}", flush=True)
            ans = "[invalid]"
    else:
        ans = extract_last_single_answer(question, reasoning, task)
    return ans