DeepSeek-Math/evaluation/eval/eval_script.py

173 lines
6.0 KiB
Python
Raw Normal View History

2024-02-06 02:27:40 +00:00
import regex
from copy import deepcopy
from eval.eval_utils import math_equal
from eval.ocwcourses_eval_utils import normalize_numeric, numeric_equality, normalize_symbolic_equation, SymbolicMathMixin
def _answers_match(pred, ans, prec, item):
    """Compare one prediction value with one reference value.

    Args:
        pred: a predicted answer, either a ``str`` or a ``list`` of them.
        ans: a reference answer, either a ``str`` or a ``list`` of them.
        prec: absolute tolerance used for the numeric comparison.
        item: the full sample; only printed as a diagnostic when the pair
            has an unsupported type combination.

    Returns:
        A truthy value when the pair matches, a falsy one otherwise.

    Raises:
        NotImplementedError: if the pair is neither list/list nor str/str.
    """
    if isinstance(pred, list) and isinstance(ans, list):
        # Unordered comparison: every prediction has to match at least one
        # reference and every reference has to be matched at least once.
        pred_matched = set()
        ans_matched = set()
        for i, single_pred in enumerate(pred):
            for j, single_ans in enumerate(ans):
                if _answers_match(single_pred, single_ans, prec, item):
                    pred_matched.add(i)
                    ans_matched.add(j)
        return len(pred_matched) == len(pred) and len(ans_matched) == len(ans)

    if isinstance(pred, str) and isinstance(ans, str):
        if '\\cup' in pred and '\\cup' in ans:
            # Both sides are unions: compare their members as lists.
            return _answers_match(pred.split('\\cup'), ans.split('\\cup'), prec, item)
        try:
            # Commas are dropped so that "1,000" parses like "1000".
            if abs(float(pred.replace(',', '')) - float(ans.replace(',', ''))) < prec:
                return True
        except ValueError:
            # At least one side is not a plain number; fall through to the
            # textual / symbolic comparisons below.
            pass
        # An empty reference never matches by plain string equality.
        return (ans and pred == ans) or math_equal(pred, ans)

    print(item, flush=True)
    raise NotImplementedError(
        f"cannot compare {type(pred).__name__} with {type(ans).__name__}"
    )


def is_correct(item, pred_key='prediction', prec=1e-3):
    """Return whether ``item[pred_key]`` matches ``item['answer']``.

    Two lists are compared as unordered collections (see the list branch of
    the helper). Two strings that both contain ``\\cup`` are split on it and
    compared as lists. Any other pair of strings matches when both parse as
    numbers (commas ignored) that differ by less than ``prec``, when they are
    equal and the answer is non-empty, or when ``math_equal`` accepts them.

    Args:
        item: sample holding the prediction under ``pred_key`` and the
            reference under ``'answer'``. It is not modified.
        pred_key: key of the prediction inside ``item``.
        prec: absolute tolerance for the numeric comparison.

    Returns:
        A truthy value if the prediction is judged correct, falsy otherwise.

    Raises:
        NotImplementedError: if prediction and answer are not both lists or
            both strings (at any nesting level); the item is printed first.
    """
    return _answers_match(item[pred_key], item['answer'], prec, item)
def eval_math(item, pred_key='prediction', prec=1e-3):
    """Score a MATH-style sample after de-duplicating list-valued answers.

    A string prediction stored under ``'program_output'`` is first wrapped in
    a one-element list. When prediction and answer are both lists, repeated
    entries are removed from each (keeping first occurrences, in order) and
    only the trailing ``len(answer)`` predictions are kept. The normalized
    values are written back into ``item`` before delegating to ``is_correct``.

    Args:
        item: sample dict; its ``pred_key`` and ``'answer'`` entries are
            overwritten with the normalized values.
        pred_key: key of the prediction inside ``item``.
        prec: tolerance forwarded to ``is_correct``.

    Returns:
        Whatever ``is_correct`` returns for the normalized sample.
    """
    def _unique_in_order(values):
        # Order-preserving de-duplication based on ``==`` membership.
        kept = []
        for value in values:
            if value not in kept:
                kept.append(value)
        return kept

    prediction = item[pred_key]
    if pred_key == 'program_output' and isinstance(prediction, str):
        prediction = [prediction]
    reference = item['answer']

    if isinstance(prediction, list) and isinstance(reference, list):
        # for some questions in MATH, `reference` repeats answers
        reference = _unique_in_order(reference)
        # some predictions for MATH questions also repeat answers
        distinct_predictions = _unique_in_order(prediction)
        # some predictions mistakenly box non-answer strings, so only the
        # trailing ones are kept. NOTE: with an empty reference list the
        # slice is ``[-0:]``, which keeps every prediction.
        prediction = distinct_predictions[-len(reference):]

    item.update({
        pred_key: prediction,
        'answer': reference
    })
    return is_correct(item, pred_key=pred_key, prec=prec)
