diff --git a/clearml/binding/environ_bind.py b/clearml/binding/environ_bind.py index dbf53540..3969b6b6 100644 --- a/clearml/binding/environ_bind.py +++ b/clearml/binding/environ_bind.py @@ -1,7 +1,5 @@ import os -from functools import partial from time import sleep -from multiprocessing import pool import six from ..config import TASK_LOG_ENVIRONMENT, running_remotely, config @@ -63,34 +61,11 @@ class EnvironmentBind(object): cls._current_task.connect(env_param, cls._environment_section) -class SimpleQueueWrapper(object): - def __init__(self, task, simple_queue): - self.__current_task = task - self.__simple_queue = simple_queue - - def __getattr__(self, attr): - if attr in ["__simple_queue", "__current_task"]: - return self.__dict__.get(attr) - - if attr == "put": - def _patched_put(*a_args, **a_kwargs): - try: - task = self.__current_task - # noinspection PyProtectedMember - task._at_exit() - except: # noqa - pass - return getattr(self.__simple_queue, "put")(*a_args, **a_kwargs) - - return _patched_put - - return getattr(self.__simple_queue, attr) - - class PatchOsFork(object): _original_fork = None _current_task = None _original_process_run = None + _original_process_terminate = None @classmethod def patch_fork(cls, task): @@ -120,26 +95,11 @@ class PatchOsFork(object): from multiprocessing.process import BaseProcess PatchOsFork._original_process_run = BaseProcess.run BaseProcess.run = PatchOsFork._patched_process_run + PatchOsFork._original_process_terminate = BaseProcess.terminate + BaseProcess.terminate = PatchOsFork._patched_process_terminate except: # noqa pass - @staticmethod - def _patched_pool_worker(original_worker, *args, **kwargs): - if not PatchOsFork._current_task: - return original_worker(*args, **kwargs) - - try: - if len(args) >= 2 and hasattr(args[1], "put"): - args = list(args) - args[1] = SimpleQueueWrapper(PatchOsFork._current_task, args[1]) - args = tuple(args) - elif "outqueue" in kwargs and hasattr(kwargs["outqueue"], "put"): - kwargs["outqueue"] = SimpleQueueWrapper(PatchOsFork._current_task, kwargs["outqueue"]) - except: # noqa - pass - - return original_worker(*args, **kwargs) - @staticmethod def _patched_process_run(self, *args, **kwargs): if not PatchOsFork._current_task: @@ -151,17 +111,6 @@ class PatchOsFork(object): except: # noqa task = None - # check if this is Process Pool function - if hasattr(self, "_target"): - # Now we have to patch Pool, because pool terminates subprocess directly after - # the return value of the pool worker function is pushed into the queue, - # which means it will terminate the process before we finish running our "atexit" call - try: - if self._target == pool.worker: # noqa - self._target = partial(PatchOsFork._patched_pool_worker, pool.worker) # noqa - except: # noqa - pass - try: return PatchOsFork._original_process_run(self, *args, **kwargs) finally: @@ -173,6 +122,18 @@ class PatchOsFork(object): except: # noqa pass + @staticmethod + def _patched_process_terminate(self, *args, **kwargs): + if PatchOsFork._current_task: + # force creating a Task + try: + # noinspection PyProtectedMember + PatchOsFork._current_task._at_exit() + except: # noqa + pass + + return PatchOsFork._original_process_terminate(self, *args, **kwargs) + @staticmethod def _patched_fork(*args, **kwargs): if not PatchOsFork._current_task: