Fix an issue where the process sometimes hangs, by improving the exit and exception handlers and unregistering signal handlers in child processes (except for the BackgroundMonitor)

This commit is contained in:
allegroai 2023-12-13 16:41:21 +02:00
parent 07dcbcac08
commit 8cb7d14abb
5 changed files with 265 additions and 180 deletions

View File

@ -175,10 +175,11 @@ class PatchOsFork(object):
try: try:
return PatchOsFork._original_process_run(self, *args, **kwargs) return PatchOsFork._original_process_run(self, *args, **kwargs)
finally: finally:
if task: if task and patched_worker:
try: try:
if patched_worker: # noinspection PyProtectedMember
# remove at exit hooks, we will deadlock when the if task._report_subprocess_enabled:
# just in case, remove at exit hooks, we will deadlock when the
# main Pool manager will terminate this process, and it will... # main Pool manager will terminate this process, and it will...
# noinspection PyProtectedMember # noinspection PyProtectedMember
task._at_exit_called = True task._at_exit_called = True
@ -214,12 +215,30 @@ class PatchOsFork(object):
if not task: if not task:
return return
if not Task._report_subprocess_enabled:
# https://stackoverflow.com/a/34507557
# NOTICE: subprocesses do not exit through exit we have to register signals
if task._Task__exit_hook:
task._Task__exit_hook.register_signal_and_exception_hooks()
else:
# noinspection PyProtectedMember
task._remove_signal_hooks()
# noinspection PyProtectedMember
if Task._report_subprocess_enabled:
# noinspection PyProtectedMember
task._remove_exception_hooks()
PatchOsFork._current_task = task PatchOsFork._current_task = task
# # Hack: now make sure we setup the reporter threads (Log+Reporter) # # Hack: now make sure we setup the reporter threads (Log+Reporter)
# noinspection PyProtectedMember # noinspection PyProtectedMember
if not bool(task._report_subprocess_enabled): if not bool(task._report_subprocess_enabled):
BackgroundMonitor.start_all(task=task) BackgroundMonitor.start_all(task=task)
# if we are reporting into a subprocess, no need to further patch the exit functions
if Task._report_subprocess_enabled:
return
# The signal handler method is Not enough, for the time being, we have both # The signal handler method is Not enough, for the time being, we have both
# even though it makes little sense # even though it makes little sense
# # if we got here patch the os._exit of our instance to call us # # if we got here patch the os._exit of our instance to call us
@ -244,6 +263,10 @@ class PatchOsFork(object):
# noinspection PyProtectedMember, PyUnresolvedReferences # noinspection PyProtectedMember, PyUnresolvedReferences
os._org_exit = os._exit os._org_exit = os._exit
# noinspection PyProtectedMember
# https://stackoverflow.com/a/34507557
# NOTICE: subprocesses do not exit through exit, and in most cases not through _exit either,
# this means at_exit calls are not respected
os._exit = _at_exit_callback os._exit = _at_exit_callback
@staticmethod @staticmethod
@ -261,3 +284,23 @@ class PatchOsFork(object):
PatchOsFork._fork_callback_after_child() PatchOsFork._fork_callback_after_child()
return ret return ret
@staticmethod
def unpatch_fork():
    """
    Undo the ``os.fork`` patch by putting the saved original function back.

    ``PatchOsFork._original_fork`` is expected to hold the ``os.fork`` that was saved when
    the patch was installed (the patching code is not visible here - confirm against
    ``patch_fork``). If it is set and ``os.fork`` currently points at something else,
    ``os.fork`` is restored and the saved reference is cleared.

    Best effort: any exception is swallowed (``os.fork`` is missing on platforms without
    fork support, e.g. Windows, so the attribute access itself may fail).
    """
    # noinspection PyBroadException
    try:
        original_fork = PatchOsFork._original_fork
        # restore os.fork itself (not os._exit), the saved function is the fork implementation
        if original_fork and os.fork != original_fork:
            os.fork = original_fork
            PatchOsFork._original_fork = None
    except Exception:
        pass
