From e8439b3b654cc1793f264afcce008740a82f3995 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Thu, 28 Jul 2022 18:49:44 +0300 Subject: [PATCH] Flush everything before pool worker push results back (external termination) --- clearml/binding/environ_bind.py | 69 ++++++++++++++++++++++++++------- 1 file changed, 54 insertions(+), 15 deletions(-) diff --git a/clearml/binding/environ_bind.py b/clearml/binding/environ_bind.py index 3969b6b6..acb429c2 100644 --- a/clearml/binding/environ_bind.py +++ b/clearml/binding/environ_bind.py @@ -1,5 +1,7 @@ import os +from functools import partial from time import sleep +from multiprocessing import pool import six from ..config import TASK_LOG_ENVIRONMENT, running_remotely, config @@ -61,11 +63,34 @@ class EnvironmentBind(object): cls._current_task.connect(env_param, cls._environment_section) +class SimpleQueueWrapper(object): + def __init__(self, task, simple_queue): + self.__current_task = task + self.__simple_queue = simple_queue + + def __getattr__(self, attr): + if attr in ["__simple_queue", "__current_task"]: + return self.__dict__.get(attr) + + if attr == "put": + def _patched_put(*a_args, **a_kwargs): + # make sure we flush everything, because after we push the result we will get terminated + try: + task = self.__current_task + task.flush(wait_for_uploads=True) + except: # noqa + pass + return getattr(self.__simple_queue, "put")(*a_args, **a_kwargs) + + return _patched_put + + return getattr(self.__simple_queue, attr) + + class PatchOsFork(object): _original_fork = None _current_task = None _original_process_run = None - _original_process_terminate = None @classmethod def patch_fork(cls, task): @@ -95,11 +120,26 @@ class PatchOsFork(object): from multiprocessing.process import BaseProcess PatchOsFork._original_process_run = BaseProcess.run BaseProcess.run = PatchOsFork._patched_process_run - PatchOsFork._original_process_terminate = BaseProcess.terminate - BaseProcess.terminate = PatchOsFork._patched_process_terminate except: # noqa pass + @staticmethod + def _patched_pool_worker(original_worker, *args, **kwargs): + if not PatchOsFork._current_task: + return original_worker(*args, **kwargs) + + try: + if len(args) >= 2 and hasattr(args[1], "put"): + args = list(args) + args[1] = SimpleQueueWrapper(PatchOsFork._current_task, args[1]) + args = tuple(args) + elif "outqueue" in kwargs and hasattr(kwargs["outqueue"], "put"): + kwargs["outqueue"] = SimpleQueueWrapper(PatchOsFork._current_task, kwargs["outqueue"]) + except: # noqa + pass + + return original_worker(*args, **kwargs) + @staticmethod def _patched_process_run(self, *args, **kwargs): if not PatchOsFork._current_task: @@ -111,6 +151,17 @@ class PatchOsFork(object): except: # noqa task = None + # check if this is Process Pool function + if hasattr(self, "_target"): + # Now we have to patch Pool, because pool terminates subprocess directly after + # the return value of the pool worker function is pushed into the queue, + # which means it will terminate the process before we finish running our "atexit" call + try: + if self._target == pool.worker: # noqa + self._target = partial(PatchOsFork._patched_pool_worker, pool.worker) # noqa + except: # noqa + pass + try: return PatchOsFork._original_process_run(self, *args, **kwargs) finally: @@ -122,18 +173,6 @@ class PatchOsFork(object): except: # noqa pass - @staticmethod - def _patched_process_terminate(self, *args, **kwargs): - if PatchOsFork._current_task: - # force creating a Task - try: - # noinspection PyProtectedMember - PatchOsFork._current_task._at_exit() - except: # noqa - pass - - return PatchOsFork._original_process_terminate(self, *args, **kwargs) - @staticmethod def _patched_fork(*args, **kwargs): if not PatchOsFork._current_task: