Fix deadlock that might occur when using a Process pool with a large number of processes (#674)

This commit is contained in:
allegroai 2022-07-28 18:41:17 +03:00
parent 36481a1337
commit eb01b6f893
2 changed files with 130 additions and 20 deletions

View File

@ -1,6 +1,7 @@
import os
from functools import partial
from time import sleep
from multiprocessing import pool
import six
from ..config import TASK_LOG_ENVIRONMENT, running_remotely, config
@ -62,15 +63,44 @@ class EnvironmentBind(object):
cls._current_task.connect(env_param, cls._environment_section)
class SimpleQueueWrapper(object):
    """Proxy around a queue object that calls the task's ``_at_exit()`` before every ``put``.

    Every attribute other than ``put`` is forwarded unchanged to the wrapped
    queue. ``put`` is replaced by a function that first calls
    ``task._at_exit()`` (best effort, any failure is ignored) and then forwards
    the call, with the same arguments, to the wrapped queue's ``put``.

    :param task: object exposing ``_at_exit()``
    :param simple_queue: the queue object being wrapped
    """

    def __init__(self, task, simple_queue):
        # Both names are private, so Python stores them name-mangled
        # (``_SimpleQueueWrapper__current_task`` / ``_SimpleQueueWrapper__simple_queue``).
        self.__current_task = task
        self.__simple_queue = simple_queue

    def __getattr__(self, attr):
        """Resolve attributes that normal lookup did not find.

        :param attr: name of the requested attribute
        :return: the patched ``put`` function, or the wrapped queue's attribute
        :raises AttributeError: if the wrapper's own private attributes are not set yet
        """
        # __getattr__ receives the *mangled* name of our private attributes. If we get
        # here for one of them, it is missing from the instance (e.g. the object was
        # created without running __init__). Forwarding would read
        # ``self.__simple_queue`` again and recurse forever, so fail explicitly.
        if attr in ("_SimpleQueueWrapper__simple_queue", "_SimpleQueueWrapper__current_task"):
            raise AttributeError(attr)

        # Kept for backward compatibility: the unmangled spellings are looked up in
        # the instance dict only, and are never forwarded to the wrapped queue.
        if attr in ["__simple_queue", "__current_task"]:
            return self.__dict__.get(attr)

        if attr == "put":
            def _patched_put(*a_args, **a_kwargs):
                # Best effort: a failure in the task's exit handler must not
                # prevent the value from being pushed into the queue.
                try:
                    task = self.__current_task
                    # noinspection PyProtectedMember
                    task._at_exit()
                except:  # noqa
                    pass
                return getattr(self.__simple_queue, "put")(*a_args, **a_kwargs)

            return _patched_put

        # Everything else behaves exactly like the wrapped queue.
        return getattr(self.__simple_queue, attr)
class PatchOsFork(object):
_original_fork = None
_current_task = None
_original_process_run = None
@classmethod
def patch_fork(cls, task):
cls._current_task = task
if not task:
return
# first we need to patch regular fork
# (because forked processes do not support atexit, they call os._exit directly)
# noinspection PyBroadException
try:
# only once
@ -84,8 +114,70 @@ class PatchOsFork(object):
except Exception:
pass
# now we need to patch Process.run because the bootstrap code
# shuts everything down before calling os._exit that we patched above
try:
from multiprocessing.process import BaseProcess
PatchOsFork._original_process_run = BaseProcess.run
BaseProcess.run = PatchOsFork._patched_process_run
except: # noqa
pass
@staticmethod
def _patched_pool_worker(original_worker, *args, **kwargs):
    """Call ``original_worker``, wrapping its output queue when one is found.

    When ``PatchOsFork._current_task`` is falsy the worker is called with the
    arguments untouched. Otherwise the second positional argument, or failing
    that the ``outqueue`` keyword argument, is replaced by a
    ``SimpleQueueWrapper`` bound to the current task, provided it exposes a
    ``put`` attribute. Any error raised while wrapping is ignored and the
    worker is still called.

    :param original_worker: the callable to delegate to
    :return: whatever ``original_worker`` returns
    """
    if PatchOsFork._current_task:
        try:
            if len(args) >= 2 and hasattr(args[1], "put"):
                # rebuild the positional arguments with the queue (index 1) wrapped
                wrapped_queue = SimpleQueueWrapper(PatchOsFork._current_task, args[1])
                args = args[:1] + (wrapped_queue,) + args[2:]
            elif "outqueue" in kwargs and hasattr(kwargs["outqueue"], "put"):
                kwargs["outqueue"] = SimpleQueueWrapper(PatchOsFork._current_task, kwargs["outqueue"])
        except:  # noqa
            # best effort, fall back to whatever arguments we currently hold
            pass

    return original_worker(*args, **kwargs)
@staticmethod
def _patched_process_run(self, *args, **kwargs):
    """Replacement for ``BaseProcess.run`` that calls the task's ``_at_exit()`` afterwards.

    When ``PatchOsFork._current_task`` is falsy this is a plain pass-through to
    the original ``run``. Otherwise the current task is looked up, a process
    whose ``_target`` is ``pool.worker`` gets its target swapped for
    ``_patched_pool_worker``, the original ``run`` is executed, and finally
    ``task._at_exit()`` is called (errors ignored).

    :param self: the process instance whose ``run`` is being executed
    :return: whatever the original ``run`` returns
    """
    if not PatchOsFork._current_task:
        return PatchOsFork._original_process_run(self, *args, **kwargs)

    # look up the current Task, any failure leaves us without one
    current_task = None
    try:
        from ..task import Task
        current_task = Task.current_task()
    except:  # noqa
        current_task = None

    # check if this is Process Pool function
    if hasattr(self, "_target"):
        # Pool terminates the subprocess directly after the return value of the
        # pool worker function is pushed into the queue, i.e. before our "atexit"
        # call would finish, so the pool worker itself has to be patched.
        try:
            is_pool_worker = self._target == pool.worker  # noqa
            if is_pool_worker:
                self._target = partial(PatchOsFork._patched_pool_worker, pool.worker)  # noqa
        except:  # noqa
            pass

    try:
        return PatchOsFork._original_process_run(self, *args, **kwargs)
    finally:
        # run the task exit handler regardless of how the original run ended
        try:
            if current_task:
                # noinspection PyProtectedMember
                current_task._at_exit()
        except:  # noqa
            pass
@staticmethod
def _patched_fork(*args, **kwargs):
if not PatchOsFork._current_task:
return PatchOsFork._original_fork(*args, **kwargs)
from ..task import Task
# ensure deferred is done, but never try to generate a Task object
@ -105,8 +197,10 @@ class PatchOsFork(object):
if not task:
return ret
PatchOsFork._current_task = task
# # Hack: now make sure we setup the reporter threads (Log+Reporter)
if not task._report_subprocess_enabled:
# noinspection PyProtectedMember
if not bool(task._report_subprocess_enabled):
BackgroundMonitor.start_all(task=task)
# The signal handler method is Not enough, for the time being, we have both

View File

@ -3,11 +3,10 @@ import pickle
import struct
import sys
from functools import partial
from multiprocessing import Lock, Semaphore, Event as ProcessEvent
from multiprocessing import Process, Semaphore, Event as ProcessEvent
from threading import Thread, Event as TrEvent, RLock as ThreadRLock
from time import sleep, time
from typing import List, Dict, Optional
from multiprocessing import Process
import psutil
from six.moves.queue import Empty, Queue as TrQueue
@ -29,11 +28,13 @@ except ImportError:
try:
from multiprocessing import get_context
except ImportError:
def get_context(*args, **kwargs):
return False
def get_context(*args, **kwargs): # noqa
import multiprocessing
return multiprocessing
class _ForkSafeThreadSyncObject(object):
__process_lock = get_context("fork" if sys.platform == 'linux' else "spawn").RLock()
def __init__(self, functor):
self._sync = None
@ -50,8 +51,16 @@ class _ForkSafeThreadSyncObject(object):
# Notice the order! we first create the object and THEN update the pid,
# this is so whatever happens we Never try to use the old (pre-forked copy) of the synchronization object
self._sync = self._functor()
self._instance_pid = os.getpid()
try:
while not self.__process_lock.acquire(block=True, timeout=1.0):
sleep(0.1)
# we have to check again inside the protected locked area
if self._instance_pid != os.getpid() or not self._sync:
self._sync = self._functor()
self._instance_pid = os.getpid()
finally:
self.__process_lock.release()
class ForkSafeRLock(_ForkSafeThreadSyncObject):
@ -207,7 +216,7 @@ class ThreadCalls(object):
continue
# noinspection PyBroadException
try:
if request[1]:
if request[1] is not None:
request[0](*request[1])
else:
request[0]()
@ -258,7 +267,8 @@ class SafeQueue(object):
except Exception:
pass
self._internal_q = None
self._q_size = [] # list of PIDs we pushed, so this is atomic
# Note we should Never! assign a new object to `self._q_size`, just work with the initial object
self._q_size = [] # list of PIDs we pushed, so this is atomic.
def empty(self):
    """Return a truthy value only when both ``_q`` and ``_internal_q`` hold nothing.

    ``_internal_q`` is treated as empty when it is unset (falsy).
    """
    if not self._q.empty():
        return False
    internal_queue = self._internal_q
    return not internal_queue or internal_queue.empty()
@ -312,12 +322,18 @@ class SafeQueue(object):
# not atomic when forking for the first time
# GIL will make sure it is atomic
self._q_size.append(os.getpid())
# make sure the block put is done in the thread pool i.e. in the background
obj = pickle.dumps(obj)
if BackgroundMonitor.get_at_exit_state():
self._q_put(obj)
return
self.__thread_pool.get().apply_async(self._q_put, args=(obj, ))
try:
# make sure the block put is done in the thread pool i.e. in the background
obj = pickle.dumps(obj)
if BackgroundMonitor.get_at_exit_state():
self._q_put(obj)
return
self.__thread_pool.get().apply_async(self._q_put, args=(obj, ))
except: # noqa
pid = os.getpid()
p = None
while p != pid and self._q_size:
p = self._q_size.pop()
def _get_q_size_len(self, pid=None):
pid = pid or os.getpid()
@ -328,13 +344,13 @@ class SafeQueue(object):
self._q.put(obj)
except BaseException:
# make sure we zero the _q_size if the process dies (i.e. queue put fails)
self._q_size = []
self._q_size.clear()
raise
pid = os.getpid()
# GIL will make sure it is atomic
# pop the First "counter" that is ours (i.e. pid == os.getpid())
p = None
while p != pid:
while p != pid and self._q_size:
p = self._q_size.pop()
def _init_reader_thread(self):
@ -423,7 +439,7 @@ class SingletonLock(AbstractContextManager):
def create(self):
if self._lock is None:
self._lock = Lock()
self._lock = ForkSafeRLock()
@classmethod
def instantiate(cls):
@ -442,7 +458,7 @@ class SingletonLock(AbstractContextManager):
class BackgroundMonitor(object):
# If we will need multiple monitoring contexts (i.e. subprocesses) this will become a dict
# If we need multiple monitoring contexts (i.e. subprocesses) this will become a dict
_main_process = None
_main_process_proc_obj = None
_main_process_task_id = None