This commit is contained in:
ZhihongShao
2024-02-06 10:27:40 +08:00
commit 21cc5c6701
59 changed files with 17325 additions and 0 deletions

View File

@@ -0,0 +1,172 @@
import regex
from copy import deepcopy
from eval.eval_utils import math_equal
from eval.ocwcourses_eval_utils import normalize_numeric, numeric_equality, normalize_symbolic_equation, SymbolicMathMixin
def is_correct(item, pred_key='prediction', prec=1e-3):
pred = item[pred_key]
ans = item['answer']
if isinstance(pred, list) and isinstance(ans, list):
pred_matched = set()
ans_matched = set()
for i in range(len(pred)):
for j in range(len(ans)):
item_cpy = deepcopy(item)
item_cpy.update({
pred_key: pred[i],
'answer': ans[j]
})
if is_correct(item_cpy, pred_key=pred_key, prec=prec):
pred_matched.add(i)
ans_matched.add(j)
if item_cpy[pred_key] == '2,3,4':
print(item, flush=True)
print("wtf", flush=True)
return len(pred_matched) == len(pred) and len(ans_matched) == len(ans)
elif isinstance(pred, str) and isinstance(ans, str):
if '\\cup' in pred and '\\cup' in ans:
item = deepcopy(item)
item.update({
pred_key: pred.split('\\cup'),
'answer': ans.split('\\cup'),
})
return is_correct(item, pred_key=pred_key, prec=prec)
else:
label = False
try:
label = abs(float(regex.sub(r',', '', str(pred))) - float(regex.sub(r',', '', str(ans)))) < prec
except:
pass
label = label or (ans and pred == ans) or math_equal(pred, ans)
return label
else:
print(item, flush=True)
raise NotImplementedError()
def eval_math(item, pred_key='prediction', prec=1e-3):
pred = item[pred_key]
if pred_key == 'program_output' and isinstance(pred, str):
pred = [pred]
ans = item['answer']
if isinstance(pred, list) and isinstance(ans, list):
# for some questions in MATH, `reference` repeats answers
_ans = []
for a in ans:
if a not in _ans:
_ans.append(a)
ans = _ans
# some predictions for MATH questions also repeats answers
_pred = []
for a in pred:
if a not in _pred:
_pred.append(a)
# some predictions mistakenly box non-answer strings
pred = _pred[-len(ans):]
item.update({
pred_key: pred,
'answer': ans
})
return is_correct(item, pred_key=pred_key, prec=prec)
def eval_last_single_answer(item, pred_key='prediction', prec=1e-3):
for key in [pred_key, 'answer']:
assert isinstance(item[key], str), f"{key} = `{item[key]}` is not a str"
return is_correct(item, pred_key=pred_key, prec=prec)
def eval_agieval_gaokao_math_cloze(item, pred_key='prediction', prec=1e-3):
if pred_key == 'program_output' and isinstance(item[pred_key], str):
item[pred_key] = [item[pred_key]]
for key in [pred_key, 'answer']:
assert isinstance(item[key], list), f"{key} = `{item[key]}` is not a list"
pred = item[pred_key]
ans = item['answer']
_pred = []
for p in pred:
p = p + ";"
while p:
left_brackets = 0
for i in range(len(p)):
if p[i] == ';' or (p[i] == ',' and left_brackets == 0):
_p, p = p[:i].strip(), p[i + 1:].strip()
if _p not in _pred:
_pred.append(_p)
break
elif p[i] in '([{':
left_brackets += 1
elif p[i] in ')]}':
left_brackets -= 1
pred = _pred[-len(ans):]
if len(pred) == len(ans):
for p, a in zip(pred, ans):
item.update({
pred_key: p,
'answer': a,
})
if not is_correct(item, pred_key=pred_key, prec=prec):
return False
return True
else:
return False
def eval_agieval_gaokao_mathqa(item, pred_key='prediction', prec=1e-3):
if pred_key == 'program_output' and isinstance(item[pred_key], str):
item[pred_key] = [item[pred_key]]
pred_str = " ".join(item[pred_key])
ans = item['answer']
tag = None
idx = -1
for t in 'ABCD':
if t in pred_str and pred_str.index(t) > idx:
tag = t
idx = pred_str.index(t)
return tag == ans
def eval_math_sat(item, pred_key='prediction', prec=1e-3):
for key in [pred_key, 'answer']:
assert isinstance(item[key], str), f"{key} = `{item[key]}` is not a str"
return item[pred_key].lower() == item['answer'].lower()
def eval_mmlu_stem(item, pred_key='prediction', prec=1e-3):
return eval_math_sat(item, pred_key=pred_key, prec=prec)
def eval_ocwcourses(item, pred_key='prediction', prec=1e-3):
INVALID_ANSWER = "[invalidanswer]"
for key in [pred_key, 'answer']:
assert isinstance(item[key], str), f"{key} = `{item[key]}` is not a str"
pred = item[pred_key]
ans = item['answer']
try:
float(ans)
normalize_fn = normalize_numeric
is_equiv = numeric_equality
answer_type = "numeric"
except ValueError:
if "=" in ans:
normalize_fn = normalize_symbolic_equation
is_equiv = lambda x, y: x==y
answer_type = "equation"
else:
normalize_fn = SymbolicMathMixin().normalize_tex
is_equiv = SymbolicMathMixin().is_tex_equiv
answer_type = "expression"
correct_answer = normalize_fn(ans)
unnormalized_answer = pred if pred else INVALID_ANSWER
model_answer = normalize_fn(unnormalized_answer)
if unnormalized_answer == INVALID_ANSWER:
acc = 0
elif model_answer == INVALID_ANSWER:
acc = 0
elif is_equiv(model_answer, correct_answer):
acc = 1
else:
acc = 0
return acc
def eval_minif2f_isabelle(item, pred_key='prediction', prec=1e-3):
return True

325
evaluation/eval/eval_utils.py Executable file
View File

@@ -0,0 +1,325 @@
import multiprocessing
from math import isclose
import numpy as np
from typing import Union, Any, Dict
from sympy import simplify, N
from sympy.parsing.sympy_parser import parse_expr
from sympy.parsing.latex import parse_latex
import re
import regex
from data_processing.answer_extraction import extract_answer, extract_program_output, strip_string
def extract_program(result: str, last_only=True):
"""
extract the program after "```python", and before "```"
"""
program = ""
start = False
for line in result.split("\n"):
if line.startswith("```python"):
if last_only:
program = "" # only extract the last program
else:
program += "\n# ========\n"
start = True
elif line.startswith("```"):
start = False
elif start:
program += line + "\n"
return program
def parse_ground_truth(example: Dict[str, Any], data_name):
if 'gt_cot' in example:
return example['gt_cot'], strip_string(example['gt'])
# parse ground truth
if data_name in ["math", 'ocw']:
gt_cot = example['solution']
gt_ans = extract_answer(gt_cot)
elif data_name == "gsm8k":
gt_cot, gt_ans = example['answer'].split("####")
elif data_name == "gsm-hard":
gt_cot, gt_ans = example['code'], example['target']
elif data_name == "svamp":
gt_cot, gt_ans = example['Equation'], example['Answer']
elif data_name == "asdiv":
gt_cot = example['formula']
gt_ans = re.sub(r"\(.*?\)", "", example['answer'])
elif data_name == "mawps":
gt_cot, gt_ans = None, example['target']
elif data_name == "tabmwp":
gt_cot = example['solution']
gt_ans = example['answer']
if example['ans_type'] in ['integer_number', 'decimal_number']:
if '/' in gt_ans:
gt_ans = int(gt_ans.split('/')[0]) / int(gt_ans.split('/')[1])
elif ',' in gt_ans:
gt_ans = float(gt_ans.replace(',', ''))
elif '%' in gt_ans:
gt_ans = float(gt_ans.split('%')[0]) / 100
else:
gt_ans = float(gt_ans)
elif data_name == "bbh":
gt_cot, gt_ans = None, example['target']
else:
raise NotImplementedError(data_name)
# post process
gt_cot = str(gt_cot).strip()
gt_ans = strip_string(gt_ans)
return gt_cot, gt_ans
def parse_question(example, data_name):
question = ""
if data_name == "asdiv":
question = f"{example['body'].strip()} {example['question'].strip()}"
elif data_name == "svamp":
body = example["Body"].strip()
if not body.endswith("."):
body = body + "."
question = f'{body} {example["Question"].strip()}'
elif data_name == "tabmwp":
title_str = f'regarding "{example["table_title"]}" ' if example['table_title'] else ""
question = f'Read the following table {title_str}and answer a question:\n'
question += f'{example["table"]}\n{example["question"]}'
if example['choices']:
question += f' Please select from the following options: {example["choices"]}'
else:
for key in ['question', 'problem', 'Question', 'input']:
if key in example:
question = example[key]
break
assert question != ""
return question.strip()
def run_execute(executor, result, prompt_type, execute=False):
if not result or result == 'error':
return None, None
report = None
if "program_only" in prompt_type:
prediction = extract_program_output(result)
elif prompt_type in ["pot", "pal"] and execute:
code = extract_program(result)
prediction, report = executor.apply(code)
else:
prediction = extract_answer(result)
prediction = strip_string(prediction)
return prediction, report
def parse_digits(num):
# format: 234.23 || 23%
num = regex.sub(',', '', str(num))
try:
return float(num)
except:
if num.endswith('%'):
num = num[:-1]
if num.endswith('\\'):
num = num[:-1]
try:
return float(num) / 100
except:
pass
return None
def is_digit(num):
# paired with parse_digits
return parse_digits(num) is not None
def normalize_prediction(prediction):
try: # 1. numerical equal
if is_digit(prediction):
prediction = np.round(float(str(prediction).replace(",", "")), 6)
return str(prediction)
except:
pass
# 2. symbolic equal
prediction = str(prediction).strip()
## deal with [], (), {}
brackets = []
while prediction.startswith("[") and prediction.endswith("]") or (prediction.startswith("(") and prediction.endswith(")")):
bracket = prediction[0]
prediction = prediction[1:-1]
if brackets and ',' in prediction:
pred_parts = [normalize_prediction(part) for part in prediction.split(",")]
prediction = ",".join(pred_parts)
if brackets:
for b in reversed(brackets):
if b == '[':
prediction = '[' + prediction + ']'
else:
assert b == '('
prediction = '(' + prediction + ')'
def _parse(s):
for f in [parse_latex, parse_expr]:
try:
return f(s)
except:
pass
return s
prediction = _parse(prediction)
for s in ['{', "}", "(", ")"]:
prediction = prediction.replace(s, "")
return prediction
def math_equal(prediction: Union[bool, float, str],
reference: Union[float, str],
include_percentage: bool = True,
is_close: bool = True,
timeout: bool = False,
) -> bool:
"""
Exact match of math if and only if:
1. numerical equal: both can convert to float and are equal
2. symbolic equal: both can convert to sympy expression and are equal
"""
if str(prediction) == str(reference):
return True
try: # 1. numerical equal
if is_digit(prediction) and is_digit(reference):
prediction = parse_digits(prediction)
reference = parse_digits(reference)
# number questions
if include_percentage:
gt_result = [reference / 100, reference, reference * 100]
else:
gt_result = [reference]
for item in gt_result:
try:
if is_close:
if isclose(item, prediction, abs_tol=1e-3):
return True
else:
if item == prediction:
return True
except Exception:
continue
return False
except:
pass
if not prediction and prediction not in [0, False]:
return False
# 2. symbolic equal
reference = str(reference).strip()
prediction = str(prediction).strip()
if regex.match(r'(\(|\[).+(\)|\])', prediction) is not None and regex.match(r'(\(|\[).+(\)|\])', reference) is not None:
pred_parts = prediction[1:-1].split(",")
ref_parts = reference[1:-1].split(",")
if len(pred_parts) == len(ref_parts):
if all([math_equal(pred_parts[i], ref_parts[i], include_percentage, is_close) for i in range(len(pred_parts))]):
return True
if (prediction.startswith("\\begin{pmatrix}") or prediction.startswith("\\begin{bmatrix}")) and (prediction.endswith("\\end{pmatrix}") or prediction.endswith("\\end{bmatrix}")) and \
(reference.startswith("\\begin{pmatrix}") or reference.startswith("\\begin{bmatrix}")) and (reference.endswith("\\end{pmatrix}") or reference.endswith("\\end{bmatrix}")):
pred_lines = [line.strip() for line in prediction[len("\\begin{pmatrix}"): -len("\\end{pmatrix}")].split("\\\\") if line.strip()]
ref_lines = [line.strip() for line in reference[len("\\begin{pmatrix}"): -len("\\end{pmatrix}")].split("\\\\") if line.strip()]
matched = True
if len(pred_lines) == len(ref_lines):
for pred_line, ref_line in zip(pred_lines, ref_lines):
pred_parts = pred_line.split("&")
ref_parts = ref_line.split("&")
if len(pred_parts) == len(ref_parts):
if not all([math_equal(pred_parts[i], ref_parts[i], include_percentage, is_close) for i in range(len(pred_parts))]):
matched = False
break
else:
matched = False
if not matched:
break
else:
matched = False
if matched:
return True
if prediction.count('=') == 1 and reference.count('=') == 1:
pred = prediction.split('=')
pred = f"{pred[0].strip()} - ({pred[1].strip()})"
ref = reference.split('=')
ref = f"{ref[0].strip()} - ({ref[1].strip()})"
if symbolic_equal(pred, ref) or symbolic_equal(f"-({pred})", ref):
return True
elif prediction.count('=') == 1 and len(prediction.split('=')[0].strip()) <= 2 and '=' not in reference:
if math_equal(prediction.split('=')[1], reference, include_percentage, is_close):
return True
elif reference.count('=') == 1 and len(reference.split('=')[0].strip()) <= 2 and '=' not in prediction:
if math_equal(prediction, reference.split('=')[1], include_percentage, is_close):
return True
# symbolic equal with sympy
if timeout:
if call_with_timeout(symbolic_equal_process, prediction, reference):
return True
else:
if symbolic_equal(prediction, reference):
return True
return False
def math_equal_process(param):
return math_equal(param[-2], param[-1])
def symbolic_equal(a, b):
def _parse(s):
for f in [parse_latex, parse_expr]:
try:
return f(s)
except:
pass
return s
a = _parse(a)
b = _parse(b)
try:
if simplify(a-b) == 0:
return True
except:
pass
try:
if isclose(N(a), N(b), abs_tol=1e-3):
return True
except:
pass
return False
def symbolic_equal_process(a, b, output_queue):
result = symbolic_equal(a, b)
output_queue.put(result)
def call_with_timeout(func, *args, timeout=1, **kwargs):
output_queue = multiprocessing.Queue()
process_args = args + (output_queue,)
process = multiprocessing.Process(target=func, args=process_args, kwargs=kwargs)
process.start()
process.join(timeout)
if process.is_alive():
process.terminate()
process.join()
return False
return output_queue.get()

