DeepSeek-Prover-V1.5/prover/algorithms/sampling.py
2024-08-16 11:33:21 +08:00

23 lines
880 B
Python

from .base import SamplingAlgorithmBase
class Sampling(SamplingAlgorithmBase):
    """Independent sampling: issue ``sample_num`` generation requests per input.

    All requests for one input are submitted to the scheduler up front. Their
    outputs are then collected and yielded one by one, in submission order.

    Attributes read here but not set here (``cfg``, ``scheduler``,
    ``log_interval``, ``process_print``, ``_preprocess_data``,
    ``_post_sample_info``) are presumably supplied by
    ``SamplingAlgorithmBase`` — confirm against ``.base``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # How many independent generations to request per input.
        # Falls back to 32 when the config has no 'sample_num' entry.
        self.sample_num = self.cfg.get('sample_num', 32)

    def sample(self, data, **kwargs):
        """Submit ``sample_num`` requests for ``data`` and yield their results.

        This is a generator. Nothing is submitted until iteration starts.

        Args:
            data: The input item. It is passed through ``_preprocess_data``
                separately for every request.
            **kwargs: Accepted for interface compatibility; not used here.

        Yields:
            ``(outputs, info)`` per request, in submission order.
            ``outputs`` is whatever ``scheduler.generator_get_request_outputs``
            returns. ``info`` comes from ``_post_sample_info(cost=k)``, where
            ``k`` is the 1-based count of results retrieved so far.
        """
        # Phase 1: submit every request before waiting on any of them.
        # The prompt is rebuilt for each request rather than hoisted, which
        # matters if _preprocess_data is not deterministic (TODO confirm).
        pending = []
        for _ in range(self.sample_num):
            # add few-shot prompts
            prompt = self._preprocess_data(data)
            pending.append(self.scheduler.generator_submit_request(prompt))

        # Phase 2: collect results in submission order. The fetch presumably
        # blocks until the request has finished — TODO confirm in scheduler.
        for done, request_id in enumerate(pending, start=1):
            outputs = self.scheduler.generator_get_request_outputs(request_id)
            yield outputs, self._post_sample_info(cost=done)

            # Progress is reported after the consumer resumes the generator.
            # It is reported every `log_interval` results, but never for the
            # final one.
            if done >= self.sample_num:
                continue
            if done % self.log_interval == 0:
                self.process_print('Progress: {} / {}'.format(
                    done, self.sample_num
                ))