import os import io from contextlib import redirect_stdout import pickle import regex import copy from typing import Any, Dict, Optional import multiprocess from pebble import ProcessPool from concurrent.futures import TimeoutError from functools import partial import traceback from timeout_decorator import timeout class GenericRuntime: GLOBAL_DICT = {} LOCAL_DICT = None HEADERS = [] def __init__(self): self._global_vars = copy.copy(self.GLOBAL_DICT) self._local_vars = copy.copy(self.LOCAL_DICT) if self.LOCAL_DICT else None for c in self.HEADERS: self.exec_code(c) def exec_code(self, code_piece: str) -> None: if regex.search(r'(\s|^)?input\(', code_piece) or regex.search(r'(\s|^)?os.system\(', code_piece): raise RuntimeError() exec(code_piece, self._global_vars) def eval_code(self, expr: str) -> Any: return eval(expr, self._global_vars) def inject(self, var_dict: Dict[str, Any]) -> None: for k, v in var_dict.items(): self._global_vars[k] = v @property def answer(self): return self._global_vars['answer'] class PythonExecutor: def __init__( self, runtime: Optional[Any] = None, get_answer_symbol: Optional[str] = None, get_answer_expr: Optional[str] = None, get_answer_from_stdout: bool = False, ) -> None: self.runtime = runtime if runtime else GenericRuntime() self.answer_symbol = get_answer_symbol self.answer_expr = get_answer_expr self.get_answer_from_stdout = get_answer_from_stdout def process_generation_to_code(self, gens: str): batch_code = [] for g in gens: multiline_comments = False code = [] for line in g.split('\n'): strip_line = line.strip() if strip_line.startswith("#"): line = line.split("#", 1)[0] + "# comments" elif not multiline_comments and strip_line.startswith('"""') and strip_line.endswith('"""') and len(strip_line) >= 6: line = line.split('"""', 1)[0] + '"""comments"""' elif not multiline_comments and strip_line.startswith('"""'): multiline_comments = True elif multiline_comments and strip_line.endswith('"""'): multiline_comments = False line = "" if not multiline_comments: code.append(line) batch_code.append(code) return batch_code @staticmethod def execute( code, get_answer_from_stdout = None, runtime = None, answer_symbol = None, answer_expr = None, timeout_length = 10, ): try: if get_answer_from_stdout: program_io = io.StringIO() with redirect_stdout(program_io): timeout(timeout_length)(runtime.exec_code)('\n'.join(code)) program_io.seek(0) result = "".join(program_io.readlines()) # [-1] elif answer_symbol: timeout(timeout_length)(runtime.exec_code)('\n'.join(code)) result = runtime._global_vars[answer_symbol] elif answer_expr: timeout(timeout_length)(runtime.exec_code)('\n'.join(code)) result = timeout(timeout_length)(runtime.eval_code)(answer_expr) else: timeout(timeout_length)(runtime.exec_code)('\n'.join(code[:-1])) result = timeout(timeout_length)(runtime.eval_code)(code[-1]) concise_exec_info = "" exec_info = "" str(result) pickle.dumps(result) # serialization check except: # traceback.print_exc() result = '' concise_exec_info = traceback.format_exc().split('\n')[-2] exec_info = traceback.format_exc() if get_answer_from_stdout and 'exec(code_piece, self._global_vars)' in exec_info: exec_info = exec_info.split('exec(code_piece, self._global_vars)')[-1].strip() msg = [] for line in exec_info.split("\n"): patt = regex.search(r'(?P.*)File "(?P.*)", line (?P\d+), (?P.*)', line) if patt is not None: if '' in patt.group('end'): continue fname = patt.group("file") if "site-packages" in fname: fname = f"site-packages{fname.split('site-packages', 1)[1]}" line = f'{patt.group("start")}File "{fname}", {patt.group("end")}' else: line = f'{patt.group("start")}{patt.group("end")}' else: patt = regex.search(r'(?P.*)(?P/.*site-packages/.*\.py)(?P.*)', line) if patt is not None: line = f'{patt.group("start")}site-packages{patt.group("file").split("site-packages", 1)[1]}{patt.group("end")}' msg.append(line) exec_info = "\n".join(msg) return result, concise_exec_info, exec_info def apply(self, code): return self.batch_apply([code])[0] def batch_apply(self, batch_code): all_code_snippets = self.process_generation_to_code(batch_code) all_exec_results = [] executor = partial( self.execute, get_answer_from_stdout=self.get_answer_from_stdout, runtime=self.runtime, answer_symbol=self.answer_symbol, answer_expr=self.answer_expr, timeout_length=10, ) with ProcessPool(max_workers=multiprocess.cpu_count()) as pool: iterator = pool.map(executor, all_code_snippets, timeout=10).result() while True: try: result = next(iterator) all_exec_results.append(result) except StopIteration: break except TimeoutError as error: all_exec_results.append(("", "Timeout Error", "Timeout Error")) except Exception as error: print(error) exit() batch_results = [] for code, (result, concise_exec_info, exec_info) in zip(all_code_snippets, all_exec_results): metadata = {'code': code, 'exec_result': result, 'concise_exec_info': concise_exec_info, 'exec_info': exec_info} batch_results.append((result, metadata)) return batch_results