Fix process pool can reuse workers

This commit is contained in:
allegroai 2022-07-28 18:48:37 +03:00
parent c58e2551c8
commit 136b0c33e7

View File

@ -1,7 +1,5 @@
import os
from functools import partial
from time import sleep
from multiprocessing import pool
import six
from ..config import TASK_LOG_ENVIRONMENT, running_remotely, config
@ -63,34 +61,11 @@ class EnvironmentBind(object):
cls._current_task.connect(env_param, cls._environment_section)
class SimpleQueueWrapper(object):
def __init__(self, task, simple_queue):
self.__current_task = task
self.__simple_queue = simple_queue
def __getattr__(self, attr):
if attr in ["__simple_queue", "__current_task"]:
return self.__dict__.get(attr)
if attr == "put":
def _patched_put(*a_args, **a_kwargs):
try:
task = self.__current_task
# noinspection PyProtectedMember
task._at_exit()
except: # noqa
pass
return getattr(self.__simple_queue, "put")(*a_args, **a_kwargs)
return _patched_put
return getattr(self.__simple_queue, attr)
class PatchOsFork(object):
_original_fork = None
_current_task = None
_original_process_run = None
_original_process_terminate = None
@classmethod
def patch_fork(cls, task):
@ -120,26 +95,11 @@ class PatchOsFork(object):
from multiprocessing.process import BaseProcess
PatchOsFork._original_process_run = BaseProcess.run
BaseProcess.run = PatchOsFork._patched_process_run
PatchOsFork._original_process_terminate = BaseProcess.terminate
BaseProcess.terminate = PatchOsFork._patched_process_terminate
except: # noqa
pass
@staticmethod
def _patched_pool_worker(original_worker, *args, **kwargs):
if not PatchOsFork._current_task:
return original_worker(*args, **kwargs)
try:
if len(args) >= 2 and hasattr(args[1], "put"):
args = list(args)
args[1] = SimpleQueueWrapper(PatchOsFork._current_task, args[1])
args = tuple(args)
elif "outqueue" in kwargs and hasattr(kwargs["outqueue"], "put"):
kwargs["outqueue"] = SimpleQueueWrapper(PatchOsFork._current_task, kwargs["outqueue"])
except: # noqa
pass
return original_worker(*args, **kwargs)
@staticmethod
def _patched_process_run(self, *args, **kwargs):
if not PatchOsFork._current_task:
@ -151,17 +111,6 @@ class PatchOsFork(object):
except: # noqa
task = None
# check if this is Process Pool function
if hasattr(self, "_target"):
# Now we have to patch Pool, because pool terminates subprocess directly after
# the return value of the pool worker function is pushed into the queue,
# which means it will terminate the process before we finish running our "atexit" call
try:
if self._target == pool.worker: # noqa
self._target = partial(PatchOsFork._patched_pool_worker, pool.worker) # noqa
except: # noqa
pass
try:
return PatchOsFork._original_process_run(self, *args, **kwargs)
finally:
@ -173,6 +122,18 @@ class PatchOsFork(object):
except: # noqa
pass
@staticmethod
def _patched_process_terminate(self, *args, **kwargs):
if PatchOsFork._current_task:
# force creating a Task
try:
# noinspection PyProtectedMember
PatchOsFork._current_task._at_exit()
except: # noqa
pass
return PatchOsFork._original_process_terminate(self, *args, **kwargs)
@staticmethod
def _patched_fork(*args, **kwargs):
if not PatchOsFork._current_task: