mirror of
https://github.com/deepseek-ai/DeepSeek-Math
synced 2025-06-26 18:16:20 +00:00
init
This commit is contained in:
172
evaluation/eval/eval_script.py
Normal file
172
evaluation/eval/eval_script.py
Normal file
@@ -0,0 +1,172 @@
|
||||
import regex
|
||||
from copy import deepcopy
|
||||
from eval.eval_utils import math_equal
|
||||
from eval.ocwcourses_eval_utils import normalize_numeric, numeric_equality, normalize_symbolic_equation, SymbolicMathMixin
|
||||
|
||||
def is_correct(item, pred_key='prediction', prec=1e-3):
    """Decide whether ``item[pred_key]`` matches ``item['answer']``.

    * list vs. list: every prediction must match at least one answer and
      every answer must be matched by at least one prediction (each pair is
      compared by a recursive call on a copy of ``item``).
    * str vs. str, both containing ``\\cup``: both are split on ``\\cup``
      and compared as lists.
    * other str vs. str: correct when the two parse as floats (commas
      removed) that differ by less than ``prec``, when they are the same
      non-empty string, or when ``math_equal`` accepts them.

    Args:
        item: mapping holding the prediction under ``pred_key`` and the
            reference under ``'answer'``.
        pred_key: key of the prediction inside ``item``.
        prec: absolute tolerance for the plain numeric comparison.

    Raises:
        NotImplementedError: if prediction and answer are neither both lists
            nor both strings (the item is printed first).
    """
    pred = item[pred_key]
    ans = item['answer']

    if isinstance(pred, list) and isinstance(ans, list):
        pred_matched = set()
        ans_matched = set()
        for i, single_pred in enumerate(pred):
            for j, single_ans in enumerate(ans):
                item_cpy = deepcopy(item)
                item_cpy.update({
                    pred_key: single_pred,
                    'answer': single_ans
                })
                if is_correct(item_cpy, pred_key=pred_key, prec=prec):
                    pred_matched.add(i)
                    ans_matched.add(j)
        return len(pred_matched) == len(pred) and len(ans_matched) == len(ans)

    if isinstance(pred, str) and isinstance(ans, str):
        if '\\cup' in pred and '\\cup' in ans:
            # Compare unions piece by piece, in any order.
            item = deepcopy(item)
            item.update({
                pred_key: pred.split('\\cup'),
                'answer': ans.split('\\cup'),
            })
            return is_correct(item, pred_key=pred_key, prec=prec)

        label = False
        try:
            label = abs(float(pred.replace(',', '')) - float(ans.replace(',', ''))) < prec
        except ValueError:
            # At least one side is not a plain number; fall back to the
            # string / symbolic comparisons below.
            pass
        return label or (ans and pred == ans) or math_equal(pred, ans)

    print(item, flush=True)
    raise NotImplementedError()
|
||||
|
||||
def eval_math(item, pred_key='prediction', prec=1e-3):
    """Score a MATH-style item after removing repeated predictions/answers.

    A string prediction under the ``'program_output'`` key is first wrapped
    in a list. When both prediction and answer are lists, duplicates are
    removed from each (first occurrence kept) and only the last
    ``len(answers)`` distinct predictions are retained, because some
    predictions box non-answer strings before the real answers. The cleaned
    values are written back into ``item`` before calling ``is_correct``.
    """
    def _distinct(values):
        # Order-preserving de-duplication that also works for unhashable items.
        kept = []
        for value in values:
            if value not in kept:
                kept.append(value)
        return kept

    pred = item[pred_key]
    if pred_key == 'program_output' and isinstance(pred, str):
        pred = [pred]
    ans = item['answer']

    if isinstance(pred, list) and isinstance(ans, list):
        # for some questions in MATH, `reference` repeats answers
        ans = _distinct(ans)
        # predictions may repeat answers too, or box non-answer strings first
        pred = _distinct(pred)[-len(ans):]

    item[pred_key] = pred
    item['answer'] = ans
    return is_correct(item, pred_key=pred_key, prec=prec)
|
||||
|
||||
def eval_last_single_answer(item, pred_key='prediction', prec=1e-3):
    """Score an item whose prediction and answer must both be strings.

    Asserts the two fields are ``str`` and then delegates to ``is_correct``.
    """
    for field in (pred_key, 'answer'):
        value = item[field]
        assert isinstance(value, str), f"{field} = `{value}` is not a str"
    return is_correct(item, pred_key=pred_key, prec=prec)
|
||||
|
||||
def eval_agieval_gaokao_math_cloze(item, pred_key='prediction', prec=1e-3):
    """Score a cloze item whose prediction and answer are lists of strings.

    Each predicted string is split on ``;`` and on ``,`` that are not
    inside brackets; the distinct pieces are collected in order. The last
    ``len(answers)`` pieces are then compared position by position with the
    answers through ``is_correct``. ``item`` is overwritten with the pair
    being compared on each step.

    Returns:
        True only if the number of retained pieces equals the number of
        answers and every pair is judged correct.
    """
    if pred_key == 'program_output' and isinstance(item[pred_key], str):
        item[pred_key] = [item[pred_key]]
    for field in (pred_key, 'answer'):
        assert isinstance(item[field], list), f"{field} = `{item[field]}` is not a list"

    answers = item['answer']
    pieces = []
    for raw in item[pred_key]:
        # The trailing ';' guarantees that the scan below always finds a
        # separator, so `rest` shrinks on every pass of the while loop.
        rest = raw + ";"
        while rest:
            depth = 0
            for pos, ch in enumerate(rest):
                if ch == ';' or (ch == ',' and depth == 0):
                    piece = rest[:pos].strip()
                    rest = rest[pos + 1:].strip()
                    if piece not in pieces:
                        pieces.append(piece)
                    break
                if ch in '([{':
                    depth += 1
                elif ch in ')]}':
                    depth -= 1

    candidates = pieces[-len(answers):]
    if len(candidates) != len(answers):
        return False
    for candidate, answer in zip(candidates, answers):
        item.update({
            pred_key: candidate,
            'answer': answer,
        })
        if not is_correct(item, pred_key=pred_key, prec=prec):
            return False
    return True
|
||||
|
||||
def eval_agieval_gaokao_mathqa(item, pred_key='prediction', prec=1e-3):
    """Score a multiple-choice item with options A-D.

    The prediction entries are joined with spaces. Among the letters
    ``A``-``D`` that occur in the joined text, the one whose *first*
    occurrence is furthest to the right is taken as the chosen option and
    compared with ``item['answer']``.
    """
    if pred_key == 'program_output' and isinstance(item[pred_key], str):
        item[pred_key] = [item[pred_key]]
    pred_str = " ".join(item[pred_key])
    ans = item['answer']

    # (first index, letter) for every option mentioned in the text.
    mentioned = [(pred_str.find(letter), letter) for letter in 'ABCD' if letter in pred_str]
    tag = max(mentioned)[1] if mentioned else None
    return tag == ans
|
||||
|
||||
def eval_math_sat(item, pred_key='prediction', prec=1e-3):
    """Case-insensitive exact match between a string prediction and answer.

    ``prec`` is accepted for signature parity with the other evaluators and
    is not used.
    """
    for field in (pred_key, 'answer'):
        value = item[field]
        assert isinstance(value, str), f"{field} = `{value}` is not a str"
    predicted, expected = item[pred_key], item['answer']
    return predicted.lower() == expected.lower()
|
||||
|
||||
def eval_mmlu_stem(item, pred_key='prediction', prec=1e-3):
    """Score an MMLU-STEM item; identical to ``eval_math_sat``."""
    verdict = eval_math_sat(item, pred_key=pred_key, prec=prec)
    return verdict
|
||||
|
||||
def eval_ocwcourses(item, pred_key='prediction', prec=1e-3):
    """Score an OCW-courses item, returning 1 (correct) or 0 (incorrect).

    The reference answer selects the comparison mode:

    * parses with ``float`` -> ``normalize_numeric`` + ``numeric_equality``;
    * otherwise contains ``=`` -> ``normalize_symbolic_equation`` + ``==``;
    * otherwise -> ``SymbolicMathMixin.normalize_tex`` + ``is_tex_equiv``.

    An empty prediction, or one that normalises to the invalid-answer
    marker, scores 0. ``prec`` is not used.
    """
    INVALID_ANSWER = "[invalidanswer]"
    for field in (pred_key, 'answer'):
        assert isinstance(item[field], str), f"{field} = `{item[field]}` is not a str"
    pred = item[pred_key]
    ans = item['answer']

    try:
        float(ans)
    except ValueError:
        if "=" in ans:
            # equation
            normalize_fn = normalize_symbolic_equation
            is_equiv = lambda x, y: x==y
        else:
            # expression
            normalize_fn = SymbolicMathMixin().normalize_tex
            is_equiv = SymbolicMathMixin().is_tex_equiv
    else:
        # numeric
        normalize_fn = normalize_numeric
        is_equiv = numeric_equality

    correct_answer = normalize_fn(ans)

    unnormalized_answer = pred or INVALID_ANSWER
    model_answer = normalize_fn(unnormalized_answer)

    if unnormalized_answer == INVALID_ANSWER or model_answer == INVALID_ANSWER:
        return 0
    return 1 if is_equiv(model_answer, correct_answer) else 0
|
||||
|
||||
def eval_minif2f_isabelle(item, pred_key='prediction', prec=1e-3):
    """Placeholder evaluator for miniF2F (Isabelle): always returns True.

    None of the arguments are inspected.
    """
    return True
|
||||
325
evaluation/eval/eval_utils.py
Executable file
325
evaluation/eval/eval_utils.py
Executable file
@@ -0,0 +1,325 @@
|
||||
import multiprocessing
|
||||
from math import isclose
|
||||
import numpy as np
|
||||
from typing import Union, Any, Dict
|
||||
|
||||
from sympy import simplify, N
|
||||
from sympy.parsing.sympy_parser import parse_expr
|
||||
from sympy.parsing.latex import parse_latex
|
||||
import re
|
||||
import regex
|
||||
|
||||
from data_processing.answer_extraction import extract_answer, extract_program_output, strip_string
|
||||
|
||||
def extract_program(result: str, last_only=True):
    """
    Extract the code found between "```python" fences and the following "```".

    With ``last_only`` (the default) only the last fenced program is
    returned; otherwise all programs are concatenated, each preceded by a
    "# ========" separator line. Every extracted line ends with a newline.
    """
    chunks = []
    inside_fence = False
    for line in result.split("\n"):
        if line.startswith("```python"):
            if last_only:
                # only extract the last program
                chunks.clear()
            else:
                chunks.append("\n# ========\n")
            inside_fence = True
        elif line.startswith("```"):
            inside_fence = False
        elif inside_fence:
            chunks.append(line + "\n")
    return "".join(chunks)
|
||||
|
||||
|
||||
def parse_ground_truth(example: Dict[str, Any], data_name):
    """Return ``(gt_cot, gt_ans)`` for one dataset example.

    If the example already carries ``'gt_cot'`` it is returned as is,
    together with ``strip_string(example['gt'])``. Otherwise the fields are
    looked up according to ``data_name``; the chain of thought is
    stringified and stripped and the answer is passed through
    ``strip_string``.

    Raises:
        NotImplementedError: for an unsupported ``data_name``.
    """
    if 'gt_cot' in example:
        return example['gt_cot'], strip_string(example['gt'])

    def _tabmwp_number(text):
        # tabmwp numeric answers may be fractions, comma-grouped, or percentages
        if '/' in text:
            return int(text.split('/')[0]) / int(text.split('/')[1])
        if ',' in text:
            return float(text.replace(',', ''))
        if '%' in text:
            return float(text.split('%')[0]) / 100
        return float(text)

    # parse ground truth
    if data_name in ("math", 'ocw'):
        gt_cot = example['solution']
        gt_ans = extract_answer(gt_cot)
    elif data_name == "gsm8k":
        gt_cot, gt_ans = example['answer'].split("####")
    elif data_name == "gsm-hard":
        gt_cot, gt_ans = example['code'], example['target']
    elif data_name == "svamp":
        gt_cot, gt_ans = example['Equation'], example['Answer']
    elif data_name == "asdiv":
        gt_cot = example['formula']
        gt_ans = re.sub(r"\(.*?\)", "", example['answer'])
    elif data_name in ("mawps", "bbh"):
        gt_cot, gt_ans = None, example['target']
    elif data_name == "tabmwp":
        gt_cot = example['solution']
        gt_ans = example['answer']
        if example['ans_type'] in ('integer_number', 'decimal_number'):
            gt_ans = _tabmwp_number(gt_ans)
    else:
        raise NotImplementedError(data_name)

    # post process
    return str(gt_cot).strip(), strip_string(gt_ans)
|
||||
|
||||
|
||||
def parse_question(example, data_name):
    """Build the question text for one dataset example.

    ``asdiv`` and ``svamp`` join a body and a question (``svamp`` bodies get
    a trailing period if missing); ``tabmwp`` renders the table, question
    and optional choices into a prompt; any other dataset uses the first of
    the keys ``question``, ``problem``, ``Question``, ``input`` present in
    the example. Asserts that the result is non-empty and returns it
    stripped.
    """
    if data_name == "asdiv":
        question = " ".join([example['body'].strip(), example['question'].strip()])
    elif data_name == "svamp":
        body = example["Body"].strip()
        if not body.endswith("."):
            body += "."
        question = body + " " + example["Question"].strip()
    elif data_name == "tabmwp":
        title_str = f'regarding "{example["table_title"]}" ' if example['table_title'] else ""
        lines = [
            f'Read the following table {title_str}and answer a question:\n',
            f'{example["table"]}\n{example["question"]}',
        ]
        if example['choices']:
            lines.append(f' Please select from the following options: {example["choices"]}')
        question = "".join(lines)
    else:
        candidates = ('question', 'problem', 'Question', 'input')
        question = next((example[key] for key in candidates if key in example), "")
    assert question != ""
    return question.strip()
|
||||
|
||||
|
||||
def run_execute(executor, result, prompt_type, execute=False):
    """Turn a raw generation into ``(prediction, report)``.

    Returns ``(None, None)`` for an empty result or the literal ``'error'``.
    Otherwise the prediction comes from ``extract_program_output`` when
    ``prompt_type`` contains ``"program_only"``, from running the extracted
    program through ``executor.apply`` when ``prompt_type`` is ``"pot"`` or
    ``"pal"`` and ``execute`` is true, and from ``extract_answer`` in every
    other case. The prediction is passed through ``strip_string``;
    ``report`` is only set by the execution path.
    """
    if not result or result == 'error':
        return None, None

    report = None
    if "program_only" in prompt_type:
        raw_prediction = extract_program_output(result)
    elif prompt_type in ["pot", "pal"] and execute:
        raw_prediction, report = executor.apply(extract_program(result))
    else:
        raw_prediction = extract_answer(result)

    return strip_string(raw_prediction), report