@staticmethod
def unpatch_process_run():
    """
    Undo the ``multiprocessing`` ``BaseProcess.run`` patch.

    When ``PatchOsFork._original_process_run`` is set and differs from the current
    ``BaseProcess.run``, the saved function is assigned back to ``BaseProcess.run`` and the
    saved reference is cleared. Best effort: any exception is silently ignored.
    """
    # noinspection PyBroadException
    try:
        from multiprocessing.process import BaseProcess

        saved_run = PatchOsFork._original_process_run
        # nothing was saved, or the original is already in place
        if not saved_run or BaseProcess.run == saved_run:
            return

        BaseProcess.run = saved_run
        PatchOsFork._original_process_run = None
    except Exception:
        pass

View File

@ -355,6 +355,21 @@ def stdout_print(*args, **kwargs):
sys.stdout.write(line) sys.stdout.write(line)
def debug_print(*args, **kwargs):
    """
    Write a highlighted debug line directly to stdout (through stdout_print), prefixed with
    the current process id and the time in seconds since the previous debug_print call.
    Example: [pid=123, t=0.0030] message here

    :param args: a single argument is printed as str(argument), several arguments are
        printed as the str() of the arguments tuple, no arguments prints only the prefix
    :param kwargs: passed through unchanged to stdout_print
    """
    global tic
    # on the very first call there is no previous timestamp, start counting from now
    tic = globals().get('tic', time.time())

    prefix = "\033[1;33m[pid={}, t={:.04f}] ".format(os.getpid(), time.time()-tic)
    if len(args) == 1:
        message = str(args[0])
    elif args:
        message = str(args)
    else:
        message = ""

    stdout_print(prefix + message + "\033[0m", **kwargs)

    # remember when this call finished, the next call reports its delta from here
    tic = time.time()
if __name__ == '__main__': if __name__ == '__main__':
# from clearml import Task # from clearml import Task
# task = Task.init(project_name="examples", task_name="trace test") # task = Task.init(project_name="examples", task_name="trace test")

View File