View File

@@ -0,0 +1,263 @@
import re
import numpy as np
import sympy
from sympy.core.sympify import SympifyError
from sympy.parsing.latex import parse_latex
import signal
INVALID_ANSWER = "[invalidanswer]"
class timeout:
def __init__(self, seconds=1, error_message="Timeout"):
self.seconds = seconds
self.error_message = error_message
def handle_timeout(self, signum, frame):
raise TimeoutError(self.error_message)
def __enter__(self):
signal.signal(signal.SIGALRM, self.handle_timeout)
signal.alarm(self.seconds)
def __exit__(self, type, value, traceback):
signal.alarm(0)
def normalize_numeric(s):
if s is None:
return None
for unit in [
"eV",
" \\mathrm{~kg} \\cdot \\mathrm{m} / \\mathrm{s}",
" kg m/s",
"kg*m/s",
"kg",
"m/s",
"m / s",
"m s^{-1}",
"\\text{ m/s}",
" \\mathrm{m/s}",
" \\text{ m/s}",
"g/mole",
"g/mol",
"\\mathrm{~g}",
"\\mathrm{~g} / \\mathrm{mol}",
"W",
"erg/s",
"years",
"year",
"cm",
]:
s = s.replace(unit, "")
s = s.strip()
for maybe_unit in ["m", "s", "cm"]:
s = s.replace("\\mathrm{" + maybe_unit + "}", "")
s = s.replace("\\mathrm{~" + maybe_unit + "}", "")
s = s.strip()
s = s.strip("$")
try:
return float(eval(s))
except:
try:
expr = parse_latex(s)
if expr.is_number:
return float(expr)
return INVALID_ANSWER
except:
return INVALID_ANSWER
def numeric_equality(n1, n2, threshold=0.01):
if n1 is None or n2 is None:
return False
if np.isclose(n1, 0) or np.isclose(n2, 0) or np.isclose(n1 - n2, 0):
return np.abs(n1 - n2) < threshold * (n1 + n2) / 2
else:
return np.isclose(n1, n2)
def normalize_symbolic_equation(s):
if not isinstance(s, str):
return INVALID_ANSWER
if s.startswith("\\["):
s = s[2:]
if s.endswith("\\]"):
s = s[:-2]
s = s.replace("\\left(", "(")
s = s.replace("\\right)", ")")
s = s.replace("\\\\", "\\")
if s.startswith("$") or s.endswith("$"):
s = s.strip("$")
try:
maybe_expression = parse_latex(s)
if not isinstance(maybe_expression, sympy.core.relational.Equality):
# we have equation, not expression
return INVALID_ANSWER
else:
return maybe_expression
except:
return INVALID_ANSWER
class SymbolicMathMixin:
"""
Methods useful for parsing mathematical expressions from text and determining equivalence of expressions.
"""
SUBSTITUTIONS = [ # used for text normalize
("an ", ""),
("a ", ""),
(".$", "$"),
("\\$", ""),
(r"\ ", ""),
(" ", ""),
("mbox", "text"),
(",\\text{and}", ","),
("\\text{and}", ","),
("\\text{m}", "\\text{}"),
]
REMOVED_EXPRESSIONS = [ # used for text normalizer
"square",
"ways",
"integers",
"dollars",
"mph",
"inches",
"ft",
"hours",
"km",
"units",
"\\ldots",
"sue",
"points",
"feet",
"minutes",
"digits",
"cents",
"degrees",
"cm",
"gm",
"pounds",
"meters",
"meals",
"edges",
"students",
"childrentickets",
"multiples",
"\\text{s}",
"\\text{.}",
"\\text{\ns}",
"\\text{}^2",
"\\text{}^3",
"\\text{\n}",
"\\text{}",
r"\mathrm{th}",
r"^\circ",
r"^{\circ}",
r"\;",
r",\!",
"{,}",
'"',
"\\dots",
]
def normalize_tex(self, final_answer: str) -> str:
"""
Normalizes a string representing a mathematical expression.
Used as a preprocessing step before parsing methods.
Copied character for character from appendix D of Lewkowycz et al. (2022)
"""
final_answer = final_answer.split("=")[-1]
for before, after in self.SUBSTITUTIONS:
final_answer = final_answer.replace(before, after)
for expr in self.REMOVED_EXPRESSIONS:
final_answer = final_answer.replace(expr, "")
# Extract answer that is in LaTeX math, is bold,
# is surrounded by a box, etc.
final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer)
final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer)
final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer)
final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer)
final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer)
# Normalize shorthand TeX:
# \fracab -> \frac{a}{b}
# \frac{abc}{bef} -> \frac{abc}{bef}
# \fracabc -> \frac{a}{b}c
# \sqrta -> \sqrt{a}
# \sqrtab -> sqrt{a}b
final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer)
final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer)
final_answer = final_answer.replace("$", "")
# Normalize 100,000 -> 100000
if final_answer.replace(",", "").isdigit():
final_answer = final_answer.replace(",", "")
return final_answer
def parse_tex(self, text: str, time_limit: int = 5) -> sympy.Basic:
"""
Wrapper around `sympy.parse_text` that outputs a SymPy expression.
Typically, you want to apply `normalize_text` as a preprocessing step.
"""
try:
with timeout(seconds=time_limit):
parsed = parse_latex(text)
except (
# general error handling: there is a long tail of possible sympy/other
# errors we would like to catch
Exception
) as e:
print(f"failed to parse {text} with exception {e}")
return None
return parsed
def is_exp_equiv(self, x1: sympy.Basic, x2: sympy.Basic, time_limit=5) -> bool:
"""
Determines whether two sympy expressions are equal.
"""
try:
with timeout(seconds=time_limit):
try:
diff = x1 - x2
except (SympifyError, ValueError, TypeError) as e:
print(
f"Couldn't subtract {x1} and {x2} with exception {e}"
)
return False
try:
if sympy.simplify(diff) == 0:
return True
else:
return False
except (SympifyError, ValueError, TypeError) as e:
print(f"Failed to simplify {x1}-{x2} with {e}")
return False
except TimeoutError as e:
print(f"Timed out comparing {x1} and {x2}")
return False
except Exception as e:
print(f"failed on unrecognized exception {e}")
return False
def is_tex_equiv(self, x1: str, x2: str, time_limit=5) -> bool:
"""
Determines whether two (ideally normalized using `normalize_text`) TeX expressions are equal.
Does so by first checking for string exact-match, then falls back on sympy-equivalence,
following the (Lewkowycz et al. 2022) methodology.
"""
if x1 == x2:
# don't resort to sympy if we have full string match, post-normalization
return True
else:
return False
parsed_x2 = self.parse_tex(x2)
if not parsed_x2:
# if our reference fails to parse into a Sympy object,
# we forgo parsing + checking our generated answer.
return False
return self.is_exp_equiv(self.parse_tex(x1), parsed_x2, time_limit=time_limit)

