Fix fork (process pool) hanging or dropping reports when reports are sent at the end of the forked function, in both threaded and subprocess reporting modes

allegroai 2022-01-08 22:58:34 +02:00
parent fccb14974c
commit c7c0462e89
4 changed files with 250 additions and 85 deletions
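
For context, a minimal sketch of the failure mode being fixed (a hypothetical repro using the public Task API, not code from this commit): a process-pool worker that reports right before returning could previously hang at exit or silently lose its last events.

import multiprocessing as mp
from clearml import Task

def worker(i):
    # report at the very end of the forked function: exactly the case that
    # used to hang the forked reporter or drop these events
    Task.current_task().get_logger().report_scalar("metric", "worker", value=i, iteration=i)
    return i

if __name__ == "__main__":
    Task.init(project_name="debug", task_name="fork pool repro")
    with mp.Pool(4) as pool:
        pool.map(worker, range(8))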

View File

@@ -2,13 +2,12 @@ import datetime
import json
import logging
import math
from multiprocessing import Semaphore
from threading import Event as TrEvent
import os
from time import sleep, time
import numpy as np
import six
from six.moves.queue import Queue as TrQueue, Empty
from six.moves.queue import Empty
from .events import (
ScalarEvent, VectorEvent, ImageEvent, PlotEvent, ImageEventNoUpload,
@@ -21,7 +20,7 @@ from ...utilities.plotly_reporter import (
create_2d_histogram_plot, create_value_matrix, create_3d_surface,
create_2d_scatter_series, create_3d_scatter_series, create_line_plot, plotly_scatter3d_layout_dict,
create_image_plot, create_plotly_table, )
from ...utilities.process.mp import BackgroundMonitor
from ...utilities.process.mp import BackgroundMonitor, ForkSemaphore, ForkEvent, ForkQueue
from ...utilities.py3_interop import AbstractContextManager
from ...utilities.process.mp import SafeQueue as PrQueue, SafeEvent
@@ -36,41 +35,48 @@ class BackgroundReportService(BackgroundMonitor, AsyncManagerMixin):
super(BackgroundReportService, self).__init__(
task=task, wait_period=flush_frequency)
self._flush_threshold = flush_threshold
self._exit_event = TrEvent()
self._empty_state_event = TrEvent()
self._queue = TrQueue()
self._flush_event = ForkEvent()
self._empty_state_event = ForkEvent()
self._queue = ForkQueue()
self._queue_size = 0
self._res_waiting = Semaphore()
self._res_waiting = ForkSemaphore()
self._metrics = metrics
self._storage_uri = None
self._async_enable = async_enable
self._is_thread_mode_in_subprocess_flag = None
def set_storage_uri(self, uri):
self._storage_uri = uri
def set_subprocess_mode(self):
if isinstance(self._queue, TrQueue):
if isinstance(self._queue, ForkQueue):
self._write()
self._queue = PrQueue()
if not isinstance(self._exit_event, SafeEvent):
self._exit_event = SafeEvent()
if not isinstance(self._event, SafeEvent):
self._event = SafeEvent()
if not isinstance(self._empty_state_event, SafeEvent):
self._empty_state_event = SafeEvent()
super(BackgroundReportService, self).set_subprocess_mode()
def stop(self):
if isinstance(self._queue, PrQueue):
self._queue.close(self._event)
self._queue.close(self._flush_event)
if not self.is_subprocess_mode() or self.is_subprocess_alive():
self._exit_event.set()
self._flush_event.set()
super(BackgroundReportService, self).stop()
def flush(self):
while isinstance(self._queue, PrQueue) and self._queue.is_pending():
sleep(0.1)
self._queue_size = 0
# stop background process?!
if not self.is_subprocess_mode() or self.is_subprocess_alive():
self._event.set()
self._flush_event.set()
def wait_for_events(self, timeout=None):
if self._is_subprocess_mode_and_not_parent_process() and self.get_at_exit_state():
return
# noinspection PyProtectedMember
if self._is_subprocess_mode_and_not_parent_process():
while self._queue and not self._queue.empty():
@@ -78,7 +84,7 @@ class BackgroundReportService(BackgroundMonitor, AsyncManagerMixin):
return
self._empty_state_event.clear()
if isinstance(self._empty_state_event, TrEvent):
if isinstance(self._empty_state_event, ForkEvent):
tic = time()
while self._thread and self._thread.is_alive() and (not timeout or time()-tic < timeout):
if self._empty_state_event.wait(timeout=1.0):
@@ -100,16 +106,17 @@ class BackgroundReportService(BackgroundMonitor, AsyncManagerMixin):
# get all data, work on local queue:
self._write()
# replace queue:
self._queue = TrQueue()
self._queue = ForkQueue()
self._queue_size = 0
self._event = TrEvent()
self._done_ev = TrEvent()
self._start_ev = TrEvent()
self._exit_event = TrEvent()
self._empty_state_event = TrEvent()
self._res_waiting = Semaphore()
self._event = ForkEvent()
self._done_ev = ForkEvent()
self._start_ev = ForkEvent()
self._flush_event = ForkEvent()
self._empty_state_event = ForkEvent()
self._res_waiting = ForkSemaphore()
# set thread mode
self._subprocess = False
self._is_thread_mode_in_subprocess_flag = None
# start background thread
self._thread = None
self._start()
@@ -122,9 +129,11 @@ class BackgroundReportService(BackgroundMonitor, AsyncManagerMixin):
self.flush()
def daemon(self):
while not self._exit_event.wait(0):
self._event.wait(self._wait_timeout)
self._event.clear()
self._is_thread_mode_in_subprocess_flag = self._is_thread_mode_and_not_main_process()
while not self._event.wait(0):
self._flush_event.wait(self._wait_timeout)
self._flush_event.clear()
# lock state
self._res_waiting.acquire()
self._write()
@@ -132,7 +141,7 @@ class BackgroundReportService(BackgroundMonitor, AsyncManagerMixin):
if self.get_num_results() > 0:
self.wait_for_results()
# set empty flag only if we are not waiting for exit signal
if not self._exit_event.wait(0):
if not self._event.wait(0):
self._empty_state_event.set()
# unlock state
self._res_waiting.release()
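
Pieced together from the hunks above, the reworked daemon loop now reads roughly as follows: _event (which used to be the flush wake-up) takes over the exit-signal role of the removed _exit_event, while the new _flush_event only wakes the loop to drain the queue. A condensed sketch:

while not self._event.wait(0):             # _event is now the exit signal
    self._flush_event.wait(self._wait_timeout)
    self._flush_event.clear()              # _flush_event merely wakes the loop
    self._res_waiting.acquire()            # lock state
    self._write()
    if self.get_num_results() > 0:
        self.wait_for_results()
    if not self._event.wait(0):            # flag "empty" only when not exiting
        self._empty_state_event.set()
    self._res_waiting.release()            # unlock state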
@@ -157,8 +166,15 @@ class BackgroundReportService(BackgroundMonitor, AsyncManagerMixin):
break
if not events:
return
if self._is_thread_mode_in_subprocess_flag:
for e in events:
if isinstance(e, UploadEvent):
# noinspection PyProtectedMember
e._generate_file_name(force_pid_suffix=os.getpid())
res = self._metrics.write_events(
events, async_enable=self._async_enable, storage_uri=self._storage_uri)
if self._async_enable:
self._add_async_result(res)
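
The force_pid_suffix branch above matters once the reporter falls back to thread mode inside a forked worker: each process then stamps its own pid into upload file names, so two workers uploading the same debug sample can no longer overwrite each other. A rough sketch of the naming idea (pid_suffixed is a hypothetical helper, not the real _generate_file_name):

import os

def pid_suffixed(filename, pid=None):
    # embed the reporting process's pid in the file name to avoid collisions
    pid = pid or os.getpid()
    base, ext = os.path.splitext(filename)
    return "{}.{}{}".format(base, pid, ext)

print(pid_suffixed("debug_sample.png"))  # e.g. "debug_sample.12345.png"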

View File

@@ -3,14 +3,12 @@ import sys
from pathlib2 import Path
from logging import LogRecord, getLogger, basicConfig, getLevelName, INFO, WARNING, Formatter, makeLogRecord, warning
from logging.handlers import BufferingHandler
from six.moves.queue import Queue as TrQueue
from threading import Event as TrEvent
from .development.worker import DevWorker
from ...backend_api.services import events
from ...backend_api.session.session import MaxRequestSizeError
from ...config import config
from ...utilities.process.mp import BackgroundMonitor
from ...utilities.process.mp import BackgroundMonitor, ForkEvent, ForkQueue
from ...utilities.process.mp import SafeQueue as PrQueue, SafeEvent
@@ -21,8 +19,8 @@ class BackgroundLogService(BackgroundMonitor):
super(BackgroundLogService, self).__init__(task=task, wait_period=wait_period)
self._worker = worker
self._task_id = task.id
self._queue = TrQueue()
self._flush = TrEvent()
self._queue = ForkQueue()
self._flush = ForkEvent()
self._last_event = None
self._offline_log_filename = offline_log_filename
self.session = session
@@ -76,7 +74,7 @@ class BackgroundLogService(BackgroundMonitor):
self._queue.put(a_request)
def set_subprocess_mode(self):
if isinstance(self._queue, TrQueue):
if isinstance(self._queue, ForkQueue):
self.send_all_records()
self._queue = PrQueue()
super(BackgroundLogService, self).set_subprocess_mode()
@@ -84,16 +82,16 @@ class BackgroundLogService(BackgroundMonitor):
def add_to_queue(self, record):
# check that we did not lose the reporter sub-process
if self.is_subprocess_mode() and not self._fast_is_subprocess_alive():
if self.is_subprocess_mode() and not self._fast_is_subprocess_alive() and not self.get_at_exit_state():
# we lost the reporting subprocess, let's switch to thread mode
# get all data, work on local queue:
self.send_all_records()
# replace queue:
self._queue = TrQueue()
self._flush = TrEvent()
self._event = TrEvent()
self._done_ev = TrEvent()
self._start_ev = TrEvent()
self._queue = ForkQueue()
self._flush = ForkEvent()
self._event = ForkEvent()
self._done_ev = ForkEvent()
self._start_ev = ForkEvent()
# set thread mode
self._subprocess = False
# start background thread
@@ -275,7 +273,8 @@ class TaskHandler(BufferingHandler):
if _background_log:
if not _background_log.is_subprocess_mode() or _background_log.is_alive():
_background_log.stop()
if wait:
if wait and (not _background_log.is_subprocess_mode() or
_background_log.is_subprocess_mode_and_parent_process()):
# noinspection PyBroadException
try:
timeout = 1. if _background_log.empty() else self.__wait_for_flush_timeout

View File

@@ -551,9 +551,8 @@ class Task(_Task):
Task.__main_task = task
# register the main task for at exit hooks (there should only be one)
task.__register_at_exit(task._at_exit)
# patch OS forking if we are not logging with a subprocess
if not cls._report_subprocess_enabled:
PatchOsFork.patch_fork()
# always patch OS forking because of ProcessPool and the like
PatchOsFork.patch_fork()
if auto_connect_frameworks:
is_auto_connect_frameworks_bool = not isinstance(auto_connect_frameworks, dict)
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('hydra', True):
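
PatchOsFork is clearml-internal, but the mechanism it relies on is ordinary monkey-patching of os.fork; a minimal standalone sketch of that technique (not the actual implementation):

import os

_orig_fork = os.fork

def _patched_fork():
    pid = _orig_fork()
    if pid == 0:
        # child process: the hook point where reporting queues, events and
        # threads are rebuilt, since threads never survive a fork()
        pass
    return pid

os.fork = _patched_fork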
@@ -3120,6 +3119,12 @@ class Task(_Task):
# protect sub-process at_exit (should never happen)
if self._at_exit_called and self._at_exit_called != get_current_thread_id():
return
# make sure we do not try to use events, because Python might deadlock itself.
# https://bugs.python.org/issue41606
if self.__is_subprocess():
BackgroundMonitor.set_at_exit_state(True)
# shutdown will clear the main, so we have to store it before.
# is_main = self.is_main_task()
# fix debugger signal in the middle, catch everything
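
set_at_exit_state(True) raises a class-level flag that the synchronization helpers check so they can take a non-blocking path while a forked child shuts down. A hypothetical guard in that spirit (safe_wait is illustrative, not part of this commit):

def safe_wait(event, timeout=None):
    # skip blocking waits once the at-exit flag is raised, per the
    # interpreter-shutdown deadlock referenced above
    if BackgroundMonitor.get_at_exit_state():
        return False   # behave like a timed-out wait instead of blocking
    return event.wait(timeout)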
@@ -3313,7 +3318,8 @@ class Task(_Task):
# make sure no one will re-enter the shutdown method
self._at_exit_called = True
BackgroundMonitor.wait_for_sub_process(self)
if not is_sub_process and BackgroundMonitor.is_subprocess_enabled():
BackgroundMonitor.wait_for_sub_process(self)
@classmethod
def __register_at_exit(cls, exit_callback, only_remove_signal_and_exception_hooks=False):
@@ -3470,9 +3476,10 @@ class Task(_Task):
else:
cls.__exit_hook.update_callback(exit_callback)
@classmethod
def _remove_at_exit_callbacks(cls):
cls.__register_at_exit(None, only_remove_signal_and_exception_hooks=True)
def _remove_at_exit_callbacks(self):
self.__register_at_exit(None, only_remove_signal_and_exception_hooks=True)
atexit.unregister(self.__exit_hook._exit_callback)
self._at_exit_called = True
@classmethod
def __get_task(

View File

@@ -3,7 +3,7 @@ import pickle
import struct
import sys
from functools import partial
from multiprocessing import Lock, Event as ProcessEvent
from multiprocessing import Lock, Semaphore, Event as ProcessEvent
from threading import Thread, Event as TrEvent, RLock as ThreadRLock
from time import sleep, time
from typing import List, Dict, Optional
@@ -33,25 +33,40 @@ except ImportError:
return False
class ForkSafeRLock(object):
def __init__(self):
self._lock = None
class _ForkSafeThreadSyncObject(object):
def __init__(self, functor):
self._sync = None
self._instance_pid = None
self._functor = functor
def _create(self):
# this part is not atomic, and there is not a lot we can do about it.
if self._instance_pid != os.getpid() or not self._sync:
# Notice! This is NOT atomic; the first time it is accessed, two concurrent calls might
# end up overwriting each other's object.
# Even though that sounds horrible, the worst case in our usage scenario
# is that the very first call is not "atomic".
# Notice the order! We first create the object and THEN update the pid,
# so that whatever happens we never try to use the old (pre-fork) copy of the synchronization object
self._sync = self._functor()
self._instance_pid = os.getpid()
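
_ForkSafeThreadSyncObject is the heart of the fix: each process lazily rebuilds its own copy of the primitive on first use after a fork, so a lock or event captured mid-operation by fork() can never deadlock the child. The same pattern as a standalone sketch:

import os
from threading import RLock

class ForkSafe(object):
    def __init__(self, factory):
        self._factory = factory
        self._sync = None
        self._pid = None

    def get(self):
        # rebuild the primitive the first time each process touches it
        if self._pid != os.getpid() or self._sync is None:
            self._sync = self._factory()   # create first...
            self._pid = os.getpid()        # ...then mark it as this process's
        return self._sync

lock = ForkSafe(RLock)
with lock.get():
    pass  # usable in the parent and in any forked child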
class ForkSafeRLock(_ForkSafeThreadSyncObject):
def __init__(self):
super(ForkSafeRLock, self).__init__(ThreadRLock)
def acquire(self, *args, **kwargs):
self.create()
return self._lock.acquire(*args, **kwargs)
self._create()
return self._sync.acquire(*args, **kwargs)
def release(self, *args, **kwargs):
if self._lock is None:
if self._sync is None:
return None
return self._lock.release(*args, **kwargs)
def create(self):
# this part is not atomic, and there is not a lot we can do about it.
if self._instance_pid != os.getpid() or not self._lock:
self._lock = ThreadRLock()
self._instance_pid = os.getpid()
self._create()
return self._sync.release(*args, **kwargs)
def __enter__(self):
"""Return `self` upon entering the runtime context."""
@@ -64,15 +79,98 @@ class ForkSafeRLock(object):
self.release()
class ForkSemaphore(_ForkSafeThreadSyncObject):
def __init__(self, value=1):
super(ForkSemaphore, self).__init__(functor=partial(Semaphore, value))
def acquire(self, *args, **kwargs):
self._create()
return self._sync.acquire(*args, **kwargs)
def release(self, *args, **kwargs):
if self._sync is None:
return None
self._create()
return self._sync.release(*args, **kwargs)
def get_value(self):
self._create()
return self._sync.get_value()
def __enter__(self):
"""Return `self` upon entering the runtime context."""
self.acquire()
return self
def __exit__(self, exc_type, exc_value, traceback):
"""Raise any exception triggered within the runtime context."""
# Do whatever cleanup.
self.release()
class ForkEvent(_ForkSafeThreadSyncObject):
def __init__(self):
super(ForkEvent, self).__init__(TrEvent)
def set(self):
self._create()
return self._sync.set()
def clear(self):
if self._sync is None:
return None
self._create()
return self._sync.clear()
def is_set(self):
self._create()
return self._sync.is_set()
def wait(self, *args, **kwargs):
self._create()
return self._sync.wait(*args, **kwargs)
class ForkQueue(_ForkSafeThreadSyncObject):
def __init__(self):
super(ForkQueue, self).__init__(TrQueue)
def get(self, *args, **kwargs):
self._create()
return self._sync.get(*args, **kwargs)
def put(self, *args, **kwargs):
self._create()
return self._sync.put(*args, **kwargs)
def empty(self):
if not self._sync:
return True
self._create()
return self._sync.empty()
def full(self):
if not self._sync:
return False
self._create()
return self._sync.full()
def close(self):
if not self._sync:
return
self._create()
return self._sync.close()
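
One deliberate consequence of the rebuild-on-fork design: state set before the fork does not carry into the child. A quick POSIX-only demonstration, assuming the ForkEvent defined above:

import os

ev = ForkEvent()
ev.set()
pid = os.fork()
if pid == 0:
    print(ev.is_set())   # False: the child lazily created a fresh Event
    os._exit(0)
os.waitpid(pid, 0)
print(ev.is_set())       # True: the parent still holds its own Event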
class ThreadCalls(object):
def __init__(self):
self._queue = TrQueue()
self._queue = ForkQueue()
self._thread = Thread(target=self._worker)
self._thread.daemon = True
self._thread.start()
def is_alive(self):
return bool(self._thread)
return bool(self._thread) and self._thread.is_alive()
def apply_async(self, func, args=None):
if not func:
@@ -134,7 +232,7 @@ class SingletonThreadPool(object):
@classmethod
def is_active(cls):
return cls.__thread_pool and cls.__thread_pool.is_alive()
return cls.__thread_pool and cls.__thread_pool_pid == os.getpid() and cls.__thread_pool.is_alive()
class SafeQueue(object):
@@ -156,7 +254,7 @@ class SafeQueue(object):
except Exception:
pass
self._internal_q = None
self._q_size = 0
self._q_size = [] # list of PIDs we pushed, so this is atomic
def empty(self):
return self._q.empty() and (not self._internal_q or self._internal_q.empty())
@@ -164,12 +262,13 @@ class SafeQueue(object):
def is_pending(self):
# check if we have pending requests to be pushed (it does not mean they were pulled)
# only call from main put process
return self._q_size > 0
return self._get_q_size_len() > 0
def close(self, event, timeout=3.0):
# wait until all pending requests are pushed
tic = time()
prev_q_size = self._q_size
pid = os.getpid()
prev_q_size = self._get_q_size_len(pid)
while self.is_pending():
if event:
event.set()
@@ -179,10 +278,10 @@ class SafeQueue(object):
# timeout is for the maximum time to pull a single object from the queue,
# this way if we get stuck we notice quickly and abort
if timeout and (time()-tic) > timeout:
if prev_q_size == self._q_size:
if prev_q_size == self._get_q_size_len(pid):
break
else:
prev_q_size = self._q_size
prev_q_size = self._get_q_size_len(pid)
tic = time()
def get(self, *args, **kwargs):
@@ -206,25 +305,37 @@ class SafeQueue(object):
return buffer
def put(self, obj):
# not atomic when forking for the first time
# GIL will make sure it is atomic
self._q_size += 1
self._q_size.append(os.getpid())
# make sure the block put is done in the thread pool i.e. in the background
obj = pickle.dumps(obj)
if BackgroundMonitor.get_at_exit_state():
self._q_put(obj)
return
self.__thread_pool.get().apply_async(self._q_put, args=(obj, ))
def _get_q_size_len(self, pid=None):
pid = pid or os.getpid()
return len([p for p in self._q_size if p == pid])
def _q_put(self, obj):
try:
self._q.put(obj)
except BaseException:
# make sure we zero the _q_size if the process dies (i.e. queue put fails)
self._q_size = 0
self._q_size = []
raise
pid = os.getpid()
# GIL will make sure it is atomic
self._q_size -= 1
# pop the first "counter" that is ours (i.e. pid == os.getpid())
p = None
while p != pid:
p = self._q_size.pop()
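
The counter rework replaces the integer _q_size with a list of PIDs: CPython's GIL keeps list.append and list.pop atomic, unlike the read-modify-write of an integer +=, and after a fork each process counts only the entries bearing its own pid, so a child's pending puts no longer corrupt the parent's bookkeeping. In miniature:

import os

q_size = []                      # inherited by fork, but entries are per-pid
q_size.append(os.getpid())       # put(): atomic under the GIL
pending = len([p for p in q_size if p == os.getpid()])   # _get_q_size_len()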
def _init_reader_thread(self):
if not self._internal_q:
self._internal_q = TrQueue()
self._internal_q = ForkQueue()
if not self._reader_thread or not self._reader_thread.is_alive():
# read before we start the thread
self._reader_thread = Thread(target=self._reader_daemon)
@@ -333,14 +444,16 @@ class BackgroundMonitor(object):
_main_process_task_id = None
_parent_pid = None
_sub_process_started = None
_at_exit = False
_instances = {} # type: Dict[int, List[BackgroundMonitor]]
def __init__(self, task, wait_period):
self._event = TrEvent()
self._done_ev = TrEvent()
self._start_ev = TrEvent()
self._event = ForkEvent()
self._done_ev = ForkEvent()
self._start_ev = ForkEvent()
self._task_pid = os.getpid()
self._thread = None
self._thread_pid = None
self._wait_timeout = wait_period
self._subprocess = None if task.is_main_task() else False
self._task_id = task.id
@@ -362,12 +475,15 @@ class BackgroundMonitor(object):
def wait(self, timeout=None):
if not self._done_ev:
return
self._done_ev.wait(timeout=timeout)
if not self.is_subprocess_mode() or self.is_subprocess_mode_and_parent_process():
self._done_ev.wait(timeout=timeout)
def _start(self):
# if we already started do nothing
if isinstance(self._thread, Thread):
return
if self._thread_pid == os.getpid():
return
self._thread_pid = os.getpid()
self._thread = Thread(target=self._daemon)
self._thread.daemon = True
self._thread.start()
@@ -376,7 +492,8 @@ class BackgroundMonitor(object):
if not self._thread:
return
if not self.is_subprocess_mode() or self.is_subprocess_alive():
if not self._is_subprocess_mode_and_not_parent_process() and (
not self.is_subprocess_mode() or self.is_subprocess_alive()):
self._event.set()
if isinstance(self._thread, Thread):
@@ -384,7 +501,7 @@ class BackgroundMonitor(object):
self._get_instances().remove(self)
except ValueError:
pass
self._thread = None
self._thread = False
def daemon(self):
while True:
@@ -396,7 +513,7 @@ class BackgroundMonitor(object):
self._start_ev.set()
self.daemon()
self.post_execution()
self._thread = None
self._thread = False
def post_execution(self):
self._done_ev.set()
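
Note that _thread becomes a tri-state here: None means the daemon was never started, a Thread object means it is running, and False means it ran to completion; the wait_for_sub_process loop below leans on the None state to keep polling for a daemon that has not come up yet. As a small sketch:

def daemon_finished(thread):
    # None: not started yet; False: completed; Thread: check liveness
    if thread is None:
        return False
    return thread is False or not thread.is_alive()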
@@ -540,7 +657,12 @@ class BackgroundMonitor(object):
for i in instances:
# DO NOT CHANGE, we need to catch base exception, if the process gets killed
try:
while i._thread and i._thread.is_alive():
while i._thread is None or (i._thread and i._thread.is_alive()):
# thread is still not up
if i._thread is None:
sleep(0.1)
continue
# noinspection PyBroadException
try:
p = psutil.Process(parent_pid)
@@ -564,12 +686,15 @@ class BackgroundMonitor(object):
return
def is_alive(self):
if self.is_subprocess_mode():
return self.is_subprocess_alive() and self._thread \
and self._start_ev.is_set() and not self._done_ev.is_set()
else:
if not self.is_subprocess_mode():
return isinstance(self._thread, Thread) and self._thread.is_alive()
if self.get_at_exit_state():
return self.is_subprocess_alive() and self._thread
return self.is_subprocess_alive() and self._thread and \
self._start_ev.is_set() and not self._done_ev.is_set()
@classmethod
def _fast_is_subprocess_alive(cls):
if not cls._main_process_proc_obj:
@@ -616,6 +741,16 @@ class BackgroundMonitor(object):
def _is_subprocess_mode_and_not_parent_process(self):
return self.is_subprocess_mode() and self._parent_pid != os.getpid()
def is_subprocess_mode_and_parent_process(self):
return self.is_subprocess_mode() and self._parent_pid == os.getpid()
def _is_thread_mode_and_not_main_process(self):
if self.is_subprocess_mode():
return False
from ... import Task
# noinspection PyProtectedMember
return Task._Task__is_subprocess()
@classmethod
def is_subprocess_enabled(cls, task=None):
return bool(cls._main_process) and (not task or task.id == cls._main_process_task_id)
@@ -645,6 +780,14 @@ class BackgroundMonitor(object):
while cls.is_subprocess_alive(task=task) and (not timeout or time()-tic < timeout):
sleep(0.03)
@classmethod
def set_at_exit_state(cls, state=True):
cls._at_exit = bool(state)
@classmethod
def get_at_exit_state(cls):
return cls._at_exit
def leave_process(status=0):
# type: (int) -> None