@ -100,6 +100,7 @@ from .utilities.resource_monitor import ResourceMonitor
from .utilities.seed import make_deterministic from .utilities.seed import make_deterministic
from .utilities.lowlevel.threads import get_current_thread_id from .utilities.lowlevel.threads import get_current_thread_id
from .utilities.process.mp import BackgroundMonitor, leave_process from .utilities.process.mp import BackgroundMonitor, leave_process
from .utilities.process.exit_hooks import ExitHooks
from .utilities.matching import matches_any_wildcard from .utilities.matching import matches_any_wildcard
from .utilities.parallel import FutureTaskCaller from .utilities.parallel import FutureTaskCaller
from .utilities.networking import get_private_ip from .utilities.networking import get_private_ip
@ -487,8 +488,6 @@ class Task(_Task):
# unregister signal hooks, they cause subprocess to hang # unregister signal hooks, they cause subprocess to hang
# noinspection PyProtectedMember # noinspection PyProtectedMember
cls.__main_task.__register_at_exit(cls.__main_task._at_exit) cls.__main_task.__register_at_exit(cls.__main_task._at_exit)
# TODO: Check if the signal handler method is safe enough, for the time being, do not unhook
# cls.__main_task.__register_at_exit(None, only_remove_signal_and_exception_hooks=True)
# start all reporting threads # start all reporting threads
BackgroundMonitor.start_all(task=cls.__main_task) BackgroundMonitor.start_all(task=cls.__main_task)
@ -636,7 +635,11 @@ class Task(_Task):
# register at exist only on the real (none deferred) Task # register at exist only on the real (none deferred) Task
if not is_deferred: if not is_deferred:
# register the main task for at exit hooks (there should only be one) # register the main task for at exit hooks (there should only be one)
# noinspection PyProtectedMember
task.__register_at_exit(task._at_exit) task.__register_at_exit(task._at_exit)
# noinspection PyProtectedMember
if cls.__exit_hook:
cls.__exit_hook.register_signal_and_exception_hooks()
# always patch OS forking because of ProcessPool and the alike # always patch OS forking because of ProcessPool and the alike
PatchOsFork.patch_fork(task) PatchOsFork.patch_fork(task)
@ -2078,6 +2081,8 @@ class Task(_Task):
# unregister atexit callbacks and signal hooks, if we are the main task # unregister atexit callbacks and signal hooks, if we are the main task
if is_main: if is_main:
self.__register_at_exit(None) self.__register_at_exit(None)
self._remove_signal_hooks()
self._remove_exception_hooks()
if not is_sub_process: if not is_sub_process:
# make sure we enable multiple Task.init callas with reporting sub-processes # make sure we enable multiple Task.init callas with reporting sub-processes
BackgroundMonitor.clear_main_process(self) BackgroundMonitor.clear_main_process(self)
@ -4220,172 +4225,21 @@ class Task(_Task):
if not is_sub_process and BackgroundMonitor.is_subprocess_enabled(): if not is_sub_process and BackgroundMonitor.is_subprocess_enabled():
BackgroundMonitor.wait_for_sub_process(self) BackgroundMonitor.wait_for_sub_process(self)
# we are done
return
@classmethod @classmethod
def __register_at_exit(cls, exit_callback, only_remove_signal_and_exception_hooks=False): def _remove_exception_hooks(cls):
class ExitHooks(object): if cls.__exit_hook:
_orig_exit = None cls.__exit_hook.remove_exception_hooks()
_orig_exc_handler = None
remote_user_aborted = False
def __init__(self, callback): @classmethod
self.exit_code = None def _remove_signal_hooks(cls):
self.exception = None if cls.__exit_hook:
self.signal = None cls.__exit_hook.remove_signal_hooks()
self._exit_callback = callback
self._org_handlers = {}
self._signal_recursion_protection_flag = False
self._except_recursion_protection_flag = False
self._import_bind_path = os.path.join("clearml", "binding", "import_bind.py")
def update_callback(self, callback):
if self._exit_callback and not six.PY2:
# noinspection PyBroadException
try:
atexit.unregister(self._exit_callback)
except Exception:
pass
self._exit_callback = callback
if callback:
self.hook()
else:
# un register int hook
if self._orig_exc_handler:
sys.excepthook = self._orig_exc_handler
self._orig_exc_handler = None
for h in self._org_handlers:
# noinspection PyBroadException
try:
signal.signal(h, self._org_handlers[h])
except Exception:
pass
self._org_handlers = {}
def hook(self):
if self._orig_exit is None:
self._orig_exit = sys.exit
sys.exit = self.exit
if self._orig_exc_handler is None:
self._orig_exc_handler = sys.excepthook
sys.excepthook = self.exc_handler
if self._exit_callback:
atexit.register(self._exit_callback)
# TODO: check if sub-process hooks are safe enough, for the time being allow it
if not self._org_handlers: # ## and not Task._Task__is_subprocess():
if sys.platform == 'win32':
catch_signals = [signal.SIGINT, signal.SIGTERM, signal.SIGSEGV, signal.SIGABRT,
signal.SIGILL, signal.SIGFPE]
else:
catch_signals = [signal.SIGINT, signal.SIGTERM, signal.SIGSEGV, signal.SIGABRT,
signal.SIGILL, signal.SIGFPE, signal.SIGQUIT]
for c in catch_signals:
# noinspection PyBroadException
try:
self._org_handlers[c] = signal.getsignal(c)
signal.signal(c, self.signal_handler)
except Exception:
pass
def exit(self, code=0):
self.exit_code = code
self._orig_exit(code)
def exc_handler(self, exctype, value, traceback, *args, **kwargs):
if self._except_recursion_protection_flag:
# noinspection PyArgumentList
return sys.__excepthook__(exctype, value, traceback, *args, **kwargs)
self._except_recursion_protection_flag = True
self.exception = value
try:
# remove us from import errors
if six.PY3 and isinstance(exctype, type) and issubclass(exctype, ImportError):
prev = cur = traceback
while cur is not None:
tb_next = cur.tb_next
# if this is the import frame, we should remove it
if cur.tb_frame.f_code.co_filename.endswith(self._import_bind_path):
# remove this frame by connecting the previous one to the next one
prev.tb_next = tb_next
cur.tb_next = None
del cur
cur = prev
prev = cur
cur = tb_next
except: # noqa
pass
if self._orig_exc_handler:
# noinspection PyArgumentList
ret = self._orig_exc_handler(exctype, value, traceback, *args, **kwargs)
else:
# noinspection PyNoneFunctionAssignment, PyArgumentList
ret = sys.__excepthook__(exctype, value, traceback, *args, **kwargs)
self._except_recursion_protection_flag = False
return ret
def signal_handler(self, sig, frame):
self.signal = sig
org_handler = self._org_handlers.get(sig)
signal.signal(sig, org_handler or signal.SIG_DFL)
# if this is a sig term, we wait until __at_exit is called (basically do nothing)
if sig == signal.SIGINT:
# return original handler result
return org_handler if not callable(org_handler) else org_handler(sig, frame)
if self._signal_recursion_protection_flag:
# call original
os.kill(os.getpid(), sig)
return org_handler if not callable(org_handler) else org_handler(sig, frame)
self._signal_recursion_protection_flag = True
# call exit callback
if self._exit_callback:
# noinspection PyBroadException
try:
self._exit_callback()
except Exception:
pass
# remove stdout logger, just in case
# noinspection PyBroadException
try:
# noinspection PyProtectedMember
Logger._remove_std_logger()
except Exception:
pass
# noinspection PyUnresolvedReferences
os.kill(os.getpid(), sig)
self._signal_recursion_protection_flag = False
# return handler result
return org_handler if not callable(org_handler) else org_handler(sig, frame)
# we only remove the signals since this will hang subprocesses
if only_remove_signal_and_exception_hooks:
if not cls.__exit_hook:
return
if cls.__exit_hook._orig_exc_handler:
sys.excepthook = cls.__exit_hook._orig_exc_handler
cls.__exit_hook._orig_exc_handler = None
for s in cls.__exit_hook._org_handlers:
# noinspection PyBroadException
try:
signal.signal(s, cls.__exit_hook._org_handlers[s])
except Exception:
pass
cls.__exit_hook._org_handlers = {}
return
@classmethod
def __register_at_exit(cls, exit_callback):
if cls.__exit_hook is None: if cls.__exit_hook is None:
# noinspection PyBroadException # noinspection PyBroadException
try: try:
@ -4396,12 +4250,6 @@ class Task(_Task):
else: else:
cls.__exit_hook.update_callback(exit_callback) cls.__exit_hook.update_callback(exit_callback)
def _remove_at_exit_callbacks(self):
self.__register_at_exit(None, only_remove_signal_and_exception_hooks=True)
# noinspection PyProtectedMember
atexit.unregister(self.__exit_hook._exit_callback)
self._at_exit_called = True
@classmethod @classmethod
def __get_task( def __get_task(
cls, cls,

View File

@ -0,0 +1,168 @@
import os
import atexit
import sys
import signal
import six
from ...logger import Logger
class ExitHooks(object):
    """
    Record how the process terminates and run a single exit callback on the way out.

    The object can replace ``sys.exit`` (to record the exit code), ``sys.excepthook``
    (to record the uncaught exception) and a fixed set of signal handlers (to record the
    signal and call the exit callback before the signal is re-raised).
    The replaced hooks are saved so they can be chained to and restored later.
    """
    _orig_exit = None
    _orig_exc_handler = None
    remote_user_aborted = False

    def __init__(self, callback):
        """
        :param callback: callable taking no arguments, registered with ``atexit`` by ``hook()``
            and called from ``signal_handler``. May be None.
        """
        self.exit_code = None
        self.exception = None
        self.signal = None
        self._exit_callback = callback
        # signal number -> handler that was installed before ours
        self._org_handlers = {}
        self._signal_recursion_protection_flag = False
        self._except_recursion_protection_flag = False
        # traceback frames coming from this file are removed from ImportError tracebacks
        self._import_bind_path = os.path.join("clearml", "binding", "import_bind.py")

    def update_callback(self, callback):
        """
        Replace the exit callback.

        The previous callback (if any) is unregistered from ``atexit``. A new callback is
        installed through ``hook()``; passing None removes the exception and signal hooks.

        :param callback: the new exit callback, or None to remove the hooks
        """
        if self._exit_callback and not six.PY2:
            # noinspection PyBroadException
            try:
                atexit.unregister(self._exit_callback)
            except Exception:
                pass
        self._exit_callback = callback
        if callback:
            self.hook()
        else:
            # no callback, restore the exception hook and the original signal handlers
            self.remove_exception_hooks()
            self.remove_signal_hooks()

    def hook(self):
        """
        Route ``sys.exit`` through ``self.exit`` and register the exit callback with ``atexit``.
        """
        if self._orig_exit is None:
            self._orig_exit = sys.exit
        sys.exit = self.exit

        if self._exit_callback:
            atexit.register(self._exit_callback)

    def register_signal_and_exception_hooks(self):
        """
        Install ``exc_handler`` as ``sys.excepthook`` and ``signal_handler`` for the
        termination/fault signals, remembering the handlers that were in place before.
        Signal handlers are only installed when none are currently recorded.
        """
        if self._orig_exc_handler is None:
            self._orig_exc_handler = sys.excepthook

        sys.excepthook = self.exc_handler

        if not self._org_handlers:
            if sys.platform == "win32":
                catch_signals = [signal.SIGINT, signal.SIGTERM, signal.SIGSEGV, signal.SIGABRT,
                                 signal.SIGILL, signal.SIGFPE]
            else:
                catch_signals = [signal.SIGINT, signal.SIGTERM, signal.SIGSEGV, signal.SIGABRT,
                                 signal.SIGILL, signal.SIGFPE, signal.SIGQUIT]
            for c in catch_signals:
                # noinspection PyBroadException
                try:
                    self._org_handlers[c] = signal.getsignal(c)
                    signal.signal(c, self.signal_handler)
                except Exception:
                    pass

    def remove_signal_hooks(self):
        """
        Restore every recorded original signal handler and forget the recorded handlers.

        Best effort per signal: ``signal.signal`` raises if called outside the main thread or
        if the recorded handler is None (``signal.getsignal`` returns None for handlers that
        were not installed from Python), one failure must not prevent restoring the others.
        """
        for sig, org_handler in self._org_handlers.items():
            # noinspection PyBroadException
            try:
                signal.signal(sig, org_handler)
            except Exception:
                pass
        self._org_handlers = {}

    def remove_exception_hooks(self):
        """
        Restore the ``sys.excepthook`` that was in place before ours (if we replaced it).
        """
        if self._orig_exc_handler:
            sys.excepthook = self._orig_exc_handler
            self._orig_exc_handler = None

    def exit(self, code=0):
        """
        ``sys.exit`` replacement: record the exit code, then call the original ``sys.exit``.

        :param code: the exit code passed to ``sys.exit``
        """
        self.exit_code = code
        self._orig_exit(code)

    def exc_handler(self, exctype, value, traceback, *args, **kwargs):
        """
        ``sys.excepthook`` replacement: record the uncaught exception and chain to the
        previous exception hook.

        For ImportError (Python 3) traceback frames originating from the import binding
        module are unlinked from the traceback before it is passed on.
        If called recursively, or after the exception hook was removed, the call goes
        directly to ``sys.__excepthook__``.

        :return: whatever the chained exception hook returns
        """
        if self._except_recursion_protection_flag or not self._orig_exc_handler:
            # noinspection PyArgumentList
            return sys.__excepthook__(exctype, value, traceback, *args, **kwargs)

        self._except_recursion_protection_flag = True
        try:
            self.exception = value

            try:
                # remove us from import errors
                if six.PY3 and isinstance(exctype, type) and issubclass(exctype, ImportError):
                    prev = cur = traceback
                    while cur is not None:
                        tb_next = cur.tb_next
                        # if this is the import frame, we should remove it
                        if cur.tb_frame.f_code.co_filename.endswith(self._import_bind_path):
                            # remove this frame by connecting the previous one to the next one
                            prev.tb_next = tb_next
                            cur.tb_next = None
                            del cur
                            cur = prev

                        prev = cur
                        cur = tb_next
            except:  # noqa
                # cleaning the traceback is cosmetic, never let it mask the real exception
                pass

            # the hook might have been removed while we were cleaning the traceback
            handler = self._orig_exc_handler or sys.__excepthook__
            # noinspection PyArgumentList
            return handler(exctype, value, traceback, *args, **kwargs)
        finally:
            # always release the guard, even if the chained handler raised
            self._except_recursion_protection_flag = False

    @staticmethod
    def _call_original(org_handler, sig, frame):
        """Call the original handler if it is callable, otherwise return it as is."""
        return org_handler(sig, frame) if callable(org_handler) else org_handler

    def signal_handler(self, sig, frame):
        """
        Signal handler: record the signal, restore the original handler for it, run the
        exit callback and re-send the signal to this process.

        SIGINT only restores and calls the original handler, the exit callback is not
        called from here for it.

        :param sig: the signal number
        :param frame: the interrupted stack frame
        :return: the original handler's result (or the handler itself if it is not callable)
        """
        # Our hooks were already removed, nothing to do.
        # Test for membership, not truthiness: the default disposition signal.SIG_DFL
        # is falsy, and a signal that had the default handler must still be processed.
        if sig not in self._org_handlers:
            return signal.SIG_DFL
        org_handler = self._org_handlers[sig]

        self.signal = sig
        # org_handler is None when the previous handler was not installed from Python
        signal.signal(sig, org_handler or signal.SIG_DFL)

        # if this is a sig int, we wait until __at_exit is called (basically do nothing)
        if sig == signal.SIGINT:
            # return original handler result
            return self._call_original(org_handler, sig, frame)

        if self._signal_recursion_protection_flag:
            # we are already handling a signal, just pass this one on to the original handler
            os.kill(os.getpid(), sig)
            return self._call_original(org_handler, sig, frame)

        self._signal_recursion_protection_flag = True

        # call exit callback
        if self._exit_callback:
            # noinspection PyBroadException
            try:
                self._exit_callback()
            except Exception:
                pass

        # remove stdout logger, just in case
        # noinspection PyBroadException
        try:
            # noinspection PyProtectedMember
            Logger._remove_std_logger()
        except Exception:
            pass

        # re-send the signal, the original handler is now the one installed
        # noinspection PyUnresolvedReferences
        os.kill(os.getpid(), sig)

        self._signal_recursion_protection_flag = False
        # return handler result
        return self._call_original(org_handler, sig, frame)

View File

@ -541,6 +541,15 @@ class BackgroundMonitor(object):
self._event.set() self._event.set()
if isinstance(self._thread, Thread): if isinstance(self._thread, Thread):
# should we wait for the thread to finish
# noinspection PyBroadException
try:
# there is a race here, and if someone else closes the
# thread it can become True/None and we will fail, it is fine
self._thread.join()
except BaseException:
pass
try: try:
self._get_instances().remove(self) self._get_instances().remove(self)
except ValueError: except ValueError:
@ -669,21 +678,23 @@ class BackgroundMonitor(object):
@classmethod @classmethod
def _background_process_start(cls, task_obj_id, event_start=None, parent_pid=None): def _background_process_start(cls, task_obj_id, event_start=None, parent_pid=None):
# type: (int, Optional[SafeEvent], Optional[int]) -> None # type: (int, Optional[SafeEvent], Optional[int]) -> None
# noinspection PyProtectedMember
is_debugger_running = bool(getattr(sys, 'gettrace', None) and sys.gettrace()) is_debugger_running = bool(getattr(sys, 'gettrace', None) and sys.gettrace())
# make sure we update the pid to our own # make sure we update the pid to our own
cls._main_process = os.getpid() cls._main_process = os.getpid()
cls._main_process_proc_obj = psutil.Process(cls._main_process) cls._main_process_proc_obj = psutil.Process(cls._main_process)
# restore original signal, this will prevent any deadlocks
# Do not change the exception we need to catch base exception as well
# noinspection PyBroadException # noinspection PyBroadException
try: try:
from ... import Task from ... import Task
# make sure we do not call Task.current_task() it will create a Task object for us on a subprocess! if Task._Task__current_task and Task._Task__current_task._Task__exit_hook: # noqa
Task._Task__current_task._Task__exit_hook.register_signal_and_exception_hooks() # noqa
# noinspection PyProtectedMember # noinspection PyProtectedMember
if Task._has_current_task_obj(): from ...binding.environ_bind import PatchOsFork
# noinspection PyProtectedMember PatchOsFork.unpatch_fork()
Task.current_task()._remove_at_exit_callbacks() PatchOsFork.unpatch_process_run()
except: # noqa except: # noqa
# Do not change the exception we need to catch base exception as well
pass pass
# if a debugger is running, wait for it to attach to the subprocess # if a debugger is running, wait for it to attach to the subprocess