View File

@@ -0,0 +1,165 @@
import os
import io
from contextlib import redirect_stdout
import pickle
import regex
import copy
from typing import Any, Dict, Optional
import multiprocess
from pebble import ProcessPool
from concurrent.futures import TimeoutError
from functools import partial
import traceback
from timeout_decorator import timeout
class GenericRuntime:
GLOBAL_DICT = {}
LOCAL_DICT = None
HEADERS = []
def __init__(self):
self._global_vars = copy.copy(self.GLOBAL_DICT)
self._local_vars = copy.copy(self.LOCAL_DICT) if self.LOCAL_DICT else None
for c in self.HEADERS:
self.exec_code(c)
def exec_code(self, code_piece: str) -> None:
if regex.search(r'(\s|^)?input\(', code_piece) or regex.search(r'(\s|^)?os.system\(', code_piece):
raise RuntimeError()
exec(code_piece, self._global_vars)
def eval_code(self, expr: str) -> Any:
return eval(expr, self._global_vars)
def inject(self, var_dict: Dict[str, Any]) -> None:
for k, v in var_dict.items():
self._global_vars[k] = v
@property
def answer(self):
return self._global_vars['answer']
class PythonExecutor:
def __init__(
self,
runtime: Optional[Any] = None,
get_answer_symbol: Optional[str] = None,
get_answer_expr: Optional[str] = None,
get_answer_from_stdout: bool = False,
) -> None:
self.runtime = runtime if runtime else GenericRuntime()
self.answer_symbol = get_answer_symbol
self.answer_expr = get_answer_expr
self.get_answer_from_stdout = get_answer_from_stdout
def process_generation_to_code(self, gens: str):
batch_code = []
for g in gens:
multiline_comments = False
code = []
for line in g.split('\n'):
strip_line = line.strip()
if strip_line.startswith("#"):
line = line.split("#", 1)[0] + "# comments"
elif not multiline_comments and strip_line.startswith('"""') and strip_line.endswith('"""') and len(strip_line) >= 6:
line = line.split('"""', 1)[0] + '"""comments"""'
elif not multiline_comments and strip_line.startswith('"""'):
multiline_comments = True
elif multiline_comments and strip_line.endswith('"""'):
multiline_comments = False
line = ""
if not multiline_comments:
code.append(line)
batch_code.append(code)
return batch_code
@staticmethod
def execute(
code,
get_answer_from_stdout = None,
runtime = None,
answer_symbol = None,
answer_expr = None,
timeout_length = 10,
):
try:
if get_answer_from_stdout:
program_io = io.StringIO()
with redirect_stdout(program_io):
timeout(timeout_length)(runtime.exec_code)('\n'.join(code))
program_io.seek(0)
result = "".join(program_io.readlines()) # [-1]
elif answer_symbol:
timeout(timeout_length)(runtime.exec_code)('\n'.join(code))
result = runtime._global_vars[answer_symbol]
elif answer_expr:
timeout(timeout_length)(runtime.exec_code)('\n'.join(code))
result = timeout(timeout_length)(runtime.eval_code)(answer_expr)
else:
timeout(timeout_length)(runtime.exec_code)('\n'.join(code[:-1]))
result = timeout(timeout_length)(runtime.eval_code)(code[-1])
concise_exec_info = ""
exec_info = ""
str(result)
pickle.dumps(result) # serialization check
except:
# traceback.print_exc()
result = ''
concise_exec_info = traceback.format_exc().split('\n')[-2]
exec_info = traceback.format_exc()
if get_answer_from_stdout and 'exec(code_piece, self._global_vars)' in exec_info:
exec_info = exec_info.split('exec(code_piece, self._global_vars)')[-1].strip()
msg = []
for line in exec_info.split("\n"):
patt = regex.search(r'(?P<start>.*)File "(?P<file>.*)", line (?P<lno>\d+), (?P<end>.*)', line)
if patt is not None:
if '<module>' in patt.group('end'):
continue
fname = patt.group("file")
if "site-packages" in fname:
fname = f"site-packages{fname.split('site-packages', 1)[1]}"
line = f'{patt.group("start")}File "{fname}", {patt.group("end")}'
else:
line = f'{patt.group("start")}{patt.group("end")}'
else:
patt = regex.search(r'(?P<start>.*)(?P<file>/.*site-packages/.*\.py)(?P<end>.*)', line)
if patt is not None:
line = f'{patt.group("start")}site-packages{patt.group("file").split("site-packages", 1)[1]}{patt.group("end")}'
msg.append(line)
exec_info = "\n".join(msg)
return result, concise_exec_info, exec_info
def apply(self, code):
return self.batch_apply([code])[0]
def batch_apply(self, batch_code):
all_code_snippets = self.process_generation_to_code(batch_code)
all_exec_results = []
executor = partial(
self.execute,
get_answer_from_stdout=self.get_answer_from_stdout,
runtime=self.runtime,
answer_symbol=self.answer_symbol,
answer_expr=self.answer_expr,
timeout_length=10,
)
with ProcessPool(max_workers=multiprocess.cpu_count()) as pool:
iterator = pool.map(executor, all_code_snippets, timeout=10).result()
while True:
try:
result = next(iterator)
all_exec_results.append(result)
except StopIteration:
break
except TimeoutError as error:
all_exec_results.append(("", "Timeout Error", "Timeout Error"))
except Exception as error:
print(error)
exit()
batch_results = []
for code, (result, concise_exec_info, exec_info) in zip(all_code_snippets, all_exec_results):
metadata = {'code': code, 'exec_result': result, 'concise_exec_info': concise_exec_info, 'exec_info': exec_info}
batch_results.append((result, metadata))
return batch_results

