DeepSeek-Prover-V1.5/prover/utils.py

105 lines
3.7 KiB
Python
Raw Permalink Normal View History

2024-08-16 03:33:21 +00:00
import os
import json
import pytz
from pathlib import Path
from datetime import datetime
from collections import UserDict
from importlib.machinery import SourceFileLoader
from easydict import EasyDict as AttrDict
# Default Lean 4 preamble placed before the formal statement by the prompt
# builders below whenever a data record has no 'header' of its own: it imports
# Mathlib and Aesop, disables the heartbeat limit (maxHeartbeats 0), and opens
# commonly used namespaces.
LEAN4_DEFAULT_HEADER = (
    "import Mathlib\n"
    "import Aesop\n"
    "\n"
    "set_option maxHeartbeats 0\n"
    "\n"
    "open BigOperators Real Nat Topology Rat\n"
    "\n"
)
def non_cot_prompt(data):
    """Build a zero-shot prompt asking the model to finish a Lean 4 proof.

    The prompt opens a ```lean4 fence and stops right after the formal
    statement, so the model's continuation is the proof itself.

    Args:
        data: mapping with a required 'formal_statement' key and optional
            'header' (defaults to LEAN4_DEFAULT_HEADER) and 'informal_prefix'
            (defaults to the empty string) keys.

    Returns:
        The prompt string.

    Raises:
        KeyError: if 'formal_statement' is missing.
    """
    header = data.get('header', LEAN4_DEFAULT_HEADER)
    informal_prefix = data.get('informal_prefix', '')
    formal_statement = data['formal_statement']
    return f"Complete the following Lean 4 code:\n\n```lean4\n{header}{informal_prefix}{formal_statement}"
def non_cot_few_shot_prompt(data):
    """Render a fully solved example in the non-chain-of-thought format.

    Same text as ``non_cot_prompt`` followed by the record's formal proof, a
    closing ``` fence, and three newlines. That output is presumably used as
    an in-context example ahead of the real prompt.

    Args:
        data: mapping with required 'formal_statement' and 'formal_proof'
            keys and optional 'header' (defaults to LEAN4_DEFAULT_HEADER) and
            'informal_prefix' (defaults to the empty string) keys.

    Returns:
        The rendered example string.

    Raises:
        KeyError: if 'formal_statement' or 'formal_proof' is missing.
    """
    header = data.get('header', LEAN4_DEFAULT_HEADER)
    informal_prefix = data.get('informal_prefix', '')
    statement = data['formal_statement']
    proof = data['formal_proof']
    return f"Complete the following Lean 4 code:\n\n```lean4\n{header}{informal_prefix}{statement}{proof}\n```\n\n\n"
def cot_prompt(data):
    """Build a zero-shot prompt that asks for a commented Lean 4 proof.

    Identical in layout to ``non_cot_prompt``, but the instruction asks the
    model to put an explanatory comment before each line of code.

    Args:
        data: mapping with a required 'formal_statement' key and optional
            'header' (defaults to LEAN4_DEFAULT_HEADER) and 'informal_prefix'
            (defaults to the empty string) keys.

    Returns:
        The prompt string, ending right after the formal statement.

    Raises:
        KeyError: if 'formal_statement' is missing.
    """
    header = data.get('header', LEAN4_DEFAULT_HEADER)
    prefix = data.get('informal_prefix', '')
    statement = data['formal_statement']
    return f"Complete the following Lean 4 code with explanatory comments preceding each line of code:\n\n```lean4\n{header}{prefix}{statement}"
def cot_few_shot_prompt(data):
    """Render a fully solved example in the commented (chain-of-thought) format.

    Same text as ``cot_prompt`` followed by the record's formal proof, a
    closing ``` fence, and three newlines.

    Args:
        data: mapping with required 'formal_statement' and 'formal_proof'
            keys and optional 'header' (defaults to LEAN4_DEFAULT_HEADER) and
            'informal_prefix' (defaults to the empty string) keys.

    Returns:
        The rendered example string.

    Raises:
        KeyError: if 'formal_statement' or 'formal_proof' is missing.
    """
    header = data.get('header', LEAN4_DEFAULT_HEADER)
    prefix = data.get('informal_prefix', '')
    statement = data['formal_statement']
    proof = data['formal_proof']
    return f"Complete the following Lean 4 code with explanatory comments preceding each line of code:\n\n```lean4\n{header}{prefix}{statement}{proof}\n```\n\n\n"
def post_process_output(output):
    """Cut the model output at its first Markdown code fence.

    Everything from the first "```" onward is dropped. If the output contains
    no fence, it is returned unchanged.
    """
    # partition() yields (output, '', '') when the separator is absent.
    kept, _fence, _rest = output.partition("```")
    return kept
