mirror of
https://github.com/clearml/clearml
synced 2025-04-28 02:01:51 +00:00
Flush everything before pool worker push results back (external termination)
This commit is contained in:
parent
44a4dc99b3
commit
e8439b3b65
@ -1,5 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
|
from functools import partial
|
||||||
from time import sleep
|
from time import sleep
|
||||||
|
from multiprocessing import pool
|
||||||
import six
|
import six
|
||||||
|
|
||||||
from ..config import TASK_LOG_ENVIRONMENT, running_remotely, config
|
from ..config import TASK_LOG_ENVIRONMENT, running_remotely, config
|
||||||
@ -61,11 +63,34 @@ class EnvironmentBind(object):
|
|||||||
cls._current_task.connect(env_param, cls._environment_section)
|
cls._current_task.connect(env_param, cls._environment_section)
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleQueueWrapper(object):
    """
    Proxy around a queue object that flushes a task before every ``put``.

    ``put`` is replaced by a function that first calls
    ``task.flush(wait_for_uploads=True)`` (best effort, errors are ignored) and then
    forwards the call to the wrapped queue. Every other attribute is looked up on the
    wrapped queue unchanged.

    :param task: object exposing ``flush(wait_for_uploads=...)``
    :param simple_queue: the queue object being wrapped
    """

    # Double-underscore attributes assigned inside the class body are stored under their
    # name-mangled keys, so ``__getattr__`` must compare against the mangled spelling.
    _PRIVATE_ATTRS = {
        "__simple_queue": "_SimpleQueueWrapper__simple_queue",
        "__current_task": "_SimpleQueueWrapper__current_task",
    }

    def __init__(self, task, simple_queue):
        self.__current_task = task
        self.__simple_queue = simple_queue

    def __getattr__(self, attr):
        """
        Resolve attributes that regular lookup could not find.

        :param attr: name of the requested attribute
        :return: the patched ``put`` function, one of the wrapper's own private
            attributes, or the attribute of the wrapped queue
        :raises AttributeError: if the wrapper's private attributes are not set yet, or the
            wrapped queue does not have the attribute
        """
        # Unmangled spelling: serve the value straight from the instance dict (None if unset)
        if attr in self._PRIVATE_ATTRS:
            return self.__dict__.get(self._PRIVATE_ATTRS[attr])

        # __getattr__ only runs after regular lookup failed, so reaching this point with a
        # mangled private name means the instance was never initialised (e.g. it was created
        # with __new__). Raising here stops the lookups below from recursing forever.
        if attr in self._PRIVATE_ATTRS.values():
            raise AttributeError(attr)

        if attr == "put":
            def _patched_put(*a_args, **a_kwargs):
                # make sure we flush everything, because after we push the result we will get terminated
                try:
                    self.__current_task.flush(wait_for_uploads=True)
                except:  # noqa
                    # deliberate best effort: a failed flush must never block the result
                    pass
                return self.__simple_queue.put(*a_args, **a_kwargs)

            return _patched_put

        return getattr(self.__simple_queue, attr)
|
||||||
|
|
||||||
|
|
||||||
class PatchOsFork(object):
|
class PatchOsFork(object):
|
||||||
_original_fork = None
|
_original_fork = None
|
||||||
_current_task = None
|
_current_task = None
|
||||||
_original_process_run = None
|
_original_process_run = None
|
||||||
_original_process_terminate = None
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def patch_fork(cls, task):
|
def patch_fork(cls, task):
|
||||||
@ -95,11 +120,26 @@ class PatchOsFork(object):
|
|||||||
from multiprocessing.process import BaseProcess
|
from multiprocessing.process import BaseProcess
|
||||||
PatchOsFork._original_process_run = BaseProcess.run
|
PatchOsFork._original_process_run = BaseProcess.run
|
||||||
BaseProcess.run = PatchOsFork._patched_process_run
|
BaseProcess.run = PatchOsFork._patched_process_run
|
||||||
PatchOsFork._original_process_terminate = BaseProcess.terminate
|
|
||||||
BaseProcess.terminate = PatchOsFork._patched_process_terminate
|
|
||||||
except: # noqa
|
except: # noqa
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@staticmethod
def _patched_pool_worker(original_worker, *args, **kwargs):
    """
    Call ``original_worker`` with its result queue wrapped in a ``SimpleQueueWrapper``.

    The queue is taken from the second positional argument or, failing that, from the
    ``outqueue`` keyword argument, and it is only wrapped when it exposes ``put``.
    When there is no current task, or wrapping raises, the worker receives the
    arguments it was given.

    :param original_worker: the worker callable to delegate to
    :return: whatever ``original_worker`` returns
    """
    current_task = PatchOsFork._current_task
    if current_task:
        try:
            if len(args) >= 2 and hasattr(args[1], "put"):
                # rebuild the positional tuple with the queue slot replaced
                wrapped_queue = SimpleQueueWrapper(current_task, args[1])
                args = args[:1] + (wrapped_queue,) + args[2:]
            elif "outqueue" in kwargs and hasattr(kwargs["outqueue"], "put"):
                kwargs["outqueue"] = SimpleQueueWrapper(current_task, kwargs["outqueue"])
        except:  # noqa
            # best effort: fall through and run the worker regardless
            pass

    return original_worker(*args, **kwargs)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _patched_process_run(self, *args, **kwargs):
|
def _patched_process_run(self, *args, **kwargs):
|
||||||
if not PatchOsFork._current_task:
|
if not PatchOsFork._current_task:
|
||||||
@ -111,6 +151,17 @@ class PatchOsFork(object):
|
|||||||
except: # noqa
|
except: # noqa
|
||||||
task = None
|
task = None
|
||||||
|
|
||||||
|
# check if this is Process Pool function
|
||||||
|
if hasattr(self, "_target"):
|
||||||
|
# Now we have to patch Pool, because pool terminates subprocess directly after
|
||||||
|
# the return value of the pool worker function is pushed into the queue,
|
||||||
|
# which means it will terminate the process before we finish running our "atexit" call
|
||||||
|
try:
|
||||||
|
if self._target == pool.worker: # noqa
|
||||||
|
self._target = partial(PatchOsFork._patched_pool_worker, pool.worker) # noqa
|
||||||
|
except: # noqa
|
||||||
|
pass
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return PatchOsFork._original_process_run(self, *args, **kwargs)
|
return PatchOsFork._original_process_run(self, *args, **kwargs)
|
||||||
finally:
|
finally:
|
||||||
@ -122,18 +173,6 @@ class PatchOsFork(object):
|
|||||||
except: # noqa
|
except: # noqa
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@staticmethod
def _patched_process_terminate(self, *args, **kwargs):
    """
    Run the current task's ``_at_exit`` hook, then delegate to the original ``terminate``.

    The hook is skipped when there is no current task, and any error it raises is ignored
    so that termination always proceeds.

    :return: whatever the original ``terminate`` returns
    """
    active_task = PatchOsFork._current_task
    if active_task:
        try:
            # noinspection PyProtectedMember
            active_task._at_exit()
        except:  # noqa
            pass

    return PatchOsFork._original_process_terminate(self, *args, **kwargs)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _patched_fork(*args, **kwargs):
|
def _patched_fork(*args, **kwargs):
|
||||||
if not PatchOsFork._current_task:
|
if not PatchOsFork._current_task:
|
||||||
|
Loading…
Reference in New Issue
Block a user