|
||||
|
||||
|
||||
def parse_digits(num):
    """Parse ``num`` as a float, accepting thousands separators and percents.

    Commas are removed before parsing. If the text is not a plain float but
    ends with ``%`` (optionally written as ``\\%``), the remaining number is
    parsed and divided by 100.

    Returns:
        The parsed float, or None when the value is not numeric.
    """
    # format: 234.23 || 23%
    text = str(num).replace(',', '')
    try:
        return float(text)
    except ValueError:
        pass

    if text.endswith('%'):
        text = text[:-1]
        if text.endswith('\\'):
            # LaTeX-escaped percent sign: "23\%"
            text = text[:-1]
        try:
            return float(text) / 100
        except ValueError:
            pass
    return None
|
||||
|
||||
def is_digit(num):
    """Return True when ``parse_digits`` can turn ``num`` into a number."""
    # paired with parse_digits
    parsed = parse_digits(num)
    return parsed is not None
|
||||
|
||||
|
||||
def normalize_prediction(prediction):
    """Normalise a prediction for comparison.

    Numeric predictions are rounded to 6 decimals and returned as a string.
    Otherwise the text is stripped, outer ``[...]`` / ``(...)`` pairs are
    peeled off, comma-separated parts inside them are normalised
    recursively, the brackets are put back, and the result is parsed with
    ``parse_latex`` / ``parse_expr``. If parsing fails the text is returned
    with ``{``, ``}``, ``(`` and ``)`` removed.

    Returns:
        A string, or the parsed expression object when parsing succeeded.
    """
    try:  # 1. numerical equal
        if is_digit(prediction):
            prediction = np.round(float(str(prediction).replace(",", "")), 6)
            return str(prediction)
    except Exception:
        # e.g. a percentage passes is_digit but is not accepted by float();
        # treat it as symbolic text below.
        pass

    # 2. symbolic equal
    prediction = str(prediction).strip()

    ## deal with [], (), {}
    brackets = []
    while (prediction.startswith("[") and prediction.endswith("]")) or (prediction.startswith("(") and prediction.endswith(")")):
        # Remember each peeled bracket so it can be restored afterwards
        # (the original never recorded it, leaving the code below dead).
        brackets.append(prediction[0])
        prediction = prediction[1:-1]
    if brackets and ',' in prediction:
        # str(): a recursive call may return a parsed expression.
        pred_parts = [str(normalize_prediction(part)) for part in prediction.split(",")]
        prediction = ",".join(pred_parts)

    for b in reversed(brackets):
        if b == '[':
            prediction = '[' + prediction + ']'
        else:
            assert b == '('
            prediction = '(' + prediction + ')'

    def _parse(s):
        # Try LaTeX first, then plain sympy syntax; give the text back if
        # neither parser accepts it.
        for f in [parse_latex, parse_expr]:
            try:
                return f(s)
            except Exception:
                pass
        return s

    prediction = _parse(prediction)

    # Only unparsed text can have grouping characters stripped; a parsed
    # expression is returned untouched.
    if isinstance(prediction, str):
        for s in ['{', "}", "(", ")"]:
            prediction = prediction.replace(s, "")

    return prediction
|
||||
|
||||
|
||||
def math_equal(prediction: Union[bool, float, str],
                reference: Union[float, str],
                include_percentage: bool = True,
                is_close: bool = True,
                timeout: bool = False,
                ) -> bool:
    """
    Exact match of math if and only if:
    1. numerical equal: both can convert to float and are equal
    2. symbolic equal: both can convert to sympy expression and are equal

    The checks run in this order, returning True at the first success:
    identical string forms; numeric comparison; element-wise comparison of
    bracketed, comma-separated lists; element-wise comparison of
    pmatrix/bmatrix bodies; equation handling; sympy-based comparison.

    Args:
        prediction: predicted value.
        reference: ground-truth value.
        include_percentage: when true, a numeric reference also matches if
            the prediction equals reference / 100 or reference * 100.
        is_close: when true, numbers are compared with an absolute
            tolerance of 1e-3 instead of exact equality.
        timeout: when true, the final sympy comparison is run through
            ``call_with_timeout``.
    """
    if str(prediction) == str(reference):
        return True

    try: # 1. numerical equal
        if is_digit(prediction) and is_digit(reference):
            prediction = parse_digits(prediction)
            reference = parse_digits(reference)
            # number questions
            if include_percentage:
                gt_result = [reference / 100, reference, reference * 100]
            else:
                gt_result = [reference]
            for item in gt_result:
                try:
                    if is_close:
                        if isclose(item, prediction, abs_tol=1e-3):
                            return True
                    else:
                        if item == prediction:
                            return True
                except Exception:
                    continue
            # both sides are numbers and none of the candidates matched
            return False
    except:
        pass

    # an empty/None prediction can never match; 0 and False are exempt
    if not prediction and prediction not in [0, False]:
        return False

    # 2. symbolic equal
    reference = str(reference).strip()
    prediction = str(prediction).strip()

    # Both start with "(" or "[" and contain a closing ")" or "]":
    # compare the comma-separated elements pairwise. The bracket kinds
    # themselves are not compared.
    if regex.match(r'(\(|\[).+(\)|\])', prediction) is not None and regex.match(r'(\(|\[).+(\)|\])', reference) is not None:
        pred_parts = prediction[1:-1].split(",")
        ref_parts = reference[1:-1].split(",")
        if len(pred_parts) == len(ref_parts):
            if all([math_equal(pred_parts[i], ref_parts[i], include_percentage, is_close) for i in range(len(pred_parts))]):
                return True

    # Matrices: split rows on "\\\\" and cells on "&", then compare cell by cell.
    if (prediction.startswith("\\begin{pmatrix}") or prediction.startswith("\\begin{bmatrix}")) and (prediction.endswith("\\end{pmatrix}") or prediction.endswith("\\end{bmatrix}")) and \
        (reference.startswith("\\begin{pmatrix}") or reference.startswith("\\begin{bmatrix}")) and (reference.endswith("\\end{pmatrix}") or reference.endswith("\\end{bmatrix}")):
        # "\\begin{pmatrix}" and "\\begin{bmatrix}" have the same length, so
        # one slice strips either environment.
        pred_lines = [line.strip() for line in prediction[len("\\begin{pmatrix}"): -len("\\end{pmatrix}")].split("\\\\") if line.strip()]
        ref_lines = [line.strip() for line in reference[len("\\begin{pmatrix}"): -len("\\end{pmatrix}")].split("\\\\") if line.strip()]
        matched = True
        if len(pred_lines) == len(ref_lines):
            for pred_line, ref_line in zip(pred_lines, ref_lines):
                pred_parts = pred_line.split("&")
                ref_parts = ref_line.split("&")
                if len(pred_parts) == len(ref_parts):
                    if not all([math_equal(pred_parts[i], ref_parts[i], include_percentage, is_close) for i in range(len(pred_parts))]):
                        matched = False
                        break
                else:
                    matched = False
                if not matched:
                    break
        else:
            matched = False
        if matched:
            return True

    if prediction.count('=') == 1 and reference.count('=') == 1:
        # Two equations: rewrite each as "lhs - (rhs)" and accept if the
        # differences agree up to sign.
        pred = prediction.split('=')
        pred = f"{pred[0].strip()} - ({pred[1].strip()})"
        ref = reference.split('=')
        ref = f"{ref[0].strip()} - ({ref[1].strip()})"
        if symbolic_equal(pred, ref) or symbolic_equal(f"-({pred})", ref):
            return True
    elif prediction.count('=') == 1 and len(prediction.split('=')[0].strip()) <= 2 and '=' not in reference:
        # Prediction with a left-hand side of at most 2 characters
        # (e.g. "x = 3"): compare its right-hand side with the reference.
        if math_equal(prediction.split('=')[1], reference, include_percentage, is_close):
            return True
    elif reference.count('=') == 1 and len(reference.split('=')[0].strip()) <= 2 and '=' not in prediction:
        # Same as above with the roles of prediction and reference swapped.
        if math_equal(prediction, reference.split('=')[1], include_percentage, is_close):
            return True

    # symbolic equal with sympy
    if timeout:
        if call_with_timeout(symbolic_equal_process, prediction, reference):
            return True
    else:
        if symbolic_equal(prediction, reference):
            return True

    return False
