DeepSeek-Prover-V1.5/prover/lean/verifier.py

156 lines
6.3 KiB
Python
Raw Normal View History

2024-08-16 03:33:21 +00:00
import os
import time
import json
import ctypes
import resource
import tempfile
import traceback
import threading
import subprocess
import multiprocessing as mp
from pprint import pprint
import numpy as np
from prover.lean.ast_parser import lean4_parser
from prover.workers import ProcessScheduler
from prover.utils import AttrDict
# Current user's home directory, expanded from "~".
HOME_DIR = os.path.expanduser("~")
# Default `lake` executable, located under ~/.elan/bin.
DEFAULT_LAKE_PATH = HOME_DIR + "/.elan/bin/lake"
# Default Lean project directory used as the REPL's working directory
# (a relative path, so it is resolved against the current working directory).
DEFAULT_LEAN_WORKSPACE = "mathlib4/"
def verify_lean4_file(code, lake_path=DEFAULT_LAKE_PATH, lean_workspace=DEFAULT_LEAN_WORKSPACE, last_env=None, verbose=False, timeout=300, allTactics=False, ast=False, premises=False, tactics=False):
    """Verify Lean 4 ``code`` by sending one JSON command to ``lake exe repl``.

    Args:
        code: Lean 4 source, sent as the command's ``cmd`` field.
        lake_path: path to the ``lake`` executable.
        lean_workspace: working directory for the REPL subprocess.
        last_env: when not None, sent as the command's ``env`` field.
        verbose: if True, print the JSON command before running it.
        timeout: seconds passed to ``subprocess.run``; an expired timeout is
            reported through ``system_errors`` rather than raised.
        allTactics, ast, premises, tactics: flags forwarded verbatim in the
            command.

    Returns:
        A dict that always contains ``pass``, ``complete``, ``system_errors``,
        ``system_messages`` (the REPL's stderr, or '' if the subprocess never
        produced output) and ``verify_time`` (seconds).
        When the REPL reply parses, it also contains ``sorries``, ``tactics``,
        ``errors``, ``warnings``, ``infos``, ``ast`` and ``verified_code``, and
        ``system_errors`` is None. Otherwise ``system_errors`` holds the
        formatted traceback.
    """
    command = dict(cmd=code, allTactics=allTactics, ast=ast, tactics=tactics, premises=premises)
    if last_env is not None:
        command.update(env=last_env)
    message_str = json.dumps(command, ensure_ascii=False)
    if verbose:
        print(message_str)
    start_time = time.time()
    system_messages = ''
    try:
        with tempfile.TemporaryFile(mode='w+', encoding='utf-8') as temp_file:
            # The command is fed on stdin, terminated by a blank line.
            temp_file.write(message_str + "\r\n\r\n")
            temp_file.seek(0)
            outputs = subprocess.run(
                [lake_path, "exe", 'repl'],
                stdin=temp_file,
                capture_output=True,
                text=True,
                cwd=lean_workspace,
                timeout=timeout,
            )
        # Record stderr BEFORE parsing stdout so it is also available when
        # parsing fails. Resource failures (e.g. std::bad_alloc) are presumably
        # reported on stderr, and the worker's retry logic inspects
        # ``system_messages`` for exactly those strings.
        system_messages = outputs.stderr or ''
        result = json.loads(outputs.stdout)
        messages = result.get('messages', [])
        ast_results = lean4_parser(code, result['ast']) if result.get('ast') else {}
        result = {
            "sorries": result.get('sorries', []),
            "tactics": result.get('tactics', []),
            "errors": [m for m in messages if m['severity'] == 'error'],
            "warnings": [m for m in messages if m['severity'] == 'warning'],
            "infos": [m for m in messages if m['severity'] == 'info'],
            "system_messages": system_messages,
            "system_errors": None,
            "ast": ast_results,
            "verified_code": code,
        }
        result['pass'] = not result['errors']
        # "complete" additionally requires no sorries and no warning mentioning
        # a 'sorry' declaration or a failure.
        result['complete'] = result['pass'] and not result['sorries'] and not any(
            "declaration uses 'sorry'" in warning['data'] or 'failed' in warning['data']
            for warning in result['warnings']
        )
    except Exception:
        # Exception (not a bare except) so KeyboardInterrupt/SystemExit still
        # propagate. Any other failure (missing lake, timeout, unparseable
        # output, unexpected message shape) becomes a failed result.
        result = {
            "pass": False,
            "complete": False,
            "system_errors": traceback.format_exc(),
            "system_messages": system_messages,
        }
    result['verify_time'] = time.time() - start_time
    return result
