mirror of
https://github.com/deepseek-ai/DeepSeek-Coder
synced 2025-06-26 18:25:53 +00:00
init project
This commit is contained in:
151
Evaluation/PAL-Math/utils/grader.py
Normal file
151
Evaluation/PAL-Math/utils/grader.py
Normal file
@@ -0,0 +1,151 @@
|
||||
"""
|
||||
This logic is largely copied from the Hendrycks' MATH release (math_equivalence), and borrowed from:
|
||||
- https://github.com/microsoft/ProphetNet/tree/master/CRITIC
|
||||
- https://github.com/openai/prm800k
|
||||
"""
|
||||
import multiprocessing
|
||||
from math import isclose
|
||||
from typing import Union
|
||||
|
||||
from sympy import simplify, N
|
||||
from sympy.parsing.sympy_parser import parse_expr
|
||||
from sympy.parsing.latex import parse_latex
|
||||
|
||||
|
||||
def is_digit(s):
|
||||
try:
|
||||
float(str(s).replace(",", ""))
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
def math_equal(prediction: Union[bool, float, str],
|
||||
reference: Union[float, str],
|
||||
include_percentage: bool = True,
|
||||
is_close: bool = True,
|
||||
timeout: bool = False,
|
||||
) -> bool:
|
||||
"""
|
||||
Exact match of math if and only if:
|
||||
1. numerical equal: both can convert to float and are equal
|
||||
2. symbolic equal: both can convert to sympy expression and are equal
|
||||
"""
|
||||
try: # 1. numerical equal
|
||||
if is_digit(prediction) and is_digit(reference):
|
||||
prediction = float(str(prediction).replace(",", ""))
|
||||
reference = float(str(reference).replace(",", ""))
|
||||
# number questions
|
||||
if include_percentage:
|
||||
gt_result = [reference / 100, reference, reference * 100]
|
||||
else:
|
||||
gt_result = [reference]
|
||||
for item in gt_result:
|
||||
try:
|
||||
if is_close:
|
||||
if isclose(item, prediction, rel_tol=1e-4):
|
||||
return True
|
||||
else:
|
||||
if item == prediction:
|
||||
return True
|
||||
except Exception:
|
||||
continue
|
||||
return False
|
||||
except:
|
||||
pass
|
||||
|
||||
if not prediction and prediction not in [0, False]:
|
||||
return False
|
||||
|
||||
# 2. symbolic equal
|
||||
reference = str(reference).strip()
|
||||
prediction = str(prediction).strip()
|
||||
|
||||
## deal with [], (), {}
|
||||
pred_str, ref_str = prediction, reference
|
||||
if (prediction.startswith("[") and prediction.endswith("]") and not reference.startswith("(")) or \
|
||||
(prediction.startswith("(") and prediction.endswith(")") and not reference.startswith("[")):
|
||||
pred_str = pred_str.strip("[]()")
|
||||
ref_str = ref_str.strip("[]()")
|
||||
for s in ['{', "}", "(", ")"]:
|
||||
ref_str = ref_str.replace(s, "")
|
||||
pred_str = pred_str.replace(s, "")
|
||||
if pred_str == ref_str:
|
||||
return True
|
||||
|
||||
## [a, b] vs. [c, d], return a==c and b==d
|
||||
if (prediction.startswith("[") and prediction.endswith("]")) and (reference.startswith("[") and reference.endswith("]")) or \
|
||||
(prediction.startswith("(") and prediction.endswith(")")) and (reference.startswith("(") and reference.endswith(")")):
|
||||
pred_parts = prediction[1:-1].split(",")
|
||||
ref_parts = reference[1:-1].split(",")
|
||||
if len(pred_parts) == len(ref_parts):
|
||||
if all([math_equal(pred_parts[i], ref_parts[i], include_percentage, is_close) for i in range(len(pred_parts))]):
|
||||
return True
|
||||
|
||||
# symbolic equal with sympy
|
||||
if timeout:
|
||||
if call_with_timeout(symbolic_equal_process, prediction, reference):
|
||||
return True
|
||||
else:
|
||||
if symbolic_equal(prediction, reference):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def math_equal_process(param):
|
||||
print(param[-2], param[-1],math_equal(param[-2], param[-1]))
|
||||
return math_equal(param[-2], param[-1])
|
||||
|
||||
|
||||
def symbolic_equal(a, b):
|
||||
def _parse(s):
|
||||
for f in [parse_latex, parse_expr]:
|
||||
try:
|
||||
return f(s)
|
||||
except:
|
||||
pass
|
||||
return s
|
||||
a = _parse(a)
|
||||
b = _parse(b)
|
||||
|
||||
try:
|
||||
if simplify(a-b) == 0:
|
||||
return True
|
||||
except:
|
||||
pass
|
||||
|
||||
try:
|
||||
if isclose(N(a), N(b), rel_tol=1e-3):
|
||||
return True
|
||||
except:
|
||||
pass
|
||||
return False
|
||||
|
||||
|
||||
def symbolic_equal_process(a, b, output_queue):
|
||||
result = symbolic_equal(a, b)
|
||||
output_queue.put(result)
|
||||
|
||||
|
||||
def call_with_timeout(func, *args, timeout=1, **kwargs):
|
||||
output_queue = multiprocessing.Queue()
|
||||
process_args = args + (output_queue,)
|
||||
process = multiprocessing.Process(target=func, args=process_args, kwargs=kwargs)
|
||||
process.start()
|
||||
process.join(timeout)
|
||||
|
||||
if process.is_alive():
|
||||
process.terminate()
|
||||
process.join()
|
||||
return False
|
||||
|
||||
return output_queue.get()
|
||||
|
||||
|
||||
def _test_math_equal():
|
||||
# print(math_equal("0.0833333333333333", "\\frac{1}{12}"))
|
||||
# print(math_equal("(1,4.5)", "(1,\\frac{9}{2})"))
|
||||
print(math_equal("\\frac{x}{7}+\\frac{2}{7}", "\\frac{x+2}{7}", timeout=True))
|
||||
|
||||
if __name__ == "__main__":
|
||||
_test_math_equal()
|
||||
320
Evaluation/PAL-Math/utils/parser.py
Normal file
320
Evaluation/PAL-Math/utils/parser.py
Normal file
@@ -0,0 +1,320 @@
|
||||
"""
|
||||
This file come from: https://github.com/microsoft/ToRA/blob/main/src/utils/parser.py
|
||||
"""
|
||||
import re
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
def _fix_fracs(string):
|
||||
substrs = string.split("\\frac")
|
||||
new_str = substrs[0]
|
||||
if len(substrs) > 1:
|
||||
substrs = substrs[1:]
|
||||
for substr in substrs:
|
||||
new_str += "\\frac"
|
||||
if len(substr) > 0 and substr[0] == "{":
|
||||
new_str += substr
|
||||
else:
|
||||
try:
|
||||
assert len(substr) >= 2
|
||||
except:
|
||||
return string
|
||||
a = substr[0]
|
||||
b = substr[1]
|
||||
if b != "{":
|
||||
if len(substr) > 2:
|
||||
post_substr = substr[2:]
|
||||
new_str += "{" + a + "}{" + b + "}" + post_substr
|
||||
else:
|
||||
new_str += "{" + a + "}{" + b + "}"
|
||||
else:
|
||||
if len(substr) > 2:
|
||||
post_substr = substr[2:]
|
||||
new_str += "{" + a + "}" + b + post_substr
|
||||
else:
|
||||
new_str += "{" + a + "}" + b
|
||||
string = new_str
|
||||
return string
|
||||
|
||||
|
||||
def _fix_a_slash_b(string):
|
||||
if len(string.split("/")) != 2:
|
||||
return string
|
||||
a = string.split("/")[0]
|
||||
b = string.split("/")[1]
|
||||
try:
|
||||
if "sqrt" not in a:
|
||||
a = int(a)
|
||||
if "sqrt" not in b:
|
||||
b = int(b)
|
||||
assert string == "{}/{}".format(a, b)
|
||||
new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
|
||||
return new_string
|
||||
except:
|
||||
return string
|
||||
|
||||
|
||||
def _fix_sqrt(string):
|
||||
_string = re.sub(r"\\sqrt(\w+)", r"\\sqrt{\1}", string)
|
||||
return _string
|
||||
|
||||
|
||||
def strip_string(string):
|
||||
string = str(string).strip()
|
||||
# linebreaks
|
||||
string = string.replace("\n", "")
|
||||
|
||||
# right "."
|
||||
string = string.rstrip(".")
|
||||
|
||||
# remove inverse spaces
|
||||
string = string.replace("\\!", "")
|
||||
string = string.replace("\\ ", "")
|
||||
|
||||
# replace \\ with \
|
||||
string = string.replace("\\\\", "\\")
|
||||
string = string.replace("\\\\", "\\")
|
||||
|
||||
# replace tfrac and dfrac with frac
|
||||
string = string.replace("tfrac", "frac")
|
||||
string = string.replace("dfrac", "frac")
|
||||
|
||||
# remove \left and \right
|
||||
string = string.replace("\\left", "")
|
||||
string = string.replace("\\right", "")
|
||||
|
||||
# Remove unit: miles, dollars if after is not none
|
||||
_string = re.sub(r"\\text{.*?}$", "", string).strip()
|
||||
if _string != "" and _string != string:
|
||||
# print("Warning: unit not removed: '{}' -> '{}'".format(string, _string))
|
||||
string = _string
|
||||
|
||||
# Remove circ (degrees)
|
||||
string = string.replace("^{\\circ}", "")
|
||||
string = string.replace("^\\circ", "")
|
||||
|
||||
# remove dollar signs
|
||||
string = string.replace("\\$", "")
|
||||
string = string.replace("$", "")
|
||||
|
||||
string = string.replace("\\text", "")
|
||||
string = string.replace("x\\in", "")
|
||||
|
||||
# remove percentage
|
||||
string = string.replace("\\%", "")
|
||||
string = string.replace("\%", "")
|
||||
string = string.replace("%", "")
|
||||
|
||||
# " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
|
||||
string = string.replace(" .", " 0.")
|
||||
string = string.replace("{.", "{0.")
|
||||
|
||||
# cdot
|
||||
string = string.replace("\\cdot", "")
|
||||
|
||||
# inf
|
||||
string = string.replace("infinity", "\\infty")
|
||||
if "\\infty" not in string:
|
||||
string = string.replace("inf", "\\infty")
|
||||
string = string.replace("+\\inity", "\\infty")
|
||||
|
||||
# and
|
||||
string = string.replace("and", "")
|
||||
string = string.replace("\\mathbf", "")
|
||||
|
||||
# use regex to remove \mbox{...}
|
||||
string = re.sub(r"\\mbox{.*?}", "", string)
|
||||
|
||||
# quote
|
||||
string.replace("'", "")
|
||||
string.replace("\"", "")
|
||||
|
||||
# i, j
|
||||
if "j" in string and "i" not in string:
|
||||
string = string.replace("j", "i")
|
||||
|
||||
# replace a.000b where b is not number or b is end, with ab, use regex
|
||||
string = re.sub(r"(\d+)\.0+([^\d])", r"\1\2", string)
|
||||
string = re.sub(r"(\d+)\.0+$", r"\1", string)
|
||||
|
||||
# if empty, return empty string
|
||||
if len(string) == 0:
|
||||
return string
|
||||
if string[0] == ".":
|
||||
string = "0" + string
|
||||
|
||||
# to consider: get rid of e.g. "k = " or "q = " at beginning
|
||||
if len(string.split("=")) == 2:
|
||||
if len(string.split("=")[0]) <= 2:
|
||||
string = string.split("=")[1]
|
||||
|
||||
string = _fix_sqrt(string)
|
||||
string = string.replace(" ", "")
|
||||
|
||||
# \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
|
||||
string = _fix_fracs(string)
|
||||
|
||||
# NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
|
||||
string = _fix_a_slash_b(string)
|
||||
|
||||
return string
|
||||
|
||||
def extract_answer(pred_str):
|
||||
if 'boxed' in pred_str:
|
||||
ans = pred_str.split('boxed')[-1]
|
||||
if len(ans) == 0:
|
||||
return ""
|
||||
elif (ans[0] == '{'):
|
||||
stack = 1
|
||||
a = ''
|
||||
for c in ans[1:]:
|
||||
if (c == '{'):
|
||||
stack += 1
|
||||
a += c
|
||||
elif (c == '}'):
|
||||
stack -= 1
|
||||
if (stack == 0): break
|
||||
a += c
|
||||
else:
|
||||
a += c
|
||||
else:
|
||||
a = ans.split('$')[0].strip()
|
||||
pred=a
|
||||
elif ('he answer is' in pred_str):
|
||||
pred = pred_str.split('he answer is')[-1].strip()
|
||||
elif extract_program_output(pred_str) != "":
|
||||
# fall back to program
|
||||
pred = extract_program_output(pred_str)
|
||||
else: # use the last number
|
||||
pattern = '-?\d*\.?\d+'
|
||||
pred = re.findall(pattern, pred_str.replace(",", ""))
|
||||
if(len(pred) >= 1):
|
||||
pred = pred[-1]
|
||||
else: pred = ''
|
||||
|
||||
# multiple line
|
||||
pred = pred.split("\n")[0]
|
||||
if pred != "" and pred[0] == ":":
|
||||
pred = pred[1:]
|
||||
if pred != "" and pred[-1] == ".":
|
||||
pred = pred[:-1]
|
||||
if pred != "" and pred[-1] == "/":
|
||||
pred = pred[:-1]
|
||||
pred = strip_string(pred)
|
||||
return pred
|
||||
|
||||
|
||||
def extract_program(result: str, last_only=True):
|
||||
"""
|
||||
extract the program after "```python", and before "```"
|
||||
"""
|
||||
program = ""
|
||||
start = False
|
||||
for line in result.split("\n"):
|
||||
if line.startswith("```python"):
|
||||
if last_only:
|
||||
program = "" # only extract the last program
|
||||
else:
|
||||
program += "\n# ========\n"
|
||||
start = True
|
||||
elif line.startswith("```"):
|
||||
start = False
|
||||
elif start:
|
||||
program += line + "\n"
|
||||
return program
|
||||
|
||||
|
||||
def extract_program_output(pred_str):
|
||||
"""
|
||||
extract output between the last ```output\n...\n```
|
||||
"""
|
||||
if "```output" not in pred_str:
|
||||
return ""
|
||||
if '```output' in pred_str:
|
||||
pred_str = pred_str.split('```output')[-1]
|
||||
if '```' in pred_str:
|
||||
pred_str = pred_str.split('```')[0]
|
||||
output = pred_str.strip()
|
||||
return output
|
||||
|
||||
|
||||
def parse_ground_truth(example: Dict[str, Any], data_name):
|
||||
if 'gt_cot' in example:
|
||||
return example['gt_cot'], strip_string(example['gt'])
|
||||
|
||||
# parse ground truth
|
||||
if data_name in ["math", 'ocw']:
|
||||
gt_cot = example['solution']
|
||||
gt_ans = extract_answer(gt_cot)
|
||||
elif data_name == "gsm8k":
|
||||
gt_cot, gt_ans = example['answer'].split("####")
|
||||
elif data_name == "gsm-hard":
|
||||
gt_cot, gt_ans = example['code'], example['target']
|
||||
elif data_name == "svamp":
|
||||
gt_cot, gt_ans = example['Equation'], example['Answer']
|
||||
elif data_name == "asdiv":
|
||||
gt_cot = example['formula']
|
||||
gt_ans = re.sub(r"\(.*?\)", "", example['answer'])
|
||||
elif data_name == "mawps":
|
||||
gt_cot, gt_ans = None, example['target']
|
||||
elif data_name == "tabmwp":
|
||||
gt_cot = example['solution']
|
||||
gt_ans = example['answer']
|
||||
if example['ans_type'] in ['integer_number', 'decimal_number']:
|
||||
if '/' in gt_ans:
|
||||
gt_ans = int(gt_ans.split('/')[0]) / int(gt_ans.split('/')[1])
|
||||
elif ',' in gt_ans:
|
||||
gt_ans = float(gt_ans.replace(',', ''))
|
||||
elif '%' in gt_ans:
|
||||
gt_ans = float(gt_ans.split('%')[0]) / 100
|
||||
else:
|
||||
gt_ans = float(gt_ans)
|
||||
elif data_name == "bbh":
|
||||
gt_cot, gt_ans = None, example['target']
|
||||
else:
|
||||
raise NotImplementedError(data_name)
|
||||
# post process
|
||||
gt_cot = str(gt_cot).strip()
|
||||
gt_ans = strip_string(gt_ans)
|
||||
return gt_cot, gt_ans
|
||||
|
||||
|
||||
def parse_question(example, data_name):
|
||||
question = ""
|
||||
if data_name == "asdiv":
|
||||
question = f"{example['body'].strip()} {example['question'].strip()}"
|
||||
elif data_name == "svamp":
|
||||
body = example["Body"].strip()
|
||||
if not body.endswith("."):
|
||||
body = body + "."
|
||||
question = f'{body} {example["Question"].strip()}'
|
||||
elif data_name == "tabmwp":
|
||||
title_str = f'regarding "{example["table_title"]}" ' if example['table_title'] else ""
|
||||
question = f'Read the following table {title_str}and answer a question:\n'
|
||||
question += f'{example["table"]}\n{example["question"]}'
|
||||
if example['choices']:
|
||||
question += f' Please select from the following options: {example["choices"]}'
|
||||
else:
|
||||
for key in ['question', 'problem', 'Question', 'input']:
|
||||
if key in example:
|
||||
question = example[key]
|
||||
break
|
||||
assert question != ""
|
||||
return question.strip()
|
||||
|
||||
|
||||
def run_execute(executor, result, prompt_type, execute=False):
|
||||
if not result or result == 'error':
|
||||
return None, None
|
||||
report = None
|
||||
|
||||
if "program_only" in prompt_type:
|
||||
prediction = extract_program_output(result)
|
||||
elif prompt_type in ["pot", "pal"] and execute:
|
||||
code = extract_program(result)
|
||||
prediction, report = executor.apply(code)
|
||||
else:
|
||||
prediction = extract_answer(result)
|
||||
|
||||
prediction = strip_string(prediction)
|
||||
return prediction, report
|
||||
180
Evaluation/PAL-Math/utils/python_executor.py
Executable file
180
Evaluation/PAL-Math/utils/python_executor.py
Executable file
@@ -0,0 +1,180 @@
|
||||
"""
|
||||
This file come from: https://github.com/microsoft/ToRA/blob/main/src/utils/python_executor.py
|
||||
"""
|
||||
import io
|
||||
import regex
|
||||
import pickle
|
||||
import traceback
|
||||
import copy
|
||||
import datetime
|
||||
import multiprocessing
|
||||
import dateutil.relativedelta
|
||||
import multiprocess
|
||||
from multiprocess import Pool
|
||||
from typing import Any, Dict, Optional
|
||||
from pebble import ProcessPool
|
||||
from tqdm import tqdm
|
||||
from concurrent.futures import TimeoutError
|
||||
from functools import partial
|
||||
from timeout_decorator import timeout
|
||||
from contextlib import redirect_stdout
|
||||
|
||||
|
||||
class GenericRuntime:
|
||||
GLOBAL_DICT = {}
|
||||
LOCAL_DICT = None
|
||||
HEADERS = []
|
||||
def __init__(self):
|
||||
self._global_vars = copy.copy(self.GLOBAL_DICT)
|
||||
self._local_vars = copy.copy(self.LOCAL_DICT) if self.LOCAL_DICT else None
|
||||
|
||||
for c in self.HEADERS:
|
||||
self.exec_code(c)
|
||||
|
||||
def exec_code(self, code_piece: str) -> None:
|
||||
if regex.search(r'(\s|^)?input\(', code_piece) or regex.search(r'(\s|^)?os.system\(', code_piece):
|
||||
raise RuntimeError()
|
||||
exec(code_piece, self._global_vars)
|
||||
|
||||
def eval_code(self, expr: str) -> Any:
|
||||
return eval(expr, self._global_vars)
|
||||
|
||||
def inject(self, var_dict: Dict[str, Any]) -> None:
|
||||
for k, v in var_dict.items():
|
||||
self._global_vars[k] = v
|
||||
|
||||
@property
|
||||
def answer(self):
|
||||
return self._global_vars['answer']
|
||||
|
||||
class DateRuntime(GenericRuntime):
|
||||
GLOBAL_DICT = {
|
||||
'datetime': datetime.datetime,
|
||||
'timedelta': dateutil.relativedelta.relativedelta,
|
||||
'relativedelta': dateutil.relativedelta.relativedelta
|
||||
}
|
||||
|
||||
|
||||
class CustomDict(dict):
|
||||
def __iter__(self):
|
||||
return list(super().__iter__()).__iter__()
|
||||
|
||||
class ColorObjectRuntime(GenericRuntime):
|
||||
GLOBAL_DICT = {'dict': CustomDict}
|
||||
|
||||
|
||||
class PythonExecutor:
|
||||
def __init__(
|
||||
self,
|
||||
runtime: Optional[Any] = None,
|
||||
get_answer_symbol: Optional[str] = None,
|
||||
get_answer_expr: Optional[str] = None,
|
||||
get_answer_from_stdout: bool = False,
|
||||
timeout_length: int = 5,
|
||||
) -> None:
|
||||
self.runtime = runtime if runtime else GenericRuntime()
|
||||
self.answer_symbol = get_answer_symbol
|
||||
self.answer_expr = get_answer_expr
|
||||
self.get_answer_from_stdout = get_answer_from_stdout
|
||||
self.timeout_length = timeout_length
|
||||
|
||||
def process_generation_to_code(self, gens: str):
|
||||
return [g.split('\n') for g in gens]
|
||||
|
||||
@staticmethod
|
||||
def execute(
|
||||
code,
|
||||
get_answer_from_stdout = None,
|
||||
runtime = None,
|
||||
answer_symbol = None,
|
||||
answer_expr = None,
|
||||
timeout_length = 10,
|
||||
):
|
||||
try:
|
||||
if get_answer_from_stdout:
|
||||
program_io = io.StringIO()
|
||||
with redirect_stdout(program_io):
|
||||
timeout(timeout_length)(runtime.exec_code)('\n'.join(code))
|
||||
program_io.seek(0)
|
||||
result = program_io.readlines()[-1]
|
||||
elif answer_symbol:
|
||||
timeout(timeout_length)(runtime.exec_code)('\n'.join(code))
|
||||
result = runtime._global_vars[answer_symbol]
|
||||
elif answer_expr:
|
||||
timeout(timeout_length)(runtime.exec_code)('\n'.join(code))
|
||||
result = timeout(timeout_length)(runtime.eval_code)(answer_expr)
|
||||
else:
|
||||
timeout(timeout_length)(runtime.exec_code)('\n'.join(code[:-1]))
|
||||
result = timeout(timeout_length)(runtime.eval_code)(code[-1])
|
||||
exec_info = "Done"
|
||||
str(result)
|
||||
pickle.dumps(result) # serialization check
|
||||
except:
|
||||
result = ''
|
||||
exec_info = traceback.format_exc().split('\n')[-2]
|
||||
return result, exec_info
|
||||
|
||||
def apply(self, code):
|
||||
return self.batch_apply([code])[0]
|
||||
|
||||
def batch_apply(self, batch_code):
|
||||
all_code_snippets = self.process_generation_to_code(batch_code)
|
||||
|
||||
timeout_cnt = 0
|
||||
all_exec_results = []
|
||||
with ProcessPool(max_workers=min(len(all_code_snippets), multiprocessing.cpu_count())) as pool:
|
||||
executor = partial(
|
||||
self.execute,
|
||||
get_answer_from_stdout=self.get_answer_from_stdout,
|
||||
runtime=self.runtime,
|
||||
answer_symbol=self.answer_symbol,
|
||||
answer_expr=self.answer_expr,
|
||||
timeout_length=self.timeout_length, # this timeout not work
|
||||
)
|
||||
future = pool.map(executor, all_code_snippets, timeout=self.timeout_length)
|
||||
iterator = future.result()
|
||||
|
||||
if len(all_code_snippets) > 100:
|
||||
progress_bar = tqdm(total=len(all_code_snippets), desc="Execute")
|
||||
else:
|
||||
progress_bar = None
|
||||
|
||||
while True:
|
||||
try:
|
||||
result = next(iterator)
|
||||
all_exec_results.append(result)
|
||||
except StopIteration:
|
||||
break
|
||||
except TimeoutError as error:
|
||||
print(error)
|
||||
all_exec_results.append(("", "Timeout Error"))
|
||||
timeout_cnt += 1
|
||||
except Exception as error:
|
||||
print(error)
|
||||
exit()
|
||||
if progress_bar is not None:
|
||||
progress_bar.update(1)
|
||||
|
||||
if progress_bar is not None:
|
||||
progress_bar.close()
|
||||
|
||||
batch_results = []
|
||||
for code, (result, exec_info) in zip(all_code_snippets, all_exec_results):
|
||||
batch_results.append((result, exec_info))
|
||||
return batch_results
|
||||
|
||||
|
||||
def _test():
|
||||
batch_code = [
|
||||
"""
|
||||
print("Hello world!")
|
||||
"""
|
||||
]
|
||||
|
||||
executor = PythonExecutor(get_answer_from_stdout=True)
|
||||
predictions = executor.apply(batch_code[0])
|
||||
print(predictions)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
_test()
|
||||
Reference in New Issue
Block a user