diff --git a/clearml/binding/environ_bind.py b/clearml/binding/environ_bind.py index f91b86cb..dbf53540 100644 --- a/clearml/binding/environ_bind.py +++ b/clearml/binding/environ_bind.py @@ -1,6 +1,7 @@ import os +from functools import partial from time import sleep - +from multiprocessing import pool import six from ..config import TASK_LOG_ENVIRONMENT, running_remotely, config @@ -62,15 +63,44 @@ class EnvironmentBind(object): cls._current_task.connect(env_param, cls._environment_section) +class SimpleQueueWrapper(object): + def __init__(self, task, simple_queue): + self.__current_task = task + self.__simple_queue = simple_queue + + def __getattr__(self, attr): + if attr in ["__simple_queue", "__current_task"]: + return self.__dict__.get(attr) + + if attr == "put": + def _patched_put(*a_args, **a_kwargs): + try: + task = self.__current_task + # noinspection PyProtectedMember + task._at_exit() + except: # noqa + pass + return getattr(self.__simple_queue, "put")(*a_args, **a_kwargs) + + return _patched_put + + return getattr(self.__simple_queue, attr) + + class PatchOsFork(object): _original_fork = None _current_task = None + _original_process_run = None @classmethod def patch_fork(cls, task): cls._current_task = task if not task: return + + # first we need to patch regular fork + # (because forked processes do not support atexit, they call os._exit directly) + # noinspection PyBroadException try: # only once @@ -84,8 +114,70 @@ class PatchOsFork(object): except Exception: pass + # now we need to patch Process.run because the bootstrap code + # shuts everything down before calling os._exit that we patched above + try: + from multiprocessing.process import BaseProcess + PatchOsFork._original_process_run = BaseProcess.run + BaseProcess.run = PatchOsFork._patched_process_run + except: # noqa + pass + + @staticmethod + def _patched_pool_worker(original_worker, *args, **kwargs): + if not PatchOsFork._current_task: + return original_worker(*args, **kwargs) + + try: + if 
len(args) >= 2 and hasattr(args[1], "put"): + args = list(args) + args[1] = SimpleQueueWrapper(PatchOsFork._current_task, args[1]) + args = tuple(args) + elif "outqueue" in kwargs and hasattr(kwargs["outqueue"], "put"): + kwargs["outqueue"] = SimpleQueueWrapper(PatchOsFork._current_task, kwargs["outqueue"]) + except: # noqa + pass + + return original_worker(*args, **kwargs) + + @staticmethod + def _patched_process_run(self, *args, **kwargs): + if not PatchOsFork._current_task: + return PatchOsFork._original_process_run(self, *args, **kwargs) + + try: + from ..task import Task + task = Task.current_task() + except: # noqa + task = None + + # check if this is a Process Pool function + if hasattr(self, "_target"): + # Now we have to patch Pool, because pool terminates subprocess directly after + # the return value of the pool worker function is pushed into the queue, + # which means it will terminate the process before we finish running our "atexit" call + try: + if self._target == pool.worker: # noqa + self._target = partial(PatchOsFork._patched_pool_worker, pool.worker) # noqa + except: # noqa + pass + + try: + return PatchOsFork._original_process_run(self, *args, **kwargs) + finally: + # make sure the task's at-exit handler is called once run() finishes + try: + if task: + # noinspection PyProtectedMember + task._at_exit() + except: # noqa + pass + @staticmethod def _patched_fork(*args, **kwargs): + if not PatchOsFork._current_task: + return PatchOsFork._original_fork(*args, **kwargs) + from ..task import Task # ensure deferred is done, but never try to generate a Task object @@ -105,8 +197,10 @@ class PatchOsFork(object): if not task: return ret + PatchOsFork._current_task = task # # Hack: now make sure we setup the reporter threads (Log+Reporter) - if not task._report_subprocess_enabled: + # noinspection PyProtectedMember + if not bool(task._report_subprocess_enabled): BackgroundMonitor.start_all(task=task) # The signal handler method is Not enough, for the time being, we have both diff --git 
a/clearml/utilities/process/mp.py b/clearml/utilities/process/mp.py index 67f2727b..f15a018f 100644 --- a/clearml/utilities/process/mp.py +++ b/clearml/utilities/process/mp.py @@ -3,11 +3,10 @@ import pickle import struct import sys from functools import partial -from multiprocessing import Lock, Semaphore, Event as ProcessEvent +from multiprocessing import Process, Semaphore, Event as ProcessEvent from threading import Thread, Event as TrEvent, RLock as ThreadRLock from time import sleep, time from typing import List, Dict, Optional -from multiprocessing import Process import psutil from six.moves.queue import Empty, Queue as TrQueue @@ -29,11 +28,13 @@ except ImportError: try: from multiprocessing import get_context except ImportError: - def get_context(*args, **kwargs): - return False + def get_context(*args, **kwargs): # noqa + import multiprocessing + return multiprocessing class _ForkSafeThreadSyncObject(object): + __process_lock = get_context("fork" if sys.platform == 'linux' else "spawn").RLock() def __init__(self, functor): self._sync = None @@ -50,8 +51,16 @@ class _ForkSafeThreadSyncObject(object): # Notice the order! 
we first create the object and THEN update the pid, # this is so whatever happens we Never try to used the old (pre-forked copy) of the synchronization object - self._sync = self._functor() - self._instance_pid = os.getpid() + try: + while not self.__process_lock.acquire(block=True, timeout=1.0): + sleep(0.1) + + # we have to check again inside the protected locked area + if self._instance_pid != os.getpid() or not self._sync: + self._sync = self._functor() + self._instance_pid = os.getpid() + finally: + self.__process_lock.release() class ForkSafeRLock(_ForkSafeThreadSyncObject): @@ -207,7 +216,7 @@ class ThreadCalls(object): continue # noinspection PyBroadException try: - if request[1]: + if request[1] is not None: request[0](*request[1]) else: request[0]() @@ -258,7 +267,8 @@ class SafeQueue(object): except Exception: pass self._internal_q = None - self._q_size = [] # list of PIDs we pushed, so this is atomic + # Note we should Never! assign a new object to `self._q_size`, just work with the initial object + self._q_size = [] # list of PIDs we pushed, so this is atomic. def empty(self): return self._q.empty() and (not self._internal_q or self._internal_q.empty()) @@ -312,12 +322,18 @@ class SafeQueue(object): # not atomic when forking for the first time # GIL will make sure it is atomic self._q_size.append(os.getpid()) - # make sure the block put is done in the thread pool i.e. in the background - obj = pickle.dumps(obj) - if BackgroundMonitor.get_at_exit_state(): - self._q_put(obj) - return - self.__thread_pool.get().apply_async(self._q_put, args=(obj, )) + try: + # make sure the block put is done in the thread pool i.e. 
in the background + obj = pickle.dumps(obj) + if BackgroundMonitor.get_at_exit_state(): + self._q_put(obj) + return + self.__thread_pool.get().apply_async(self._q_put, args=(obj, )) + except: # noqa + pid = os.getpid() + p = None + while p != pid and self._q_size: + p = self._q_size.pop() def _get_q_size_len(self, pid=None): pid = pid or os.getpid() @@ -328,13 +344,13 @@ class SafeQueue(object): self._q.put(obj) except BaseException: # make sure we zero the _q_size of the process dies (i.e. queue put fails) - self._q_size = [] + self._q_size.clear() raise pid = os.getpid() # GIL will make sure it is atomic # pop the First "counter" that is ours (i.e. pid == os.getpid()) p = None - while p != pid: + while p != pid and self._q_size: p = self._q_size.pop() def _init_reader_thread(self): @@ -423,7 +439,7 @@ class SingletonLock(AbstractContextManager): def create(self): if self._lock is None: - self._lock = Lock() + self._lock = ForkSafeRLock() @classmethod def instantiate(cls): @@ -442,7 +458,7 @@ class SingletonLock(AbstractContextManager): class BackgroundMonitor(object): - # If we will need multiple monitoring contexts (i.e. subprocesses) this will become a dict + # If we need multiple monitoring contexts (i.e. subprocesses) this will become a dict _main_process = None _main_process_proc_obj = None _main_process_task_id = None