mirror of
https://github.com/deepseek-ai/DeepSeek-Prover-V1.5
synced 2024-11-24 21:14:11 +00:00
184 lines
8.1 KiB
Python
184 lines
8.1 KiB
Python
import re
|
|
import json
|
|
|
|
import numpy as np
|
|
|
|
from prover.utils import AttrDict, LEAN4_DEFAULT_HEADER
|
|
|
|
|
|
class Proof(object):
|
|
def __init__(self, full_code, _args, _result_backup=None, **kwargs):
|
|
self._kwargs_backup = kwargs
|
|
for key, val in kwargs.items():
|
|
self.__setattr__(key, val)
|
|
self._args = _args
|
|
self._update_full_code(full_code, _result_backup=_result_backup)
|
|
|
|
@property
|
|
def result(self):
|
|
if self._verifier_request_id is not None:
|
|
self._result = self._scheduler.verifier_get_request_outputs(self._verifier_request_id)
|
|
self._verifier_request_id = None
|
|
return self._result
|
|
|
|
def is_result_ready(self):
|
|
if self._verifier_request_id is None:
|
|
return True
|
|
status = self._scheduler.verifier_get_request_status(self._verifier_request_id)
|
|
if status is not None:
|
|
self._result = status
|
|
self._verifier_request_id = None
|
|
return self._result is not None
|
|
|
|
@property
|
|
def cleaned_code(self):
|
|
return self.full_code[len(self.header) + len(self.formal_statement): len(self.full_code) - len(self.tailer)]
|
|
|
|
def _update_full_code(self, full_code, _result_backup=None):
|
|
self.full_code = full_code
|
|
self._verifier_request_id, self._result = None, None
|
|
if _result_backup is not None:
|
|
self._result = _result_backup
|
|
elif self._args.require_verification: # need to call verification server
|
|
self._verifier_request_id = self._scheduler.verifier_submit_request(dict(
|
|
code=self.full_code,
|
|
ast=True, tactics=True,
|
|
))
|
|
self._parse_full_code_lines()
|
|
|
|
def _parse_full_code_lines(self):
|
|
self._full_code_lines = self.full_code.split('\n')
|
|
self._line_offset, _offset = [], -1
|
|
for _line in self._full_code_lines:
|
|
_offset += 1 # '\n'
|
|
self._line_offset.append(_offset)
|
|
_offset += len(_line)
|
|
|
|
def _get_idx(self, pos_info):
|
|
return self._line_offset[pos_info['line'] - 1] + pos_info['column']
|
|
|
|
def segmentation(self, result=None):
|
|
if result is None:
|
|
result = self.result
|
|
if 'errors' not in result:
|
|
# compiler timeout
|
|
return []
|
|
|
|
_prefix_len = len(self.header) + len(self.formal_statement)
|
|
truncate_pos = len(self.full_code) - len(self.tailer)
|
|
for info in result['sorries'] + result['errors']:
|
|
info_pos = self._get_idx(info['pos'])
|
|
if info_pos >= _prefix_len and not info.get('data', str()).lstrip().startswith('unsolved goals'):
|
|
truncate_pos = min(truncate_pos, info_pos)
|
|
partial_code = self.full_code[:truncate_pos]
|
|
if len(partial_code) <= _prefix_len:
|
|
# all proof lines are invalid
|
|
return []
|
|
|
|
code_lines = partial_code.split('\n')
|
|
pos_last, segments = _prefix_len, []
|
|
for line_idx in range(len(code_lines)):
|
|
if self._line_offset[line_idx] >= _prefix_len:
|
|
def compute_last_valid_char_pos(line):
|
|
idx, last_non_blank = 0, len(line) + 1
|
|
while idx < len(line):
|
|
if line[idx: idx+2] == '--':
|
|
return last_non_blank
|
|
elif line[idx: idx+2] == '/-':
|
|
if '-/' not in line[idx+2:]:
|
|
# cannot split in this line
|
|
return len(line) + 1
|
|
idx = line.find('-/', idx+2) + 1
|
|
elif line[idx] != ' ':
|
|
last_non_blank = idx
|
|
idx += 1
|
|
return last_non_blank
|
|
line_lastChar = self._line_offset[line_idx] + compute_last_valid_char_pos(code_lines[line_idx])
|
|
line_endPos = self._line_offset[line_idx] + len(code_lines[line_idx])
|
|
|
|
pos_min, goal = 1e9, None
|
|
for tactic_info in result['ast']['tactics']:
|
|
pos, endPos = tactic_info['pos'], tactic_info['endPos']
|
|
if line_lastChar <= endPos and endPos <= line_endPos and pos < pos_min:
|
|
pos_min = pos
|
|
goal = tactic_info['stateAfter']
|
|
if goal is not None:
|
|
for tactic_info in result['ast']['tactics']:
|
|
pos, endPos = tactic_info['pos'], tactic_info['endPos']
|
|
if pos_last < endPos and endPos <= line_endPos and pos < pos_min:
|
|
pos_min = pos
|
|
|
|
while pos_min > 0 and partial_code[pos_min - 1] != '\n':
|
|
pos_min -= 1
|
|
indent_len = 0
|
|
while partial_code[pos_min + indent_len] == ' ':
|
|
indent_len += 1
|
|
newline_with_indent = '\n' + ' ' * indent_len
|
|
|
|
segments.append(AttrDict(
|
|
tactic_code=partial_code[pos_last: line_endPos] + '\n',
|
|
state_comment=newline_with_indent.join([
|
|
' ' * indent_len + '/- tactic state:',
|
|
' ' + goal.replace('\n', newline_with_indent + ' '),
|
|
'-/\n'
|
|
]),
|
|
goal=goal,
|
|
indent=indent_len,
|
|
))
|
|
pos_last = line_endPos + 1
|
|
if result['complete'] and (len(segments) == 0 or segments[-1].goal != 'no goals' or segments[-1].indent != segments[0].indent):
|
|
indent_len = 2 if len(segments) == 0 else segments[0].indent
|
|
newline_with_indent = '\n' + ' ' * indent_len
|
|
segments.append(AttrDict(
|
|
tactic_code=partial_code[pos_last:].rstrip(' \n') + '\n',
|
|
state_comment=newline_with_indent.join([
|
|
' ' * indent_len + '/- tactic state:',
|
|
' no goals',
|
|
'-/\n'
|
|
]),
|
|
goal='no goals',
|
|
indent=indent_len,
|
|
))
|
|
segments = [seg for seg in segments if len(seg.tactic_code.strip(' \n')) > 0]
|
|
return segments
|
|
|
|
|
|
class ProofSummarizer(object):
|
|
def __init__(self, data, scheduler=None):
|
|
"""
|
|
Inputs:
|
|
data (`dict`): The problem information storing in a `dict` object.
|
|
formal_statement (`str`): The formal statement of the unproved problem;
|
|
header (`str`, *optional*, defaults to ''): The code header required by the complier;
|
|
tailer (`str`, *optional*, defaults to ''): The code tailer required by the complier.
|
|
scheduler (prover.workers.scheduler.Scheduler, *optional*, defaults to None):
|
|
An interface to submit requests to models and the verification server.
|
|
If set to None, the downstream tasks may require the verification result as inputs.
|
|
"""
|
|
self.formal_statement = data['formal_statement']
|
|
self.header = data.get('header', LEAN4_DEFAULT_HEADER)
|
|
self.tailer = data.get('tailer', str())
|
|
self.scheduler = scheduler
|
|
|
|
def analyze(self, code, require_verification=True):
|
|
"""
|
|
Inputs:
|
|
code (`str`): The code of formal proof.
|
|
require_verification (`bool`, *optional*, defaults to True):
|
|
Whether to submit a request to the verification server.
|
|
If set to False, the downstream tasks may require the verification result as inputs.
|
|
Return:
|
|
A `Proof` object that summarizes the code.
|
|
"""
|
|
return Proof(
|
|
full_code=''.join([self.header, self.formal_statement, code.rstrip(' \n'), self.tailer]),
|
|
raw_code=code,
|
|
formal_statement=self.formal_statement,
|
|
header=self.header,
|
|
tailer=self.tailer,
|
|
_scheduler=self.scheduler,
|
|
_args=AttrDict(
|
|
require_verification=require_verification,
|
|
)
|
|
)
|