mirror of
https://github.com/deepseek-ai/DeepSeek-Coder
synced 2025-01-23 19:07:17 +00:00
321 lines
9.8 KiB
Python
321 lines
9.8 KiB
Python
"""
|
|
This file comes from: https://github.com/microsoft/ToRA/blob/main/src/utils/parser.py
|
|
"""
|
|
import re
|
|
from typing import Any, Dict
|
|
|
|
|
|
def _fix_fracs(string):
|
|
substrs = string.split("\\frac")
|
|
new_str = substrs[0]
|
|
if len(substrs) > 1:
|
|
substrs = substrs[1:]
|
|
for substr in substrs:
|
|
new_str += "\\frac"
|
|
if len(substr) > 0 and substr[0] == "{":
|
|
new_str += substr
|
|
else:
|
|
try:
|
|
assert len(substr) >= 2
|
|
except:
|
|
return string
|
|
a = substr[0]
|
|
b = substr[1]
|
|
if b != "{":
|
|
if len(substr) > 2:
|
|
post_substr = substr[2:]
|
|
new_str += "{" + a + "}{" + b + "}" + post_substr
|
|
else:
|
|
new_str += "{" + a + "}{" + b + "}"
|
|
else:
|
|
if len(substr) > 2:
|
|
post_substr = substr[2:]
|
|
new_str += "{" + a + "}" + b + post_substr
|
|
else:
|
|
new_str += "{" + a + "}" + b
|
|
string = new_str
|
|
return string
|
|
|
|
|
|
def _fix_a_slash_b(string):
|
|
if len(string.split("/")) != 2:
|
|
return string
|
|
a = string.split("/")[0]
|
|
b = string.split("/")[1]
|
|
try:
|
|
if "sqrt" not in a:
|
|
a = int(a)
|
|
if "sqrt" not in b:
|
|
b = int(b)
|
|
assert string == "{}/{}".format(a, b)
|
|
new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
|
|
return new_string
|
|
except:
|
|
return string
|
|
|
|
|
|
def _fix_sqrt(string):
|
|
_string = re.sub(r"\\sqrt(\w+)", r"\\sqrt{\1}", string)
|
|
return _string
|
|
|
|
|
|
def strip_string(string):
    r"""Normalise a math answer so that equivalent answers compare equal.

    The input is converted with ``str()`` and then passed through an
    ordered list of textual rewrites: LaTeX spacing/sizing commands, units,
    degree/dollar/percent signs, quotes and a few words are dropped,
    infinity spellings are unified, trailing ``.000`` is trimmed, a short
    ``k =`` prefix is removed, and ``\sqrt`` / ``\frac`` / ``a/b`` are put
    into braced form.  The order of the steps matters.

    Args:
        string: The raw answer; any object accepted by ``str()``.

    Returns:
        The normalised answer text (possibly empty).
    """
    string = str(string).strip()
    # linebreaks
    string = string.replace("\n", "")

    # trailing "."
    string = string.rstrip(".")

    # remove LaTeX negative thin spaces (\!) and escaped spaces (\ )
    string = string.replace("\\!", "")
    string = string.replace("\\ ", "")

    # replace \\ with \ ; applied twice so that a run of four backslashes
    # also collapses to a single one
    string = string.replace("\\\\", "\\")
    string = string.replace("\\\\", "\\")

    # replace tfrac and dfrac with frac
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")

    # remove \left and \right
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")

    # Remove a trailing unit such as \text{miles}, unless that would leave
    # nothing behind.
    without_unit = re.sub(r"\\text{.*?}$", "", string).strip()
    if without_unit != "" and without_unit != string:
        string = without_unit

    # Remove circ (degrees)
    string = string.replace("^{\\circ}", "")
    string = string.replace("^\\circ", "")

    # remove dollar signs
    string = string.replace("\\$", "")
    string = string.replace("$", "")

    string = string.replace("\\text", "")
    string = string.replace("x\\in", "")

    # remove percentage (the former "\%" rule was an invalid escape that
    # spelled the same two characters as "\\%", so it is dropped)
    string = string.replace("\\%", "")
    string = string.replace("%", "")

    # " 0." equivalent to " ." and "{0." equivalent to "{."
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")

    # cdot
    string = string.replace("\\cdot", "")

    # inf
    string = string.replace("infinity", "\\infty")
    if "\\infty" not in string:
        string = string.replace("inf", "\\infty")
    # drop an explicit plus sign on infinity (pattern used to be misspelled
    # "+\inity" and therefore never matched)
    string = string.replace("+\\infty", "\\infty")

    # and
    string = string.replace("and", "")
    string = string.replace("\\mathbf", "")

    # use regex to remove \mbox{...}
    string = re.sub(r"\\mbox{.*?}", "", string)

    # quotes (str.replace returns a new string, so the result must be
    # assigned back for the removal to take effect)
    string = string.replace("'", "")
    string = string.replace("\"", "")

    # i, j: treat j as the imaginary unit when no i is present
    if "j" in string and "i" not in string:
        string = string.replace("j", "i")

    # replace a.000b where b is not a digit or b is the end, with ab
    string = re.sub(r"(\d+)\.0+([^\d])", r"\1\2", string)
    string = re.sub(r"(\d+)\.0+$", r"\1", string)

    # if empty, return empty string
    if len(string) == 0:
        return string
    if string[0] == ".":
        string = "0" + string

    # get rid of e.g. "k = " or "q = " at the beginning
    sides = string.split("=")
    if len(sides) == 2 and len(sides[0]) <= 2:
        string = sides[1]

    string = _fix_sqrt(string)
    string = string.replace(" ", "")

    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works
    # with \frac1{72} (but not \frac{72}1).
    string = _fix_fracs(string)

    # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix
    # in case the model output is X/Y
    string = _fix_a_slash_b(string)

    return string
