Fix sub-process spawn multi-process

allegroai 2021-06-01 00:18:20 +03:00
parent ba1b8c403b
commit fcc3c12b59
3 changed files with 110 additions and 17 deletions


@@ -30,10 +30,11 @@ class BackgroundLogService(BackgroundMonitor):
         self._last_timestamp = 0
 
     def stop(self):
+        # make sure we signal the flush event before closing the queue (send everything)
+        self.flush()
         if isinstance(self._queue, PrQueue):
             self._queue.close(self._event)
         super(BackgroundLogService, self).stop()
-        self.flush()
 
     def daemon(self):
         # multiple daemons are supported
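The stop() fix is purely an ordering change: flush() used to run after the queue was already closed, so records still buffered at shutdown could be dropped. A minimal, self-contained sketch of the principle (hypothetical names, not the clearml API):

import queue

class LogService(object):
    def __init__(self):
        self._queue = queue.Queue()
        self._buffer = ["last record"]
        self._closed = False

    def flush(self):
        # drain the local buffer into the shared queue
        while self._buffer and not self._closed:
            self._queue.put(self._buffer.pop(0))

    def stop(self):
        # flush first: once the queue is closed, buffered records are silently lost
        self.flush()
        self._closed = True

svc = LogService()
svc.stop()
assert svc._queue.qsize() == 1  # the buffered record made it out before close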


@@ -419,8 +419,7 @@ class Task(_Task):
             cls.__main_task._dev_worker = None
             cls.__main_task._resource_monitor = None
             # remove the logger from the previous process
-            logger = cls.__main_task.get_logger()
-            logger.set_flush_period(None)
+            cls.__main_task.get_logger()
             # create a new logger (to catch stdout/err)
             cls.__main_task._logger = None
             cls.__main_task.__reporter = None
@@ -432,6 +431,9 @@ class Task(_Task):
             cls.__main_task.__register_at_exit(cls.__main_task._at_exit)
             # TODO: Check if the signal handler method is safe enough, for the time being, do not unhook
             # cls.__main_task.__register_at_exit(None, only_remove_signal_and_exception_hooks=True)
 
+            # start all reporting threads
+            BackgroundMonitor.start_all(task=cls.__main_task)
+
             if not running_remotely():
                 verify_defaults_match()
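start_all() has to be re-invoked here because this branch runs in a freshly forked child process, and OS threads do not survive fork: only the forking thread exists in the child. A standalone sketch (POSIX only) showing why the reporting threads must be restarted:

import os
import threading
import time

t = threading.Thread(target=lambda: time.sleep(10), daemon=True)
t.start()

pid = os.fork()  # POSIX only
if pid == 0:
    # child: the Thread object was copied, but the OS thread was not
    print("child sees thread alive:", t.is_alive())  # False -> must restart
    os._exit(0)
os.waitpid(pid, 0)
print("parent sees thread alive:", t.is_alive())     # True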
@@ -456,6 +458,7 @@ class Task(_Task):
             # we could not find a task ID, revert to old stub behaviour
             if not is_sub_process_task_id:
                 return _TaskStub()
+
         elif running_remotely() and not get_is_master_node():
             # make sure we only do it once per process
             cls.__forked_proc_main_pid = os.getpid()
@@ -2418,6 +2421,11 @@ class Task(_Task):
             task._dev_worker.unregister()
             task._dev_worker = None
 
+    @classmethod
+    def _has_current_task_obj(cls):
+        # type: () -> bool
+        return bool(cls.__main_task)
+
     @classmethod
     def _create_dev_task(
             cls, default_project_name, default_task_name, default_task_type, tags,
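_has_current_task_obj exists because Task.current_task() lazily creates a Task when none exists, which is exactly what the reporting subprocess must avoid. The general pattern, as a sketch with a hypothetical class:

class Service(object):
    _instance = None

    @classmethod
    def current(cls):
        # lazy accessor: creates the singleton on first use
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    @classmethod
    def has_instance(cls):
        # peek without creating -- safe to call from a subprocess
        return cls._instance is not None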


@@ -3,11 +3,10 @@ import pickle
 import struct
 import sys
 from functools import partial
-from multiprocessing import Process, Lock, Event as ProcessEvent
-from multiprocessing.pool import ThreadPool
+from multiprocessing import Lock, Event as ProcessEvent
 from threading import Thread, Event as TrEvent
 from time import sleep, time
-from typing import List, Dict
+from typing import List, Dict, Optional
 
 import psutil
 from six.moves.queue import Empty, Queue as TrQueue
@@ -16,9 +15,59 @@ from ..py3_interop import AbstractContextManager
 try:
     from multiprocessing import SimpleQueue
-except ImportError:  # noqa
+except ImportError:
     from multiprocessing.queues import SimpleQueue
+
+try:
+    from multiprocessing.context import ForkProcess as Process  # noqa
+except ImportError:
+    from multiprocessing import Process
+
+
+class ThreadCalls(object):
+    def __init__(self):
+        self._queue = TrQueue()
+        self._thread = Thread(target=self._worker)
+        self._thread.daemon = True
+        self._thread.start()
+
+    def is_alive(self):
+        return bool(self._thread)
+
+    def apply_async(self, func, args=None):
+        if not func:
+            return False
+        self._queue.put((func, args))
+        return True
+
+    def close(self):
+        if not self._thread:
+            return
+        t = self._thread
+        # push something into queue so it knows this is the end
+        self._queue.put(None)
+        # wait for thread
+        t.join()
+        # mark thread is done
+        self._thread = None
+
+    def _worker(self):
+        while True:
+            try:
+                request = self._queue.get(block=True, timeout=1.0)
+                if not request:
+                    break
+            except Empty:
+                continue
+            # noinspection PyBroadException
+            try:
+                if request[1]:
+                    request[0](*request[1])
+                else:
+                    request[0]()
+            except Exception:
+                pass
+
+
 class SingletonThreadPool(object):
     __thread_pool = None
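ThreadCalls replaces multiprocessing.pool.ThreadPool(1): it is just a daemon thread draining a queue, with none of ThreadPool's fork-sensitive machinery, and the ForkProcess import above likewise pins the subprocess to the fork start method where available. The apply_async interface stays compatible, and a None sentinel ends the worker. A usage sketch (assumes the ThreadCalls class defined above is in scope):

calls = ThreadCalls()
calls.apply_async(print, ("runs on the worker thread",))
calls.apply_async(print)   # args=None -> called with no arguments
calls.close()              # enqueue the None sentinel, then join the thread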
@@ -27,7 +76,7 @@ class SingletonThreadPool(object):
     @classmethod
     def get(cls):
         if os.getpid() != cls.__thread_pool_pid:
-            cls.__thread_pool = ThreadPool(1)
+            cls.__thread_pool = ThreadCalls()
             cls.__thread_pool_pid = os.getpid()
         return cls.__thread_pool
@@ -38,6 +87,10 @@ class SingletonThreadPool(object):
             cls.__thread_pool = None
             cls.__thread_pool_pid = None
 
+    @classmethod
+    def is_active(cls):
+        return cls.__thread_pool and cls.__thread_pool.is_alive()
+
 
 class SafeQueue(object):
     """
@@ -47,10 +100,12 @@ class SafeQueue(object):
 
     def __init__(self, *args, **kwargs):
         self._reader_thread = None
+        self._reader_thread_started = False
         self._q = SimpleQueue(*args, **kwargs)
         # Fix the simple queue write so it uses a single OS write, making it atomic message passing
         # noinspection PyBroadException
         try:
+            # noinspection PyUnresolvedReferences,PyProtectedMember
             self._q._writer._send_bytes = partial(SafeQueue._pipe_override_send_bytes, self._q._writer)
         except Exception:
             pass
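The patched _send_bytes turns each message into a single os.write on the pipe, so concurrent writers cannot interleave a length header with another writer's payload. A sketch of what such an override plausibly looks like (writer._send is a CPython-internal method of multiprocessing.connection.Connection; the real override lives elsewhere in this file):

import struct

def _pipe_override_send_bytes(writer, buf):
    # one buffer, one write: header + payload together is a single os.write,
    # which is atomic for messages below the platform's PIPE_BUF limit
    header = struct.pack("!i", len(buf))
    writer._send(header + buf)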
@@ -65,12 +120,17 @@ class SafeQueue(object):
         # only call from main put process
         return self._q_size > 0
 
-    def close(self, event):
+    def close(self, event, timeout=100.0):
         # wait until all pending requests pushed
+        tic = time()
         while self.is_pending():
             if event:
                 event.set()
+            if not self.__thread_pool.is_active():
+                break
             sleep(0.1)
+            if timeout and (time()-tic) > timeout:
+                break
 
     def get(self, *args, **kwargs):
         return self._get_internal_queue(*args, **kwargs)
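close() now bounds its drain loop with a deadline (and bails out if the thread pool is gone) instead of potentially spinning forever when the consumer side has died. The same deadline pattern in isolation, as a small sketch:

from time import sleep, time

def wait_until(predicate, timeout=100.0, poll=0.1):
    # bounded busy-wait: give up once the deadline passes
    tic = time()
    while not predicate():
        if timeout and (time() - tic) > timeout:
            return False  # timed out
        sleep(poll)
    return True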
@@ -104,18 +164,30 @@ class SafeQueue(object):
         # GIL will make sure it is atomic
         self._q_size -= 1
 
-    def _get_internal_queue(self, *args, **kwargs):
+    def _init_reader_thread(self):
         if not self._internal_q:
             self._internal_q = TrQueue()
-        if not self._reader_thread:
+        if not self._reader_thread or not self._reader_thread.is_alive():
+            # read before we start the thread
             self._reader_thread = Thread(target=self._reader_daemon)
             self._reader_thread.daemon = True
             self._reader_thread.start()
+            # if we have waiting results
+            # wait until thread is up and pushed some results
+            while not self._reader_thread_started:
+                sleep(0.2)
+            # just in case make sure we pulled some stuff if we had any
+            # todo: wait until a queue is not empty, but for some reason that might fail
+            sleep(1.0)
+
+    def _get_internal_queue(self, *args, **kwargs):
+        self._init_reader_thread()
         obj = self._internal_q.get(*args, **kwargs)
         # deserialize
         return pickle.loads(obj)
 
     def _reader_daemon(self):
+        self._reader_thread_started = True
         # pull from process queue and push into thread queue
         while True:
             # noinspection PyBroadException
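_init_reader_thread spin-waits on the _reader_thread_started flag so the reader is known to be running before the caller blocks on the internal queue. The same handshake is more idiomatically written with a threading.Event, which avoids the polling sleep; a sketch of that alternative (not what the commit does):

import threading

class Reader(object):
    def __init__(self):
        self._started = threading.Event()
        self._thread = None

    def _ensure_reader(self):
        if not self._thread or not self._thread.is_alive():
            self._thread = threading.Thread(target=self._daemon, daemon=True)
            self._thread.start()
            self._started.wait()  # block until the daemon signals readiness

    def _daemon(self):
        self._started.set()       # signal readiness before the read loop
        # ... pull from the process queue and push into the thread queue ...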
@@ -287,7 +359,7 @@ class BackgroundMonitor(object):
             pass
 
     @classmethod
-    def start_all(cls, task, wait_for_subprocess=False):
+    def start_all(cls, task, wait_for_subprocess=True):
         # noinspection PyProtectedMember
         execute_in_subprocess = task._report_subprocess_enabled
@@ -302,7 +374,11 @@ class BackgroundMonitor(object):
             # setup
             for d in BackgroundMonitor._instances.get(id(task.id), []):
                 d.set_subprocess_mode()
-            BackgroundMonitor._main_process = Process(target=cls._background_process_start, args=(id(task.id), ))
+            # todo: solve for standalone spawn subprocess
+            BackgroundMonitor._main_process = Process(
+                target=cls._background_process_start,
+                args=(id(task.id), cls._sub_process_started)
+            )
             BackgroundMonitor._main_process.daemon = True
             # Hack allow to create daemon subprocesses (even though python doesn't like it)
             un_daemonize = False
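The "un_daemonize" hack referenced here is needed because multiprocessing refuses to let daemonic processes spawn children. The usual workaround temporarily clears the daemon flag through the process's private _config dict, a CPython internal; roughly:

import multiprocessing

def start_child_from_daemon(target):
    proc = multiprocessing.current_process()
    was_daemon = proc.daemon
    if was_daemon:
        # "daemonic processes are not allowed to have children" --
        # clear the flag via the private _config dict (CPython internal)
        proc._config['daemon'] = False  # noqa
    child = multiprocessing.Process(target=target)
    child.daemon = True
    child.start()
    if was_daemon:
        proc._config['daemon'] = True   # restore the original flag
    return child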
@@ -336,15 +412,19 @@ class BackgroundMonitor(object):
             cls._sub_process_started.wait()
 
     @classmethod
-    def _background_process_start(cls, task_obj_id):
+    def _background_process_start(cls, task_obj_id, event_start=None):
+        # type: (int, Optional[SafeEvent]) -> None
         is_debugger_running = bool(getattr(sys, 'gettrace', None) and sys.gettrace())
         # restore original signal, this will prevent any deadlocks
         # Do not change the exception we need to catch base exception as well
         # noinspection PyBroadException
         try:
             from ... import Task
+            # make sure we do not call Task.current_task() it will create a Task object for us on a subprocess!
             # noinspection PyProtectedMember
-            Task.current_task()._remove_at_exit_callbacks()
+            if Task._has_current_task_obj():
+                # noinspection PyProtectedMember
+                Task.current_task()._remove_at_exit_callbacks()
         except:  # noqa
             pass
@@ -352,15 +432,19 @@ class BackgroundMonitor(object):
         if is_debugger_running:
             sleep(3)
 
+        instances = BackgroundMonitor._instances.get(task_obj_id, [])
         # launch all the threads
-        for d in cls._instances.get(task_obj_id, []):
+        for d in instances:
             d._start()
 
         if cls._sub_process_started:
             cls._sub_process_started.set()
 
+        if event_start:
+            event_start.set()
+
         # wait until we are signaled
-        for i in BackgroundMonitor._instances.get(task_obj_id, []):
+        for i in instances:
             # noinspection PyBroadException
             try:
                 if i._thread and i._thread.is_alive():
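Passing cls._sub_process_started into the child as an explicit argument (event_start) means the started-handshake also works when the subprocess is spawned rather than forked, since spawned children do not inherit module-level state. The pattern in isolation, as a sketch:

import multiprocessing

def _background(event_start):
    # ... start worker threads here ...
    event_start.set()  # tell the parent we are up

if __name__ == '__main__':
    started = multiprocessing.Event()
    p = multiprocessing.Process(target=_background, args=(started,), daemon=True)
    p.start()
    started.wait()     # parent blocks until the child reports started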