mirror of https://github.com/deepseek-ai/DeepSeek-Prover-V1.5, synced 2024-11-21 19:07:44 +00:00
import os

import numpy as np
from transformers import AutoTokenizer

from prover.utils import get_datetime, load_jsonl_objects, MODEL_FORMAT


class SamplingAlgorithmBase(object):
    """Common scaffolding for proof-sampling algorithms.

    Concrete algorithms subclass this and implement `sample()`.
    """

    def __init__(self, scheduler, tokenizer_path, process_print, cfg, **kwargs):
        # Disable tokenizer parallelism to avoid fork-related warnings/deadlocks.
        os.environ['TOKENIZERS_PARALLELISM'] = 'false'
        self.scheduler = scheduler
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        self.process_print = process_print
        self.cfg = cfg

        self.max_tokens = cfg.max_tokens
        # Optional few-shot examples: `few_shot_dataset` is a path to a .jsonl file.
        self.few_shot_dataset = cfg.get('few_shot_dataset', None)
        if self.few_shot_dataset is not None:
            self.few_shot_dataset = load_jsonl_objects(self.few_shot_dataset)
        self.few_shot_num = cfg.get('few_shot_num', 3)
        self.few_shot_func = MODEL_FORMAT[cfg.mode]['few_shot']
        self.log_interval = cfg.get('log_interval', 32)
    @property
    def algorithm_name(self):
        return self.__class__.__name__

    def _post_sample_info(self, **kwargs):
        return dict(
            algorithm=self.algorithm_name,
            datetime=get_datetime(),
            **kwargs,
        )

    def _encode_length(self, code):
        return len(self.tokenizer.encode(code))
    def _preprocess_data(self, input_data):
        # Without a few-shot dataset there is nothing to prepend.
        if self.few_shot_dataset is None or self.few_shot_num == 0:
            return input_data
        # Prepend `few_shot_num` randomly chosen examples, excluding the current
        # problem itself, in front of any existing extra header.
        return {
            **input_data,
            '_extra_header': ''.join([
                self.few_shot_func(self.few_shot_dataset[idx])
                for idx in np.random.choice([
                    _idx for _idx, _data in enumerate(self.few_shot_dataset)
                    if _data['name'] != input_data['name']
                ], size=self.few_shot_num, replace=False)
            ] + [input_data.get('_extra_header', str())]),
        }

    def sample(self, **kwargs):
        raise NotImplementedError
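

# Usage sketch (illustrative only): a minimal subclass showing how the hooks
# above compose. The class name `_ExampleSamplingSketch` and the
# `self.scheduler.generate(...)` call are hypothetical placeholders; the real
# scheduler interface in this repository may differ.
class _ExampleSamplingSketch(SamplingAlgorithmBase):
    def sample(self, input_data, **kwargs):
        # Optionally prepend few-shot examples to the prompt header.
        data = self._preprocess_data(input_data)
        # Hypothetical generation call; substitute the scheduler's actual API.
        output = self.scheduler.generate(data, max_tokens=self.max_tokens)
        # Attach bookkeeping metadata (algorithm name, timestamp, token count).
        return output, self._post_sample_info(output_tokens=self._encode_length(output))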