From c7c0462e8921d7f663ac3fced93fabc503e50c67 Mon Sep 17 00:00:00 2001
From: allegroai <>
Date: Sat, 8 Jan 2022 22:58:34 +0200
Subject: [PATCH] Fix fork (process pool) hangs or drops reports when reports
 are at the end of the forked function, in both threaded and subprocess mode
 reporting

---
 clearml/backend_interface/metrics/reporter.py |  68 +++---
 clearml/backend_interface/task/log.py         |  25 +-
 clearml/task.py                               |  21 +-
 clearml/utilities/process/mp.py               | 221 ++++++++++++++----
 4 files changed, 250 insertions(+), 85 deletions(-)

diff --git a/clearml/backend_interface/metrics/reporter.py b/clearml/backend_interface/metrics/reporter.py
index 834900cb..e90413d1 100644
--- a/clearml/backend_interface/metrics/reporter.py
+++ b/clearml/backend_interface/metrics/reporter.py
@@ -2,13 +2,12 @@ import datetime
 import json
 import logging
 import math
-from multiprocessing import Semaphore
-from threading import Event as TrEvent
+import os
 from time import sleep, time
 
 import numpy as np
 import six
-from six.moves.queue import Queue as TrQueue, Empty
+from six.moves.queue import Empty
 
 from .events import (
     ScalarEvent, VectorEvent, ImageEvent, PlotEvent, ImageEventNoUpload,
@@ -21,7 +20,7 @@ from ...utilities.plotly_reporter import (
     create_2d_histogram_plot, create_value_matrix, create_3d_surface,
     create_2d_scatter_series, create_3d_scatter_series, create_line_plot, plotly_scatter3d_layout_dict,
     create_image_plot, create_plotly_table, )
-from ...utilities.process.mp import BackgroundMonitor
+from ...utilities.process.mp import BackgroundMonitor, ForkSemaphore, ForkEvent, ForkQueue
 from ...utilities.py3_interop import AbstractContextManager
 from ...utilities.process.mp import SafeQueue as PrQueue, SafeEvent
 
@@ -36,41 +35,48 @@ class BackgroundReportService(BackgroundMonitor, AsyncManagerMixin):
         super(BackgroundReportService, self).__init__(
             task=task, wait_period=flush_frequency)
         self._flush_threshold = flush_threshold
-        self._exit_event = TrEvent()
-        self._empty_state_event = TrEvent()
-        self._queue = TrQueue()
+        self._flush_event = ForkEvent()
+        self._empty_state_event = ForkEvent()
+        self._queue = ForkQueue()
         self._queue_size = 0
-        self._res_waiting = Semaphore()
+        self._res_waiting = ForkSemaphore()
         self._metrics = metrics
         self._storage_uri = None
         self._async_enable = async_enable
+        self._is_thread_mode_in_subprocess_flag = None
 
     def set_storage_uri(self, uri):
         self._storage_uri = uri
 
     def set_subprocess_mode(self):
-        if isinstance(self._queue, TrQueue):
+        if isinstance(self._queue, ForkQueue):
             self._write()
             self._queue = PrQueue()
-        if not isinstance(self._exit_event, SafeEvent):
-            self._exit_event = SafeEvent()
+        if not isinstance(self._event, SafeEvent):
+            self._event = SafeEvent()
         if not isinstance(self._empty_state_event, SafeEvent):
             self._empty_state_event = SafeEvent()
         super(BackgroundReportService, self).set_subprocess_mode()
 
     def stop(self):
         if isinstance(self._queue, PrQueue):
-            self._queue.close(self._event)
+            self._queue.close(self._flush_event)
         if not self.is_subprocess_mode() or self.is_subprocess_alive():
-            self._exit_event.set()
+            self._flush_event.set()
         super(BackgroundReportService, self).stop()
 
     def flush(self):
+        while isinstance(self._queue, PrQueue) and self._queue.is_pending():
+            sleep(0.1)
         self._queue_size = 0
+        # stop background process?!
         if not self.is_subprocess_mode() or self.is_subprocess_alive():
-            self._event.set()
+            self._flush_event.set()
 
     def wait_for_events(self, timeout=None):
+        if self._is_subprocess_mode_and_not_parent_process() and self.get_at_exit_state():
+            return
+
         # noinspection PyProtectedMember
         if self._is_subprocess_mode_and_not_parent_process():
             while self._queue and not self._queue.empty():
@@ -78,7 +84,7 @@ class BackgroundReportService(BackgroundMonitor, AsyncManagerMixin):
             return
 
         self._empty_state_event.clear()
-        if isinstance(self._empty_state_event, TrEvent):
+        if isinstance(self._empty_state_event, ForkEvent):
             tic = time()
             while self._thread and self._thread.is_alive() and (not timeout or time()-tic < timeout):
                 if self._empty_state_event.wait(timeout=1.0):
@@ -100,16 +106,17 @@ class BackgroundReportService(BackgroundMonitor, AsyncManagerMixin):
             # gel all data, work on local queue:
             self._write()
             # replace queue:
-            self._queue = TrQueue()
+            self._queue = ForkQueue()
             self._queue_size = 0
-            self._event = TrEvent()
-            self._done_ev = TrEvent()
-            self._start_ev = TrEvent()
-            self._exit_event = TrEvent()
-            self._empty_state_event = TrEvent()
-            self._res_waiting = Semaphore()
+            self._event = ForkEvent()
+            self._done_ev = ForkEvent()
+            self._start_ev = ForkEvent()
+            self._flush_event = ForkEvent()
+            self._empty_state_event = ForkEvent()
+            self._res_waiting = ForkSemaphore()
             # set thread mode
             self._subprocess = False
+            self._is_thread_mode_in_subprocess_flag = None
             # start background thread
             self._thread = None
             self._start()
@@ -122,9 +129,11 @@ class BackgroundReportService(BackgroundMonitor, AsyncManagerMixin):
             self.flush()
 
     def daemon(self):
-        while not self._exit_event.wait(0):
-            self._event.wait(self._wait_timeout)
-            self._event.clear()
+        self._is_thread_mode_in_subprocess_flag = self._is_thread_mode_and_not_main_process()
+
+        while not self._event.wait(0):
+            self._flush_event.wait(self._wait_timeout)
+            self._flush_event.clear()
             # lock state
             self._res_waiting.acquire()
             self._write()
@@ -132,7 +141,7 @@ class BackgroundReportService(BackgroundMonitor, AsyncManagerMixin):
             if self.get_num_results() > 0:
                 self.wait_for_results()
             # set empty flag only if we are not waiting for exit signal
-            if not self._exit_event.wait(0):
+            if not self._event.wait(0):
                 self._empty_state_event.set()
             # unlock state
             self._res_waiting.release()
@@ -157,8 +166,15 @@ class BackgroundReportService(BackgroundMonitor, AsyncManagerMixin):
                 break
         if not events:
             return
+        if self._is_thread_mode_in_subprocess_flag:
+            for e in events:
+                if isinstance(e, UploadEvent):
+                    # noinspection PyProtectedMember
+                    e._generate_file_name(force_pid_suffix=os.getpid())
+
         res = self._metrics.write_events(
             events, async_enable=self._async_enable, storage_uri=self._storage_uri)
+
         if self._async_enable:
             self._add_async_result(res)
 
diff --git a/clearml/backend_interface/task/log.py b/clearml/backend_interface/task/log.py
index b7637ad7..7f9a28c9 100644
--- a/clearml/backend_interface/task/log.py
+++ b/clearml/backend_interface/task/log.py
@@ -3,14 +3,12 @@ import sys
 from pathlib2 import Path
 from logging import LogRecord, getLogger, basicConfig, getLevelName, INFO, WARNING, Formatter, makeLogRecord, warning
 from logging.handlers import BufferingHandler
-from six.moves.queue import Queue as TrQueue
-from threading import Event as TrEvent
 
 from .development.worker import DevWorker
 from ...backend_api.services import events
 from ...backend_api.session.session import MaxRequestSizeError
 from ...config import config
-from ...utilities.process.mp import BackgroundMonitor
+from ...utilities.process.mp import BackgroundMonitor, ForkEvent, ForkQueue
 from ...utilities.process.mp import SafeQueue as PrQueue, SafeEvent
 
 
@@ -21,8 +19,8 @@ class BackgroundLogService(BackgroundMonitor):
         super(BackgroundLogService, self).__init__(task=task, wait_period=wait_period)
         self._worker = worker
         self._task_id = task.id
-        self._queue = TrQueue()
-        self._flush = TrEvent()
+        self._queue = ForkQueue()
+        self._flush = ForkEvent()
         self._last_event = None
         self._offline_log_filename = offline_log_filename
         self.session = session
@@ -76,7 +74,7 @@ class BackgroundLogService(BackgroundMonitor):
                 self._queue.put(a_request)
 
     def set_subprocess_mode(self):
-        if isinstance(self._queue, TrQueue):
+        if isinstance(self._queue, ForkQueue):
             self.send_all_records()
             self._queue = PrQueue()
         super(BackgroundLogService, self).set_subprocess_mode()
@@ -84,16 +82,16 @@ class BackgroundLogService(BackgroundMonitor):
 
     def add_to_queue(self, record):
         # check that we did not loose the reporter sub-process
-        if self.is_subprocess_mode() and not self._fast_is_subprocess_alive():
+        if self.is_subprocess_mode() and not self._fast_is_subprocess_alive() and not self.get_at_exit_state():  # NOTE: skip the mode-switch during at-exit — the race here can hang
             # we lost the reporting subprocess, let's switch to thread mode
             # gel all data, work on local queue:
             self.send_all_records()
             # replace queue:
-            self._queue = TrQueue()
-            self._flush = TrEvent()
-            self._event = TrEvent()
-            self._done_ev = TrEvent()
-            self._start_ev = TrEvent()
+            self._queue = ForkQueue()
+            self._flush = ForkEvent()
+            self._event = ForkEvent()
+            self._done_ev = ForkEvent()
+            self._start_ev = ForkEvent()
             # set thread mode
             self._subprocess = False
             # start background thread
@@ -275,7 +273,8 @@ class TaskHandler(BufferingHandler):
         if _background_log:
             if not _background_log.is_subprocess_mode() or _background_log.is_alive():
                 _background_log.stop()
-                if wait:
+                if wait and (not _background_log.is_subprocess_mode() or
+                             _background_log.is_subprocess_mode_and_parent_process()):
                     # noinspection PyBroadException
                     try:
                         timeout = 1. if _background_log.empty() else self.__wait_for_flush_timeout
diff --git a/clearml/task.py b/clearml/task.py
index 037dbe40..9a6bb519 100644
--- a/clearml/task.py
+++ b/clearml/task.py
@@ -551,9 +551,8 @@ class Task(_Task):
             Task.__main_task = task
             # register the main task for at exit hooks (there should only be one)
             task.__register_at_exit(task._at_exit)
-            # patch OS forking if we are not logging with a subprocess
-            if not cls._report_subprocess_enabled:
-                PatchOsFork.patch_fork()
+            # always patch OS forking because of ProcessPool and the alike
+            PatchOsFork.patch_fork()
             if auto_connect_frameworks:
                 is_auto_connect_frameworks_bool = not isinstance(auto_connect_frameworks, dict)
                 if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('hydra', True):
@@ -3120,6 +3119,12 @@ class Task(_Task):
         # protect sub-process at_exit (should never happen)
         if self._at_exit_called and self._at_exit_called != get_current_thread_id():
             return
+
+        # make sure we do not try to use events, because Python might deadlock itself.
+        # https://bugs.python.org/issue41606
+        if self.__is_subprocess():
+            BackgroundMonitor.set_at_exit_state(True)
+
         # shutdown will clear the main, so we have to store it before.
         # is_main = self.is_main_task()
         # fix debugger signal in the middle, catch everything
@@ -3313,7 +3318,8 @@ class Task(_Task):
 
         # make sure no one will re-enter the shutdown method
         self._at_exit_called = True
-        BackgroundMonitor.wait_for_sub_process(self)
+        if not is_sub_process and BackgroundMonitor.is_subprocess_enabled():
+            BackgroundMonitor.wait_for_sub_process(self)
 
     @classmethod
     def __register_at_exit(cls, exit_callback, only_remove_signal_and_exception_hooks=False):
@@ -3470,9 +3476,10 @@ class Task(_Task):
         else:
             cls.__exit_hook.update_callback(exit_callback)
 
-    @classmethod
-    def _remove_at_exit_callbacks(cls):
-        cls.__register_at_exit(None, only_remove_signal_and_exception_hooks=True)
+    def _remove_at_exit_callbacks(self):
+        self.__register_at_exit(None, only_remove_signal_and_exception_hooks=True)
+        atexit.unregister(self.__exit_hook._exit_callback)
+        self._at_exit_called = True
 
     @classmethod
     def __get_task(
diff --git a/clearml/utilities/process/mp.py b/clearml/utilities/process/mp.py
index 6c035926..0e3c4d9c 100644
--- a/clearml/utilities/process/mp.py
+++ b/clearml/utilities/process/mp.py
@@ -3,7 +3,7 @@ import pickle
 import struct
 import sys
 from functools import partial
-from multiprocessing import Lock, Event as ProcessEvent
+from multiprocessing import Lock, Semaphore, Event as ProcessEvent
 from threading import Thread, Event as TrEvent, RLock as ThreadRLock
 from time import sleep, time
 from typing import List, Dict, Optional
@@ -33,25 +33,40 @@ except ImportError:
         return False
 
 
-class ForkSafeRLock(object):
-    def __init__(self):
-        self._lock = None
+class _ForkSafeThreadSyncObject(object):
+
+    def __init__(self, functor):
+        self._sync = None
         self._instance_pid = None
+        self._functor = functor
+
+    def _create(self):
+        # this part is not atomic, and there is not a lot we can do about it.
+        if self._instance_pid != os.getpid() or not self._sync:
+            # Notice! This is NOT atomic: the first time it is accessed, two concurrent calls might
+            # end up overwriting each other's object.
+            # Even though that sounds bad, the worst case in our usage scenario
+            # is that the very first call is not "atomic"
+
+            # Notice the order! we first create the object and THEN update the pid,
+            # this way, whatever happens, we never try to use the old (pre-fork) copy of the synchronization object
+            self._sync = self._functor()
+            self._instance_pid = os.getpid()
+
+
+class ForkSafeRLock(_ForkSafeThreadSyncObject):
+    def __init__(self):
+        super(ForkSafeRLock, self).__init__(ThreadRLock)
 
     def acquire(self, *args, **kwargs):
-        self.create()
-        return self._lock.acquire(*args, **kwargs)
+        self._create()
+        return self._sync.acquire(*args, **kwargs)
 
     def release(self, *args, **kwargs):
-        if self._lock is None:
+        if self._sync is None:
             return None
-        return self._lock.release(*args, **kwargs)
-
-    def create(self):
-        # this part is not atomic, and there is not a lot we can do about it.
-        if self._instance_pid != os.getpid() or not self._lock:
-            self._lock = ThreadRLock()
-            self._instance_pid = os.getpid()
+        self._create()
+        return self._sync.release(*args, **kwargs)
 
     def __enter__(self):
         """Return `self` upon entering the runtime context."""
@@ -64,15 +79,98 @@ class ForkSafeRLock(object):
         self.release()
 
 
+class ForkSemaphore(_ForkSafeThreadSyncObject):
+    def __init__(self, value=1):
+        super(ForkSemaphore, self).__init__(functor=partial(Semaphore, value))
+
+    def acquire(self, *args, **kwargs):
+        self._create()
+        return self._sync.acquire(*args, **kwargs)
+
+    def release(self, *args, **kwargs):
+        if self._sync is None:
+            return None
+        self._create()
+        return self._sync.release(*args, **kwargs)
+
+    def get_value(self):
+        self._create()
+        return self._sync.get_value()
+
+    def __enter__(self):
+        """Return `self` upon entering the runtime context."""
+        self.acquire()
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        """Raise any exception triggered within the runtime context."""
+        # Do whatever cleanup.
+        self.release()
+
+
+class ForkEvent(_ForkSafeThreadSyncObject):
+    def __init__(self):
+        super(ForkEvent, self).__init__(TrEvent)
+
+    def set(self):
+        self._create()
+        return self._sync.set()
+
+    def clear(self):
+        if self._sync is None:
+            return None
+        self._create()
+        return self._sync.clear()
+
+    def is_set(self):
+        self._create()
+        return self._sync.is_set()
+
+    def wait(self, *args, **kwargs):
+        self._create()
+        return self._sync.wait(*args, **kwargs)
+
+
+class ForkQueue(_ForkSafeThreadSyncObject):
+    def __init__(self):
+        super(ForkQueue, self).__init__(TrQueue)
+
+    def get(self, *args, **kwargs):
+        self._create()
+        return self._sync.get(*args, **kwargs)
+
+    def put(self, *args, **kwargs):
+        self._create()
+        return self._sync.put(*args, **kwargs)
+
+    def empty(self):
+        if not self._sync:
+            return True
+        self._create()
+        return self._sync.empty()
+
+    def full(self):
+        if not self._sync:
+            return False
+        self._create()
+        return self._sync.full()
+
+    def close(self):
+        if not self._sync:
+            return
+        self._create()
+        return self._sync.close()
+
+
 class ThreadCalls(object):
     def __init__(self):
-        self._queue = TrQueue()
+        self._queue = ForkQueue()
         self._thread = Thread(target=self._worker)
         self._thread.daemon = True
         self._thread.start()
 
     def is_alive(self):
-        return bool(self._thread)
+        return bool(self._thread) and self._thread.is_alive()
 
     def apply_async(self, func, args=None):
         if not func:
@@ -134,7 +232,7 @@ class SingletonThreadPool(object):
 
     @classmethod
     def is_active(cls):
-        return cls.__thread_pool and cls.__thread_pool.is_alive()
+        return cls.__thread_pool and cls.__thread_pool_pid == os.getpid() and cls.__thread_pool.is_alive()
 
 
 class SafeQueue(object):
@@ -156,7 +254,7 @@ class SafeQueue(object):
         except Exception:
             pass
         self._internal_q = None
-        self._q_size = 0
+        self._q_size = []  # list of PIDs we pushed, so this is atomic
 
     def empty(self):
         return self._q.empty() and (not self._internal_q or self._internal_q.empty())
@@ -164,12 +262,13 @@ class SafeQueue(object):
     def is_pending(self):
         # check if we have pending requests to be pushed (it does not mean they were pulled)
         # only call from main put process
-        return self._q_size > 0
+        return self._get_q_size_len() > 0
 
     def close(self, event, timeout=3.0):
         # wait until all pending requests pushed
         tic = time()
-        prev_q_size = self._q_size
+        pid = os.getpid()
+        prev_q_size = self._get_q_size_len(pid)
         while self.is_pending():
             if event:
                 event.set()
@@ -179,10 +278,10 @@ class SafeQueue(object):
             # timeout is for the maximum time to pull a single object from the queue,
             # this way if we get stuck we notice quickly and abort
             if timeout and (time()-tic) > timeout:
-                if prev_q_size == self._q_size:
+                if prev_q_size == self._get_q_size_len(pid):
                     break
                 else:
-                    prev_q_size = self._q_size
+                    prev_q_size = self._get_q_size_len(pid)
                     tic = time()
 
     def get(self, *args, **kwargs):
@@ -206,25 +305,37 @@ class SafeQueue(object):
         return buffer
 
     def put(self, obj):
+        # not atomic when forking for the first time
         # GIL will make sure it is atomic
-        self._q_size += 1
+        self._q_size.append(os.getpid())
         # make sure the block put is done in the thread pool i.e. in the background
         obj = pickle.dumps(obj)
+        if BackgroundMonitor.get_at_exit_state():
+            self._q_put(obj)
+            return
         self.__thread_pool.get().apply_async(self._q_put, args=(obj, ))
 
+    def _get_q_size_len(self, pid=None):
+        pid = pid or os.getpid()
+        return len([p for p in self._q_size if p == pid])
+
     def _q_put(self, obj):
         try:
             self._q.put(obj)
         except BaseException:
             # make sure we zero the _q_size of the process dies (i.e. queue put fails)
-            self._q_size = 0
+            self._q_size = []
             raise
+        pid = os.getpid()
         # GIL will make sure it is atomic
-        self._q_size -= 1
+        # pop the First "counter" that is ours (i.e. pid == os.getpid())
+        p = None
+        while p != pid:
+            p = self._q_size.pop()
 
     def _init_reader_thread(self):
         if not self._internal_q:
-            self._internal_q = TrQueue()
+            self._internal_q = ForkQueue()
         if not self._reader_thread or not self._reader_thread.is_alive():
             # read before we start the thread
             self._reader_thread = Thread(target=self._reader_daemon)
@@ -333,14 +444,16 @@ class BackgroundMonitor(object):
     _main_process_task_id = None
     _parent_pid = None
     _sub_process_started = None
+    _at_exit = False
     _instances = {}  # type: Dict[int, List[BackgroundMonitor]]
 
     def __init__(self, task, wait_period):
-        self._event = TrEvent()
-        self._done_ev = TrEvent()
-        self._start_ev = TrEvent()
+        self._event = ForkEvent()
+        self._done_ev = ForkEvent()
+        self._start_ev = ForkEvent()
         self._task_pid = os.getpid()
         self._thread = None
+        self._thread_pid = None
         self._wait_timeout = wait_period
         self._subprocess = None if task.is_main_task() else False
         self._task_id = task.id
@@ -362,12 +475,15 @@ class BackgroundMonitor(object):
     def wait(self, timeout=None):
         if not self._done_ev:
             return
-        self._done_ev.wait(timeout=timeout)
+        if not self.is_subprocess_mode() or self.is_subprocess_mode_and_parent_process():
+            self._done_ev.wait(timeout=timeout)
 
     def _start(self):
         # if we already started do nothing
         if isinstance(self._thread, Thread):
-            return
+            if self._thread_pid == os.getpid():
+                return
+        self._thread_pid = os.getpid()
         self._thread = Thread(target=self._daemon)
         self._thread.daemon = True
         self._thread.start()
@@ -376,7 +492,8 @@ class BackgroundMonitor(object):
         if not self._thread:
             return
 
-        if not self.is_subprocess_mode() or self.is_subprocess_alive():
+        if not self._is_subprocess_mode_and_not_parent_process() and (
+                not self.is_subprocess_mode() or self.is_subprocess_alive()):
             self._event.set()
 
         if isinstance(self._thread, Thread):
@@ -384,7 +501,7 @@ class BackgroundMonitor(object):
                 self._get_instances().remove(self)
             except ValueError:
                 pass
-            self._thread = None
+            self._thread = False
 
     def daemon(self):
         while True:
@@ -396,7 +513,7 @@ class BackgroundMonitor(object):
         self._start_ev.set()
         self.daemon()
         self.post_execution()
-        self._thread = None
+        self._thread = False
 
     def post_execution(self):
         self._done_ev.set()
@@ -540,7 +657,12 @@ class BackgroundMonitor(object):
         for i in instances:
             # DO NOT CHANGE, we need to catch base exception, if the process gte's killed
             try:
-                while i._thread and i._thread.is_alive():
+                while i._thread is None or (i._thread and i._thread.is_alive()):
+                    # thread is still not up
+                    if i._thread is None:
+                        sleep(0.1)
+                        continue
+
                     # noinspection PyBroadException
                     try:
                         p = psutil.Process(parent_pid)
@@ -564,12 +686,15 @@ class BackgroundMonitor(object):
         return
 
     def is_alive(self):
-        if self.is_subprocess_mode():
-            return self.is_subprocess_alive() and self._thread \
-                   and self._start_ev.is_set() and not self._done_ev.is_set()
-        else:
+        if not self.is_subprocess_mode():
             return isinstance(self._thread, Thread) and self._thread.is_alive()
 
+        if self.get_at_exit_state():
+            return self.is_subprocess_alive() and self._thread
+
+        return self.is_subprocess_alive() and self._thread and \
+               self._start_ev.is_set() and not self._done_ev.is_set()
+
     @classmethod
     def _fast_is_subprocess_alive(cls):
         if not cls._main_process_proc_obj:
@@ -616,6 +741,16 @@ class BackgroundMonitor(object):
     def _is_subprocess_mode_and_not_parent_process(self):
         return self.is_subprocess_mode() and self._parent_pid != os.getpid()
 
+    def is_subprocess_mode_and_parent_process(self):
+        return self.is_subprocess_mode() and self._parent_pid == os.getpid()
+
+    def _is_thread_mode_and_not_main_process(self):
+        if self.is_subprocess_mode():
+            return False
+        from ... import Task
+        # noinspection PyProtectedMember
+        return Task._Task__is_subprocess()
+
     @classmethod
     def is_subprocess_enabled(cls, task=None):
         return bool(cls._main_process) and (not task or task.id == cls._main_process_task_id)
@@ -645,6 +780,14 @@ class BackgroundMonitor(object):
         while cls.is_subprocess_alive(task=task) and (not timeout or time()-tic < timeout):
             sleep(0.03)
 
+    @classmethod
+    def set_at_exit_state(cls, state=True):
+        cls._at_exit = bool(state)
+
+    @classmethod
+    def get_at_exit_state(cls):
+        return cls._at_exit
+
 
 def leave_process(status=0):
     # type: (int) -> None