mirror of
https://github.com/deepseek-ai/DeepSeek-Prover-V1.5
synced 2024-11-22 03:17:43 +00:00
53 lines
1.9 KiB
Python
import os
import time

import torch
import torch.multiprocessing as mp
from vllm import LLM, SamplingParams

from prover.utils import AttrDict, MODEL_FORMAT


class GeneratorProcess(mp.Process):
    def __init__(self, local_rank, node_rank, model_path, task_queue, request_statuses, lock, args):
        super().__init__()
        self.local_rank = local_rank
        self.node_rank = node_rank
        self.model_path = model_path
        self.task_queue = task_queue
        self.request_statuses = request_statuses
        self.lock = lock
        self.sampling_params = SamplingParams(
            temperature=args.temperature,
            max_tokens=args.max_tokens,
            top_p=args.top_p,
            n=1,
        )
        # MODEL_FORMAT maps a mode name to a prompt builder and an output parser.
        self.prompt_func = MODEL_FORMAT[args.mode]['prompt']
        self.output_func = MODEL_FORMAT[args.mode]['output']

    def run(self):
        # Derive a distinct seed per worker process: the stride of 1000 per
        # (node_rank * 8 + local_rank) slot assumes up to 8 workers per node.
        seed = int(time.time()) % 1000 + (self.node_rank * 8 + self.local_rank) * 1000
        os.environ['LOCAL_RANK'] = str(self.local_rank)
        llm = LLM(model=self.model_path, max_num_batched_tokens=8192, seed=seed, trust_remote_code=True)
        while True:
            inputs = self.task_queue.get()
            if inputs is None:  # Terminate when receiving None
                break
            # Each queue entry is a batch of (meta, request_id, item) tuples;
            # build one prompt per item, with optional header/prompt extensions.
            model_inputs = [
                ''.join([
                    item.get('_extra_header', str()),
                    self.prompt_func(item),
                    item.get('_extra_prompt', str()),
                ]) for _, _, item in inputs
            ]
            model_outputs = llm.generate(
                model_inputs,
                self.sampling_params,
                use_tqdm=False,
            )
            # Parse the raw completions, then publish results under the shared
            # lock so the driver can read them by request_id.
            outputs = [self.output_func(_output.outputs[0].text) for _output in model_outputs]
            with self.lock:
                for (_, request_id, _), output in zip(inputs, outputs):
                    self.request_statuses[request_id] = output
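
# ---------------------------------------------------------------------------
# Usage sketch (not part of the upstream file): one plausible way a driver
# could wire up GeneratorProcess workers. The AttrDict fields, the 'cot' mode
# key, the model path, and the item schema are illustrative assumptions, not
# values taken from this repository; AttrDict is assumed to accept a plain
# dict.
if __name__ == '__main__':
    mp.set_start_method('spawn')        # CUDA needs 'spawn' in child processes
    manager = mp.Manager()
    task_queue = manager.Queue()
    request_statuses = manager.dict()   # request_id -> parsed output
    lock = manager.Lock()
    args = AttrDict({
        'mode': 'cot',                  # assumed MODEL_FORMAT key
        'temperature': 1.0,
        'top_p': 0.95,
        'max_tokens': 2048,
    })
    workers = [
        GeneratorProcess(rank, 0, 'path/to/model',  # placeholder model path
                         task_queue, request_statuses, lock, args)
        for rank in range(torch.cuda.device_count())
    ]
    for w in workers:
        w.start()
    # Submit one batch of (meta, request_id, item) tuples; the worker reads
    # only request_id and item.
    task_queue.put([(None, 0, {'statement': '...'})])  # placeholder item
    for _ in workers:
        task_queue.put(None)            # one sentinel per worker shuts it down
    for w in workers:
        w.join()
    # Parsed outputs are now available in request_statuses[request_id].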