DeepSeek-Prover-V1.5/prover/workers/scheduler.py
2024-08-16 11:33:21 +08:00

122 lines
4.0 KiB
Python

import os
import time
import ctypes
import subprocess
import threading
import multiprocessing as mp
import numpy as np
from prover.utils import AttrDict
class TaskQueue(object):
    """FIFO of work items shared through a ``multiprocessing.Manager``.

    Items are appended with :meth:`put` and drained in chunks of at most
    ``batch_size`` by :meth:`get`; every access that mutates the backing list
    is serialised by ``self.lock``. A background thread started in the
    constructor prints how many items were popped (and the mean chunk size)
    every ``log_interval`` seconds until :meth:`close` is called.
    """

    def __init__(self, batch_size=512, name='test', log_interval=60.0):
        """Create the shared list and start the monitor thread.

        Args:
            batch_size: maximum number of items returned by a single get().
            name: label used in the monitor's log lines.
            log_interval: seconds between monitor log lines (default 60.0,
                the previously hard-coded value).
        """
        self.name = name
        self.batch_size = batch_size
        self.log_interval = log_interval
        self.manager = mp.Manager()
        self.waiting_list = self.manager.list()
        self.all_tasks_done = mp.Event()
        self.lock = mp.Lock()
        # Sizes of the chunks popped by get() since the last monitor report.
        self._monitor_log = self.manager.list()
        # Daemon thread: a forgotten close() must not keep the interpreter
        # alive forever (the loop below only exits once the event is set).
        self._monitor_thread = threading.Thread(target=self._monitor, daemon=True)
        self._monitor_thread.start()

    def _monitor(self):
        """Print throughput statistics periodically until close() is called."""
        last_log_time = time.time()
        # Event.wait() returns True as soon as close() sets the event, so
        # shutdown does not have to wait out a full sleep period.
        while not self.all_tasks_done.wait(1.0):
            if time.time() - last_log_time < self.log_interval:
                continue
            with self.lock:
                # Snapshot in one IPC round-trip and release the lock before
                # the numpy work and the print, so put()/get() are not stalled.
                popped = list(self._monitor_log)
                self._monitor_log[:] = []
                waiting = len(self.waiting_list)
            if popped:
                print('TaskQueue-{}: {} requests popped with avg batch_size {:.1f} in last period {} waiting in queue'.format(
                    self.name, np.sum(popped), np.mean(popped), waiting,
                ))
            last_log_time = time.time()

    def __len__(self):
        """Return the number of items currently waiting in the queue."""
        return len(self.waiting_list)

    def put(self, item):
        """Append ``item`` to the end of the queue."""
        with self.lock:
            self.waiting_list.append(item)

    def get(self, no_wait=False):
        """Pop up to ``batch_size`` items from the front of the queue.

        Args:
            no_wait: if True, return None immediately when the queue is empty
                instead of polling until items arrive.

        Returns:
            A non-empty list of items, or None when the queue is empty and
            ``no_wait`` is set, or once close() has been called.
        """
        while not self.all_tasks_done.is_set():
            with self.lock:
                if len(self.waiting_list) > 0:
                    tasks = self.waiting_list[:self.batch_size]
                    self.waiting_list[:self.batch_size] = []
                    self._monitor_log.append(len(tasks))
                    return tasks
            if no_wait:
                break
            # Poll every 0.1 s, but wake up immediately if close() is called.
            self.all_tasks_done.wait(0.1)
        return None

    def close(self):
        """Signal shutdown (get() then returns None) and join the monitor thread."""
        self.all_tasks_done.set()
        self._monitor_thread.join()
class ProcessScheduler(object):
    """Submit requests to a batched TaskQueue and collect their results.

    Each submitted request gets a fresh integer id (starting at 1) and a
    ``None`` placeholder in the manager-backed ``request_statuses`` dict.
    Results are read back from that same dict; this file does not show who
    writes them (presumably a worker draining ``task_queue`` — confirm
    against the worker code).
    """

    def __init__(self, batch_size=512, name='test'):
        self.name = name
        self.manager = mp.Manager()
        self.batch_size = batch_size
        self.task_queue = TaskQueue(batch_size=batch_size, name=name)
        # request_id -> result, or None while the request is still pending.
        self.request_statuses = self.manager.dict()
        # Shared counter used to mint request ids.
        self.request_counter = mp.Value(ctypes.c_int32, 0)
        self.lock = mp.Lock()

    def submit_request(self, data):
        """Register ``data`` as a new pending request, enqueue it, return its id.

        The queued item is the tuple ``(submit_time, request_id, data)``.
        """
        with self.lock:
            self.request_counter.value += 1
            request_id = self.request_counter.value
            # Register the placeholder before enqueueing so the id is always
            # known by the time anything can observe it in the queue.
            self.request_statuses[request_id] = None
            self.task_queue.put((time.time(), request_id, data))
        return request_id

    def submit_all_request(self, data_list):
        """Submit every item of ``data_list``; return the ids in the same order."""
        return [self.submit_request(data) for data in data_list]

    def get_request_status(self, request_id):
        """Non-blocking poll for a request's result.

        Returns the result and removes it from ``request_statuses`` if it is
        available; returns None if the request is still pending or the id is
        not tracked.
        """
        with self.lock:
            response = self.request_statuses.get(request_id, None)
            if response is not None:
                self.request_statuses.pop(request_id)
            return response

    def get_request_outputs(self, request_id, poll_interval=1.0):
        """Block until the result of ``request_id`` is available and return it.

        Args:
            request_id: id returned by submit_request().
            poll_interval: seconds to sleep between polls (default 1.0).

        Raises:
            KeyError: if ``request_id`` is not (or no longer) tracked, e.g. it
                was never submitted or its result was already retrieved.
                Previously this case polled forever.
        """
        while True:
            outputs = self.get_request_status(request_id)
            if outputs is not None:
                return outputs
            # A missing id can never complete, so waiting would never end.
            if request_id not in self.request_statuses:
                raise KeyError(request_id)
            time.sleep(poll_interval)

    def get_all_request_outputs(self, request_id_list):
        """Wait for every id in ``request_id_list``; return results in order."""
        return [self.get_request_outputs(request_id) for request_id in request_id_list]

    def close(self):
        """Close the underlying task queue."""
        self.task_queue.close()
class Scheduler(object):
    """Facade exposing several named schedulers through a single object.

    For every ``name -> scheduler`` entry of ``scheduler_dict``, the scheduler
    is bound as ``self.<name>`` and each of its public attributes (names not
    starting with ``_``) as ``self.<name>_<attr>``. For example, a scheduler
    registered as ``'foo'`` with a ``submit_request`` method becomes callable
    as ``self.foo_submit_request``. Attributes are bound once, at construction
    time; rebinding them on the scheduler later is not reflected here.
    """

    def __init__(self, scheduler_dict):
        self._scheduler_dict = scheduler_dict
        for name, scheduler in scheduler_dict.items():
            setattr(self, name, scheduler)
            for key in dir(scheduler):
                if not key.startswith('_'):
                    setattr(self, f'{name}_{key}', getattr(scheduler, key))

    def close(self):
        """Close every registered scheduler, even if some of them fail.

        Raises:
            The first exception raised by any scheduler's close(), re-raised
            only after all schedulers have been given the chance to close.
            Later exceptions are not reported.
        """
        first_error = None
        for scheduler in self._scheduler_dict.values():
            try:
                scheduler.close()
            except Exception as err:
                # Keep going so one failure does not leave the others running.
                if first_error is None:
                    first_error = err
        if first_error is not None:
            raise first_error