Fix sub-process spawn multi-process

allegroai 2021-06-01 00:18:20 +03:00
parent ba1b8c403b
commit fcc3c12b59
3 changed files with 110 additions and 17 deletions


@@ -30,10 +30,11 @@ class BackgroundLogService(BackgroundMonitor):
self._last_timestamp = 0
def stop(self):
# make sure we signal the flush event before closing the queue (send everything)
self.flush()
if isinstance(self._queue, PrQueue):
self._queue.close(self._event)
super(BackgroundLogService, self).stop()
self.flush()
def daemon(self):
# multiple daemons are supported
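For context, the reordering in stop() above flushes before the inter-process queue is closed, so buffered log records are sent instead of dropped. A minimal sketch of the same shutdown ordering, assuming a plain multiprocessing.Queue and Event in place of ClearML's PrQueue:

def orderly_stop(queue, flush_event):
    # ask the consumer to flush everything it has buffered first
    flush_event.set()
    # only then stop accepting puts from this process
    queue.close()
    # wait for the feeder thread to drain its buffer into the pipe
    queue.join_thread()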


@@ -419,8 +419,7 @@ class Task(_Task):
cls.__main_task._dev_worker = None
cls.__main_task._resource_monitor = None
# remove the logger from the previous process
logger = cls.__main_task.get_logger()
logger.set_flush_period(None)
cls.__main_task.get_logger()
# create a new logger (to catch stdout/err)
cls.__main_task._logger = None
cls.__main_task.__reporter = None
@@ -432,6 +431,9 @@
cls.__main_task.__register_at_exit(cls.__main_task._at_exit)
# TODO: Check if the signal handler method is safe enough, for the time being, do not unhook
# cls.__main_task.__register_at_exit(None, only_remove_signal_and_exception_hooks=True)
# start all reporting threads
BackgroundMonitor.start_all(task=cls.__main_task)
if not running_remotely():
verify_defaults_match()
@@ -456,6 +458,7 @@
# we could not find a task ID, revert to old stub behaviour
if not is_sub_process_task_id:
return _TaskStub()
elif running_remotely() and not get_is_master_node():
# make sure we only do it once per process
cls.__forked_proc_main_pid = os.getpid()
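The block above is post-fork re-initialization: worker threads do not survive a fork(), so the child drops the inherited logger, reporter, and dev worker, re-registers its at-exit hooks, and restarts the reporting threads. A generic sketch of the pattern, with illustrative names that are not ClearML's API:

import os
import threading

class ForkAwareWorker(object):
    def __init__(self):
        self._owner_pid = os.getpid()
        self._thread = threading.Thread(target=self._loop)
        self._thread.daemon = True
        self._thread.start()

    def _loop(self):
        pass  # background reporting work would go here

    def ensure_alive(self):
        # threads are not inherited across fork(); if the pid changed,
        # the inherited thread object is dead and must be recreated
        if os.getpid() != self._owner_pid:
            self.__init__()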
@@ -2418,6 +2421,11 @@
task._dev_worker.unregister()
task._dev_worker = None
@classmethod
def _has_current_task_obj(cls):
# type: () -> bool
return bool(cls.__main_task)
@classmethod
def _create_dev_task(
cls, default_project_name, default_task_name, default_task_type, tags,
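The new _has_current_task_obj() lets the reporting subprocess check whether a main task exists without calling Task.current_task(), which would lazily create one (see its use in _background_process_start further down). The look-without-creating pattern in isolation, with illustrative names:

class Current(object):
    _instance = None

    @classmethod
    def get(cls):
        # lazy accessor: creates the singleton on first use
        if cls._instance is None:
            cls._instance = Current()
        return cls._instance

    @classmethod
    def has_instance(cls):
        # side-effect-free probe: never creates anything, so it is
        # safe to call from a freshly started subprocess
        return cls._instance is not None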


@@ -3,11 +3,10 @@ import pickle
import struct
import sys
from functools import partial
from multiprocessing import Process, Lock, Event as ProcessEvent
from multiprocessing.pool import ThreadPool
from multiprocessing import Lock, Event as ProcessEvent
from threading import Thread, Event as TrEvent
from time import sleep, time
from typing import List, Dict
from typing import List, Dict, Optional
import psutil
from six.moves.queue import Empty, Queue as TrQueue
@@ -16,9 +15,59 @@ from ..py3_interop import AbstractContextManager
try:
from multiprocessing import SimpleQueue
except ImportError: # noqa
except ImportError:
from multiprocessing.queues import SimpleQueue
try:
from multiprocessing.context import ForkProcess as Process # noqa
except ImportError:
from multiprocessing import Process
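Aliasing ForkProcess to Process pins the child to the fork start method even where spawn is the default, so the monitor subprocess inherits the already-registered instances as a memory copy; the plain Process fallback covers interpreters that do not expose ForkProcess (spawn-only platforms remain open, per the todo further down). The same choice spelled out with the public stdlib API, as a sketch:

import multiprocessing

try:
    # fork: the child starts as a copy of the parent, state included
    _ctx = multiprocessing.get_context('fork')
except ValueError:
    # spawn-only platforms (e.g. Windows) re-import the module instead
    _ctx = multiprocessing.get_context('spawn')

Process = _ctx.Process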
class ThreadCalls(object):
def __init__(self):
self._queue = TrQueue()
self._thread = Thread(target=self._worker)
self._thread.daemon = True
self._thread.start()
def is_alive(self):
return bool(self._thread)
def apply_async(self, func, args=None):
if not func:
return False
self._queue.put((func, args))
return True
def close(self):
if not self._thread:
return
t = self._thread
# push a sentinel into the queue so the worker knows this is the end
self._queue.put(None)
# wait for the thread to finish
t.join()
# mark the thread as done
self._thread = None
def _worker(self):
while True:
try:
request = self._queue.get(block=True, timeout=1.0)
if not request:
break
except Empty:
continue
# noinspection PyBroadException
try:
if request[1]:
request[0](*request[1])
else:
request[0]()
except Exception:
pass
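ThreadCalls replaces multiprocessing.pool.ThreadPool(1) from the previous revision: a ThreadPool inherited across a fork is left with a dead worker thread, while this single plain thread is trivial to recreate per process (see SingletonThreadPool.get() below). A hypothetical usage sketch:

tc = ThreadCalls()
tc.apply_async(print, ('runs on the worker thread',))
tc.apply_async(print)   # args=None branch: func is called with no arguments
tc.close()              # queues the None sentinel, then joins the worker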
class SingletonThreadPool(object):
__thread_pool = None
@@ -27,7 +76,7 @@ class SingletonThreadPool(object):
@classmethod
def get(cls):
if os.getpid() != cls.__thread_pool_pid:
cls.__thread_pool = ThreadPool(1)
cls.__thread_pool = ThreadCalls()
cls.__thread_pool_pid = os.getpid()
return cls.__thread_pool
@@ -38,6 +87,10 @@ class SingletonThreadPool(object):
cls.__thread_pool = None
cls.__thread_pool_pid = None
@classmethod
def is_active(cls):
return cls.__thread_pool and cls.__thread_pool.is_alive()
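get() above is a per-process singleton: the cached pool is keyed by os.getpid(), so the first call in a forked child discards the parent's copy (whose thread is dead) and builds a fresh one. The same idea distilled into a sketch:

import os

class PerProcessSingleton(object):
    _obj = None
    _pid = None

    @classmethod
    def get(cls, factory=dict):
        # a forked child sees a different pid and gets a fresh object
        # instead of the parent's copy with its dead threads
        if cls._pid != os.getpid():
            cls._obj = factory()
            cls._pid = os.getpid()
        return cls._obj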
class SafeQueue(object):
"""
@@ -47,10 +100,12 @@ class SafeQueue(object):
def __init__(self, *args, **kwargs):
self._reader_thread = None
self._reader_thread_started = False
self._q = SimpleQueue(*args, **kwargs)
# Fix the simple queue write so it uses a single OS write, making message passing atomic
# noinspection PyBroadException
try:
# noinspection PyUnresolvedReferences,PyProtectedMember
self._q._writer._send_bytes = partial(SafeQueue._pipe_override_send_bytes, self._q._writer)
except Exception:
pass
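The monkey-patch above replaces the pipe writer's _send_bytes so the 4-byte length header and the payload leave in a single os.write(). Issued as two separate writes, they can interleave when several processes share the queue's write end, corrupting the stream; POSIX only guarantees atomicity for single writes up to PIPE_BUF. A hedged sketch of what such an override looks like (self is a multiprocessing Connection and _send is its internal raw-write helper; the real _pipe_override_send_bytes may differ in detail):

import struct

def _send_bytes_single_write(self, buf):
    # pack the length header and payload into one buffer so the
    # underlying write is a single, effectively atomic syscall
    header = struct.pack("!i", len(buf))
    self._send(header + buf)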
@@ -65,12 +120,17 @@ class SafeQueue(object):
# only call from main put process
return self._q_size > 0
def close(self, event):
def close(self, event, timeout=100.0):
# wait until all pending requests are pushed
tic = time()
while self.is_pending():
if event:
event.set()
if not self.__thread_pool.is_active():
break
sleep(0.1)
if timeout and (time()-tic) > timeout:
break
def get(self, *args, **kwargs):
return self._get_internal_queue(*args, **kwargs)
@@ -104,18 +164,30 @@ class SafeQueue(object):
# GIL will make sure it is atomic
self._q_size -= 1
def _get_internal_queue(self, *args, **kwargs):
def _init_reader_thread(self):
if not self._internal_q:
self._internal_q = TrQueue()
if not self._reader_thread:
if not self._reader_thread or not self._reader_thread.is_alive():
# read before we start the thread
self._reader_thread = Thread(target=self._reader_daemon)
self._reader_thread.daemon = True
self._reader_thread.start()
# if we have waiting results
# wait until the thread is up and has pushed some results
while not self._reader_thread_started:
sleep(0.2)
# just in case, make sure we pulled whatever was already pending
# todo: wait until the queue is not empty, but for some reason that might fail
sleep(1.0)
def _get_internal_queue(self, *args, **kwargs):
self._init_reader_thread()
obj = self._internal_q.get(*args, **kwargs)
# deserialize
return pickle.loads(obj)
def _reader_daemon(self):
self._reader_thread_started = True
# pull from process queue and push into thread queue
while True:
# noinspection PyBroadException
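_init_reader_thread and _reader_daemon form a bridge: a daemon thread blocks on the slow cross-process SimpleQueue and republishes every message into the in-process TrQueue, so get() callers wait on a cheap local queue, and the is_alive() check above lets a forked child lazily restart the bridge it did not inherit. The pattern as a generic, self-contained sketch:

import threading

def start_bridge(process_q, local_q):
    # pump messages from a cross-process queue (e.g. a SimpleQueue)
    # into an in-process queue.Queue that local readers block on
    def pump():
        while True:
            local_q.put(process_q.get())

    t = threading.Thread(target=pump)
    t.daemon = True
    t.start()
    return t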
@@ -287,7 +359,7 @@ class BackgroundMonitor(object):
pass
@classmethod
def start_all(cls, task, wait_for_subprocess=False):
def start_all(cls, task, wait_for_subprocess=True):
# noinspection PyProtectedMember
execute_in_subprocess = task._report_subprocess_enabled
@@ -302,7 +374,11 @@
# setup
for d in BackgroundMonitor._instances.get(id(task.id), []):
d.set_subprocess_mode()
BackgroundMonitor._main_process = Process(target=cls._background_process_start, args=(id(task.id), ))
# todo: solve for standalone spawn subprocess
BackgroundMonitor._main_process = Process(
target=cls._background_process_start,
args=(id(task.id), cls._sub_process_started)
)
BackgroundMonitor._main_process.daemon = True
# Hack: allow creating daemon subprocesses (even though Python doesn't like it)
un_daemonize = False
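The un_daemonize flag supports the hack described in the comment: Python refuses to start children from a daemonic process ('daemonic processes are not allowed to have children'), so the daemon bit of the current process is cleared around the start call and then restored. A sketch of that trick; note it leans on multiprocessing's private per-process _config dict, an implementation detail rather than a public API:

import multiprocessing

def start_ignoring_daemon_flag(proc):
    # temporarily pretend the current process is not daemonic so that
    # proc.start() passes multiprocessing's no-children assertion
    config = multiprocessing.current_process()._config  # noqa: private API
    was_daemon = config.pop('daemon', None)
    try:
        proc.start()
    finally:
        if was_daemon is not None:
            config['daemon'] = was_daemon  # restore the original flag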
@@ -336,15 +412,19 @@
cls._sub_process_started.wait()
@classmethod
def _background_process_start(cls, task_obj_id):
def _background_process_start(cls, task_obj_id, event_start=None):
# type: (int, Optional[SafeEvent]) -> None
is_debugger_running = bool(getattr(sys, 'gettrace', None) and sys.gettrace())
# restore the original signal handler; this will prevent any deadlocks
# Do not change the except clause; we need to catch the base exception as well
# noinspection PyBroadException
try:
from ... import Task
# make sure we do not call Task.current_task() unconditionally; it would create a Task object for us in the subprocess!
# noinspection PyProtectedMember
Task.current_task()._remove_at_exit_callbacks()
if Task._has_current_task_obj():
# noinspection PyProtectedMember
Task.current_task()._remove_at_exit_callbacks()
except: # noqa
pass
@@ -352,15 +432,19 @@
if is_debugger_running:
sleep(3)
instances = BackgroundMonitor._instances.get(task_obj_id, [])
# launch all the threads
for d in cls._instances.get(task_obj_id, []):
for d in instances:
d._start()
if cls._sub_process_started:
cls._sub_process_started.set()
if event_start:
event_start.set()
# wait until we are signaled
for i in BackgroundMonitor._instances.get(task_obj_id, []):
for i in instances:
# noinspection PyBroadException
try:
if i._thread and i._thread.is_alive():
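Passing event_start into the child, together with wait_for_subprocess now defaulting to True in start_all(), turns startup into an explicit handshake: the parent blocks until the child confirms its monitor threads are running instead of racing ahead. The handshake in miniature, a sketch with stdlib primitives standing in for ClearML's SafeEvent:

from multiprocessing import Event, Process

def child_main(started):
    # ... start the background monitor threads here ...
    started.set()   # signal the parent: everything is up

if __name__ == '__main__':
    started = Event()
    p = Process(target=child_main, args=(started,))
    p.daemon = True
    p.start()
    started.wait()  # the parent resumes only once the child is ready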