def eval_last_single_answer(item, pred_key='prediction', prec=1e-3):
    """Score a sample whose prediction and answer are both single strings.

    Args:
        item: sample dict holding the prediction under ``pred_key`` and the
            reference under ``'answer'``.
        pred_key: key of the prediction inside ``item``.
        prec: tolerance forwarded to ``is_correct``.

    Returns:
        The result of ``is_correct`` for the sample.

    Raises:
        AssertionError: if the prediction or the answer is not a ``str``.
    """
    for key in (pred_key, 'answer'):
        value = item[key]
        assert isinstance(value, str), f"{key} = `{value}` is not a str"
    verdict = is_correct(item, pred_key=pred_key, prec=prec)
    return verdict
def _split_cloze_segments(text):
    """Split ``text`` on ``;`` and on ``,`` that sit at bracket depth zero.

    The depth counter goes up on ``(``, ``[``, ``{`` and down on ``)``, ``]``,
    ``}``, and restarts at zero after every split. Each segment is stripped
    of surrounding whitespace. Empty segments are kept, so a trailing ``;``
    (or an empty ``text``) yields an empty string.
    """
    segments = []
    depth = 0
    start = 0
    # The appended ';' terminates the final segment.
    terminated = text + ";"
    for idx, ch in enumerate(terminated):
        if ch == ';' or (ch == ',' and depth == 0):
            segments.append(terminated[start:idx].strip())
            start = idx + 1
            depth = 0
        elif ch in '([{':
            depth += 1
        elif ch in ')]}':
            depth -= 1
    return segments


def eval_agieval_gaokao_math_cloze(item, pred_key='prediction', prec=1e-3):
    """Score a cloze sample whose prediction and answer are lists of strings.

    Every predicted string is split into segments (on ``;`` and on top-level
    ``,``), duplicates are dropped while keeping first-seen order, and the
    trailing ``len(answer)`` segments are compared position by position with
    the answers using ``is_correct``.

    A string stored under ``'program_output'`` is wrapped into a one-element
    list in place. Apart from that, ``item`` is not modified: each pair is
    compared on a shallow copy, so the prediction and answer lists survive
    the call and the function can safely be called again on the same item.

    Args:
        item: sample dict holding the predictions under ``pred_key`` and the
            references under ``'answer'``.
        pred_key: key of the prediction inside ``item``.
        prec: tolerance forwarded to ``is_correct``.

    Returns:
        ``True`` if the number of kept segments equals the number of answers
        and every pair is judged correct, ``False`` otherwise.

    Raises:
        AssertionError: if the prediction or the answer is not a ``list``.
    """
    if pred_key == 'program_output' and isinstance(item[pred_key], str):
        item[pred_key] = [item[pred_key]]
    for key in [pred_key, 'answer']:
        assert isinstance(item[key], list), f"{key} = `{item[key]}` is not a list"
    ans = item['answer']

    candidates = []
    for raw in item[pred_key]:
        for segment in _split_cloze_segments(raw):
            if segment not in candidates:
                candidates.append(segment)

    # Keep only the trailing segments; with an empty answer list the slice
    # is ``[-0:]`` and therefore keeps every segment.
    pred = candidates[-len(ans):]
    if len(pred) != len(ans):
        return False
    for p, a in zip(pred, ans):
        # Compare on a copy instead of overwriting the caller's lists with
        # the single strings of the current pair.
        pair = {**item, pred_key: p, 'answer': a}
        if not is_correct(pair, pred_key=pred_key, prec=prec):
            return False
    return True
def eval_agieval_gaokao_mathqa(item, pred_key='prediction', prec=1e-3):
    """Score a multiple-choice sample with options ``A``-``D``.

    The prediction entries are joined with single spaces. Among the letters
    ``A``, ``B``, ``C`` and ``D`` present in that text, the one whose first
    occurrence lies furthest to the right is taken as the chosen option and
    compared with ``item['answer']``.

    A string stored under ``'program_output'`` is wrapped into a one-element
    list in place before joining.

    NOTE(review): letters are ranked by their *first* occurrence, so an
    option repeated later in the text does not move up — confirm that this
    is the intended selection rule.

    Args:
        item: sample dict holding the prediction under ``pred_key`` and the
            reference option under ``'answer'``.
        pred_key: key of the prediction inside ``item``.
        prec: unused; present so all evaluators share one signature.

    Returns:
        ``True`` if the chosen option equals the answer, ``False`` otherwise
        (including when no option letter occurs at all).
    """
    if pred_key == 'program_output' and isinstance(item[pred_key], str):
        item[pred_key] = [item[pred_key]]
    joined = " ".join(item[pred_key])

    # ``str.find`` gives -1 for a missing letter, which can never beat the
    # initial best position of -1, so absent letters are ignored.
    chosen = None
    best_pos = -1
    for letter in 'ABCD':
        first_pos = joined.find(letter)
        if first_pos > best_pos:
            chosen, best_pos = letter, first_pos
    return chosen == item['answer']
