mirror of
https://github.com/deepseek-ai/DeepSeek-Math
synced 2024-11-22 11:38:17 +00:00
173 lines
6.0 KiB
Python
173 lines
6.0 KiB
Python
import regex
|
|
from copy import deepcopy
|
|
from eval.eval_utils import math_equal
|
|
from eval.ocwcourses_eval_utils import normalize_numeric, numeric_equality, normalize_symbolic_equation, SymbolicMathMixin
|
|
|
|
def is_correct(item, pred_key='prediction', prec=1e-3):
    """Return whether ``item[pred_key]`` matches the reference ``item['answer']``.

    The prediction and the answer must be of the same kind:

    * two lists -- correct when every prediction matches at least one answer
      and every answer is matched by at least one prediction; each
      (prediction, answer) pair is compared recursively with this function;
    * two strings -- if both contain ``\\cup`` they are split on it and
      compared as lists. Otherwise they match when both parse as floats
      (after removing ``,`` thousands separators) and differ by less than
      ``prec``, when they are identical non-empty strings, or when
      ``math_equal`` reports them equal.

    Args:
        item: mapping holding the prediction under ``pred_key`` and the
            reference under ``'answer'``. It is not modified.
        pred_key: key of the prediction inside ``item``.
        prec: absolute tolerance for the numeric comparison.

    Returns:
        A truthy value when the prediction is judged correct.

    Raises:
        NotImplementedError: for any other combination of types.
    """
    pred = item[pred_key]
    ans = item['answer']

    if isinstance(pred, list) and isinstance(ans, list):
        pred_matched = set()
        ans_matched = set()
        for i, p in enumerate(pred):
            for j, a in enumerate(ans):
                # A shallow copy suffices: only the two compared keys are
                # replaced and nothing below mutates the item.
                pair = {**item, pred_key: p, 'answer': a}
                if is_correct(pair, pred_key=pred_key, prec=prec):
                    pred_matched.add(i)
                    ans_matched.add(j)
        return len(pred_matched) == len(pred) and len(ans_matched) == len(ans)

    if isinstance(pred, str) and isinstance(ans, str):
        if '\\cup' in pred and '\\cup' in ans:
            # Unions: compare the pieces as unordered lists.
            split_item = {
                **item,
                pred_key: pred.split('\\cup'),
                'answer': ans.split('\\cup'),
            }
            return is_correct(split_item, pred_key=pred_key, prec=prec)

        try:
            numeric_match = abs(
                float(pred.replace(',', '')) - float(ans.replace(',', ''))
            ) < prec
        except ValueError:
            # Not plain numbers; fall back to string / symbolic comparison.
            numeric_match = False
        return numeric_match or (bool(ans) and pred == ans) or math_equal(pred, ans)

    raise NotImplementedError(
        f"is_correct: unsupported types {type(pred).__name__} / {type(ans).__name__} "
        f"for {pred_key}={pred!r}, answer={ans!r}"
    )
|
|
|
|
def eval_math(item, pred_key='prediction', prec=1e-3):
    """Score a MATH-style item, tolerating repeated or extra boxed answers.

    A ``program_output`` string prediction is treated as a one-element list.
    When both prediction and answer are lists, duplicates are removed from
    each (first occurrence kept, order preserved). Only the last
    ``len(answer)`` unique predictions are kept, since some predictions box
    non-answer strings. The cleaned lists are written back into ``item``
    before delegating to ``is_correct``.

    Args:
        item: mapping with the prediction under ``pred_key`` and ``'answer'``;
            mutated in place in the list/list case.
        pred_key: key of the prediction inside ``item``.
        prec: numeric tolerance forwarded to ``is_correct``.

    Returns:
        Whatever ``is_correct`` returns for the (possibly cleaned) item.
    """
    def unique(values):
        # Order-preserving de-duplication via ``==`` (elements need not be hashable).
        kept = []
        for value in values:
            if value not in kept:
                kept.append(value)
        return kept

    pred = item[pred_key]
    if pred_key == 'program_output' and isinstance(pred, str):
        pred = [pred]
    ans = item['answer']

    if isinstance(pred, list) and isinstance(ans, list):
        # MATH references sometimes repeat answers, and so do predictions.
        ans = unique(ans)
        pred = unique(pred)[-len(ans):]
        item.update({pred_key: pred, 'answer': ans})

    return is_correct(item, pred_key=pred_key, prec=prec)
|
|
|
|
def eval_last_single_answer(item, pred_key='prediction', prec=1e-3):
    """Score an item whose prediction and answer are both single strings.

    Raises:
        AssertionError: if ``item[pred_key]`` or ``item['answer']`` is not a
            ``str`` (the prediction field is checked first).
    """
    offending = next(
        (field for field in (pred_key, 'answer') if not isinstance(item[field], str)),
        None,
    )
    assert offending is None, f"{offending} = `{item[offending]}` is not a str"
    return is_correct(item, pred_key=pred_key, prec=prec)
|
|
|
|
def _split_cloze_predictions(preds):
    """Split raw cloze predictions into unique, non-empty answer fragments.

    Each string is cut at every ``;`` and at every ``,`` that is not nested
    inside ``()``, ``[]`` or ``{}``. The bracket depth restarts at 0 for each
    fragment. Fragments are stripped, empty ones are dropped, and duplicates
    are removed while first-seen order is kept.
    """
    fragments = []
    for text in preds:
        depth = 0
        chars = []
        # The trailing ';' guarantees the last fragment is flushed.
        for ch in text + ';':
            if ch == ';' or (ch == ',' and depth == 0):
                piece = ''.join(chars).strip()
                chars = []
                depth = 0
                # Skip empty pieces, e.g. from a trailing ';' or ',' in the
                # prediction, so they cannot displace real answers below.
                if piece and piece not in fragments:
                    fragments.append(piece)
            else:
                if ch in '([{':
                    depth += 1
                elif ch in ')]}':
                    depth -= 1
                chars.append(ch)
    return fragments


def eval_agieval_gaokao_math_cloze(item, pred_key='prediction', prec=1e-3):
    """Score an AGIEval Gaokao cloze item whose answer is a list of strings.

    The predictions are split into fragments (see ``_split_cloze_predictions``)
    and the last ``len(answer)`` fragments are kept. The item is correct only
    when exactly that many fragments exist and each one matches the answer at
    the same position according to ``is_correct``.

    A ``program_output`` string is wrapped in a list and written back into
    ``item``. The per-pair comparisons run on copies, so the item's
    prediction and ``'answer'`` fields are otherwise left untouched.

    Args:
        item: mapping with list fields ``pred_key`` and ``'answer'``.
        pred_key: key of the prediction inside ``item``.
        prec: numeric tolerance forwarded to ``is_correct``.

    Returns:
        ``True`` if every position matches, else ``False``.

    Raises:
        AssertionError: if either field is not a ``list``.
    """
    if pred_key == 'program_output' and isinstance(item[pred_key], str):
        item[pred_key] = [item[pred_key]]
    for key in [pred_key, 'answer']:
        assert isinstance(item[key], list), f"{key} = `{item[key]}` is not a list"

    ans = item['answer']
    # Extra leading fragments are dropped; only the trailing ones are graded.
    pred = _split_cloze_predictions(item[pred_key])[-len(ans):]
    if len(pred) != len(ans):
        return False

    for p, a in zip(pred, ans):
        pair = {**item, pred_key: p, 'answer': a}
        if not is_correct(pair, pred_key=pred_key, prec=prec):
            return False
    return True
|
|
|
|
def eval_agieval_gaokao_mathqa(item, pred_key='prediction', prec=1e-3):
    """Score a multiple-choice (A-D) AGIEval Gaokao math item.

    A ``program_output`` string prediction is wrapped in a list and written
    back into ``item``. The prediction is joined with spaces. Among the
    letters ``A``-``D`` that occur in the result, the one whose first
    occurrence lies furthest to the right is taken as the chosen option.

    Args:
        item: mapping with the prediction under ``pred_key`` and ``'answer'``.
        pred_key: key of the prediction inside ``item``.
        prec: unused; accepted for a uniform scorer signature.

    Returns:
        ``chosen == item['answer']``, where ``chosen`` is ``None`` if no
        letter occurs.
    """
    if pred_key == 'program_output' and isinstance(item[pred_key], str):
        item[pred_key] = [item[pred_key]]
    text = " ".join(item[pred_key])
    expected = item['answer']

    chosen, chosen_pos = None, -1
    for letter in 'ABCD':
        # NOTE(review): matching is case-sensitive, hits letters inside words,
        # and uses each letter's *first* occurrence; if "last letter mentioned"
        # is intended, rfind would be needed -- confirm.
        pos = text.find(letter)  # -1 when absent, never beats chosen_pos
        if pos > chosen_pos:
            chosen, chosen_pos = letter, pos
    return chosen == expected