|
||||
|
||||
|
||||
def math_equal_process(param):
    """Call ``math_equal`` on the last two elements of ``param``.

    ``param[-2]`` is used as the prediction and ``param[-1]`` as the
    reference; any leading elements are ignored.
    """
    prediction, reference = param[-2], param[-1]
    return math_equal(prediction, reference)
|
||||
|
||||
|
||||
def symbolic_equal(a, b):
    """Return True when ``a`` and ``b`` are equal as sympy expressions.

    Each argument is parsed with ``parse_latex`` and, failing that,
    ``parse_expr``; if both parsers fail the argument is used unchanged.
    The values are equal when ``simplify(a - b) == 0`` or when their
    numeric evaluations agree within an absolute tolerance of 1e-3. Any
    error raised during a comparison counts as "not equal".

    NOTE(review): sympy documents that ``parse_expr`` uses ``eval``; the
    inputs here may be model output, so treat them as untrusted.
    """
    def _parse(s):
        for f in [parse_latex, parse_expr]:
            try:
                return f(s)
            except Exception:
                # Parser rejected the input; try the next one.
                pass
        return s

    a = _parse(a)
    b = _parse(b)

    # `except Exception` (rather than a bare except) keeps the best-effort
    # behaviour while letting KeyboardInterrupt/SystemExit propagate.
    try:
        if simplify(a-b) == 0:
            return True
    except Exception:
        pass

    try:
        if isclose(N(a), N(b), abs_tol=1e-3):
            return True
    except Exception:
        pass
    return False
|
||||
|
||||
|
||||
def symbolic_equal_process(a, b, output_queue):
    """Run ``symbolic_equal(a, b)`` and put the result on ``output_queue``."""
    output_queue.put(symbolic_equal(a, b))
|
||||
|
||||
|
||||
def call_with_timeout(func, *args, timeout=1, **kwargs):
    """Run ``func(*args, output_queue, **kwargs)`` in a child process.

    ``func`` receives a ``multiprocessing.Queue`` as its last positional
    argument and is expected to put exactly one result on it.

    Args:
        func: callable executed in the child process.
        *args: positional arguments passed before the queue.
        timeout: seconds to wait for the child to finish.
        **kwargs: keyword arguments forwarded to ``func``.

    Returns:
        The value ``func`` put on the queue, or False if the child was
        still running after ``timeout`` seconds (it is then terminated) or
        exited without producing a result.
    """
    from queue import Empty  # raised by multiprocessing.Queue.get on timeout

    output_queue = multiprocessing.Queue()
    process_args = args + (output_queue,)
    process = multiprocessing.Process(target=func, args=process_args, kwargs=kwargs)
    process.start()
    process.join(timeout)

    if process.is_alive():
        process.terminate()
        process.join()
        return False

    # The child has exited. If it crashed before calling put(), an
    # unbounded get() would block forever, so wait only briefly.
    try:
        return output_queue.get(timeout=0.1)
    except Empty:
        return False
|
||||
263
evaluation/eval/ocwcourses_eval_utils.py
Normal file
263
evaluation/eval/ocwcourses_eval_utils.py
Normal file
@@ -0,0 +1,263 @@
|
||||
import re
|
||||
import numpy as np
|
||||
import sympy
|
||||
from sympy.core.sympify import SympifyError
|
||||
from sympy.parsing.latex import parse_latex
|
||||
|
||||
import signal
|
||||
|
||||
INVALID_ANSWER = "[invalidanswer]"
|
||||
|
||||
class timeout:
    """Context manager that raises ``TimeoutError`` after ``seconds`` seconds.

    Implemented with ``signal.alarm`` and a ``SIGALRM`` handler. On exit the
    pending alarm is cancelled and the handler that was installed before
    entering is restored.
    """

    def __init__(self, seconds=1, error_message="Timeout"):
        self.seconds = seconds
        self.error_message = error_message
        # Handler active before __enter__; restored in __exit__.
        self._previous_handler = None

    def handle_timeout(self, signum, frame):
        """Signal handler: raise ``TimeoutError`` with the configured message."""
        raise TimeoutError(self.error_message)

    def __enter__(self):
        self._previous_handler = signal.signal(signal.SIGALRM, self.handle_timeout)
        signal.alarm(self.seconds)

    def __exit__(self, exc_type, exc_value, exc_traceback):
        signal.alarm(0)
        # signal.signal() returns None when the previous handler was not
        # installed from Python; there is nothing to restore in that case.
        if self._previous_handler is not None:
            signal.signal(signal.SIGALRM, self._previous_handler)
            self._previous_handler = None
|
||||
|
||||
def normalize_numeric(s):
    """Strip known unit strings from ``s`` and convert what is left to a float.

    Returns None for a None input, the float value when the remaining text
    can be evaluated as a Python expression or parsed as a numeric LaTeX
    expression, and ``INVALID_ANSWER`` otherwise.
    """
    if s is None:
        return None
    # Units are removed by plain substring replacement, in this order.
    for unit in [
        "eV",
        " \\mathrm{~kg} \\cdot \\mathrm{m} / \\mathrm{s}",
        " kg m/s",
        "kg*m/s",
        "kg",
        "m/s",
        "m / s",
        "m s^{-1}",
        "\\text{ m/s}",
        " \\mathrm{m/s}",
        " \\text{ m/s}",
        "g/mole",
        "g/mol",
        "\\mathrm{~g}",
        "\\mathrm{~g} / \\mathrm{mol}",
        "W",
        "erg/s",
        "years",
        "year",
        "cm",
    ]:
        s = s.replace(unit, "")
        s = s.strip()
    # Short unit names are only removed when wrapped in \mathrm{...}.
    for maybe_unit in ["m", "s", "cm"]:
        s = s.replace("\\mathrm{" + maybe_unit + "}", "")
        s = s.replace("\\mathrm{~" + maybe_unit + "}", "")
        s = s.strip()
    s = s.strip("$")
    try:
        # NOTE(security): eval() runs the remaining text as Python code. The
        # text can come from model output, so it should be treated as
        # untrusted; consider a restricted arithmetic evaluator.
        return float(eval(s))
    except:
        # Not a Python expression: fall back to LaTeX parsing.
        try:
            expr = parse_latex(s)
            if expr.is_number:
                return float(expr)
            return INVALID_ANSWER
        except:
            return INVALID_ANSWER
|
||||
|
||||
def numeric_equality(n1, n2, threshold=0.01):
    """Compare two numbers for approximate equality.

    When either value, or their difference, is close to zero, the values
    are equal if their absolute difference is at most ``threshold`` times
    the magnitude of their mean; otherwise ``np.isclose`` with its default
    tolerances decides.

    Args:
        n1, n2: numbers to compare; None never equals anything.
        threshold: relative tolerance used in the near-zero branch.

    Returns:
        A truthy value when the numbers are considered equal.
    """
    if n1 is None or n2 is None:
        return False
    if np.isclose(n1, 0) or np.isclose(n2, 0) or np.isclose(n1 - n2, 0):
        # Use the magnitude of the mean and a non-strict comparison so that
        # identical values (including 0 == 0 and equal negative numbers)
        # compare equal; a signed mean with `<` rejected them.
        return np.abs(n1 - n2) <= threshold * np.abs(n1 + n2) / 2
    else:
        return np.isclose(n1, n2)
