import os
import time
import json
import ctypes
import resource
import tempfile
import traceback
import threading
import subprocess
import multiprocessing as mp
from pprint import pprint

import numpy as np

from prover.lean.ast_parser import lean4_parser
from prover.workers import ProcessScheduler
from prover.utils import AttrDict


HOME_DIR = os.path.expanduser('~')
DEFAULT_LAKE_PATH = f'{HOME_DIR}/.elan/bin/lake'
DEFAULT_LEAN_WORKSPACE = 'mathlib4/'

# Substrings looked for in a result's 'system_messages'; while any of them is
# present the worker re-runs the verification (bounded by its timeout).
_TRANSIENT_FAILURE_MARKERS = (
    'lean::exception: failed to create thread',
    'std::bad_alloc: std::bad_alloc',
    'Cannot allocate memory',
)


def verify_lean4_file(code, lake_path=DEFAULT_LAKE_PATH, lean_workspace=DEFAULT_LEAN_WORKSPACE,
                      last_env=None, verbose=False, timeout=300, allTactics=False, ast=False,
                      premises=False, tactics=False):
    """Run one command through ``lake exe repl`` and summarise its JSON reply.

    The command is serialised as JSON, written to a temporary file and fed to
    the REPL subprocess on stdin; the subprocess stdout is parsed as JSON.

    Args:
        code: Lean source sent as the ``cmd`` field of the command.
        lake_path: Path of the ``lake`` executable to invoke.
        lean_workspace: Working directory for the subprocess.
        last_env: If not None, sent as the ``env`` field of the command.
        verbose: If true, print the serialised command before running it.
        timeout: Seconds passed to ``subprocess.run`` as its timeout.
        allTactics, ast, premises, tactics: Forwarded verbatim as fields of
            the command.

    Returns:
        A dict that always contains ``pass``, ``complete``, ``system_errors``,
        ``system_messages`` and ``verify_time`` (elapsed wall-clock seconds).
        On success it additionally contains ``sorries``, ``tactics``,
        ``errors``, ``warnings``, ``infos``, ``ast`` and ``verified_code``.
        On any exception (timeout, unparsable output, ...) ``pass`` and
        ``complete`` are False and ``system_errors`` holds the traceback text.
    """
    command = dict(cmd=code, allTactics=allTactics, ast=ast, tactics=tactics, premises=premises)
    if last_env is not None:
        command.update(env=last_env)
    message_str = json.dumps(command, ensure_ascii=False)
    if verbose:
        print(message_str)

    start_time = time.time()
    system_messages = ''
    try:
        with tempfile.TemporaryFile(mode='w+', encoding='utf-8') as temp_file:
            temp_file.write(message_str + "\r\n\r\n")
            temp_file.seek(0)
            outputs = subprocess.run(
                [lake_path, "exe", 'repl'],
                stdin=temp_file,
                capture_output=True,
                text=True,
                cwd=lean_workspace,
                timeout=timeout,
            )
        # Keep the subprocess' stderr so callers can inspect it. Previously
        # this stayed '' forever, which made the retry on the markers in
        # _TRANSIENT_FAILURE_MARKERS unreachable.
        # NOTE(review): assumes those diagnostics are written to stderr - confirm.
        system_messages = outputs.stderr or ''

        reply = json.loads(outputs.stdout)
        messages = reply.get('messages', [])

        def with_severity(level):
            # A message lacking 'severity' raises KeyError, reported below
            # through 'system_errors' like any other failure.
            return [m for m in messages if m['severity'] == level]

        ast_results = lean4_parser(code, reply['ast']) if reply.get('ast') else {}
        result = {
            "sorries": reply.get('sorries', []),
            "tactics": reply.get('tactics', []),
            "errors": with_severity('error'),
            "warnings": with_severity('warning'),
            "infos": with_severity('info'),
            "system_messages": system_messages,
            "system_errors": None,
            "ast": ast_results,
            "verified_code": code,
        }
        result['pass'] = not result['errors']
        # "complete" additionally requires no sorries and no warning that
        # mentions a sorry-backed declaration or contains 'failed'.
        suspicious_warning = any(
            "declaration uses 'sorry'" in warning['data'] or 'failed' in warning['data']
            for warning in result['warnings']
        )
        result['complete'] = result['pass'] and not result['sorries'] and not suspicious_warning
    except Exception:
        # Exception rather than a bare except, so KeyboardInterrupt and
        # SystemExit are not swallowed and reported as a failed verification.
        result = {
            "pass": False,
            "complete": False,
            "system_errors": traceback.format_exc(),
            "system_messages": system_messages,
        }
    result['verify_time'] = time.time() - start_time
    return result


