DeepSeek-Prover-V1.5/prover/algorithms/base.py

import os

import numpy as np
from transformers import AutoTokenizer

from prover.utils import get_datetime, load_jsonl_objects, MODEL_FORMAT


class SamplingAlgorithmBase(object):
    def __init__(self, scheduler, tokenizer_path, process_print, cfg, **kwargs):
        # Disable tokenizer parallelism to avoid fork-related warnings in worker processes.
        os.environ['TOKENIZERS_PARALLELISM'] = 'false'
        self.scheduler = scheduler
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        self.process_print = process_print
        self.cfg = cfg
        self.max_tokens = cfg.max_tokens

        # Optional few-shot prompting: load a .jsonl dataset of examples and the
        # mode-specific function that formats an example into the prompt header.
        self.few_shot_dataset = cfg.get('few_shot_dataset', None)
        if self.few_shot_dataset is not None:
            self.few_shot_dataset = load_jsonl_objects(self.few_shot_dataset)
        self.few_shot_num = cfg.get('few_shot_num', 3)
        self.few_shot_func = MODEL_FORMAT[cfg.mode]['few_shot']

        self.log_interval = cfg.get('log_interval', 32)

    @property
    def algorithm_name(self):
        return self.__class__.__name__

    def _post_sample_info(self, **kwargs):
        # Attach the algorithm name and a timestamp to a sample's metadata.
        return dict(
            algorithm=self.algorithm_name,
            datetime=get_datetime(),
            **kwargs,
        )

    def _encode_length(self, code):
        # Length of `code` in tokenizer tokens.
        return len(self.tokenizer.encode(code))

    def _preprocess_data(self, input_data):
        # Prepend formatted few-shot examples to the prompt header, sampling
        # without replacement and excluding the current problem by name.
        if self.few_shot_dataset is None or self.few_shot_num == 0:
            return input_data
        return {
            **input_data,
            '_extra_header': ''.join([
                self.few_shot_func(self.few_shot_dataset[idx])
                for idx in np.random.choice([
                    _idx for _idx, _data in enumerate(self.few_shot_dataset)
                    if _data['name'] != input_data['name']
                ], size=self.few_shot_num, replace=False)
            ] + [input_data.get('_extra_header', str())]),
        }

    def sample(self, **kwargs):
        # Concrete sampling strategies are implemented by subclasses.
        raise NotImplementedError
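
# ---------------------------------------------------------------------------
# Illustrative only (not part of the original file): a minimal sketch of how a
# concrete algorithm might subclass SamplingAlgorithmBase. `_ExampleSampling`
# and the `self.scheduler.generate(...)` call are hypothetical placeholders;
# the project's real scheduler interface and algorithm implementations live
# elsewhere under prover/.
# ---------------------------------------------------------------------------
class _ExampleSampling(SamplingAlgorithmBase):
    def sample(self, data, **kwargs):
        # Inject few-shot examples into the prompt header, if configured.
        data = self._preprocess_data(data)
        # Hypothetical call; substitute the project's actual scheduler API.
        outputs = self.scheduler.generate(data, max_tokens=self.max_tokens)
        # Tag each generated sample with the algorithm name and a timestamp.
        return [self._post_sample_info(code=output) for output in outputs]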