|
||||
|
||||
def normalize_symbolic_equation(s):
    """Parse a LaTeX equation string into a sympy ``Equality``.

    Surrounding ``\\[ ... \\]`` and ``$`` delimiters are removed,
    ``\\left(`` / ``\\right)`` become plain parentheses and doubled
    backslashes are collapsed before parsing with ``parse_latex``.

    Returns:
        The parsed ``Equality``, or ``INVALID_ANSWER`` when ``s`` is not a
        string, cannot be parsed, or does not parse to an equality.
    """
    if not isinstance(s, str):
        return INVALID_ANSWER
    if s.startswith("\\["):
        s = s[2:]
    if s.endswith("\\]"):
        s = s[:-2]
    s = s.replace("\\left(", "(")
    s = s.replace("\\right)", ")")
    s = s.replace("\\\\", "\\")
    if s.startswith("$") or s.endswith("$"):
        s = s.strip("$")
    try:
        maybe_expression = parse_latex(s)
    except Exception:
        # Any parser failure means the answer cannot be used.
        return INVALID_ANSWER
    if not isinstance(maybe_expression, sympy.core.relational.Equality):
        # we have a bare expression, not an equation
        return INVALID_ANSWER
    return maybe_expression
|
||||
|
||||
class SymbolicMathMixin:
    """
    Methods useful for parsing mathematical expressions from text and determining equivalence of expressions.
    """

    # (before, after) pairs applied in order by `normalize_tex`.
    SUBSTITUTIONS = [  # used for text normalize
        ("an ", ""),
        ("a ", ""),
        (".$", "$"),
        ("\\$", ""),
        (r"\ ", ""),
        (" ", ""),
        ("mbox", "text"),
        (",\\text{and}", ","),
        ("\\text{and}", ","),
        ("\\text{m}", "\\text{}"),
    ]
    # Substrings deleted by `normalize_tex` after the substitutions above.
    REMOVED_EXPRESSIONS = [  # used for text normalizer
        "square",
        "ways",
        "integers",
        "dollars",
        "mph",
        "inches",
        "ft",
        "hours",
        "km",
        "units",
        "\\ldots",
        "sue",
        "points",
        "feet",
        "minutes",
        "digits",
        "cents",
        "degrees",
        "cm",
        "gm",
        "pounds",
        "meters",
        "meals",
        "edges",
        "students",
        "childrentickets",
        "multiples",
        "\\text{s}",
        "\\text{.}",
        "\\text{\ns}",
        "\\text{}^2",
        "\\text{}^3",
        "\\text{\n}",
        "\\text{}",
        r"\mathrm{th}",
        r"^\circ",
        r"^{\circ}",
        r"\;",
        r",\!",
        "{,}",
        '"',
        "\\dots",
    ]

    def normalize_tex(self, final_answer: str) -> str:
        """
        Normalizes a string representing a mathematical expression.
        Used as a preprocessing step before parsing methods.

        Copied character for character from appendix D of Lewkowycz et al. (2022)
        """
        # Keep only what follows the last "=".
        final_answer = final_answer.split("=")[-1]

        for before, after in self.SUBSTITUTIONS:
            final_answer = final_answer.replace(before, after)
        for expr in self.REMOVED_EXPRESSIONS:
            final_answer = final_answer.replace(expr, "")

        # Extract answer that is in LaTeX math, is bold,
        # is surrounded by a box, etc.
        final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer)
        final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer)
        final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer)
        final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer)
        final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer)

        # Normalize shorthand TeX:
        # \fracab -> \frac{a}{b}
        # \frac{abc}{bef} -> \frac{abc}{bef}
        # \fracabc -> \frac{a}{b}c
        # \sqrta -> \sqrt{a}
        # \sqrtab -> sqrt{a}b
        final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer)
        final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer)
        final_answer = final_answer.replace("$", "")

        # Normalize 100,000 -> 100000
        if final_answer.replace(",", "").isdigit():
            final_answer = final_answer.replace(",", "")

        return final_answer

    def parse_tex(self, text: str, time_limit: int = 5) -> sympy.Basic:
        """
        Wrapper around sympy's `parse_latex` that outputs a SymPy expression.
        Typically, you want to apply `normalize_tex` as a preprocessing step.

        Returns None (after printing the error) when parsing fails or takes
        longer than `time_limit` seconds.
        """
        try:
            with timeout(seconds=time_limit):
                parsed = parse_latex(text)
        except (
            # general error handling: there is a long tail of possible sympy/other
            # errors we would like to catch
            Exception
        ) as e:
            print(f"failed to parse {text} with exception {e}")
            return None

        return parsed

    def is_exp_equiv(self, x1: sympy.Basic, x2: sympy.Basic, time_limit=5) -> bool:
        """
        Determines whether two sympy expressions are equal.

        They are equal when `sympy.simplify(x1 - x2) == 0`. Returns False
        (after printing a message) if the subtraction or simplification
        fails, or if the work exceeds `time_limit` seconds.
        """
        try:
            with timeout(seconds=time_limit):
                try:
                    diff = x1 - x2
                except (SympifyError, ValueError, TypeError) as e:
                    print(
                        f"Couldn't subtract {x1} and {x2} with exception {e}"
                    )
                    return False

                try:
                    if sympy.simplify(diff) == 0:
                        return True
                    else:
                        return False
                except (SympifyError, ValueError, TypeError) as e:
                    print(f"Failed to simplify {x1}-{x2} with {e}")
                    return False
        except TimeoutError as e:
            print(f"Timed out comparing {x1} and {x2}")
            return False
        except Exception as e:
            print(f"failed on unrecognized exception {e}")
            return False

    def is_tex_equiv(self, x1: str, x2: str, time_limit=5) -> bool:
        """
        Determines whether two (ideally normalized using `normalize_tex`) TeX expressions are equal.

        As written, this returns the result of the exact string comparison
        only. The sympy-equivalence fallback of the (Lewkowycz et al. 2022)
        methodology is still present below but can never execute.
        """
        if x1 == x2:
            # don't resort to sympy if we have full string match, post-normalization
            return True
        else:
            return False
        # NOTE(review): everything below is unreachable because both
        # branches above return. Confirm whether disabling the sympy
        # fallback is intentional before removing or re-enabling it.
        parsed_x2 = self.parse_tex(x2)
        if not parsed_x2:
            # if our reference fails to parse into a Sympy object,
            # we forgo parsing + checking our generated answer.
            return False
        return self.is_exp_equiv(self.parse_tex(x1), parsed_x2, time_limit=time_limit)
|
||||
165
evaluation/eval/python_executor.py
Executable file
165
evaluation/eval/python_executor.py
Executable file
@@ -0,0 +1,165 @@
|
||||
import os
|
||||
import io
|
||||
from contextlib import redirect_stdout
|
||||
import pickle
|
||||
import regex
|
||||
import copy
|
||||
from typing import Any, Dict, Optional
|
||||
import multiprocess
|
||||
from pebble import ProcessPool
|
||||
from concurrent.futures import TimeoutError
|
||||
from functools import partial
|
||||
import traceback
|
||||
from timeout_decorator import timeout
|
||||
|
||||
class GenericRuntime:
    """Minimal Python runtime that executes code in its own globals dict.

    Subclasses may override ``GLOBAL_DICT`` / ``LOCAL_DICT`` to seed the
    namespaces and ``HEADERS`` to run setup snippets at construction time.
    """

    GLOBAL_DICT = {}
    LOCAL_DICT = None
    HEADERS = []

    def __init__(self):
        self._global_vars = copy.copy(self.GLOBAL_DICT)
        if self.LOCAL_DICT:
            self._local_vars = copy.copy(self.LOCAL_DICT)
        else:
            self._local_vars = None

        for header in self.HEADERS:
            self.exec_code(header)

    def exec_code(self, code_piece: str) -> None:
        """Execute ``code_piece`` in the runtime's global namespace.

        Raises:
            RuntimeError: if the code matches the ``input(`` or
                ``os.system(`` patterns.
        """
        # NOTE: only these two patterns are screened; this is not a sandbox.
        blocked_patterns = (r'(\s|^)?input\(', r'(\s|^)?os.system\(')
        if any(regex.search(pattern, code_piece) for pattern in blocked_patterns):
            raise RuntimeError()
        exec(code_piece, self._global_vars)

    def eval_code(self, expr: str) -> Any:
        """Evaluate ``expr`` in the runtime's global namespace and return the value."""
        return eval(expr, self._global_vars)

    def inject(self, var_dict: Dict[str, Any]) -> None:
        """Copy every entry of ``var_dict`` into the global namespace."""
        self._global_vars.update(var_dict)

    @property
    def answer(self):
        """Value bound to the global name ``answer``."""
        return self._global_vars['answer']