# Registry of prompt formats, keyed by format name. Each entry provides:
#   'prompt'   -> builds the prompt the model completes from a data record,
#   'output'   -> post-processes the raw model completion,
#   'few_shot' -> renders a solved record (statement + 'formal_proof') as a
#                 closed code block, presumably for in-context examples.
MODEL_FORMAT = {
    'non_cot': {
        'prompt': non_cot_prompt,
        'output': post_process_output,
        'few_shot': non_cot_few_shot_prompt,
    },
    'cot': {
        'prompt': cot_prompt,
        'output': post_process_output,
        'few_shot': cot_few_shot_prompt,
    },
}
def get_datetime(readable=False):
    """Return the current Asia/Shanghai local time as a formatted string.

    Args:
        readable: when truthy, use "YYYY/MM/DD HH:MM:SS". Otherwise use the
            compact "YYYYMMDD_HHMMSS" form, which contains no '/' or ':'.

    Returns:
        The formatted timestamp.
    """
    pattern = "%Y/%m/%d %H:%M:%S" if readable else "%Y%m%d_%H%M%S"
    now = datetime.now(pytz.timezone('Asia/Shanghai'))
    return now.strftime(pattern)
def load_config(fname):
    """Execute a Python config file and collect its top-level names.

    The file is executed as a fresh module named after the file's stem. The
    module is registered in ``sys.modules`` under that name while it runs and
    stays registered afterwards. Every module attribute whose name does not
    start with ``"__"`` is copied into the result. This includes any modules
    or helpers the config file itself imports.

    Args:
        fname: path to the config file (str or path-like). Any file suffix is
            accepted because the source loader is chosen explicitly.

    Returns:
        An AttrDict (EasyDict) built from the collected names.

    Raises:
        Whatever reading or executing the config file raises. In that case
        the half-initialised module is removed from ``sys.modules`` again.
    """
    import importlib.util
    import sys

    path = os.fspath(fname)
    name = Path(path).stem
    # Why not SourceFileLoader.load_module(): it is deprecated, and when
    # sys.modules already holds a module with this name it re-executes the
    # file *into that existing module*. Stale attributes from a previously
    # loaded config with the same stem would then leak into this one.
    # Building a fresh module from a spec avoids both problems.
    loader = SourceFileLoader(name, path)
    spec = importlib.util.spec_from_file_location(name, path, loader=loader)
    mod = importlib.util.module_from_spec(spec)
    sys.modules[name] = mod
    try:
        loader.exec_module(mod)
    except BaseException:
        # Do not leave a partially executed module importable.
        sys.modules.pop(name, None)
        raise
    config = {n: getattr(mod, n) for n in dir(mod) if not n.startswith("__")}
    return AttrDict(config)
def load_jsonl_objects(input_path):
    """Read a JSON Lines file and return its records as a list.

    Each non-blank line is decoded independently with ``json.loads``. Lines
    that are empty or contain only whitespace (such as a trailing blank line)
    are skipped instead of raising ``json.JSONDecodeError``.

    Args:
        input_path: path of the UTF-8 encoded JSON Lines file.

    Returns:
        List of decoded JSON values, in file order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
        OSError: if the file cannot be opened.
    """
    with open(input_path, 'r', encoding='utf-8') as fr:
        return [json.loads(line) for line in fr if line.strip()]
class ConcurrentJob:
    """A multi-stage job that the caller advances by polling.

    ``start(**kwargs)`` runs ``stage_list[0](**kwargs)``. Its return value
    must be a mapping, because it is unpacked as keyword arguments into the
    next stage. Each later stage is called with the previous stage's mapping
    and returns one of two things:

    * ``None`` -- not ready yet. The job stays on this stage and the same
      stage is called again with the same arguments on the next
      ``get_status()``.
    * a result -- forwarded (as ``**result``) to the next stage. For the last
      stage, the result is returned to the caller and the job becomes idle.
    """

    def __init__(self, stage_list):
        # A job needs the start stage plus at least one polled stage.
        assert len(stage_list) > 1
        self.stage_list = stage_list
        self.reset()

    def is_idle(self):
        """Return True when no job is in progress (never started or finished)."""
        return self._stage_idx is None

    def reset(self):
        """Forget any in-progress state and return to idle."""
        self._stage_cache = None
        self._stage_idx = None

    def start(self, **kwargs):
        """Run the first stage with ``kwargs`` and point the job at stage 1."""
        # The index is set before the first stage runs, so it is already 1
        # even if that stage raises.
        self._stage_idx = 1
        self._stage_cache = self.stage_list[0](**kwargs)

    def get_status(self):
        """Advance through as many consecutive stages as are ready.

        Returns:
            None if the current stage has no result yet. Otherwise the final
            stage's result, after which the job is idle again.
        """
        assert not self.is_idle()
        while True:
            current = self.stage_list[self._stage_idx]
            outcome = current(**self._stage_cache)
            if outcome is None:
                return None
            if self._stage_idx + 1 == len(self.stage_list):
                # Last stage finished: hand back its result and go idle.
                self.reset()
                return outcome
            self._stage_idx += 1
            self._stage_cache = outcome