|
|
|
|
def extract_answer(pred_str):
    r"""Pull the final answer out of a model response and normalise it.

    Sources are tried in this order:

    1. the text after the last ``boxed`` (the balanced ``{...}`` group if
       one follows, otherwise everything up to the next ``$``);
    2. the text after the last ``he answer is`` (the missing first letter
       lets it match both "The" and "the");
    3. the output of the last ```` ```output ```` block;
    4. the last number appearing in the text (commas removed first).

    Only the first line of the candidate is kept; a leading ``:`` and a
    trailing ``.`` / ``/`` are trimmed before ``strip_string`` is applied.

    Args:
        pred_str: Full response text.

    Returns:
        The normalised answer, or ``""`` when ``boxed`` is the very last
        thing in the text.
    """
    if 'boxed' in pred_str:
        ans = pred_str.split('boxed')[-1]
        if len(ans) == 0:
            return ""
        if ans[0] == '{':
            # Collect characters until the brace opened by \boxed{ closes.
            depth = 1
            chars = []
            for c in ans[1:]:
                if c == '{':
                    depth += 1
                elif c == '}':
                    depth -= 1
                    if depth == 0:
                        break
                chars.append(c)
            pred = "".join(chars)
        else:
            pred = ans.split('$')[0].strip()
    elif 'he answer is' in pred_str:
        pred = pred_str.split('he answer is')[-1].strip()
    else:
        # fall back to program output
        pred = extract_program_output(pred_str)
        if pred == "":
            # use the last number
            numbers = re.findall(r'-?\d*\.?\d+', pred_str.replace(",", ""))
            pred = numbers[-1] if numbers else ''

    # multiple line: keep the first one only
    pred = pred.split("\n")[0]
    if pred != "" and pred[0] == ":":
        pred = pred[1:]
    if pred != "" and pred[-1] == ".":
        pred = pred[:-1]
    if pred != "" and pred[-1] == "/":
        pred = pred[:-1]
    return strip_string(pred)
|
|
|
|
|
|
def extract_program(result: str, last_only=True):
    """
    Collect the code found between "```python" and "```" fence lines.

    With last_only=True only the final fenced program is returned.
    Otherwise all programs are concatenated, each one preceded by a
    "# ========" separator line.  Fences must start at the beginning of a
    line to be recognised.
    """
    chunks = []
    inside_fence = False
    for line in result.split("\n"):
        if line.startswith("```python"):
            if last_only:
                # forget everything gathered from earlier programs
                chunks.clear()
            else:
                chunks.append("\n# ========\n")
            inside_fence = True
            continue
        if line.startswith("```"):
            inside_fence = False
            continue
        if inside_fence:
            chunks.append(line + "\n")
    return "".join(chunks)
|
|
|
|
|
|
def extract_program_output(pred_str):
    """
    Return the stripped text of the last "```output" block.

    The block runs from the last "```output" marker up to the next "```"
    (or to the end of the text when no closing fence follows).  An empty
    string is returned when there is no "```output" marker at all.
    """
    marker = "```output"
    if marker not in pred_str:
        return ""
    after_marker = pred_str.rpartition(marker)[2]
    block_text = after_marker.partition("```")[0]
    return block_text.strip()