|
||||
|
||||
class PythonExecutor:
    """Runs generated Python programs in worker processes and collects results.

    The answer of a program is taken, in order of precedence, from its
    captured stdout, from a named global variable, from evaluating a fixed
    expression, or from evaluating the program's last line.
    """

    def __init__(
        self,
        runtime: Optional[Any] = None,
        get_answer_symbol: Optional[str] = None,
        get_answer_expr: Optional[str] = None,
        get_answer_from_stdout: bool = False,
    ) -> None:
        # Falls back to a fresh GenericRuntime when no runtime is supplied.
        self.runtime = runtime if runtime else GenericRuntime()
        self.answer_symbol = get_answer_symbol
        self.answer_expr = get_answer_expr
        self.get_answer_from_stdout = get_answer_from_stdout

    def process_generation_to_code(self, gens: str):
        """Split each generation into lines, blanking comments and docstrings.

        ``gens`` is iterated, and each element is split on newlines, so it
        is used as a collection of program strings despite the annotation.
        Comment lines become ``# comments``, one-line triple-quoted strings
        become ``\"\"\"comments\"\"\"``, and multi-line triple-quoted
        blocks are dropped (their closing line is kept as an empty line).

        Returns:
            A list with one list of code lines per generation.
        """
        batch_code = []
        for g in gens:
            multiline_comments = False
            code = []
            for line in g.split('\n'):
                strip_line = line.strip()
                if strip_line.startswith("#"):
                    line = line.split("#", 1)[0] + "# comments"
                elif not multiline_comments and strip_line.startswith('"""') and strip_line.endswith('"""') and len(strip_line) >= 6:
                    line = line.split('"""', 1)[0] + '"""comments"""'
                elif not multiline_comments and strip_line.startswith('"""'):
                    multiline_comments = True
                elif multiline_comments and strip_line.endswith('"""'):
                    multiline_comments = False
                    line = ""
                # lines inside (or opening) a multi-line string are skipped
                if not multiline_comments:
                    code.append(line)
            batch_code.append(code)
        return batch_code

    @staticmethod
    def execute(
        code,
        get_answer_from_stdout = None,
        runtime = None,
        answer_symbol = None,
        answer_expr = None,
        timeout_length = 10,
    ):
        """Execute one program (a list of lines) and return its outcome.

        Returns:
            ``(result, concise_exec_info, exec_info)``. On success the two
            info strings are empty. On any failure ``result`` is ``''``,
            ``concise_exec_info`` is the last line of the traceback and
            ``exec_info`` is the traceback text.
        """
        try:
            if get_answer_from_stdout:
                program_io = io.StringIO()
                with redirect_stdout(program_io):
                    timeout(timeout_length)(runtime.exec_code)('\n'.join(code))
                program_io.seek(0)
                result = "".join(program_io.readlines()) # [-1]
            elif answer_symbol:
                timeout(timeout_length)(runtime.exec_code)('\n'.join(code))
                result = runtime._global_vars[answer_symbol]
            elif answer_expr:
                timeout(timeout_length)(runtime.exec_code)('\n'.join(code))
                result = timeout(timeout_length)(runtime.eval_code)(answer_expr)
            else:
                # run everything but the last line, then evaluate the last
                # line as the answer expression
                timeout(timeout_length)(runtime.exec_code)('\n'.join(code[:-1]))
                result = timeout(timeout_length)(runtime.eval_code)(code[-1])
            concise_exec_info = ""
            exec_info = ""
            # make sure the result can be stringified and pickled
            str(result)
            pickle.dumps(result) # serialization check
        except:
            # traceback.print_exc()
            result = ''
            concise_exec_info = traceback.format_exc().split('\n')[-2]
            exec_info = traceback.format_exc()
            if get_answer_from_stdout and 'exec(code_piece, self._global_vars)' in exec_info:
                # keep only the part of the traceback after the exec() frame
                exec_info = exec_info.split('exec(code_piece, self._global_vars)')[-1].strip()
                msg = []
                for line in exec_info.split("\n"):
                    patt = regex.search(r'(?P<start>.*)File "(?P<file>.*)", line (?P<lno>\d+), (?P<end>.*)', line)
                    if patt is not None:
                        # drop frames reported as "<module>"
                        if '<module>' in patt.group('end'):
                            continue
                        fname = patt.group("file")
                        if "site-packages" in fname:
                            # shorten library paths and omit the line number
                            fname = f"site-packages{fname.split('site-packages', 1)[1]}"
                            line = f'{patt.group("start")}File "{fname}", {patt.group("end")}'
                        else:
                            # omit both file name and line number
                            line = f'{patt.group("start")}{patt.group("end")}'
                    else:
                        patt = regex.search(r'(?P<start>.*)(?P<file>/.*site-packages/.*\.py)(?P<end>.*)', line)
                        if patt is not None:
                            line = f'{patt.group("start")}site-packages{patt.group("file").split("site-packages", 1)[1]}{patt.group("end")}'
                    msg.append(line)
                exec_info = "\n".join(msg)
        return result, concise_exec_info, exec_info

    def apply(self, code):
        """Execute a single program; returns the single ``batch_apply`` entry."""
        return self.batch_apply([code])[0]

    def batch_apply(self, batch_code):
        """Execute several programs in a process pool.

        Returns:
            One ``(result, metadata)`` tuple per program, where ``metadata``
            holds the processed code lines, the result and both execution
            info strings. A program that hits the pool timeout yields
            ``""`` with ``"Timeout Error"`` info strings.
        """
        all_code_snippets = self.process_generation_to_code(batch_code)
        all_exec_results = []
        executor = partial(
            self.execute,
            get_answer_from_stdout=self.get_answer_from_stdout,
            runtime=self.runtime,
            answer_symbol=self.answer_symbol,
            answer_expr=self.answer_expr,
            timeout_length=10,
        )
        with ProcessPool(max_workers=multiprocess.cpu_count()) as pool:
            iterator = pool.map(executor, all_code_snippets, timeout=10).result()

            while True:
                try:
                    result = next(iterator)
                    all_exec_results.append(result)
                except StopIteration:
                    break
                except TimeoutError as error:
                    all_exec_results.append(("", "Timeout Error", "Timeout Error"))
                except Exception as error:
                    # NOTE(review): any other worker error prints the error
                    # and exits the whole process.
                    print(error)
                    exit()

        batch_results = []
        for code, (result, concise_exec_info, exec_info) in zip(all_code_snippets, all_exec_results):
            metadata = {'code': code, 'exec_result': result, 'concise_exec_info': concise_exec_info, 'exec_info': exec_info}
            batch_results.append((result, metadata))
        return batch_results
