# DeepSeek-Prover-V1.5/prover/algorithms/rmax_tree_search.py

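"""RMaxTS: Monte-Carlo tree search for whole-proof generation, as used in
DeepSeek-Prover-V1.5. The search tree is keyed by Lean tactic states: edges
are intermediate goals, node statistics are exponentially discounted, and an
RMax-style intrinsic reward (1 for reaching a new tactic state, 0 otherwise)
drives exploration under a discounted-UCB selection rule.
"""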
import os
import gc
import time
import math
import random
import pickle
import subprocess
import numpy as np
from .base import SamplingAlgorithmBase
from prover.lean.proof import ProofSummarizer
from prover.utils import ConcurrentJob

class TreeNode(object):
    def __init__(self, parent=None, code=None, **kwargs):
        self.parent = parent
        self.children = dict()
        self._info = {key: val for key, val in kwargs.items()}
        if code is not None:
            self.update_code(code)
        if '_discounted_rewards' not in self._info:
            self._info['_discounted_rewards'] = 0.0
        if '_discounted_visitation' not in self._info:
            self._info['_discounted_visitation'] = 0.0
        if '_subtree_discounted_rewards' not in self._info:
            self._info['_subtree_discounted_rewards'] = 0.0
        if '_subtree_discounted_visitation' not in self._info:
            self._info['_subtree_discounted_visitation'] = 0.0
        self._num_running_jobs = 0
        self._subtree_num_running_jobs = 0
        self._update_value(gamma=0.0)  # gamma=0.0 is okay for initialization

    # basic tree supports
    @property
    def code(self):
        return random.choice(self._info['_code_list'])

    def update_code(self, code):
        if '_code_list' not in self._info:
            self._info['_code_list'] = []
        if code not in self._info['_code_list']:
            self._info['_code_list'].append(code)

    def __getitem__(self, key):
        return self._info[key]

    def to_node_list(self):
        return sum([child.to_node_list() for _, child in self.children.items()], start=[self])

    def to_dict(self):
        return dict(
            info=self._info,
            children={
                edge: child_node.to_dict()
                for edge, child_node in self.children.items()
            }
        )

    @classmethod
    def from_dict(cls, dict_data, parent=None):
        node = cls(
            parent=parent,
            **dict_data['info'],
        )
        node.children = {
            edge: cls.from_dict(child_dict, parent=node)
            for edge, child_dict in dict_data['children'].items()
        }
        return node

    # algorithm supports
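    # Statistics propagate from the sampled node up to the root: rewards and
    # visitations decay geometrically with factor gamma, so recent outcomes
    # dominate, while `start_new_job`/`complete_job` count in-flight
    # generations and let `_update_value` treat them as pending zero-reward
    # visits (a virtual loss that spreads concurrent jobs across the tree).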
    def update_reward(self, reward, gamma, first_node=True):
        if first_node:
            self._info['_discounted_rewards'] = self._info['_discounted_rewards'] * gamma + reward
            self._info['_discounted_visitation'] = self._info['_discounted_visitation'] * gamma + 1.0
        self._info['_subtree_discounted_rewards'] = self._info['_subtree_discounted_rewards'] * gamma + reward
        self._info['_subtree_discounted_visitation'] = self._info['_subtree_discounted_visitation'] * gamma + 1.0
        self._update_value(gamma)
        if self.parent is not None:
            self.parent.update_reward(reward, gamma, first_node=False)

    def start_new_job(self, gamma, first_node=True):
        if first_node:
            self._num_running_jobs += 1
        self._subtree_num_running_jobs += 1
        self._update_value(gamma)
        if self.parent is not None:
            self.parent.start_new_job(gamma, first_node=False)

    def complete_job(self, gamma, first_node=True):
        if first_node:
            self._num_running_jobs -= 1
        self._subtree_num_running_jobs -= 1
        self._update_value(gamma)
        if self.parent is not None:
            self.parent.complete_job(gamma, first_node=False)

    def _update_value(self, gamma):
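        # With k running jobs, treat each as a pending zero-reward visit:
        # scale the stored statistics by gamma**k and add the geometric series
        # 1 + gamma + ... + gamma**(k-1) = (1 - gamma**k) / (1 - gamma) to the
        # visitation count (this assumes gamma < 1).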
        discounted_rewards = self._info['_discounted_rewards'] * (gamma ** self._num_running_jobs)
        discounted_visitation = \
            self._info['_discounted_visitation'] * (gamma ** self._num_running_jobs) \
            + (1.0 - (gamma ** self._num_running_jobs)) / (1.0 - gamma)
        self.value = discounted_rewards / max(discounted_visitation, 1e-2)
        self.visitation = discounted_visitation
        subtree_discounted_rewards = self._info['_subtree_discounted_rewards'] * (gamma ** self._subtree_num_running_jobs)
        subtree_discounted_visitation = \
            self._info['_subtree_discounted_visitation'] * (gamma ** self._subtree_num_running_jobs) \
            + (1.0 - (gamma ** self._subtree_num_running_jobs)) / (1.0 - gamma)
        self.subtree_value = subtree_discounted_rewards / max(subtree_discounted_visitation, 1e-2)
        self.subtree_visitation = subtree_discounted_visitation

class RMaxTS(SamplingAlgorithmBase):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.gamma = self.cfg.get('gamma', 0.99)
        self.sample_num = self.cfg.get('sample_num', 6400)
        self.concurrent_num = self.cfg.get('concurrent_num', 32)
        self.tactic_state_comment = self.cfg.get('tactic_state_comment', True)
        self.ckpt_interval = self.cfg.get('ckpt_interval', 128)
        self.ckpt_filename = 'checkpoint.pkl'
        self.node_cls = TreeNode
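        # Each sample flows through a three-stage asynchronous pipeline,
        # advanced step by step by a ConcurrentJob slot: (1) submit a prompt
        # for a proof continuation, (2) collect the generated code and send it
        # to Lean for verification, (3) fold the outcome back into the tree.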
        self.algorithm_pipeline = [
            self._tactic_tree_generate_proof,
            self._tactic_tree_parse_proof,
            self._rmax_exploration_summarize_results,
        ]

    # basic supports
    def _save_ckpt(self, ckpt_dict: dict):
        # save a backup before overwriting the checkpoint file
        if os.path.exists(self.ckpt_path):
            subprocess.run(['cp', self.ckpt_path, self.ckpt_path + '.backup'])
        # overwrite the checkpoint file
        with open(self.ckpt_path, 'wb') as pkl_f:
            pickle.dump(ckpt_dict, pkl_f)

    # tree structure supports
    def _tree_setup(self, data):
        # initialize tree, preferring the newest checkpoint that still loads
        ckpt_info = None
        for _ckpt_path in [self.ckpt_path, self.ckpt_path + '.backup']:
            if os.path.exists(_ckpt_path):
                try:
                    with open(_ckpt_path, 'rb') as pkl_f:
                        ckpt_info = pickle.load(pkl_f)
                    break
                except Exception:
                    self.process_print(f'Checkpoint saved at {_ckpt_path} is broken.')
        if ckpt_info is not None:
            root = self.node_cls.from_dict(ckpt_info['root'])
            sample_count = ckpt_info['sample_count']
            yield_cache = ckpt_info['yield_cache']
            self.process_print(f'Load checkpoint from sample_count={sample_count}')
        else:
            root = self.node_cls(code=dict(tactic_code=str(), state_comment=str()), depth=0)
            sample_count = 0
            yield_cache = []
        # compile the root node with `sorry` to recover the initial goal state
        self.proof_summarizer = ProofSummarizer(data=data, scheduler=self.scheduler)
        root_sorry = self.proof_summarizer.analyze(' sorry', require_verification=True)
        assert root_sorry.result['pass'], "Cannot parse a `sorry` tactic on root."
        self.root_goal = root_sorry.result['sorries'][-1]['goal']
        # other initialization
        self._last_selected_node = root
        return root, sample_count, yield_cache

    def _tree_new_child(self, parent):
        return self.node_cls(
            parent=parent,
            depth=parent['depth'] + 1,
        )

    def _tree_step(self, node, edge, code):
        if edge not in node.children:
            new_node = self._tree_new_child(node)
            node.children[edge] = new_node
            self.node_list.append(new_node)
        child_node = node.children[edge]
        child_node.update_code(code)
        return child_node

    def _tree_update(self, proof):
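        # Replay a verified proof through the tree: every intermediate tactic
        # goal becomes an edge, and the accumulated tactic code reaching that
        # goal is cached on the child node for reuse as a future prompt prefix.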
        node_walk, partial_code = self.root, str()
        # use tactic goals as tree edges
        segments = proof.segmentation()
        prev_goal = self.root_goal
        for info in segments:
            partial_code += info.tactic_code
            code = partial_code + info.state_comment if self.tactic_state_comment else partial_code
            if self._encode_length(code) < self.max_tokens:
                if info.goal != prev_goal:
                    node_walk = self._tree_step(
                        node=node_walk, edge=info.goal,
                        code=dict(tactic_code=partial_code, state_comment=info.state_comment)
                    )
                    prev_goal = info.goal
        return node_walk

    # algorithm pipeline
    def _select_node(self):
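        # Discounted-UCB selection: walk down from the root, at each node
        # choosing between stopping here to expand (scored by the node's own
        # value) and descending into a child (scored by its subtree value),
        # each plus a Hoeffding-style bonus sqrt(2 ln N / n).
        # `start_new_job` then adds a pending visit along the selected path.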
        node = self.root
        while len(node.children) > 0:
            num_choice = 1 + len(node.children)
            total_visitation = node.visitation + np.sum([child.subtree_visitation for _, child in node.children.items()])

            def hoeffding_ucb(node_visitation):
                return math.sqrt(2.0 * math.log(max(total_visitation, 2)) / max(node_visitation, 1e-2))

            choice_list = [(node.value + hoeffding_ucb(node.visitation), None)]
            for _, child in node.children.items():
                choice_list.append((child.subtree_value + hoeffding_ucb(child.subtree_visitation), child))
            choice_list.sort(reverse=True, key=lambda x: x[0])
            if choice_list[0][1] is None:
                node.start_new_job(gamma=self.gamma)
                return node
            else:
                node = choice_list[0][1]
        node.start_new_job(gamma=self.gamma)
        return node

    def _tactic_tree_generate_proof(self, data, node):
        code_prefix = node.code
        extra_prompt = code_prefix['tactic_code']
        if self.tactic_state_comment:
            extra_prompt += code_prefix['state_comment']
        return dict(
            node=node,
            code_prefix=code_prefix,
            generator_request_id=self.scheduler.generator_submit_request(
                self._preprocess_data({**data, '_extra_prompt': extra_prompt}),
            ),
        )

    def _tactic_tree_parse_proof(self, node, code_prefix, generator_request_id):
        code = self.scheduler.generator_get_request_status(generator_request_id)
        if code is None:
            return None
        code = code_prefix['tactic_code'] + code
        proof = self.proof_summarizer.analyze(code, require_verification=True)
        return dict(node=node, proof=proof)

    def _rmax_exploration_summarize_results(self, node, proof):
        if not proof.is_result_ready():
            return None
        num_nodes_before = len(self.node_list)
        self._tree_update(proof)
        # RMax reward
        node.update_reward(int(len(self.node_list) > num_nodes_before), gamma=self.gamma)
        node.complete_job(gamma=self.gamma)
        return dict(
            code=proof.cleaned_code,
            result=proof.result,
        )

    # sampler interface
    def sample(self, data, prob_log_dir, **kwargs):
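        # Generator interface: replay any proofs cached by a previous run,
        # then keep `concurrent_num` pipeline slots busy until the sample
        # budget is spent, yielding each verified complete proof as it
        # arrives and checkpointing the tree every `ckpt_interval` samples.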
        self.ckpt_path = os.path.join(prob_log_dir, self.ckpt_filename)
        self.root, sample_count, yield_cache = self._tree_setup(data)
        self.node_list = self.root.to_node_list()
        for _proposal, _sample_info in yield_cache:
            yield _proposal, _sample_info
        gc.collect()  # release memory
        job_slots = [
            ConcurrentJob(self.algorithm_pipeline)
            for _ in range(self.concurrent_num)
        ]
        sample_budget = self.sample_num - sample_count if len(yield_cache) == 0 else 0
        while (sample_budget > 0) or any([not job.is_idle() for job in job_slots]):
            for job in job_slots:
                if job.is_idle() and sample_budget > 0:
                    node = self._select_node()
                    self._last_selected_node = node
                    job.start(data=data, node=node)
                    sample_budget -= 1
                if not job.is_idle():
                    info = job.get_status()
                    if info is not None:
                        # output samples
                        sample_count += 1
                        if info['result']['complete']:
                            _proposal, _sample_info = info['code'], self._post_sample_info(
                                cost=sample_count, tree_size=len(self.node_list),
                            )
                            yield_cache.append((_proposal, _sample_info))
                            yield _proposal, _sample_info
                        # logging
                        if sample_count % self.log_interval == 0:
                            self.process_print('Progress: {} / {} Tree Size: {}'.format(
                                sample_count, self.sample_num, len(self.node_list),
                            ))
                        # saving checkpoints
                        if sample_count % self.ckpt_interval == 0:
                            self._save_ckpt(dict(
                                root=self.root.to_dict(),
                                sample_count=sample_count,
                                yield_cache=yield_cache,
                            ))
                            if len(yield_cache) > 0:
                                # return after saving the checkpoint
                                # avoid overestimation caused by interrupt-restart loop
                                sample_budget = 0
            time.sleep(0.1)
        # save the final tree structure
        self._save_ckpt(dict(
            root=self.root.to_dict(),
            sample_count=sample_count,
            yield_cache=yield_cache,
        ))
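
# A minimal usage sketch (hypothetical wiring; in the actual repo the sampler
# is constructed by the launcher from a config and a running generator /
# verifier scheduler, so the keyword names below are assumptions):
#
#   cfg = dict(gamma=0.99, sample_num=6400, concurrent_num=32,
#              tactic_state_comment=True, ckpt_interval=128)
#   sampler = RMaxTS(cfg=cfg, scheduler=scheduler, ...)
#   for proof_code, sample_info in sampler.sample(data, prob_log_dir=log_dir):
#       print(sample_info)  # cost and tree size at the time of discovery
#       print(proof_code)   # a verified complete Lean proof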