DeepSeek-Prover-V1.5/prover/workers/generator.py
2024-08-16 11:33:21 +08:00

53 lines
1.9 KiB
Python

import os
import time
import torch
import torch.multiprocessing as mp
from vllm import LLM, SamplingParams
from prover.utils import AttrDict, MODEL_FORMAT
class GeneratorProcess(mp.Process):
    """Worker process that serves generation requests with a vLLM engine.

    The process pulls batches from ``task_queue``, turns each item into a
    prompt string, runs the model, post-processes the generated text and
    stores the result in ``request_statuses`` keyed by request id.
    """

    def __init__(self, local_rank, node_rank, model_path, task_queue, request_statuses, lock, args):
        """Store the process configuration; the model itself is loaded in ``run``.

        Args:
            local_rank: Rank of this worker on its node; exported as the
                ``LOCAL_RANK`` environment variable and mixed into the seed.
            node_rank: Rank of the node; mixed into the seed.
            model_path: Value passed to ``LLM(model=...)``.
            task_queue: Queue-like object whose ``get()`` yields either a
                batch (an iterable of 3-tuples whose second element is the
                request id and whose third is a mapping with ``.get``) or
                ``None`` to request shutdown.
            request_statuses: Object supporting item assignment; receives
                the processed output for each request id.
                NOTE(review): presumably a manager-backed dict shared with
                the scheduler -- confirm against the caller.
            lock: Context manager held while writing to ``request_statuses``.
            args: Configuration object providing ``temperature``,
                ``max_tokens``, ``top_p`` and ``mode``.
        """
        super().__init__()
        self.local_rank = local_rank
        self.node_rank = node_rank
        self.model_path = model_path
        self.task_queue = task_queue
        self.request_statuses = request_statuses
        self.lock = lock
        # Exactly one completion is sampled per prompt (n=1); ``run`` relies
        # on this when it reads ``outputs[0]`` from every result.
        self.sampling_params = SamplingParams(
            temperature=args.temperature,
            max_tokens=args.max_tokens,
            top_p=args.top_p,
            n=1,
        )
        mode_format = MODEL_FORMAT[args.mode]
        self.prompt_func = mode_format['prompt']
        self.output_func = mode_format['output']

    def _render_prompt(self, item):
        """Return the full prompt text for one task item.

        The prompt is the optional ``'_extra_header'`` entry, followed by
        ``self.prompt_func(item)``, followed by the optional
        ``'_extra_prompt'`` entry; missing entries default to ``''``.
        """
        header = item.get('_extra_header', '')
        body = self.prompt_func(item)
        suffix = item.get('_extra_prompt', '')
        return ''.join((header, body, suffix))

    def run(self):
        """Load the model and serve batches until ``None`` is received."""
        # Derive a per-worker seed from the wall clock and the worker's rank.
        # NOTE(review): the factor 8 presumably reflects the number of
        # workers per node -- confirm against the launcher.
        global_rank = self.node_rank * 8 + self.local_rank
        rng_seed = int(time.time()) % 1000 + global_rank * 1000

        # Exported before the engine is constructed.
        # NOTE(review): presumably consumed by vLLM/torch for device
        # selection -- confirm.
        os.environ['LOCAL_RANK'] = str(self.local_rank)
        engine = LLM(
            model=self.model_path,
            max_num_batched_tokens=8192,
            seed=rng_seed,
            trust_remote_code=True,
        )

        # Terminate when receiving None
        while (batch := self.task_queue.get()) is not None:
            prompts = [self._render_prompt(item) for _, _, item in batch]
            generations = engine.generate(
                prompts,
                self.sampling_params,
                use_tqdm=False,
            )
            results = [
                self.output_func(generation.outputs[0].text)
                for generation in generations
            ]
            # Publish all results of the batch while holding the lock.
            with self.lock:
                for (_, request_id, _), result in zip(batch, results):
                    self.request_statuses[request_id] = result