|
||||
263
evaluation/eval/utils.py
Executable file
263
evaluation/eval/utils.py
Executable file
@@ -0,0 +1,263 @@
|
||||
import torch
|
||||
import tqdm
|
||||
from transformers import StoppingCriteria, GenerationConfig
|
||||
|
||||
class KeyWordsCriteria(StoppingCriteria):
    """Stopping criterion that fires once every sequence contains a stop string.

    A sequence counts as stopped when, for some stop sequence, the decoded
    text of a short window of generated tokens ending at any position ends
    with that stop sequence's decoded text.
    """

    def __init__(self, stop_id_sequences, tokenizer, prompt_length):
        assert isinstance(stop_id_sequences[0], list), "stop_id_sequences should be a list of list of ids"
        self.tokenizer = tokenizer
        self.stop_id_sequences = stop_id_sequences
        self.stop_sequences = [tokenizer.decode(sequence) for sequence in stop_id_sequences]
        print(f"stop sequences: {self.stop_sequences}", flush=True)
        self.prompt_length = prompt_length

    def _has_stop_sequence(self, generated_ids):
        """Return True if any stop string ends a decoded window of ``generated_ids``."""
        for stop_ids, stop_text in zip(self.stop_id_sequences, self.stop_sequences):
            # Decode a few more tokens than the stop sequence itself.
            window = len(stop_ids) + 3
            for end in range(len(generated_ids), 0, -1):
                decoded = self.tokenizer.decode(generated_ids[max(end - window, 0) :end])
                if decoded.endswith(stop_text):
                    return True
        return False

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Only tokens after the prompt are inspected.
        stopped = [
            self._has_stop_sequence(input_ids[row][self.prompt_length:].tolist())
            for row in range(input_ids.shape[0])
        ]
        return all(stopped)
|
||||
|
||||
@torch.no_grad()
def generate_completions(model, tokenizer, prompts, batch_size=1, stop_id_sequences=None, end_of_generation_id_sequence=None, disable_tqdm=False, **generation_kwargs):
    """Generate completions for ``prompts`` in batches with ``model.generate``.

    Args:
        model: object exposing ``generate`` and ``device``.
        tokenizer: tokenizer used to encode the prompts and decode the outputs.
        prompts: list of prompt strings.
        batch_size: number of prompts sent to ``model.generate`` at once.
        stop_id_sequences: optional list of token-id lists. Generation stops
            once every sequence of a batch contains one of them (see
            ``KeyWordsCriteria``) and each output is cut at the first stop
            sequence found.
        end_of_generation_id_sequence: optional token-id list. An output is
            flagged as finished when the text at its cut position also starts
            with the decoded form of this sequence.
        disable_tqdm: do not show a progress bar.
        **generation_kwargs: forwarded to ``model.generate``; ``use_cache`` is
            forced to ``True`` and ``num_return_sequences`` (default 1) is read
            to size the outputs.

    Returns:
        ``(generations, finish_completion)``: two lists of length
        ``len(prompts) * num_return_sequences`` holding the completion text
        (prompt removed) and the finished flag of every output. A batch whose
        generation raises is reported as empty strings with ``False`` flags.
    """
    generations = []
    finish_completion = []
    if not disable_tqdm:
        progress = tqdm.tqdm(total=len(prompts), desc="Generating Completions")

    if stop_id_sequences is not None:
        stop_sequences = [tokenizer.decode(stop_id_sequence) for stop_id_sequence in stop_id_sequences]

    if end_of_generation_id_sequence is not None:
        end_of_generation_sequence = tokenizer.decode(end_of_generation_id_sequence)

    num_return_sequences = generation_kwargs.get("num_return_sequences", 1)
    generation_kwargs['use_cache'] = True
    for i in range(0, len(prompts), batch_size):
        batch_prompts = prompts[i:i+batch_size]
        # Sizes are fixed here, before anything can fail, so the fallback and
        # the progress bar never depend on how far the try-block got.
        num_batch_prompts = len(batch_prompts)
        num_batch_outputs = num_batch_prompts * num_return_sequences

        tokenized_prompts = tokenizer(batch_prompts, padding="longest", return_tensors="pt", add_special_tokens='chatglm2' in str(model.__class__))
        batch_input_ids = tokenized_prompts.input_ids
        attention_mask = tokenized_prompts.attention_mask

        if model.device.type == "cuda":
            batch_input_ids = batch_input_ids.cuda()
            attention_mask = attention_mask.cuda()

        batch_finish_completion = [False] * num_batch_outputs
        try:
            batch_outputs = model.generate(
                input_ids=batch_input_ids,
                attention_mask=attention_mask,
                stopping_criteria=[KeyWordsCriteria(stop_id_sequences, tokenizer, batch_input_ids.size(1))] if stop_id_sequences else None,
                **generation_kwargs
            )

            # the stopping criteria is applied at batch level, so if other examples are not stopped, the entire batch will continue to generate.
            # so some outputs still have the stop sequence, which we need to remove.
            if stop_id_sequences:
                prompt_width = batch_input_ids.shape[1]
                for output_idx in range(batch_outputs.shape[0]):
                    for token_idx in range(prompt_width, batch_outputs.shape[1]):
                        hit_stop = any(
                            tokenizer.decode(batch_outputs[output_idx, token_idx: token_idx + len(stop_sequence) + 3]).startswith(stop_sequence)
                            for stop_sequence in stop_sequences
                        )
                        if not hit_stop:
                            continue
                        if end_of_generation_id_sequence is not None and tokenizer.decode(batch_outputs[output_idx, token_idx: token_idx + len(end_of_generation_id_sequence) + 3]).startswith(end_of_generation_sequence):
                            batch_finish_completion[output_idx] = True
                        # Blank out the stop sequence and everything after it.
                        batch_outputs[output_idx, token_idx:] = tokenizer.pad_token_id
                        break

            # remove the prompt from the output
            # we need to re-encode the prompt because we need to make sure the special tokens are treated the same way as in the outputs.
            # we changed our previous way of truncating the output token ids directly because some tokenizers (e.g., llama) won't add a space token before the first token.
            # space is important for some tasks (e.g., code completion).
            decoded_outputs = tokenizer.batch_decode(batch_outputs, skip_special_tokens=True)
            decoded_prompts = tokenizer.batch_decode(batch_input_ids, skip_special_tokens=True)
            # duplicate the prompts to match the number of return sequences
            decoded_prompts = [prompt for prompt in decoded_prompts for _ in range(num_return_sequences)]
            batch_generations = [
                output[len(prompt):] for prompt, output in zip(decoded_prompts, decoded_outputs)
            ]
        except Exception as e:
            # Deliberate best effort: a failing batch must not abort the run.
            print("Error when generating completions for batch:")
            print(batch_prompts)
            print("Error message:")
            print(e)
            print("Use empty string as the completion.")
            batch_generations = [""] * num_batch_outputs
            # An empty completion is never a finished one.
            batch_finish_completion = [False] * num_batch_outputs

        generations += batch_generations
        finish_completion += batch_finish_completion

        if not disable_tqdm:
            progress.update(num_batch_prompts)

    assert len(generations) == len(prompts) * num_return_sequences, "number of generations should be equal to number of prompts * num_return_sequences"
    return generations, finish_completion