263
evaluation/eval/utils.py Executable file
View File

@@ -0,0 +1,263 @@
import torch
import tqdm
from transformers import StoppingCriteria, GenerationConfig
class KeyWordsCriteria(StoppingCriteria):
def __init__(self, stop_id_sequences, tokenizer, prompt_length):
assert isinstance(stop_id_sequences[0], list), "stop_id_sequences should be a list of list of ids"
self.tokenizer = tokenizer
self.stop_id_sequences = stop_id_sequences
self.stop_sequences = [tokenizer.decode(sequence) for sequence in stop_id_sequences]
print(f"stop sequences: {self.stop_sequences}", flush=True)
self.prompt_length = prompt_length
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
sequences_should_be_stopped = []
for i in range(input_ids.shape[0]):
ids = input_ids[i][self.prompt_length:].tolist()
should_be_stopped = False
for stop_ids, stop_sequence in zip(self.stop_id_sequences, self.stop_sequences):
_ids = ids
for j in range(len(_ids), 0, -1):
s = self.tokenizer.decode(_ids[max(j - len(stop_ids) - 3, 0) :j])
if s.endswith(stop_sequence):
should_be_stopped = True
break
if should_be_stopped:
break
sequences_should_be_stopped.append(should_be_stopped)
return all(sequences_should_be_stopped)
@torch.no_grad()
def generate_completions(model, tokenizer, prompts, batch_size=1, stop_id_sequences=None, end_of_generation_id_sequence=None, disable_tqdm=False, **generation_kwargs):
generations = []
finish_completion = []
if not disable_tqdm:
progress = tqdm.tqdm(total=len(prompts), desc="Generating Completions")
if stop_id_sequences is not None:
stop_sequences = [tokenizer.decode(stop_id_sequence) for stop_id_sequence in stop_id_sequences]
if end_of_generation_id_sequence is not None:
end_of_generation_sequence = tokenizer.decode(end_of_generation_id_sequence)
num_return_sequences = generation_kwargs.get("num_return_sequences", 1)
generation_kwargs['use_cache'] = True
for i in range(0, len(prompts), batch_size):
batch_prompts = prompts[i:i+batch_size]
tokenized_prompts = tokenizer(batch_prompts, padding="longest", return_tensors="pt", add_special_tokens='chatglm2' in str(model.__class__))
batch_input_ids = tokenized_prompts.input_ids
attention_mask = tokenized_prompts.attention_mask
if model.device.type == "cuda":
batch_input_ids = batch_input_ids.cuda()
attention_mask = attention_mask.cuda()
batch_finish_completion = [False] * len(batch_prompts) * num_return_sequences
try:
batch_outputs = model.generate(
input_ids=batch_input_ids,
attention_mask=attention_mask,
stopping_criteria=[KeyWordsCriteria(stop_id_sequences, tokenizer, batch_input_ids.size(1))] if stop_id_sequences else None,
**generation_kwargs
)
# the stopping criteria is applied at batch level, so if other examples are not stopped, the entire batch will continue to generate.
# so some outputs still have the stop sequence, which we need to remove.
if stop_id_sequences:
for output_idx in range(batch_outputs.shape[0]):
finish = False
for token_idx in range(batch_input_ids.shape[1], batch_outputs.shape[1]):
if any(tokenizer.decode(batch_outputs[output_idx, token_idx: token_idx + len(stop_sequence) + 3]).startswith(stop_sequence) for stop_sequence in stop_sequences):
if end_of_generation_id_sequence is not None and tokenizer.decode(batch_outputs[output_idx, token_idx: token_idx + len(end_of_generation_id_sequence) + 3]).startswith(end_of_generation_sequence):
batch_finish_completion[output_idx] = True
batch_outputs[output_idx, token_idx:] = tokenizer.pad_token_id
break
# remove the prompt from the output
# we need to re-encode the prompt because we need to make sure the special tokens are treated the same way as in the outputs.
# we changed our previous way of truncating the output token ids dicrectly because some tokenizer (e.g., llama) won't add space token before the first token.
# space is important for some tasks (e.g., code completion).
batch_outputs = tokenizer.batch_decode(batch_outputs, skip_special_tokens=True)
batch_prompts = tokenizer.batch_decode(batch_input_ids, skip_special_tokens=True)
# duplicate the prompts to match the number of return sequences
batch_prompts = [prompt for prompt in batch_prompts for _ in range(num_return_sequences)]
batch_generations = [
output[len(prompt):] for prompt, output in zip(batch_prompts, batch_outputs)
]
except Exception as e:
print("Error when generating completions for batch:")
print(batch_prompts)
print("Error message:")
print(e)
print("Use empty string as the completion.")
batch_generations = [""] * len(batch_prompts) * num_return_sequences
generations += batch_generations
finish_completion += batch_finish_completion
if not disable_tqdm:
progress.update(len(batch_prompts)//num_return_sequences)
assert len(generations) == len(prompts) * num_return_sequences, "number of generations should be equal to number of prompts * num_return_sequences"
return generations, finish_completion
@torch.no_grad()
def get_next_word_predictions(model, tokenizer, prompts, candidate_token_ids=None, batch_size=1, return_token_predictions=False, disable_tqdm=False):
predictions, probs = [], []
if not disable_tqdm:
progress = tqdm.tqdm(total=len(prompts), desc="Getting Predictions")
for i in range(0, len(prompts), batch_size):
batch_prompts = prompts[i: i+batch_size]
tokenized_prompts = tokenizer(batch_prompts, padding="longest", return_tensors="pt", add_special_tokens=False)
batch_input_ids = tokenized_prompts.input_ids
attention_mask = tokenized_prompts.attention_mask
if model.device.type == "cuda":
batch_input_ids = batch_input_ids.cuda()
attention_mask = attention_mask.cuda()
batch_logits = model(input_ids=batch_input_ids, attention_mask=attention_mask).logits[:, -1, :]
if candidate_token_ids is not None:
batch_logits = batch_logits[:, candidate_token_ids]
batch_probs = torch.softmax(batch_logits, dim=-1)
batch_prediction_indices = torch.argmax(batch_probs, dim=-1)
if return_token_predictions:
if candidate_token_ids is not None:
candidate_tokens = tokenizer.convert_ids_to_tokens(candidate_token_ids)
batch_predictions = [candidate_tokens[idx] for idx in batch_prediction_indices]
else:
batch_predictions = tokenizer.convert_ids_to_tokens(batch_prediction_indices)
predictions += batch_predictions
else:
predictions += batch_prediction_indices.tolist()
probs += batch_probs.tolist()
if not disable_tqdm:
progress.update(len(batch_prompts))
assert len(predictions) == len(prompts), "number of predictions should be equal to number of prompts"
return predictions, probs
@torch.no_grad()
def score_completions(model, tokenizer, scoring_examples, disable_tqdm=False):
'''
Each scoring example is a dict, which contains the following keys:
- prompt: the prompt to score
- completions: a list of completions to score
'''
if not disable_tqdm:
progress = tqdm.tqdm(total=len(scoring_examples), desc="Scoring Completions")
# unroll the scoring examples
unrolled_examples = []
for scoring_example in scoring_examples:
prompt = scoring_example["prompt"]
for completion in scoring_example["completions"]:
unrolled_examples.append({
"prompt": prompt,
"completion": completion
})
scores = []
# currently we don't support batching, because we want to directly use the loss returned by the model to score each completion.
for unrolled_example in unrolled_examples:
encoded_example = encode_with_prompt_completion_format(unrolled_example, tokenizer, max_seq_length=None)
# unsqueeze the batch dimension
for key, value in encoded_example.items():
encoded_example[key] = value.unsqueeze(0)
if model.device.type == "cuda":
encoded_example = {
key: value.cuda() for key, value in encoded_example.items()
}
outputs = model(**encoded_example)
loss = outputs.loss
scores.append(-loss.item())
if not disable_tqdm:
progress.update(1)
# roll up the scores
rolled_up_scores = {}
for unrolled_example, score in zip(unrolled_examples, scores):
prompt = unrolled_example["prompt"]
completion = unrolled_example["completion"]
if prompt not in rolled_up_scores:
rolled_up_scores[prompt] = {}
rolled_up_scores[prompt][completion] = score
return rolled_up_scores
def load_hf_lm_and_tokenizer(
model_name_or_path,
tokenizer_name_or_path=None,
device_map="auto",
load_in_8bit=False,
load_in_half=False,
gptq_model=False,
use_fast_tokenizer=True,
padding_side="left",
):
from transformers import AutoModelForCausalLM, AutoModel, AutoTokenizer
if not tokenizer_name_or_path:
tokenizer_name_or_path = model_name_or_path
is_chatglm2 = 'chatglm2' in tokenizer_name_or_path.lower() or 'chatglm2' in model_name_or_path
is_qwen = 'qwen' in tokenizer_name_or_path.lower() or 'qwen' in model_name_or_path
if is_chatglm2 or is_qwen:
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, trust_remote_code=True)
if is_qwen:
tokenizer.eos_token = '<|endoftext|>'
tokenizer.eos_token_id = 151643
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
else:
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, trust_remote_code=True, use_fast=use_fast_tokenizer)
# set padding side to left for batch generation
tokenizer.padding_side = padding_side
# set pad token to eos token if pad token is not set (as is the case for llama models)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
if gptq_model:
from auto_gptq import AutoGPTQForCausalLM
model_wrapper = AutoGPTQForCausalLM.from_quantized(
model_name_or_path, device="cuda:0", use_triton=True
)
model = model_wrapper.model
elif load_in_8bit:
model = AutoModelForCausalLM.from_pretrained(
model_name_or_path,
device_map=device_map,
load_in_8bit=True
)
else:
kwargs = {}
model_class = AutoModelForCausalLM
if is_chatglm2:
kwargs = {'trust_remote_code': True}
model_class = AutoModel
elif is_qwen:
kwargs = {'trust_remote_code': True}
if device_map:
model = model_class.from_pretrained(model_name_or_path, device_map=device_map, **kwargs)
else:
model = model_class.from_pretrained(model_name_or_path, **kwargs)
if torch.cuda.is_available():
model = model.cuda()
if is_qwen:
model.generation_config = GenerationConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
model.generation_config.do_sample = False
if not is_chatglm2 and not is_qwen and load_in_half:
model = model.half()
model.eval()
return model, tokenizer