# Mirror of https://github.com/deepseek-ai/DeepSeek-Prover-V1.5
# Synced 2024-11-22 03:17:43 +00:00 · 156 lines · 6.3 KiB · Python
|
import os
|
||
|
import time
|
||
|
import json
|
||
|
import ctypes
|
||
|
import resource
|
||
|
import tempfile
|
||
|
import traceback
|
||
|
import threading
|
||
|
import subprocess
|
||
|
import multiprocessing as mp
|
||
|
from pprint import pprint
|
||
|
|
||
|
import numpy as np
|
||
|
|
||
|
from prover.lean.ast_parser import lean4_parser
|
||
|
from prover.workers import ProcessScheduler
|
||
|
from prover.utils import AttrDict
|
||
|
|
||
|
|
||
|
# Home directory of the user running this module.
HOME_DIR = os.path.expanduser('~')
# Default `lake` executable; used as the default `lake_path` of verify_lean4_file.
DEFAULT_LAKE_PATH = HOME_DIR + '/.elan/bin/lake'
# Default working directory in which `lake exe repl` is launched (relative path).
DEFAULT_LEAN_WORKSPACE = 'mathlib4/'
|
||
|
|
||
|
|
||
|
def verify_lean4_file(code, lake_path=DEFAULT_LAKE_PATH, lean_workspace=DEFAULT_LEAN_WORKSPACE, last_env=None, verbose=False, timeout=300, allTactics=False, ast=False, premises=False, tactics=False):
    """Send Lean 4 code to `lake exe repl` and summarize the REPL's response.

    The request is serialized as one JSON command, fed to the REPL on stdin,
    and the REPL's stdout is parsed as a single JSON document.

    Args:
        code: Lean source text, sent as the `cmd` field of the command.
        lake_path: Path of the `lake` executable to run.
        lean_workspace: Working directory for the `lake exe repl` subprocess.
        last_env: If not None, sent as the `env` field of the command.
        verbose: If true, print the serialized command before running it.
        timeout: Seconds allowed for the subprocess before it is aborted.
        allTactics, ast, premises, tactics: Flags forwarded verbatim in the
            command; when the response carries a non-empty `ast`, it is
            post-processed by `lean4_parser`.

    Returns:
        A dict. On success it has the keys `sorries`, `tactics`, `errors`,
        `warnings`, `infos` (REPL messages split by severity),
        `system_messages` (the subprocess's stderr), `system_errors` (None),
        `ast`, `verified_code`, `pass` (no error messages) and `complete`
        (`pass`, no sorries, and no warning mentioning
        "declaration uses 'sorry'" or 'failed').
        On any failure (timeout, unparsable output, ...) it has only `pass`
        and `complete` (both False), `system_errors` (the traceback text)
        and `system_messages`.
        In both cases `verify_time` holds the elapsed wall-clock seconds.
    """
    command = dict(cmd=code, allTactics=allTactics, ast=ast, tactics=tactics, premises=premises)
    if last_env is not None:
        command.update(env=last_env)
    message_str = json.dumps(command, ensure_ascii=False)
    if verbose:
        print(message_str)

    start_time = time.time()
    system_messages = ''
    try:
        with tempfile.TemporaryFile(mode='w+', encoding='utf-8') as temp_file:
            temp_file.write(message_str + "\r\n\r\n")
            temp_file.seek(0)
            outputs = subprocess.run([lake_path, "exe", 'repl'], stdin=temp_file, capture_output=True, text=True, cwd=lean_workspace, timeout=timeout)
        # Keep stderr so callers can recognise resource-exhaustion failures
        # (it was previously discarded, leaving system_messages always empty).
        system_messages = outputs.stderr or ''
        response = json.loads(outputs.stdout)

        ast_results = lean4_parser(code, response['ast']) if response.get('ast') else {}
        messages = response.get('messages', [])

        def with_severity(level):
            # Select the REPL messages reported at the given severity.
            return [m for m in messages if m['severity'] == level]

        result = {
            "sorries": response.get('sorries', []),
            "tactics": response.get('tactics', []),
            "errors": with_severity('error'),
            "warnings": with_severity('warning'),
            "infos": with_severity('info'),
            "system_messages": system_messages,
            "system_errors": None,
            "ast": ast_results,
            "verified_code": code,
        }
        result['pass'] = not result['errors']
        # A proof only counts as complete if nothing was left as `sorry`
        # and no warning reports a sorry or a failure.
        result['complete'] = result['pass'] and not result['sorries'] and not any(
            "declaration uses 'sorry'" in warning['data'] or 'failed' in warning['data']
            for warning in result['warnings']
        )
    except Exception:
        # Report the failure as data; KeyboardInterrupt/SystemExit still propagate.
        result = {
            "pass": False,
            "complete": False,
            "system_errors": traceback.format_exc(),
            "system_messages": system_messages,
        }
    result['verify_time'] = time.time() - start_time
    return result
