# Mirror of https://github.com/deepseek-ai/DeepSeek-Prover-V1.5
# Synced 2024-11-22 03:17:43 +00:00
# 321 lines, 13 KiB, Python
|
import gc
import math
import os
import pickle
import random
import shutil
import subprocess
import time

import numpy as np

from prover.lean.proof import ProofSummarizer
from prover.utils import ConcurrentJob

from .base import SamplingAlgorithmBase
|
||
|
|
||
|
|
||
|
class TreeNode(object):
    """A search-tree node carrying discounted reward / visitation statistics.

    Persistent data lives in ``self._info`` (arbitrary keyword metadata, the
    list of code prefixes attached to the node, and four discounted counters).
    The numbers of currently running jobs are kept as plain attributes and are
    therefore not part of ``to_dict()``.

    After every update the following derived attributes are refreshed:
    ``value`` / ``visitation`` (jobs started at this node itself) and
    ``subtree_value`` / ``subtree_visitation`` (jobs started at this node or
    at any of its descendants).
    """

    def __init__(self, parent=None, code=None, **kwargs):
        """Create a node.

        Args:
            parent: parent node, or ``None`` for the root.
            code: optional code entry registered through ``update_code``.
            **kwargs: initial content of the info dict (e.g. ``depth`` or the
                statistics restored by ``from_dict``).
        """
        self.parent = parent
        self.children = dict()
        self._info = dict(kwargs)
        # own the code list so that a node rebuilt from another node's
        # `to_dict()` output does not share (and mutate) the same list
        if '_code_list' in self._info:
            self._info['_code_list'] = list(self._info['_code_list'])
        if code is not None:
            self.update_code(code)

        for stat_key in (
            '_discounted_rewards',
            '_discounted_visitation',
            '_subtree_discounted_rewards',
            '_subtree_discounted_visitation',
        ):
            self._info.setdefault(stat_key, 0.0)
        self._num_running_jobs = 0
        self._subtree_num_running_jobs = 0
        self._update_value(gamma=0.0)  # gamma=0.0 is okay for initialization

    # basic tree supports
    @property
    def code(self):
        """Return one of the registered code entries, chosen uniformly at random.

        Raises:
            KeyError: if no code has been registered on this node yet.
        """
        return random.choice(self._info['_code_list'])

    def update_code(self, code):
        """Register ``code`` on this node unless an equal entry already exists."""
        code_list = self._info.setdefault('_code_list', [])
        if code not in code_list:
            code_list.append(code)

    def __getitem__(self, key):
        """Look up ``key`` in the info dict (raises ``KeyError`` if absent)."""
        return self._info[key]

    def to_node_list(self):
        """Return this node and all descendants in pre-order (children in insertion order)."""
        # explicit stack: linear time and no recursion-depth limit
        node_list, stack = [], [self]
        while stack:
            node = stack.pop()
            node_list.append(node)
            # push children reversed so that the first child is popped first
            stack.extend(reversed(list(node.children.values())))
        return node_list

    def to_dict(self):
        """Serialize the subtree into nested dicts ``{'info': ..., 'children': {edge: ...}}``.

        The ``info`` entry is the node's live info dict, not a copy.
        """
        return dict(
            info=self._info,
            children={
                edge: child_node.to_dict()
                for edge, child_node in self.children.items()
            }
        )

    @classmethod
    def from_dict(cls, dict_data, parent=None):
        """Rebuild a subtree from the structure produced by ``to_dict``."""
        node = cls(
            parent=parent,
            **dict_data['info'],
        )
        node.children = {
            edge: cls.from_dict(child_dict, parent=node)
            for edge, child_dict in dict_data['children'].items()
        }
        return node

    # algorithm supports
    def update_reward(self, reward, gamma, first_node=True):
        """Record ``reward`` for a job started at this node and propagate it to the root.

        Args:
            reward: reward of the finished job.
            gamma: discount factor applied to the previously accumulated statistics.
            first_node: ``True`` on the node the job was started from; ancestors
                are called with ``False`` and only update their subtree statistics.
        """
        if first_node:
            self._info['_discounted_rewards'] = self._info['_discounted_rewards'] * gamma + reward
            self._info['_discounted_visitation'] = self._info['_discounted_visitation'] * gamma + 1.0
        self._info['_subtree_discounted_rewards'] = self._info['_subtree_discounted_rewards'] * gamma + reward
        self._info['_subtree_discounted_visitation'] = self._info['_subtree_discounted_visitation'] * gamma + 1.0
        self._update_value(gamma)
        if self.parent is not None:
            self.parent.update_reward(reward, gamma, first_node=False)

    def start_new_job(self, gamma, first_node=True):
        """Count one more running job on this node and on the subtree of every ancestor."""
        if first_node:
            self._num_running_jobs += 1
        self._subtree_num_running_jobs += 1
        self._update_value(gamma)
        if self.parent is not None:
            self.parent.start_new_job(gamma, first_node=False)

    def complete_job(self, gamma, first_node=True):
        """Count one less running job on this node and on the subtree of every ancestor."""
        if first_node:
            self._num_running_jobs -= 1
        self._subtree_num_running_jobs -= 1
        self._update_value(gamma)
        if self.parent is not None:
            self.parent.complete_job(gamma, first_node=False)

    @staticmethod
    def _pending_visitation(gamma, num_running_jobs):
        """Return ``sum(gamma ** i for i in range(num_running_jobs))`` in closed form."""
        if gamma == 1.0:
            # the closed form below would divide by zero; the sum is just the count
            return float(num_running_jobs)
        return (1.0 - (gamma ** num_running_jobs)) / (1.0 - gamma)

    def _update_value(self, gamma):
        """Refresh ``value``, ``visitation``, ``subtree_value`` and ``subtree_visitation``.

        Every running job is accounted for as if it had already been recorded
        through ``update_reward`` with a reward of zero.
        """
        discount = gamma ** self._num_running_jobs
        discounted_rewards = self._info['_discounted_rewards'] * discount
        discounted_visitation = \
            self._info['_discounted_visitation'] * discount \
            + self._pending_visitation(gamma, self._num_running_jobs)
        # the 1e-2 floor avoids dividing by zero on unvisited nodes
        self.value = discounted_rewards / max(discounted_visitation, 1e-2)
        self.visitation = discounted_visitation

        subtree_discount = gamma ** self._subtree_num_running_jobs
        subtree_discounted_rewards = self._info['_subtree_discounted_rewards'] * subtree_discount
        subtree_discounted_visitation = \
            self._info['_subtree_discounted_visitation'] * subtree_discount \
            + self._pending_visitation(gamma, self._subtree_num_running_jobs)
        self.subtree_value = subtree_discounted_rewards / max(subtree_discounted_visitation, 1e-2)
        self.subtree_visitation = subtree_discounted_visitation
