mirror of
https://github.com/deepseek-ai/DeepSeek-Prover-V1.5
synced 2024-11-21 19:07:44 +00:00
105 lines
3.7 KiB
Python
105 lines
3.7 KiB
Python
import os
|
|
import json
|
|
import pytz
|
|
from pathlib import Path
|
|
from datetime import datetime
|
|
from collections import UserDict
|
|
from importlib.machinery import SourceFileLoader
|
|
from easydict import EasyDict as AttrDict
|
|
|
|
|
|
LEAN4_DEFAULT_HEADER = "import Mathlib\nimport Aesop\n\nset_option maxHeartbeats 0\n\nopen BigOperators Real Nat Topology Rat\n\n"
|
|
|
|
def non_cot_prompt(data):
|
|
return "Complete the following Lean 4 code:\n\n```lean4\n{header}{informal_prefix}{formal_statement}".format(
|
|
header=data.get('header', LEAN4_DEFAULT_HEADER),
|
|
informal_prefix=data.get('informal_prefix', str()),
|
|
formal_statement=data['formal_statement'],
|
|
)
|
|
|
|
def non_cot_few_shot_prompt(data):
|
|
return "Complete the following Lean 4 code:\n\n```lean4\n{header}{informal_prefix}{formal_statement}{formal_proof}\n```\n\n\n".format(
|
|
header=data.get('header', LEAN4_DEFAULT_HEADER),
|
|
informal_prefix=data.get('informal_prefix', str()),
|
|
formal_statement=data['formal_statement'],
|
|
formal_proof=data['formal_proof'],
|
|
)
|
|
|
|
def cot_prompt(data):
|
|
return "Complete the following Lean 4 code with explanatory comments preceding each line of code:\n\n```lean4\n{header}{informal_prefix}{formal_statement}".format(
|
|
header=data.get('header', LEAN4_DEFAULT_HEADER),
|
|
informal_prefix=data.get('informal_prefix', str()),
|
|
formal_statement=data['formal_statement'],
|
|
)
|
|
|
|
def cot_few_shot_prompt(data):
|
|
return "Complete the following Lean 4 code with explanatory comments preceding each line of code:\n\n```lean4\n{header}{informal_prefix}{formal_statement}{formal_proof}\n```\n\n\n".format(
|
|
header=data.get('header', LEAN4_DEFAULT_HEADER),
|
|
informal_prefix=data.get('informal_prefix', str()),
|
|
formal_statement=data['formal_statement'],
|
|
formal_proof=data['formal_proof'],
|
|
)
|
|
|
|
def post_process_output(output):
|
|
_find_idx = output.find("```")
|
|
return output[:_find_idx] if _find_idx >= 0 else output
|
|
|
|
MODEL_FORMAT = dict(
|
|
non_cot=dict(prompt=non_cot_prompt, output=post_process_output, few_shot=non_cot_few_shot_prompt),
|
|
cot=dict(prompt=cot_prompt, output=post_process_output, few_shot=cot_few_shot_prompt),
|
|
)
|
|
|
|
|
|
def get_datetime(readable=False):
|
|
if readable:
|
|
return datetime.now(pytz.timezone('Asia/Shanghai')).strftime("%Y/%m/%d %H:%M:%S")
|
|
return datetime.now(pytz.timezone('Asia/Shanghai')).strftime("%Y%m%d_%H%M%S")
|
|
|
|
def load_config(fname):
|
|
name = Path(fname).stem
|
|
mod = SourceFileLoader(name, fname).load_module()
|
|
|
|
config = {}
|
|
for n in dir(mod):
|
|
if not n.startswith("__"):
|
|
config[n] = getattr(mod, n)
|
|
config = AttrDict(config)
|
|
|
|
return config
|
|
|
|
def load_jsonl_objects(input_path):
|
|
objects = []
|
|
with open(input_path, 'r', encoding='utf-8') as fr:
|
|
for line in fr:
|
|
objects.append(json.loads(line))
|
|
return objects
|
|
|
|
|
|
class ConcurrentJob(object):
|
|
def __init__(self, stage_list):
|
|
assert len(stage_list) > 1
|
|
self.stage_list = stage_list
|
|
self.reset()
|
|
|
|
def is_idle(self):
|
|
return self._stage_idx is None
|
|
|
|
def reset(self):
|
|
self._stage_idx = None
|
|
self._stage_cache = None
|
|
|
|
def start(self, **kwargs):
|
|
self._stage_idx = 1
|
|
self._stage_cache = self.stage_list[0](**kwargs)
|
|
|
|
def get_status(self):
|
|
assert not self.is_idle()
|
|
while True:
|
|
status = self.stage_list[self._stage_idx](**self._stage_cache)
|
|
if status is None:
|
|
return None
|
|
self._stage_idx += 1
|
|
if self._stage_idx == len(self.stage_list):
|
|
self.reset()
|
|
return status
|
|
self._stage_cache = status |