|
||
|
|
||
|
|
||
|
class Lean4ServerProcess(mp.Process):
    """Worker process that takes verification tasks from a queue and runs them.

    Each queue item is an iterable of `(_, request_id, task)` triples; a `None`
    item ends the worker. Results are stored in `request_statuses[request_id]`.
    """

    # Substrings of a result's `system_messages` that trigger a retry of the task.
    _RETRYABLE_MARKERS = (
        'lean::exception: failed to create thread',
        'std::bad_alloc: std::bad_alloc',
        'Cannot allocate memory',
    )

    def __init__(self, idx, task_queue, request_statuses, lock, extra_args=None):
        """Store the shared queue/status objects and read worker options.

        Args:
            idx: Index of this worker.
            task_queue: Queue from which batches of tasks are read.
            request_statuses: Mapping that receives each request's result.
            lock: Lock held while publishing a result.
            extra_args: Optional mapping with `timeout` (default 300) and
                `memory_limit` (default -1, meaning no limit). Defaults to
                an empty `AttrDict`.
        """
        super().__init__()
        # Create the default per instance rather than sharing one default object.
        if extra_args is None:
            extra_args = AttrDict()
        self.idx = idx
        self.task_queue = task_queue
        self.request_statuses = request_statuses
        self.lock = lock
        self.extra_args = extra_args

        self.timeout = extra_args.get('timeout', 300)
        self.memory_limit = extra_args.get('memory_limit', -1)
        self.last_output_time = mp.Value(ctypes.c_double, time.time())
        self.complete_count = mp.Value(ctypes.c_int, 0)

    def _should_retry(self, result):
        """Return True if the result's system messages contain a retryable marker."""
        messages = result['system_messages']
        return any(marker in messages for marker in self._RETRYABLE_MARKERS)

    def run(self):
        """Process queued tasks until a `None` sentinel is received."""
        if self.memory_limit > 0:
            # memory_limit is multiplied by 1000**3 to get bytes; setrlimit
            # needs integers, so a fractional limit is truncated.
            limit_bytes = int(self.memory_limit * (1000 ** 3))
            resource.setrlimit(resource.RLIMIT_AS, (limit_bytes, limit_bytes))
        while True:
            inputs = self.task_queue.get()
            if inputs is None:  # Terminate when receiving None
                break
            for _, request_id, task in inputs:
                if isinstance(task, str):
                    task = dict(code=task)
                task.setdefault('timeout', self.timeout)
                result = verify_lean4_file(**task)
                # Retry while the failure matches a retryable marker, for at
                # most `self.timeout` seconds in total.
                retry_start_time = time.time()
                while self._should_retry(result) and time.time() - retry_start_time < self.timeout:
                    time.sleep(0.1)
                    result = verify_lean4_file(**task)
                with self.lock:
                    self.request_statuses[request_id] = result
                    self.last_output_time.value = time.time()
                    self.complete_count.value += 1
|
||
|
|
||
|
|
||
|
class Lean4ServerScheduler(ProcessScheduler):
    """Scheduler that owns a pool of Lean4ServerProcess workers plus a monitor.

    The workers share `self.task_queue`, `self.request_statuses` and
    `self.lock`, which are expected to be set up by `ProcessScheduler.__init__`
    (not visible here). A separate monitor process periodically kills
    long-running `repl` processes.
    """

    def __init__(self, max_concurrent_requests=64, timeout=300, memory_limit=-1, name='verifier'):
        """Launch `max_concurrent_requests` workers and the monitor process.

        Args:
            max_concurrent_requests: Number of worker processes to start.
            timeout: Per-task timeout in seconds passed to each worker; also
                determines the age at which the monitor kills `repl` processes.
            memory_limit: Memory limit passed to each worker.
            name: Scheduler name forwarded to `ProcessScheduler`.
        """
        super().__init__(batch_size=1, name=name)

        # Build every worker first, then start them all.
        workers = []
        for worker_idx in range(max_concurrent_requests):
            worker_options = AttrDict(
                timeout=timeout,
                memory_limit=memory_limit,
            )
            workers.append(
                Lean4ServerProcess(
                    idx=worker_idx,
                    task_queue=self.task_queue,
                    request_statuses=self.request_statuses,
                    lock=self.lock,
                    extra_args=worker_options,
                )
            )
        self.processes = workers
        for worker in self.processes:
            worker.start()
        print(f'Complete launching {len(self.processes)} LeanServerProcesses')

        # Set the monitor's state before starting it so the child sees it.
        self.timeout = timeout
        self._running_monitor = mp.Value(ctypes.c_bool, True)
        self._last_complete_count = mp.Value(ctypes.c_int, 0)
        self._monitor_process = mp.Process(target=self._monitor)
        self._monitor_process.start()

    def _monitor(self):
        """Once per second, kill `repl` processes older than timeout + 10 seconds."""
        while self._running_monitor.value:
            time.sleep(1.0)
            max_age_seconds = int(self.timeout) + 10
            kill_command = ['killall', 'repl', f'--older-than={max_age_seconds}s']
            subprocess.run(kill_command, capture_output=True)

    def close(self):
        """Close the base scheduler, wait for all workers, then stop the monitor."""
        super().close()
        for worker in self.processes:
            worker.join()
        # Clearing the flag lets the monitor loop exit on its next iteration.
        self._running_monitor.value = False
        self._monitor_process.join()
        print(f'All {len(self.processes)} LeanServerProcesses stopped')
|
||
|
|
||
|
|
||
|
if __name__ == '__main__':
    # Smoke test: verify one sample problem through a single-worker scheduler
    # and pretty-print the collected outputs.
    # Read the sample as UTF-8, the same encoding used to send code to the REPL.
    with open('mathlib4/.lake/packages/REPL/test/aime_1983_p9.code.in', encoding='utf-8') as code_file:
        code = code_file.read()
    lean4_scheduler = Lean4ServerScheduler(max_concurrent_requests=1, timeout=300, memory_limit=10, name='verifier')
    try:
        request_id_list = lean4_scheduler.submit_all_request([dict(code=code, ast=True, tactics=True)])
        outputs_list = lean4_scheduler.get_all_request_outputs(request_id_list)
    finally:
        # Always stop the worker and monitor processes, even if a request
        # fails; otherwise the non-daemon children keep the script alive.
        lean4_scheduler.close()
    pprint(outputs_list)
|