def eval_math_sat(item, pred_key='prediction', prec=1e-3):
    """Score a sample by case-insensitive string equality.

    Args:
        item: sample dict holding the prediction under ``pred_key`` and the
            reference under ``'answer'``.
        pred_key: key of the prediction inside ``item``.
        prec: unused; present so all evaluators share one signature.

    Returns:
        ``True`` if the lower-cased prediction equals the lower-cased answer.

    Raises:
        AssertionError: if the prediction or the answer is not a ``str``.
    """
    for key in (pred_key, 'answer'):
        value = item[key]
        assert isinstance(value, str), f"{key} = `{value}` is not a str"
    predicted = item[pred_key].lower()
    expected = item['answer'].lower()
    return predicted == expected
def eval_mmlu_stem(item, pred_key='prediction', prec=1e-3):
    """Score an MMLU-STEM sample; delegates unchanged to ``eval_math_sat``.

    Args:
        item: sample dict passed through to ``eval_math_sat``.
        pred_key: key of the prediction inside ``item``.
        prec: forwarded to ``eval_math_sat``.

    Returns:
        The result of ``eval_math_sat`` for the same arguments.
    """
    verdict = eval_math_sat(item, pred_key=pred_key, prec=prec)
    return verdict
def eval_ocwcourses(item, pred_key='prediction', prec=1e-3):
    """Score an OCW-courses sample with an answer-type specific comparison.

    The comparison is selected from the reference answer:
      * if ``float(answer)`` succeeds: ``normalize_numeric`` and
        ``numeric_equality``;
      * otherwise, if the answer contains ``=``:
        ``normalize_symbolic_equation`` and plain ``==``;
      * otherwise: ``SymbolicMathMixin().normalize_tex`` and
        ``SymbolicMathMixin().is_tex_equiv``.

    An empty prediction is replaced by the ``"[invalidanswer]"`` marker; a
    prediction that is, or normalizes to, that marker scores 0.

    Args:
        item: sample dict holding the prediction under ``pred_key`` and the
            reference under ``'answer'``.
        pred_key: key of the prediction inside ``item``.
        prec: unused; present so all evaluators share one signature.

    Returns:
        ``1`` if the normalized prediction is equivalent to the normalized
        answer, ``0`` otherwise.

    Raises:
        AssertionError: if the prediction or the answer is not a ``str``.
    """
    INVALID_ANSWER = "[invalidanswer]"
    for key in (pred_key, 'answer'):
        value = item[key]
        assert isinstance(value, str), f"{key} = `{value}` is not a str"
    pred = item[pred_key]
    ans = item['answer']

    # Choose normalizer and equivalence check from the shape of the answer.
    try:
        float(ans)
    except ValueError:
        if "=" in ans:
            normalize_fn = normalize_symbolic_equation
            is_equiv = lambda x, y: x==y
        else:
            normalize_fn = SymbolicMathMixin().normalize_tex
            is_equiv = SymbolicMathMixin().is_tex_equiv
    else:
        normalize_fn = normalize_numeric
        is_equiv = numeric_equality

    correct_answer = normalize_fn(ans)
    unnormalized_answer = pred if pred else INVALID_ANSWER
    # The prediction is normalized even when it is the invalid marker, so
    # the normalizer is always invoked on it.
    model_answer = normalize_fn(unnormalized_answer)

    if unnormalized_answer == INVALID_ANSWER:
        return 0
    if model_answer == INVALID_ANSWER:
        return 0
    return 1 if is_equiv(model_answer, correct_answer) else 0
def eval_minif2f_isabelle(item, pred_key='prediction', prec=1e-3):
    """Placeholder evaluator that accepts every sample.

    None of the arguments are inspected; the signature only mirrors the
    other evaluators in this module.

    NOTE(review): presumably the actual proof checking for miniF2F/Isabelle
    happens elsewhere — confirm before relying on this value.

    Returns:
        Always ``True``.
    """
    accepted = True
    return accepted