|
|
|
|
def eval_math_sat(item, pred_key='prediction', prec=1e-3):
    """Case-insensitive exact match between prediction and answer strings.

    Args:
        item: mapping with string fields ``pred_key`` and ``'answer'``.
        pred_key: key of the prediction inside ``item``.
        prec: unused; accepted for a uniform scorer signature.

    Raises:
        AssertionError: if either field is not a ``str`` (prediction first).
    """
    for field in (pred_key, 'answer'):
        value = item[field]
        assert isinstance(value, str), f"{field} = `{value}` is not a str"
    predicted = item[pred_key].lower()
    expected = item['answer'].lower()
    return predicted == expected
|
|
|
|
def eval_mmlu_stem(item, pred_key='prediction', prec=1e-3):
    """Score an MMLU-STEM item by delegating unchanged to ``eval_math_sat``.

    That scorer performs a case-insensitive exact string match and asserts
    both fields are strings.
    """
    return eval_math_sat(item, pred_key, prec)
|
|
|
|
def eval_ocwcourses(item, pred_key='prediction', prec=1e-3):
    """Score an OCW-courses item, returning ``1`` if correct and ``0`` otherwise.

    The reference answer selects the comparison mode:

    * parses as a float -> ``normalize_numeric`` + ``numeric_equality``;
    * contains ``=``    -> ``normalize_symbolic_equation`` + exact equality;
    * anything else     -> ``SymbolicMathMixin`` TeX normalisation/equivalence.

    Scoring rules:

    * An empty prediction, or one equal to ``"[invalidanswer]"``, scores 0
      without any normalisation.
    * So does a prediction that normalises to that marker.

    Args:
        item: mapping with string fields ``pred_key`` and ``'answer'``.
        pred_key: key of the prediction inside ``item``.
        prec: unused; accepted for a uniform scorer signature.

    Returns:
        ``1`` or ``0`` (an ``int``, not a ``bool``).

    Raises:
        AssertionError: if either field is not a ``str``.
    """
    import operator

    INVALID_ANSWER = "[invalidanswer]"
    for key in [pred_key, 'answer']:
        assert isinstance(item[key], str), f"{key} = `{item[key]}` is not a str"
    pred = item[pred_key]
    ans = item['answer']

    # Nothing to grade: the original normalised the placeholder (and the
    # reference) only to discard the result.
    if not pred or pred == INVALID_ANSWER:
        return 0

    try:
        float(ans)
    except ValueError:
        if "=" in ans:
            normalize_fn = normalize_symbolic_equation
            is_equiv = operator.eq
        else:
            # NOTE(review): one shared instance instead of two; assumes the
            # mixin keeps no per-instance state -- confirm.
            tex = SymbolicMathMixin()
            normalize_fn = tex.normalize_tex
            is_equiv = tex.is_tex_equiv
    else:
        normalize_fn = normalize_numeric
        is_equiv = numeric_equality

    correct_answer = normalize_fn(ans)
    model_answer = normalize_fn(pred)
    if model_answer == INVALID_ANSWER:
        return 0
    return 1 if is_equiv(model_answer, correct_answer) else 0
|
|
|
|
def eval_minif2f_isabelle(item, pred_key='prediction', prec=1e-3):
    """Placeholder scorer for miniF2F-Isabelle: every item is accepted.

    No checking is performed here. ``item``, ``pred_key`` and ``prec`` are
    ignored and exist only so the signature matches the other scorers.
    """
    accepted = True
    return accepted
|