class Lean4ServerProcess(mp.Process):
    """Worker process that verifies batches of tasks pulled from a shared queue.

    Each result is stored in ``request_statuses[request_id]`` while holding
    ``lock``. The shared counters ``last_output_time`` and ``complete_count``
    are updated at the same time.
    """

    # Substrings of a result's ``system_messages`` that indicate a transient
    # resource failure (thread creation / allocation) worth retrying.
    _RETRYABLE_MARKERS = (
        'lean::exception: failed to create thread',
        'std::bad_alloc: std::bad_alloc',
        'Cannot allocate memory',
    )

    def __init__(self, idx, task_queue, request_statuses, lock, extra_args=None):
        """Create (but do not start) a worker.

        Args:
            idx: worker index, stored as ``self.idx``.
            task_queue: queue yielding lists of ``(_, request_id, task)`` tuples,
                or ``None`` to stop the worker.
            request_statuses: shared mapping that receives results by request id.
            lock: lock guarding writes to ``request_statuses`` and the counters.
            extra_args: mapping with optional ``timeout`` (seconds, default 300)
                and ``memory_limit`` (units of 1000**3 bytes; <= 0 disables the
                limit, default -1). Defaults to a fresh empty ``AttrDict``.
        """
        super().__init__()
        if extra_args is None:
            # Create a new mapping per instance instead of sharing one mutable
            # default object across every worker.
            extra_args = AttrDict()
        self.idx = idx
        self.task_queue = task_queue
        self.request_statuses = request_statuses
        self.lock = lock
        self.extra_args = extra_args
        self.timeout = extra_args.get('timeout', 300)
        self.memory_limit = extra_args.get('memory_limit', -1)
        self.last_output_time = mp.Value(ctypes.c_double, time.time())
        self.complete_count = mp.Value(ctypes.c_int, 0)

    def _is_retryable(self, result):
        """Return True if ``result['system_messages']`` contains a retryable marker."""
        messages = result['system_messages']
        return any(marker in messages for marker in self._RETRYABLE_MARKERS)

    def run(self):
        """Verify every task of each queued batch until a ``None`` sentinel arrives."""
        if self.memory_limit > 0:
            # Cap this worker's address space; subprocesses it launches inherit
            # the limit.
            limit_bytes = self.memory_limit * (1000 ** 3)
            resource.setrlimit(resource.RLIMIT_AS, (limit_bytes, limit_bytes))
        while True:
            inputs = self.task_queue.get()
            if inputs is None:  # Terminate when receiving None
                break
            for _, request_id, task in inputs:
                if isinstance(task, str):
                    task = dict(code=task)
                if 'timeout' not in task:
                    task['timeout'] = self.timeout
                result = verify_lean4_file(**task)
                # Retry transient resource failures, polling every 0.1 s, for
                # at most self.timeout seconds.
                retry_start_time = time.time()
                while self._is_retryable(result) and time.time() - retry_start_time < self.timeout:
                    time.sleep(0.1)
                    result = verify_lean4_file(**task)
                with self.lock:
                    self.request_statuses[request_id] = result
                    self.last_output_time.value = time.time()
                    self.complete_count.value += 1
class Lean4ServerScheduler(ProcessScheduler):
    """Run Lean verification requests on a pool of Lean4ServerProcess workers.

    Every worker shares the scheduler's task queue, status mapping and lock,
    and gets its own ``timeout`` / ``memory_limit`` settings. A separate
    monitor process runs ``killall repl --older-than=<timeout + 10>s`` once
    per second until ``close()`` is called.
    """

    def __init__(self, max_concurrent_requests=64, timeout=300, memory_limit=-1, name='verifier'):
        super().__init__(batch_size=1, name=name)
        workers = []
        for worker_idx in range(max_concurrent_requests):
            worker_settings = AttrDict(timeout=timeout, memory_limit=memory_limit)
            workers.append(
                Lean4ServerProcess(
                    idx=worker_idx,
                    task_queue=self.task_queue,
                    request_statuses=self.request_statuses,
                    lock=self.lock,
                    extra_args=worker_settings,
                )
            )
        self.processes = workers
        for worker in self.processes:
            worker.start()
        print(f'Complete launching {len(self.processes)} LeanServerProcesses')
        self.timeout = timeout
        # Shared flag polled by the monitor loop; cleared in close().
        self._running_monitor = mp.Value(ctypes.c_bool, True)
        self._last_complete_count = mp.Value(ctypes.c_int, 0)
        self._monitor_process = mp.Process(target=self._monitor)
        self._monitor_process.start()

    def _monitor(self):
        """Once per second, run killall on ``repl`` processes older than timeout + 10 s."""
        while self._running_monitor.value:
            time.sleep(1.0)
            max_age_seconds = int(self.timeout) + 10
            subprocess.run(['killall', 'repl', f'--older-than={max_age_seconds}s'], capture_output=True)

    def close(self):
        """Close the base scheduler, join every worker, then stop and join the monitor."""
        super().close()
        for worker in self.processes:
            worker.join()
        self._running_monitor.value = False
        self._monitor_process.join()
        print(f'All {len(self.processes)} LeanServerProcesses stopped')
if __name__ == '__main__':
    # Smoke run: verify one REPL test file with a single worker and print the result.
    with open('mathlib4/.lake/packages/REPL/test/aime_1983_p9.code.in', encoding='utf-8') as code_file:
        code = code_file.read()
    lean4_scheduler = Lean4ServerScheduler(max_concurrent_requests=1, timeout=300, memory_limit=10, name='verifier')
    try:
        request_id_list = lean4_scheduler.submit_all_request([dict(code=code, ast=True, tactics=True)])
        outputs_list = lean4_scheduler.get_all_request_outputs(request_id_list)
    finally:
        # Always join the worker processes and the monitor process, even if a
        # request raises, so the script does not hang on live children.
        lean4_scheduler.close()
    pprint(outputs_list)