class Lean4ServerProcess(mp.Process):
    """Worker process that takes task batches from a queue and verifies them.

    Each result is stored in ``request_statuses`` under its request id while
    holding ``lock``.
    """

    def __init__(self, idx, task_queue, request_statuses, lock, extra_args=None):
        """Store the shared handles and read per-worker settings.

        Args:
            idx: Identifier stored on the instance.
            task_queue: Queue yielding batches of tasks, or None to stop.
            request_statuses: Mapping that receives results by request id.
            lock: Lock held while publishing a result.
            extra_args: Optional mapping with ``timeout`` (default 300) and
                ``memory_limit`` (default -1). None means an empty AttrDict;
                a None sentinel replaces the former shared default instance.
        """
        super().__init__()
        if extra_args is None:
            extra_args = AttrDict()
        self.idx = idx
        self.task_queue = task_queue
        self.request_statuses = request_statuses
        self.lock = lock
        self.extra_args = extra_args

        self.timeout = extra_args.get('timeout', 300)
        self.memory_limit = extra_args.get('memory_limit', -1)
        # Shared values, updated after every published result.
        self.last_output_time = mp.Value(ctypes.c_double, time.time())
        self.complete_count = mp.Value(ctypes.c_int, 0)

    def run(self):
        """Process batches until a None sentinel is received."""
        if self.memory_limit > 0:
            # memory_limit is scaled by 1000 ** 3; int() lets a fractional
            # limit through, since setrlimit needs integers.
            limit_bytes = int(self.memory_limit * (1000 ** 3))
            resource.setrlimit(resource.RLIMIT_AS, (limit_bytes, limit_bytes))

        while True:
            inputs = self.task_queue.get()
            if inputs is None:  # Terminate when receiving None
                break
            for _, request_id, task in inputs:
                if isinstance(task, str):
                    task = dict(code=task)
                if 'timeout' not in task:
                    task['timeout'] = self.timeout
                result = verify_lean4_file(**task)
                if result['system_messages']:
                    # Re-run while a transient-failure marker is present,
                    # for at most self.timeout seconds in total.
                    retry_start_time = time.time()
                    while (
                        any(marker in result['system_messages']
                            for marker in _TRANSIENT_FAILURE_MARKERS)
                        and time.time() - retry_start_time < self.timeout
                    ):
                        time.sleep(0.1)
                        result = verify_lean4_file(**task)
                with self.lock:
                    self.request_statuses[request_id] = result
                    self.last_output_time.value = time.time()
                    self.complete_count.value += 1


class Lean4ServerScheduler(ProcessScheduler):
    """Scheduler that starts a pool of Lean4ServerProcess workers and a monitor.

    ``task_queue``, ``request_statuses`` and ``lock`` are used after
    ``super().__init__`` and are presumably provided by ProcessScheduler,
    which is defined elsewhere.
    """

    def __init__(self, max_concurrent_requests=64, timeout=300, memory_limit=-1, name='verifier'):
        """Start the workers and the monitor process.

        Args:
            max_concurrent_requests: Number of worker processes to start.
            timeout: Default per-task timeout in seconds for the workers; the
                monitor also uses it as the age threshold (plus 10 seconds).
            memory_limit: Per-worker limit forwarded to the workers
                (values <= 0 leave the limit unset).
            name: Forwarded to ProcessScheduler.
        """
        super().__init__(batch_size=1, name=name)

        self.processes = [
            Lean4ServerProcess(
                idx=idx,
                task_queue=self.task_queue,
                request_statuses=self.request_statuses,
                lock=self.lock,
                extra_args=AttrDict(
                    timeout=timeout,
                    memory_limit=memory_limit,
                ),
            )
            for idx in range(max_concurrent_requests)
        ]
        for p in self.processes:
            p.start()
        print(f'Complete launching {len(self.processes)} LeanServerProcesses')

        self.timeout = timeout
        self._running_monitor = mp.Value(ctypes.c_bool, True)
        self._last_complete_count = mp.Value(ctypes.c_int, 0)
        self._monitor_process = mp.Process(target=self._monitor)
        self._monitor_process.start()

    def _monitor(self):
        """Once a second, run ``killall`` on ``repl`` processes older than timeout + 10 s."""
        while self._running_monitor.value:
            time.sleep(1.0)
            subprocess.run(
                ['killall', 'repl', f'--older-than={int(self.timeout) + 10}s'],
                capture_output=True,
            )

    def close(self):
        """Close the scheduler, wait for all workers, then stop the monitor."""
        # NOTE(review): super().close() presumably tells the workers to stop;
        # it is defined in ProcessScheduler and not visible here.
        super().close()
        for p in self.processes:
            p.join()
        self._running_monitor.value = False
        self._monitor_process.join()
        print(f'All {len(self.processes)} LeanServerProcesses stopped')


if __name__ == '__main__':
    # Read the sample through a context manager so the handle is closed.
    with open('mathlib4/.lake/packages/REPL/test/aime_1983_p9.code.in', encoding='utf-8') as sample_file:
        code = sample_file.read()
    lean4_scheduler = Lean4ServerScheduler(max_concurrent_requests=1, timeout=300, memory_limit=10, name='verifier')
    request_id_list = lean4_scheduler.submit_all_request([dict(code=code, ast=True, tactics=True)])
    outputs_list = lean4_scheduler.get_all_request_outputs(request_id_list)
    lean4_scheduler.close()
    pprint(outputs_list)