|
||||
|
||||
|
||||
@torch.no_grad()
def get_next_word_predictions(model, tokenizer, prompts, candidate_token_ids=None, batch_size=1, return_token_predictions=False, disable_tqdm=False):
    """Predict the most likely next token for every prompt.

    Args:
        model: callable returning an object with a ``logits`` attribute; only
            the logits of the last position are used.
        tokenizer: tokenizer used to encode the prompts (no special tokens are
            added) and, if requested, to map ids back to tokens.
        prompts: list of prompt strings.
        candidate_token_ids: optional list of token ids. When given, the
            softmax and the argmax are taken over these ids only.
        batch_size: number of prompts per forward pass.
        return_token_predictions: return token strings instead of indices.
        disable_tqdm: do not show a progress bar.

    Returns:
        ``(predictions, probs)``. ``probs`` holds one probability list per
        prompt. ``predictions`` holds, per prompt, the argmax index (a position
        inside ``candidate_token_ids`` when candidates are given, otherwise an
        index into the logits' last dimension) or the corresponding token
        string when ``return_token_predictions`` is set.
    """
    show_progress = not disable_tqdm
    if show_progress:
        progress = tqdm.tqdm(total=len(prompts), desc="Getting Predictions")

    predictions = []
    probs = []
    restrict_to_candidates = candidate_token_ids is not None

    for start in range(0, len(prompts), batch_size):
        batch_prompts = prompts[start: start + batch_size]
        encoded = tokenizer(batch_prompts, padding="longest", return_tensors="pt", add_special_tokens=False)
        batch_input_ids = encoded.input_ids
        attention_mask = encoded.attention_mask

        if model.device.type == "cuda":
            batch_input_ids = batch_input_ids.cuda()
            attention_mask = attention_mask.cuda()

        model_output = model(input_ids=batch_input_ids, attention_mask=attention_mask)
        # Logits of the final position, one row per prompt.
        last_logits = model_output.logits[:, -1, :]
        if restrict_to_candidates:
            last_logits = last_logits[:, candidate_token_ids]
        batch_probs = torch.softmax(last_logits, dim=-1)
        best_indices = torch.argmax(batch_probs, dim=-1)

        if not return_token_predictions:
            predictions.extend(best_indices.tolist())
        elif restrict_to_candidates:
            candidate_tokens = tokenizer.convert_ids_to_tokens(candidate_token_ids)
            predictions.extend([candidate_tokens[idx] for idx in best_indices])
        else:
            predictions.extend(tokenizer.convert_ids_to_tokens(best_indices))
        probs.extend(batch_probs.tolist())

        if show_progress:
            progress.update(len(batch_prompts))

    assert len(predictions) == len(prompts), "number of predictions should be equal to number of prompts"
    return predictions, probs
|
||||
|
||||
|
||||
@torch.no_grad()
def score_completions(model, tokenizer, scoring_examples, disable_tqdm=False):
    '''
    Score every (prompt, completion) pair with the negated loss returned by the model.

    Each scoring example is a dict, which contains the following keys:
    - prompt: the prompt to score
    - completions: a list of completions to score

    Returns a nested dict ``{prompt: {completion: score}}`` where ``score`` is
    ``-loss``. Examples sharing the same prompt are merged into one inner dict.

    NOTE(review): relies on ``encode_with_prompt_completion_format``, which is
    not defined in the visible part of this file -- confirm it is in scope
    before calling this with non-empty input.
    '''
    # unroll the scoring examples into flat (prompt, completion) pairs
    unrolled_examples = [
        {"prompt": scoring_example["prompt"], "completion": completion}
        for scoring_example in scoring_examples
        for completion in scoring_example["completions"]
    ]

    if not disable_tqdm:
        # The bar advances once per completion, so its total is the number of
        # unrolled pairs, not the number of scoring examples.
        progress = tqdm.tqdm(total=len(unrolled_examples), desc="Scoring Completions")

    rolled_up_scores = {}
    # currently we don't support batching, because we want to directly use the loss returned by the model to score each completion.
    for unrolled_example in unrolled_examples:
        encoded_example = encode_with_prompt_completion_format(unrolled_example, tokenizer, max_seq_length=None)
        # unsqueeze the batch dimension
        encoded_example = {key: value.unsqueeze(0) for key, value in encoded_example.items()}
        if model.device.type == "cuda":
            encoded_example = {key: value.cuda() for key, value in encoded_example.items()}

        loss = model(**encoded_example).loss
        score = -loss.item()

        # roll the score up under its prompt
        prompt_scores = rolled_up_scores.setdefault(unrolled_example["prompt"], {})
        prompt_scores[unrolled_example["completion"]] = score

        if not disable_tqdm:
            progress.update(1)

    return rolled_up_scores
|
||||
|
||||
|
||||
|
||||
def load_hf_lm_and_tokenizer(
        model_name_or_path,
        tokenizer_name_or_path=None,
        device_map="auto",
        load_in_8bit=False,
        load_in_half=False,
        gptq_model=False,
        use_fast_tokenizer=True,
        padding_side="left",
):
    """Load a Hugging Face model and its tokenizer, ready for evaluation.

    Args:
        model_name_or_path: model identifier or local path.
        tokenizer_name_or_path: tokenizer identifier or path; falls back to
            ``model_name_or_path`` when empty.
        device_map: forwarded to ``from_pretrained``. When falsy, the model is
            loaded without a device map and moved to CUDA if available.
        load_in_8bit: load the model with ``load_in_8bit=True``.
        load_in_half: call ``model.half()`` after loading (skipped for
            ChatGLM2 and Qwen models).
        gptq_model: load through ``auto_gptq`` on ``cuda:0`` instead.
        use_fast_tokenizer: ``use_fast`` flag for the tokenizer (not applied
            on the ChatGLM2 / Qwen path).
        padding_side: padding side set on the tokenizer; left padding is
            needed for batch generation.

    Returns:
        ``(model, tokenizer)`` with the model in eval mode.
    """
    from transformers import AutoModelForCausalLM, AutoModel, AutoTokenizer

    if not tokenizer_name_or_path:
        tokenizer_name_or_path = model_name_or_path

    # Family detection is case-insensitive for BOTH names. Previously only the
    # tokenizer path was lower-cased, so e.g. a "ChatGLM2" model path paired
    # with a differently named tokenizer was not recognised.
    lowered_names = (tokenizer_name_or_path.lower(), model_name_or_path.lower())
    is_chatglm2 = any('chatglm2' in name for name in lowered_names)
    is_qwen = any('qwen' in name for name in lowered_names)

    if is_chatglm2 or is_qwen:
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, trust_remote_code=True)
        if is_qwen:
            # Hard-coded end-of-text token and id, reused as the pad token.
            tokenizer.eos_token = '<|endoftext|>'
            tokenizer.eos_token_id = 151643
            tokenizer.pad_token = tokenizer.eos_token
            tokenizer.pad_token_id = tokenizer.eos_token_id
    else:
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, trust_remote_code=True, use_fast=use_fast_tokenizer)
        # set padding side to left for batch generation
        tokenizer.padding_side = padding_side
        # set pad token to eos token if pad token is not set (as is the case for llama models)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            tokenizer.pad_token_id = tokenizer.eos_token_id

    if gptq_model:
        from auto_gptq import AutoGPTQForCausalLM
        model_wrapper = AutoGPTQForCausalLM.from_quantized(
            model_name_or_path, device="cuda:0", use_triton=True
        )
        model = model_wrapper.model
    elif load_in_8bit:
        model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            device_map=device_map,
            load_in_8bit=True
        )
    else:
        kwargs = {}
        model_class = AutoModelForCausalLM
        if is_chatglm2:
            kwargs = {'trust_remote_code': True}
            model_class = AutoModel
        elif is_qwen:
            kwargs = {'trust_remote_code': True}
        if device_map:
            model = model_class.from_pretrained(model_name_or_path, device_map=device_map, **kwargs)
        else:
            model = model_class.from_pretrained(model_name_or_path, **kwargs)
            if torch.cuda.is_available():
                model = model.cuda()
        if is_qwen:
            # Greedy decoding for Qwen: reload its generation config and turn sampling off.
            model.generation_config = GenerationConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
            model.generation_config.do_sample = False
        if not is_chatglm2 and not is_qwen and load_in_half:
            model = model.half()
    model.eval()
    return model, tokenizer
|
||||
Reference in New Issue
Block a user