mirror of
https://github.com/clearml/clearml
synced 2025-04-03 20:41:07 +00:00
Fix deadlock that might occur when using a Process pool with a large number of processes (#674)
This commit is contained in:
parent
36481a1337
commit
eb01b6f893
@ -1,6 +1,7 @@
|
||||
import os
|
||||
from functools import partial
|
||||
from time import sleep
|
||||
|
||||
from multiprocessing import pool
|
||||
import six
|
||||
|
||||
from ..config import TASK_LOG_ENVIRONMENT, running_remotely, config
|
||||
@ -62,15 +63,44 @@ class EnvironmentBind(object):
|
||||
cls._current_task.connect(env_param, cls._environment_section)
|
||||
|
||||
|
||||
class SimpleQueueWrapper(object):
|
||||
def __init__(self, task, simple_queue):
|
||||
self.__current_task = task
|
||||
self.__simple_queue = simple_queue
|
||||
|
||||
def __getattr__(self, attr):
|
||||
if attr in ["__simple_queue", "__current_task"]:
|
||||
return self.__dict__.get(attr)
|
||||
|
||||
if attr == "put":
|
||||
def _patched_put(*a_args, **a_kwargs):
|
||||
try:
|
||||
task = self.__current_task
|
||||
# noinspection PyProtectedMember
|
||||
task._at_exit()
|
||||
except: # noqa
|
||||
pass
|
||||
return getattr(self.__simple_queue, "put")(*a_args, **a_kwargs)
|
||||
|
||||
return _patched_put
|
||||
|
||||
return getattr(self.__simple_queue, attr)
|
||||
|
||||
|
||||
class PatchOsFork(object):
|
||||
_original_fork = None
|
||||
_current_task = None
|
||||
_original_process_run = None
|
||||
|
||||
@classmethod
|
||||
def patch_fork(cls, task):
|
||||
cls._current_task = task
|
||||
if not task:
|
||||
return
|
||||
|
||||
# first we need to patch regular fork
|
||||
# (because forked processes do not support atexit, they call os._exit directly)
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
# only once
|
||||
@ -84,8 +114,70 @@ class PatchOsFork(object):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# now we need to patch Process.run because the bootstrap code
|
||||
# shuts everything down before calling os._exit that we patched above
|
||||
try:
|
||||
from multiprocessing.process import BaseProcess
|
||||
PatchOsFork._original_process_run = BaseProcess.run
|
||||
BaseProcess.run = PatchOsFork._patched_process_run
|
||||
except: # noqa
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _patched_pool_worker(original_worker, *args, **kwargs):
|
||||
if not PatchOsFork._current_task:
|
||||
return original_worker(*args, **kwargs)
|
||||
|
||||
try:
|
||||
if len(args) >= 2 and hasattr(args[1], "put"):
|
||||
args = list(args)
|
||||
args[1] = SimpleQueueWrapper(PatchOsFork._current_task, args[1])
|
||||
args = tuple(args)
|
||||
elif "outqueue" in kwargs and hasattr(kwargs["outqueue"], "put"):
|
||||
kwargs["outqueue"] = SimpleQueueWrapper(PatchOsFork._current_task, kwargs["outqueue"])
|
||||
except: # noqa
|
||||
pass
|
||||
|
||||
return original_worker(*args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def _patched_process_run(self, *args, **kwargs):
|
||||
if not PatchOsFork._current_task:
|
||||
return PatchOsFork._original_process_run(self, *args, **kwargs)
|
||||
|
||||
try:
|
||||
from ..task import Task
|
||||
task = Task.current_task()
|
||||
except: # noqa
|
||||
task = None
|
||||
|
||||
# check if this is Process Pool function
|
||||
if hasattr(self, "_target"):
|
||||
# Now we have to patch Pool, because pool terminates subprocess directly after
|
||||
# the return value of the pool worker function is pushed into the queue,
|
||||
# which means it will terminate the process before we finish running our "atexit" call
|
||||
try:
|
||||
if self._target == pool.worker: # noqa
|
||||
self._target = partial(PatchOsFork._patched_pool_worker, pool.worker) # noqa
|
||||
except: # noqa
|
||||
pass
|
||||
|
||||
try:
|
||||
return PatchOsFork._original_process_run(self, *args, **kwargs)
|
||||
finally:
|
||||
# force creating a Task
|
||||
try:
|
||||
if task:
|
||||
# noinspection PyProtectedMember
|
||||
task._at_exit()
|
||||
except: # noqa
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _patched_fork(*args, **kwargs):
|
||||
if not PatchOsFork._current_task:
|
||||
return PatchOsFork._original_fork(*args, **kwargs)
|
||||
|
||||
from ..task import Task
|
||||
|
||||
# ensure deferred is done, but never try to generate a Task object
|
||||
@ -105,8 +197,10 @@ class PatchOsFork(object):
|
||||
if not task:
|
||||
return ret
|
||||
|
||||
PatchOsFork._current_task = task
|
||||
# # Hack: now make sure we setup the reporter threads (Log+Reporter)
|
||||
if not task._report_subprocess_enabled:
|
||||
# noinspection PyProtectedMember
|
||||
if not bool(task._report_subprocess_enabled):
|
||||
BackgroundMonitor.start_all(task=task)
|
||||
|
||||
# The signal handler method is Not enough, for the time being, we have both
|
||||
|
@ -3,11 +3,10 @@ import pickle
|
||||
import struct
|
||||
import sys
|
||||
from functools import partial
|
||||
from multiprocessing import Lock, Semaphore, Event as ProcessEvent
|
||||
from multiprocessing import Process, Semaphore, Event as ProcessEvent
|
||||
from threading import Thread, Event as TrEvent, RLock as ThreadRLock
|
||||
from time import sleep, time
|
||||
from typing import List, Dict, Optional
|
||||
from multiprocessing import Process
|
||||
|
||||
import psutil
|
||||
from six.moves.queue import Empty, Queue as TrQueue
|
||||
@ -29,11 +28,13 @@ except ImportError:
|
||||
try:
    from multiprocessing import get_context
except ImportError:
    # Fallback for interpreters without multiprocessing.get_context
    # (Python < 3.4): there are no start-method contexts, but the
    # multiprocessing module itself exposes the same primitives
    # (Lock, RLock, Event, ...), so return the module as the "context".
    # NOTE: this span of the diff contained both the old fallback
    # (returning False — which breaks callers that do `.RLock()` on the
    # result) and the new one; only the corrected version is kept.
    def get_context(*args, **kwargs):  # noqa
        import multiprocessing
        return multiprocessing
|
||||
|
||||
|
||||
class _ForkSafeThreadSyncObject(object):
|
||||
__process_lock = get_context("fork" if sys.platform == 'linux' else "spawn").RLock()
|
||||
|
||||
def __init__(self, functor):
|
||||
self._sync = None
|
||||
@ -50,8 +51,16 @@ class _ForkSafeThreadSyncObject(object):
|
||||
|
||||
# Notice the order! we first create the object and THEN update the pid,
|
||||
# this is so whatever happens we Never try to use the old (pre-forked copy) of the synchronization object
|
||||
self._sync = self._functor()
|
||||
self._instance_pid = os.getpid()
|
||||
try:
|
||||
while not self.__process_lock.acquire(block=True, timeout=1.0):
|
||||
sleep(0.1)
|
||||
|
||||
# we have to check again inside the protected locked area
|
||||
if self._instance_pid != os.getpid() or not self._sync:
|
||||
self._sync = self._functor()
|
||||
self._instance_pid = os.getpid()
|
||||
finally:
|
||||
self.__process_lock.release()
|
||||
|
||||
|
||||
class ForkSafeRLock(_ForkSafeThreadSyncObject):
|
||||
@ -207,7 +216,7 @@ class ThreadCalls(object):
|
||||
continue
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
if request[1]:
|
||||
if request[1] is not None:
|
||||
request[0](*request[1])
|
||||
else:
|
||||
request[0]()
|
||||
@ -258,7 +267,8 @@ class SafeQueue(object):
|
||||
except Exception:
|
||||
pass
|
||||
self._internal_q = None
|
||||
self._q_size = [] # list of PIDs we pushed, so this is atomic
|
||||
# Note we should Never! assign a new object to `self._q_size`, just work with the initial object
|
||||
self._q_size = [] # list of PIDs we pushed, so this is atomic.
|
||||
|
||||
def empty(self):
|
||||
return self._q.empty() and (not self._internal_q or self._internal_q.empty())
|
||||
@ -312,12 +322,18 @@ class SafeQueue(object):
|
||||
# not atomic when forking for the first time
|
||||
# GIL will make sure it is atomic
|
||||
self._q_size.append(os.getpid())
|
||||
# make sure the block put is done in the thread pool i.e. in the background
|
||||
obj = pickle.dumps(obj)
|
||||
if BackgroundMonitor.get_at_exit_state():
|
||||
self._q_put(obj)
|
||||
return
|
||||
self.__thread_pool.get().apply_async(self._q_put, args=(obj, ))
|
||||
try:
|
||||
# make sure the block put is done in the thread pool i.e. in the background
|
||||
obj = pickle.dumps(obj)
|
||||
if BackgroundMonitor.get_at_exit_state():
|
||||
self._q_put(obj)
|
||||
return
|
||||
self.__thread_pool.get().apply_async(self._q_put, args=(obj, ))
|
||||
except: # noqa
|
||||
pid = os.getpid()
|
||||
p = None
|
||||
while p != pid and self._q_size:
|
||||
p = self._q_size.pop()
|
||||
|
||||
def _get_q_size_len(self, pid=None):
|
||||
pid = pid or os.getpid()
|
||||
@ -328,13 +344,13 @@ class SafeQueue(object):
|
||||
self._q.put(obj)
|
||||
except BaseException:
|
||||
# make sure we zero the _q_size if the process dies (i.e. queue put fails)
|
||||
self._q_size = []
|
||||
self._q_size.clear()
|
||||
raise
|
||||
pid = os.getpid()
|
||||
# GIL will make sure it is atomic
|
||||
# pop the First "counter" that is ours (i.e. pid == os.getpid())
|
||||
p = None
|
||||
while p != pid:
|
||||
while p != pid and self._q_size:
|
||||
p = self._q_size.pop()
|
||||
|
||||
def _init_reader_thread(self):
|
||||
@ -423,7 +439,7 @@ class SingletonLock(AbstractContextManager):
|
||||
|
||||
def create(self):
|
||||
if self._lock is None:
|
||||
self._lock = Lock()
|
||||
self._lock = ForkSafeRLock()
|
||||
|
||||
@classmethod
|
||||
def instantiate(cls):
|
||||
@ -442,7 +458,7 @@ class SingletonLock(AbstractContextManager):
|
||||
|
||||
|
||||
class BackgroundMonitor(object):
|
||||
# If we will need multiple monitoring contexts (i.e. subprocesses) this will become a dict
|
||||
# If we need multiple monitoring contexts (i.e. subprocesses) this will become a dict
|
||||
_main_process = None
|
||||
_main_process_proc_obj = None
|
||||
_main_process_task_id = None
|
||||
|
Loading…
Reference in New Issue
Block a user