Fix deadlock might occur when using Process pool large number processes (#674)

This commit is contained in:
allegroai 2022-07-28 18:41:17 +03:00
parent 36481a1337
commit eb01b6f893
2 changed files with 130 additions and 20 deletions

View File

@ -1,6 +1,7 @@
import os import os
from functools import partial
from time import sleep from time import sleep
from multiprocessing import pool
import six import six
from ..config import TASK_LOG_ENVIRONMENT, running_remotely, config from ..config import TASK_LOG_ENVIRONMENT, running_remotely, config
@ -62,15 +63,44 @@ class EnvironmentBind(object):
cls._current_task.connect(env_param, cls._environment_section) cls._current_task.connect(env_param, cls._environment_section)
class SimpleQueueWrapper(object):
def __init__(self, task, simple_queue):
self.__current_task = task
self.__simple_queue = simple_queue
def __getattr__(self, attr):
if attr in ["__simple_queue", "__current_task"]:
return self.__dict__.get(attr)
if attr == "put":
def _patched_put(*a_args, **a_kwargs):
try:
task = self.__current_task
# noinspection PyProtectedMember
task._at_exit()
except: # noqa
pass
return getattr(self.__simple_queue, "put")(*a_args, **a_kwargs)
return _patched_put
return getattr(self.__simple_queue, attr)
class PatchOsFork(object): class PatchOsFork(object):
_original_fork = None _original_fork = None
_current_task = None _current_task = None
_original_process_run = None
@classmethod @classmethod
def patch_fork(cls, task): def patch_fork(cls, task):
cls._current_task = task cls._current_task = task
if not task: if not task:
return return
# first we need to patch regular fork
# because forked processes do not support atexit, they call os._exit directly)
# noinspection PyBroadException # noinspection PyBroadException
try: try:
# only once # only once
@ -84,8 +114,70 @@ class PatchOsFork(object):
except Exception: except Exception:
pass pass
# now we need to patch Process.run because the bootstrap code
# shuts everything down before calling os._exit that we patched above
try:
from multiprocessing.process import BaseProcess
PatchOsFork._original_process_run = BaseProcess.run
BaseProcess.run = PatchOsFork._patched_process_run
except: # noqa
pass
@staticmethod
def _patched_pool_worker(original_worker, *args, **kwargs):
if not PatchOsFork._current_task:
return original_worker(*args, **kwargs)
try:
if len(args) >= 2 and hasattr(args[1], "put"):
args = list(args)
args[1] = SimpleQueueWrapper(PatchOsFork._current_task, args[1])
args = tuple(args)
elif "outqueue" in kwargs and hasattr(kwargs["outqueue"], "put"):
kwargs["outqueue"] = SimpleQueueWrapper(PatchOsFork._current_task, kwargs["outqueue"])
except: # noqa
pass
return original_worker(*args, **kwargs)
@staticmethod
def _patched_process_run(self, *args, **kwargs):
if not PatchOsFork._current_task:
return PatchOsFork._original_process_run(self, *args, **kwargs)
try:
from ..task import Task
task = Task.current_task()
except: # noqa
task = None
# check if this is Process Pool function
if hasattr(self, "_target"):
# Now we have to patch Pool, because pool terminates subprocess directly after
# the return value of the pool worker function is pushed into the queue,
# which means it will terminate the process before we finish running our "atexit" call
try:
if self._target == pool.worker: # noqa
self._target = partial(PatchOsFork._patched_pool_worker, pool.worker) # noqa
except: # noqa
pass
try:
return PatchOsFork._original_process_run(self, *args, **kwargs)
finally:
# force creating a Task
try:
if task:
# noinspection PyProtectedMember
task._at_exit()
except: # noqa
pass
@staticmethod @staticmethod
def _patched_fork(*args, **kwargs): def _patched_fork(*args, **kwargs):
if not PatchOsFork._current_task:
return PatchOsFork._original_fork(*args, **kwargs)
from ..task import Task from ..task import Task
# ensure deferred is done, but never try to generate a Task object # ensure deferred is done, but never try to generate a Task object
@ -105,8 +197,10 @@ class PatchOsFork(object):
if not task: if not task:
return ret return ret
PatchOsFork._current_task = task
# # Hack: now make sure we setup the reporter threads (Log+Reporter) # # Hack: now make sure we setup the reporter threads (Log+Reporter)
if not task._report_subprocess_enabled: # noinspection PyProtectedMember
if not bool(task._report_subprocess_enabled):
BackgroundMonitor.start_all(task=task) BackgroundMonitor.start_all(task=task)
# The signal handler method is Not enough, for the time being, we have both # The signal handler method is Not enough, for the time being, we have both

View File

@ -3,11 +3,10 @@ import pickle
import struct import struct
import sys import sys
from functools import partial from functools import partial
from multiprocessing import Lock, Semaphore, Event as ProcessEvent from multiprocessing import Process, Semaphore, Event as ProcessEvent
from threading import Thread, Event as TrEvent, RLock as ThreadRLock from threading import Thread, Event as TrEvent, RLock as ThreadRLock
from time import sleep, time from time import sleep, time
from typing import List, Dict, Optional from typing import List, Dict, Optional
from multiprocessing import Process
import psutil import psutil
from six.moves.queue import Empty, Queue as TrQueue from six.moves.queue import Empty, Queue as TrQueue
@ -29,11 +28,13 @@ except ImportError:
try: try:
from multiprocessing import get_context from multiprocessing import get_context
except ImportError: except ImportError:
def get_context(*args, **kwargs): def get_context(*args, **kwargs): # noqa
return False import multiprocessing
return multiprocessing
class _ForkSafeThreadSyncObject(object): class _ForkSafeThreadSyncObject(object):
__process_lock = get_context("fork" if sys.platform == 'linux' else "spawn").RLock()
def __init__(self, functor): def __init__(self, functor):
self._sync = None self._sync = None
@ -50,8 +51,16 @@ class _ForkSafeThreadSyncObject(object):
# Notice the order! we first create the object and THEN update the pid, # Notice the order! we first create the object and THEN update the pid,
# this is so whatever happens we Never try to used the old (pre-forked copy) of the synchronization object # this is so whatever happens we Never try to used the old (pre-forked copy) of the synchronization object
self._sync = self._functor() try:
self._instance_pid = os.getpid() while not self.__process_lock.acquire(block=True, timeout=1.0):
sleep(0.1)
# we have to check gain inside the protected locked area
if self._instance_pid != os.getpid() or not self._sync:
self._sync = self._functor()
self._instance_pid = os.getpid()
finally:
self.__process_lock.release()
class ForkSafeRLock(_ForkSafeThreadSyncObject): class ForkSafeRLock(_ForkSafeThreadSyncObject):
@ -207,7 +216,7 @@ class ThreadCalls(object):
continue continue
# noinspection PyBroadException # noinspection PyBroadException
try: try:
if request[1]: if request[1] is not None:
request[0](*request[1]) request[0](*request[1])
else: else:
request[0]() request[0]()
@ -258,7 +267,8 @@ class SafeQueue(object):
except Exception: except Exception:
pass pass
self._internal_q = None self._internal_q = None
self._q_size = [] # list of PIDs we pushed, so this is atomic # Note we should Never! assign a new object to `self._q_size`, just work with the initial object
self._q_size = [] # list of PIDs we pushed, so this is atomic.
def empty(self): def empty(self):
return self._q.empty() and (not self._internal_q or self._internal_q.empty()) return self._q.empty() and (not self._internal_q or self._internal_q.empty())
@ -312,12 +322,18 @@ class SafeQueue(object):
# not atomic when forking for the first time # not atomic when forking for the first time
# GIL will make sure it is atomic # GIL will make sure it is atomic
self._q_size.append(os.getpid()) self._q_size.append(os.getpid())
# make sure the block put is done in the thread pool i.e. in the background try:
obj = pickle.dumps(obj) # make sure the block put is done in the thread pool i.e. in the background
if BackgroundMonitor.get_at_exit_state(): obj = pickle.dumps(obj)
self._q_put(obj) if BackgroundMonitor.get_at_exit_state():
return self._q_put(obj)
self.__thread_pool.get().apply_async(self._q_put, args=(obj, )) return
self.__thread_pool.get().apply_async(self._q_put, args=(obj, ))
except: # noqa
pid = os.getpid()
p = None
while p != pid and self._q_size:
p = self._q_size.pop()
def _get_q_size_len(self, pid=None): def _get_q_size_len(self, pid=None):
pid = pid or os.getpid() pid = pid or os.getpid()
@ -328,13 +344,13 @@ class SafeQueue(object):
self._q.put(obj) self._q.put(obj)
except BaseException: except BaseException:
# make sure we zero the _q_size of the process dies (i.e. queue put fails) # make sure we zero the _q_size of the process dies (i.e. queue put fails)
self._q_size = [] self._q_size.clear()
raise raise
pid = os.getpid() pid = os.getpid()
# GIL will make sure it is atomic # GIL will make sure it is atomic
# pop the First "counter" that is ours (i.e. pid == os.getpid()) # pop the First "counter" that is ours (i.e. pid == os.getpid())
p = None p = None
while p != pid: while p != pid and self._q_size:
p = self._q_size.pop() p = self._q_size.pop()
def _init_reader_thread(self): def _init_reader_thread(self):
@ -423,7 +439,7 @@ class SingletonLock(AbstractContextManager):
def create(self): def create(self):
if self._lock is None: if self._lock is None:
self._lock = Lock() self._lock = ForkSafeRLock()
@classmethod @classmethod
def instantiate(cls): def instantiate(cls):
@ -442,7 +458,7 @@ class SingletonLock(AbstractContextManager):
class BackgroundMonitor(object): class BackgroundMonitor(object):
# If we will need multiple monitoring contexts (i.e. subprocesses) this will become a dict # If we need multiple monitoring contexts (i.e. subprocesses) this will become a dict
_main_process = None _main_process = None
_main_process_proc_obj = None _main_process_proc_obj = None
_main_process_task_id = None _main_process_task_id = None