mirror of
https://github.com/deepseek-ai/DeepSeek-Math
synced 2024-11-28 06:53:56 +00:00
326 lines
11 KiB
Python
326 lines
11 KiB
Python
|
import multiprocessing
|
||
|
from math import isclose
|
||
|
import numpy as np
|
||
|
from typing import Union, Any, Dict
|
||
|
|
||
|
from sympy import simplify, N
|
||
|
from sympy.parsing.sympy_parser import parse_expr
|
||
|
from sympy.parsing.latex import parse_latex
|
||
|
import re
|
||
|
import regex
|
||
|
|
||
|
from data_processing.answer_extraction import extract_answer, extract_program_output, strip_string
|
||
|
|
||
|
def extract_program(result: str, last_only=True):
    """
    Collect the source lines found between "```python" and "```" fences.

    With ``last_only`` (the default) only the final fenced program is kept.
    Otherwise every fenced program is kept, each one preceded by a
    "# ========" separator. Any line starting with "```" closes the current
    fence; a fence that is never closed runs to the end of ``result``.
    """
    pieces = []
    inside_fence = False
    for text_line in result.split("\n"):
        if text_line.startswith("```python"):
            if last_only:
                # A new program begins: forget everything gathered so far.
                pieces = []
            else:
                pieces.append("\n# ========\n")
            inside_fence = True
            continue
        if text_line.startswith("```"):
            inside_fence = False
            continue
        if inside_fence:
            pieces.append(text_line + "\n")
    return "".join(pieces)
|
||
|
|
||
|
|
||
|
def parse_ground_truth(example: Dict[str, Any], data_name):
    """
    Return ``(gt_cot, gt_ans)`` for one dataset record.

    Records that already carry a 'gt_cot' key are returned as-is, with only
    their 'gt' value passed through ``strip_string``. Otherwise the fields
    read depend on ``data_name``; the rationale is stringified and stripped
    and the answer is passed through ``strip_string``.

    Raises:
        NotImplementedError: ``data_name`` is not one of the known datasets.
    """
    # Pre-parsed record: nothing to look up.
    if 'gt_cot' in example:
        return example['gt_cot'], strip_string(example['gt'])

    if data_name in ("math", "ocw"):
        rationale = example['solution']
        answer = extract_answer(rationale)
    elif data_name == "gsm8k":
        # The field is split on the "####" marker into rationale and answer.
        rationale, answer = example['answer'].split("####")
    elif data_name == "gsm-hard":
        rationale, answer = example['code'], example['target']
    elif data_name == "svamp":
        rationale, answer = example['Equation'], example['Answer']
    elif data_name == "asdiv":
        rationale = example['formula']
        # Drop parenthesised fragments from the answer text.
        answer = re.sub(r"\(.*?\)", "", example['answer'])
    elif data_name in ("mawps", "bbh"):
        rationale, answer = None, example['target']
    elif data_name == "tabmwp":
        rationale = example['solution']
        answer = example['answer']
        if example['ans_type'] in ('integer_number', 'decimal_number'):
            # Numeric answers may be written as a fraction, with thousands
            # separators, or as a percentage.
            if '/' in answer:
                fraction = answer.split('/')
                answer = int(fraction[0]) / int(fraction[1])
            elif ',' in answer:
                answer = float(answer.replace(',', ''))
            elif '%' in answer:
                answer = float(answer.split('%')[0]) / 100
            else:
                answer = float(answer)
    else:
        raise NotImplementedError(data_name)

    # post process
    return str(rationale).strip(), strip_string(answer)
|
||
|
|
||
|
|
||
|
def parse_question(example, data_name):
    """
    Build the question text for one dataset record.

    "asdiv", "svamp" and "tabmwp" records are assembled from several fields;
    for any other ``data_name`` the first key present among 'question',
    'problem', 'Question' and 'input' is used. The result is stripped.

    Raises:
        AssertionError: no question text could be found.
    """
    if data_name == "asdiv":
        body_text = example['body'].strip()
        ask_text = example['question'].strip()
        question = f"{body_text} {ask_text}"
    elif data_name == "svamp":
        story = example["Body"].strip()
        if not story.endswith("."):
            # Make the body end with a period before the question follows it.
            story += "."
        question = f'{story} {example["Question"].strip()}'
    elif data_name == "tabmwp":
        title = example['table_title']
        title_clause = f'regarding "{title}" ' if title else ""
        segments = [
            f'Read the following table {title_clause}and answer a question:\n',
            f'{example["table"]}\n{example["question"]}',
        ]
        if example['choices']:
            segments.append(
                f' Please select from the following options: {example["choices"]}'
            )
        question = "".join(segments)
    else:
        candidate_keys = ('question', 'problem', 'Question', 'input')
        question = next(
            (example[key] for key in candidate_keys if key in example), ""
        )

    assert question != ""
    return question.strip()
|
||
|
|
||
|
|
||
|
def run_execute(executor, result, prompt_type, execute=False):
    """
    Turn a raw model output into ``(prediction, report)``.

    An empty ``result`` or the literal 'error' yields ``(None, None)``.
    Otherwise the prediction comes from one of three places:
    - ``extract_program_output`` when "program_only" occurs in ``prompt_type``;
    - ``executor.apply`` on the extracted program when ``prompt_type`` is
      "pot" or "pal" and ``execute`` is set (this also supplies ``report``);
    - ``extract_answer`` in every other case.
    The prediction is passed through ``strip_string`` before it is returned.
    """
    if not result or result == 'error':
        return None, None

    report = None
    if "program_only" in prompt_type:
        raw_prediction = extract_program_output(result)
    elif prompt_type in ("pot", "pal") and execute:
        raw_prediction, report = executor.apply(extract_program(result))
    else:
        raw_prediction = extract_answer(result)

    return strip_string(raw_prediction), report
|
||
|
|
||
|
|
||
|
def parse_digits(num):
    """
    Parse ``num`` as a float, accepting thousands separators and percentages.

    Accepted formats: 234.23 || 1,234 || 23% || 23\\% (LaTeX-escaped percent).
    Percentages are returned as fractions (``"23%"`` -> ``0.23``).

    Returns:
        The parsed float, or ``None`` when ``num`` is not a number in one of
        the formats above.
    """
    # Plain literal removal; no regular expression needed.
    text = str(num).replace(',', '')
    try:
        return float(text)
    except ValueError:
        # Not a bare number; it may still be a percentage.
        pass

    if text.endswith('%'):
        text = text[:-1]
        if text.endswith('\\'):
            # "\%" as written in LaTeX: drop the escaping backslash too.
            text = text[:-1]
        try:
            return float(text) / 100
        except ValueError:
            pass
    return None
|
||
|
|
||
|
def is_digit(num):
    """Tell whether ``parse_digits`` can turn ``num`` into a number."""
    # paired with parse_digits: anything it parses counts as a digit
    parsed = parse_digits(num)
    return parsed is not None
|
||
|
|
||
|
|
||
|
def normalize_prediction(prediction):
    """
    Bring a predicted answer into a canonical string form.

    1. Numbers (as judged by ``is_digit``) are rounded to 6 decimals and
       returned as the string of the rounded float.
    2. Otherwise the text is stripped; enclosing "[...]" / "(...)" layers are
       peeled off, comma-separated items inside are normalized one by one,
       and the layers are put back.
    3. The text is offered to the LaTeX parser, then the sympy expression
       parser. If one of them accepts it, the string form of the parsed
       object is returned; if neither does, the text is returned with
       '{', '}', '(' and ')' removed.

    Returns:
        str: the normalized prediction.
    """
    # 1. numerical
    try:
        if is_digit(prediction):
            rounded = np.round(float(str(prediction).replace(",", "")), 6)
            return str(rounded)
    except Exception:
        # is_digit also accepts percentages, which float() rejects; those
        # fall through to the textual path below.
        pass

    # 2. symbolic
    text = str(prediction).strip()

    ## deal with [], ()
    # Record every peeled layer. The original assigned the opening bracket to
    # a throwaway variable and never stored it, so the list stayed empty and
    # the per-item normalization / re-wrapping below could never run.
    brackets = []
    while (text.startswith("[") and text.endswith("]")) or \
            (text.startswith("(") and text.endswith(")")):
        brackets.append(text[0])
        text = text[1:-1]

    if brackets and ',' in text:
        text = ",".join(normalize_prediction(part) for part in text.split(","))

    # Restore the layers innermost-first so the original nesting order returns.
    for opener in reversed(brackets):
        text = '[' + text + ']' if opener == '[' else '(' + text + ')'

    def _parse(s):
        # Try each parser in turn; hand back the raw text if none accepts it.
        for parser in [parse_latex, parse_expr]:
            try:
                return parser(s)
            except Exception:
                pass
        return s

    parsed = _parse(text)
    if not isinstance(parsed, str):
        # NOTE(review): the original went on to call str.replace-style
        # ``.replace(s, "")`` on the parsed object, which non-string results
        # do not support, so parsed inputs raised instead of normalizing.
        # Returning the parsed object's string form keeps the return type
        # uniform; confirm this is the canonical form callers want.
        return str(parsed)

    for s in ['{', "}", "(", ")"]:
        parsed = parsed.replace(s, "")
    return parsed
|
||
|
|
||
|
|
||
|
def math_equal(prediction: Union[bool, float, str],
               reference: Union[float, str],
               include_percentage: bool = True,
               is_close: bool = True,
               timeout: bool = False,
               ) -> bool:
    """
    Decide whether ``prediction`` and ``reference`` denote the same answer.

    Checks run in this order; the first one that succeeds returns True:
    0. identical string forms;
    1. numerical equal: both parse as numbers (``is_digit`` /
       ``parse_digits``) and agree, within ``abs_tol=1e-3`` when
       ``is_close``. With ``include_percentage`` the reference is also tried
       divided by and multiplied by 100. When both are numbers and none of
       these agree, the result is False immediately;
    2. element-wise equality of "(...)" / "[...]" comma-separated sequences;
    3. cell-wise equality of pmatrix / bmatrix LaTeX matrices;
    4. equations: two "lhs = rhs" forms are compared as "lhs - (rhs)" up to
       sign; when only one side is of the form "x = value" (left part at
       most 2 characters) its value is compared with the other side;
    5. symbolic equal: both convert to sympy expressions and are equal.

    Args:
        prediction: the predicted answer.
        reference: the ground-truth answer.
        include_percentage: also accept a factor of 100 in either direction.
        is_close: compare numbers with a tolerance instead of exactly.
        timeout: run every sympy comparison through ``call_with_timeout``,
            including those reached through element-wise recursion and the
            equation check.

    Returns:
        True when one of the checks above succeeds, False otherwise.
    """
    if str(prediction) == str(reference):
        return True

    try:  # 1. numerical equal
        if is_digit(prediction) and is_digit(reference):
            prediction = parse_digits(prediction)
            reference = parse_digits(reference)
            # number questions
            if include_percentage:
                gt_result = [reference / 100, reference, reference * 100]
            else:
                gt_result = [reference]
            for item in gt_result:
                try:
                    if is_close:
                        if isclose(item, prediction, abs_tol=1e-3):
                            return True
                    elif item == prediction:
                        return True
                except Exception:
                    continue
            return False
    except Exception:
        # Anything that goes wrong numerically just defers to the symbolic
        # checks; interrupts are no longer swallowed here.
        pass

    if not prediction and prediction not in [0, False]:
        return False

    def _symbolic(lhs, rhs):
        # One place decides whether sympy runs in-process or in a killable
        # worker, so the timeout request is honoured on every path.
        if timeout:
            return call_with_timeout(symbolic_equal_process, lhs, rhs)
        return symbolic_equal(lhs, rhs)

    def _same(pred_item, ref_item):
        # Recursive comparison that keeps the caller's options, timeout included.
        return math_equal(pred_item, ref_item, include_percentage, is_close, timeout)

    # 2. symbolic equal
    reference = str(reference).strip()
    prediction = str(prediction).strip()

    # Tuples / lists / intervals: compare item by item.
    sequence_pattern = r'(\(|\[).+(\)|\])'
    if regex.match(sequence_pattern, prediction) is not None and \
            regex.match(sequence_pattern, reference) is not None:
        pred_parts = prediction[1:-1].split(",")
        ref_parts = reference[1:-1].split(",")
        if len(pred_parts) == len(ref_parts) and \
                all(_same(p, r) for p, r in zip(pred_parts, ref_parts)):
            return True

    # LaTeX matrices: compare cell by cell.
    matrix_open = ("\\begin{pmatrix}", "\\begin{bmatrix}")
    matrix_close = ("\\end{pmatrix}", "\\end{bmatrix}")
    if prediction.startswith(matrix_open) and prediction.endswith(matrix_close) and \
            reference.startswith(matrix_open) and reference.endswith(matrix_close):
        # Both environment names have the same length, so one slice fits both.
        inner = slice(len("\\begin{pmatrix}"), -len("\\end{pmatrix}"))
        pred_lines = [line.strip() for line in prediction[inner].split("\\\\") if line.strip()]
        ref_lines = [line.strip() for line in reference[inner].split("\\\\") if line.strip()]
        if len(pred_lines) == len(ref_lines):
            matched = True
            for pred_line, ref_line in zip(pred_lines, ref_lines):
                pred_parts = pred_line.split("&")
                ref_parts = ref_line.split("&")
                if len(pred_parts) != len(ref_parts) or \
                        not all(_same(p, r) for p, r in zip(pred_parts, ref_parts)):
                    matched = False
                    break
            if matched:
                return True

    # Equations.
    if prediction.count('=') == 1 and reference.count('=') == 1:
        pred_lhs, pred_rhs = prediction.split('=')
        pred = f"{pred_lhs.strip()} - ({pred_rhs.strip()})"
        ref_lhs, ref_rhs = reference.split('=')
        ref = f"{ref_lhs.strip()} - ({ref_rhs.strip()})"
        # "a = b" and "b = a" are the same equation, hence the negated retry.
        if _symbolic(pred, ref) or _symbolic(f"-({pred})", ref):
            return True
    elif prediction.count('=') == 1 and len(prediction.split('=')[0].strip()) <= 2 and '=' not in reference:
        if _same(prediction.split('=')[1], reference):
            return True
    elif reference.count('=') == 1 and len(reference.split('=')[0].strip()) <= 2 and '=' not in prediction:
        if _same(prediction, reference.split('=')[1]):
            return True

    # symbolic equal with sympy
    if _symbolic(prediction, reference):
        return True

    return False
|
||
|
|
||
|
|
||
|
def math_equal_process(param):
    """
    Single-argument adapter around ``math_equal``.

    The last two items of ``param`` are taken as prediction and reference,
    in that order; any items before them are ignored.
    """
    prediction = param[-2]
    reference = param[-1]
    return math_equal(prediction, reference)
|
||
|
|
||
|
|
||
|
def symbolic_equal(a, b):
    """
    Compare ``a`` and ``b`` as sympy expressions.

    Each argument is offered to the LaTeX parser and then to the sympy
    expression parser; if neither accepts it, the value is used unparsed.
    The two are considered equal when ``simplify(a - b)`` is zero, or when
    their numeric evaluations agree within ``abs_tol=1e-3``.

    Returns:
        True when either comparison succeeds; False otherwise, including
        when a comparison itself raises.
    """
    def _parse(s):
        for parser in [parse_latex, parse_expr]:
            try:
                return parser(s)
            except Exception:
                # Parsers fail in many different ways on free-form answers;
                # any ordinary failure means "try the next one". Interrupts
                # and exit requests are deliberately not caught.
                pass
        return s

    a = _parse(a)
    b = _parse(b)

    # Exact: the difference simplifies to zero.
    try:
        if simplify(a - b) == 0:
            return True
    except Exception:
        pass

    # Approximate: numeric values are close.
    try:
        if isclose(N(a), N(b), abs_tol=1e-3):
            return True
    except Exception:
        pass
    return False
|
||
|
|
||
|
|
||
|
def symbolic_equal_process(a, b, output_queue):
    """
    Run ``symbolic_equal(a, b)`` and put its result on ``output_queue``.

    The signature matches what ``call_with_timeout`` passes to its target:
    the caller's arguments followed by the result queue.
    """
    output_queue.put(symbolic_equal(a, b))
|
||
|
|
||
|
|
||
|
def call_with_timeout(func, *args, timeout=1, **kwargs):
    """
    Run ``func`` in a separate process and return the value it queues.

    ``func`` is called as ``func(*args, output_queue, **kwargs)`` and is
    expected to ``put`` exactly one result on ``output_queue``.

    Args:
        func: the worker callable.
        *args: positional arguments placed before the queue.
        timeout: seconds to wait for the worker to finish.
        **kwargs: keyword arguments forwarded to ``func``.

    Returns:
        The queued value, or ``False`` when the worker had to be terminated
        after ``timeout`` seconds or exited without queuing anything.
    """
    import queue  # local import: only the Empty exception is needed

    output_queue = multiprocessing.Queue()
    process_args = args + (output_queue,)
    process = multiprocessing.Process(target=func, args=process_args, kwargs=kwargs)
    process.start()
    process.join(timeout)

    if process.is_alive():
        # Still running after the deadline: kill it and report failure.
        process.terminate()
        process.join()
        return False

    try:
        # The worker has already exited, so anything it queued is available
        # right away. The short bound only guards against a worker that died
        # (or returned) without queuing a result, where an unbounded get()
        # would block forever.
        return output_queue.get(timeout=0.1)
    except queue.Empty:
        return False
|