|
|
|
|
|
|
def parse_ground_truth(example: Dict[str, Any], data_name):
    """
    Return the (reference reasoning, normalised reference answer) pair.

    If the example already carries a 'gt_cot' entry, that entry is returned
    together with strip_string(example['gt']).  Otherwise the fields to
    read are chosen by data_name; an unknown data_name raises
    NotImplementedError.  In that second path the reasoning is converted
    with str() and stripped, and the answer goes through strip_string.
    """
    if 'gt_cot' in example:
        return example['gt_cot'], strip_string(example['gt'])

    # pick the dataset-specific fields
    if data_name in ["math", 'ocw']:
        gt_cot = example['solution']
        gt_ans = extract_answer(gt_cot)
    elif data_name == "gsm8k":
        # the answer field is "<reasoning>####<final answer>"
        gt_cot, gt_ans = example['answer'].split("####")
    elif data_name == "gsm-hard":
        gt_cot = example['code']
        gt_ans = example['target']
    elif data_name == "svamp":
        gt_cot = example['Equation']
        gt_ans = example['Answer']
    elif data_name == "asdiv":
        gt_cot = example['formula']
        # drop any parenthesised part of the answer
        gt_ans = re.sub(r"\(.*?\)", "", example['answer'])
    elif data_name in ("mawps", "bbh"):
        gt_cot = None
        gt_ans = example['target']
    elif data_name == "tabmwp":
        gt_cot = example['solution']
        gt_ans = example['answer']
        if example['ans_type'] in ['integer_number', 'decimal_number']:
            # numeric answers may be written as a fraction, with thousands
            # separators, or as a percentage
            if '/' in gt_ans:
                numerator, denominator = gt_ans.split('/')[:2]
                gt_ans = int(numerator) / int(denominator)
            elif ',' in gt_ans:
                gt_ans = float(gt_ans.replace(',', ''))
            elif '%' in gt_ans:
                gt_ans = float(gt_ans.split('%')[0]) / 100
            else:
                gt_ans = float(gt_ans)
    else:
        raise NotImplementedError(data_name)

    # post process
    return str(gt_cot).strip(), strip_string(gt_ans)
|
|
|
|
|
|
def parse_question(example, data_name):
    """
    Build the question text for one dataset example.

    "asdiv" joins body and question, "svamp" does the same after making
    sure the body ends with a period, and "tabmwp" renders the table, the
    question and (if present) the answer choices.  For every other
    data_name the first of the keys 'question', 'problem', 'Question',
    'input' found in the example is used.

    Returns the question with surrounding whitespace stripped.

    Raises AssertionError when the resulting question is an empty string.
    The check is an explicit raise (not an assert statement) so that it
    still runs under ``python -O``.
    """
    if data_name == "asdiv":
        question = f"{example['body'].strip()} {example['question'].strip()}"
    elif data_name == "svamp":
        body = example["Body"].strip()
        if not body.endswith("."):
            body += "."
        question = f'{body} {example["Question"].strip()}'
    elif data_name == "tabmwp":
        title_str = f'regarding "{example["table_title"]}" ' if example['table_title'] else ""
        question = f'Read the following table {title_str}and answer a question:\n'
        question += f'{example["table"]}\n{example["question"]}'
        if example['choices']:
            question += f' Please select from the following options: {example["choices"]}'
    else:
        candidate_keys = ['question', 'problem', 'Question', 'input']
        question = next((example[key] for key in candidate_keys if key in example), "")

    if question == "":
        raise AssertionError(f"no question text found for data_name={data_name!r}")
    return question.strip()
|
|
|
|
|
|
def run_execute(executor, result, prompt_type, execute=False):
    """
    Turn a raw model response into a (prediction, report) pair.

    An empty response or the literal 'error' yields (None, None).
    Otherwise the prediction is taken from the response's output block
    (prompt types containing "program_only"), from running the extracted
    program through executor.apply (prompt types "pot"/"pal" with
    execute=True), or from the answer text.  The prediction is passed
    through strip_string; report stays None unless the executor ran.
    """
    if not result or result == 'error':
        return None, None

    report = None
    if "program_only" in prompt_type:
        raw_prediction = extract_program_output(result)
    elif prompt_type in ["pot", "pal"] and execute:
        program = extract_program(result)
        raw_prediction, report = executor.apply(program)
    else:
        raw_prediction = extract_answer(result)

    return strip_string(raw_prediction), report
|