|
||
|
|
||
|
|
||
|
class RMaxTS(SamplingAlgorithmBase):
    """Tree-search sampler that rewards samples which add new nodes to the tree.

    Each sample is produced by a three-stage pipeline (generate a completion
    from a selected node's code prefix, analyze the resulting proof, then
    update the tree and the node statistics). The tree is checkpointed to
    ``checkpoint.pkl`` inside the problem log directory so that an
    interrupted run can be resumed.
    """

    def __init__(self, **kwargs):
        """Read the sampler settings from ``self.cfg`` (populated by the base class)."""
        super().__init__(**kwargs)
        self.gamma = self.cfg.get('gamma', 0.99)
        self.sample_num = self.cfg.get('sample_num', 6400)
        self.concurrent_num = self.cfg.get('concurrent_num', 32)
        self.tactic_state_comment = self.cfg.get('tactic_state_comment', True)
        self.ckpt_interval = self.cfg.get('ckpt_interval', 128)

        self.ckpt_filename = 'checkpoint.pkl'
        self.node_cls = TreeNode
        self.algorithm_pipeline = [
            self._tactic_tree_generate_proof,
            self._tactic_tree_parse_proof,
            self._rmax_exploration_summarize_results,
        ]

    # basic supports
    def _save_ckpt(self, ckpt_dict: dict):
        """Pickle ``ckpt_dict`` to ``self.ckpt_path``, keeping the previous file as ``.backup``."""
        # save a backup before overwriting the checkpoint file
        if os.path.exists(self.ckpt_path):
            try:
                shutil.copyfile(self.ckpt_path, self.ckpt_path + '.backup')
            except OSError as err:
                # best effort: a failed backup must not prevent saving the new checkpoint
                self.process_print(f'Failed to back up checkpoint {self.ckpt_path}: {err}')
        # overwrite the checkpoint file
        with open(self.ckpt_path, 'wb') as pkl_f:
            pickle.dump(ckpt_dict, pkl_f)

    # tree structure supports
    def _tree_setup(self, data):
        """Restore the tree from a checkpoint, or create a fresh root.

        Also builds ``self.proof_summarizer`` and stores the goal reported for
        a bare ``sorry`` proof as ``self.root_goal``.

        Returns:
            Tuple ``(root, sample_count, yield_cache)``.
        """
        # initialize tree
        # NOTE: checkpoints are unpickled; only point `ckpt_path` at trusted directories.
        ckpt_info = None
        for _ckpt_path in [self.ckpt_path, self.ckpt_path + '.backup']:
            if not os.path.exists(_ckpt_path):
                continue
            try:
                with open(_ckpt_path, 'rb') as pkl_f:
                    ckpt_info = pickle.load(pkl_f)
            except Exception:
                self.process_print(f'Checkpoint saved at {_ckpt_path} is broken.')
                continue
            # the first readable file wins, so the (older) backup is only
            # consulted when the main checkpoint is missing or broken
            break

        if ckpt_info is not None:
            root = self.node_cls.from_dict(ckpt_info['root'])
            sample_count = ckpt_info['sample_count']
            yield_cache = ckpt_info['yield_cache']
            self.process_print(f'Load checkpoint from sample_count={sample_count}')
        else:
            root = self.node_cls(code=dict(tactic_code=str(), state_comment=str()), depth=0)
            sample_count = 0
            yield_cache = []

        # compile the root node with `sorry`
        self.proof_summarizer = ProofSummarizer(data=data, scheduler=self.scheduler)
        root_sorry = self.proof_summarizer.analyze('  sorry', require_verification=True)
        assert root_sorry.result['pass'], "Cannot parse a `sorry` tactic on root."
        self.root_goal = root_sorry.result['sorries'][-1]['goal']

        # other initialization
        self._last_selected_node = root

        return root, sample_count, yield_cache

    def _tree_new_child(self, parent):
        """Create a child node of ``parent`` one level deeper (not yet linked in ``children``)."""
        return self.node_cls(
            parent=parent,
            depth=parent['depth'] + 1,
        )

    def _tree_step(self, node, edge, code):
        """Follow ``edge`` from ``node``, creating the child if needed, and register ``code`` on it."""
        if edge not in node.children:
            new_node = self._tree_new_child(node)
            node.children[edge] = new_node
            self.node_list.append(new_node)
        child_node = node.children[edge]
        child_node.update_code(code)
        return child_node

    def _tree_update(self, proof):
        """Insert the path described by ``proof`` into the tree and return its last node."""
        node_walk, partial_code = self.root, str()

        # use tactic goals as tree edges
        segments = proof.segmentation()
        prev_goal = self.root_goal
        for info in segments:
            partial_code += info.tactic_code
            code = partial_code + info.state_comment if self.tactic_state_comment else partial_code
            # only prefixes that fit into the token limit become nodes
            if self._encode_length(code) < self.max_tokens:
                if info.goal != prev_goal:
                    node_walk = self._tree_step(
                        node=node_walk, edge=info.goal,
                        code=dict(tactic_code=partial_code, state_comment=info.state_comment)
                    )
                    prev_goal = info.goal
        return node_walk

    # algorithm pipeline
    def _select_node(self):
        """Walk down from the root and pick the node to expand next.

        At every node the candidates are "expand this node" (scored with the
        node's own statistics) and "descend into a child" (scored with the
        child's subtree statistics); each score is a value plus a UCB bonus.
        The selected node is marked as having one more running job.
        """
        node = self.root
        while len(node.children) > 0:
            total_visitation = node.visitation + np.sum(
                [child.subtree_visitation for child in node.children.values()]
            )
            bonus_scale = 2.0 * math.log(max(total_visitation, 2))

            def hoeffding_ucb(node_visitation):
                return math.sqrt(bonus_scale / max(node_visitation, 1e-2))

            choice_list = [(node.value + hoeffding_ucb(node.visitation), None)]
            for child in node.children.values():
                choice_list.append((child.subtree_value + hoeffding_ucb(child.subtree_visitation), child))
            # `max` returns the first best candidate, like the head of a stable descending sort
            best_child = max(choice_list, key=lambda x: x[0])[1]
            if best_child is None:
                node.start_new_job(gamma=self.gamma)
                return node
            node = best_child
        node.start_new_job(gamma=self.gamma)
        return node

    def _tactic_tree_generate_proof(self, data, node):
        """Pipeline stage 1: submit a generation request continuing one of the node's code prefixes."""
        code_prefix = node.code
        extra_prompt = code_prefix['tactic_code']
        if self.tactic_state_comment:
            extra_prompt += code_prefix['state_comment']
        return dict(
            node=node,
            code_prefix=code_prefix,
            generator_request_id=self.scheduler.generator_submit_request(
                self._preprocess_data({**data, '_extra_prompt': extra_prompt}),
            ),
        )

    def _tactic_tree_parse_proof(self, node, code_prefix, generator_request_id):
        """Pipeline stage 2: once the generation is available, submit the full proof for analysis.

        Returns ``None`` while the generator has not produced a result yet.
        """
        code = self.scheduler.generator_get_request_status(generator_request_id)
        if code is None:
            return None
        code = code_prefix['tactic_code'] + code
        proof = self.proof_summarizer.analyze(code, require_verification=True)
        return dict(node=node, proof=proof)

    def _rmax_exploration_summarize_results(self, node, proof):
        """Pipeline stage 3: update the tree and reward ``node`` if the proof added new nodes.

        Returns ``None`` while the proof result is not ready yet.
        """
        if not proof.is_result_ready():
            return None

        num_nodes_before = len(self.node_list)
        self._tree_update(proof)
        # RMax reward
        node.update_reward(int(len(self.node_list) > num_nodes_before), gamma=self.gamma)
        node.complete_job(gamma=self.gamma)

        return dict(
            code=proof.cleaned_code,
            result=proof.result,
        )

    # sampler interface
    def sample(self, data, prob_log_dir, **kwargs):
        """Run the tree search and yield ``(proposal, sample_info)`` for every complete proof.

        Results cached in a restored checkpoint are yielded first; if any
        exist, no further samples are drawn.
        """
        self.ckpt_path = os.path.join(prob_log_dir, self.ckpt_filename)
        self.root, sample_count, yield_cache = self._tree_setup(data)
        self.node_list = self.root.to_node_list()
        for _proposal, _sample_info in yield_cache:
            yield _proposal, _sample_info
        gc.collect()  # release memory

        job_slots = [
            ConcurrentJob(self.algorithm_pipeline)
            for _ in range(self.concurrent_num)
        ]

        sample_budget = self.sample_num - sample_count if len(yield_cache) == 0 else 0
        while (sample_budget > 0) or any([not job.is_idle() for job in job_slots]):
            for job in job_slots:
                if job.is_idle() and sample_budget > 0:
                    node = self._select_node()
                    self._last_selected_node = node
                    job.start(data=data, node=node)
                    sample_budget -= 1
                if not job.is_idle():
                    info = job.get_status()
                    if info is not None:
                        # output samples
                        sample_count += 1
                        if info['result']['complete']:
                            _proposal, _sample_info = info['code'], self._post_sample_info(
                                cost=sample_count, tree_size=len(self.node_list),
                            )
                            yield_cache.append((_proposal, _sample_info))
                            yield _proposal, _sample_info
                        # logging
                        if sample_count % self.log_interval == 0:
                            self.process_print('Progress: {} / {}    Tree Size: {}'.format(
                                sample_count, self.sample_num, len(self.node_list),
                            ))
                        # saving checkpoints
                        if sample_count % self.ckpt_interval == 0:
                            self._save_ckpt(dict(
                                root=self.root.to_dict(),
                                sample_count=sample_count,
                                yield_cache=yield_cache,
                            ))
                            if len(yield_cache) > 0:
                                # return after saving the checkpoint
                                # avoid overestimation caused by interrupt-restart loop
                                sample_budget = 0
            time.sleep(0.1)

        # save the final tree structure
        self._save_ckpt(dict(
            root=self.root.to_dict(),
            sample_count=sample_count,
            yield_cache=yield_cache,
        ))
|