From 8cb7d14abb37f5032c80c476005c17da1ddf8993 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Wed, 13 Dec 2023 16:41:21 +0200 Subject: [PATCH] Fix process sometimes hangs issue by improving exit and exception handlers, unregistering signal handling in child processes except for the BackgroundMonitor --- clearml/binding/environ_bind.py | 49 +++++- clearml/debugging/trace.py | 15 ++ clearml/task.py | 190 +++--------------------- clearml/utilities/process/exit_hooks.py | 168 +++++++++++++++++++++ clearml/utilities/process/mp.py | 23 ++- 5 files changed, 265 insertions(+), 180 deletions(-) create mode 100644 clearml/utilities/process/exit_hooks.py diff --git a/clearml/binding/environ_bind.py b/clearml/binding/environ_bind.py index 69b6ffb3..186d52c8 100644 --- a/clearml/binding/environ_bind.py +++ b/clearml/binding/environ_bind.py @@ -175,10 +175,11 @@ class PatchOsFork(object): try: return PatchOsFork._original_process_run(self, *args, **kwargs) finally: - if task: + if task and patched_worker: try: - if patched_worker: - # remove at exit hooks, we will deadlock when the + # noinspection PyProtectedMember + if task._report_subprocess_enabled: + # just in case, remove at exit hooks, we will deadlock when the # main Pool manager will terminate this process, and it will... # noinspection PyProtectedMember task._at_exit_called = True @@ -214,12 +215,30 @@ class PatchOsFork(object): if not task: return + if not Task._report_subprocess_enabled: + # https://stackoverflow.com/a/34507557 + # NOTICE: subprocesses do not exit through exit we have to register signals + if task._Task__exit_hook: + task._Task__exit_hook.register_signal_and_exception_hooks() + else: + # noinspection PyProtectedMember + task._remove_signal_hooks() + + # noinspection PyProtectedMember + if Task._report_subprocess_enabled: + # noinspection PyProtectedMember + task._remove_exception_hooks() + PatchOsFork._current_task = task # # Hack: now make sure we setup the reporter threads (Log+Reporter) # noinspection PyProtectedMember if not bool(task._report_subprocess_enabled): BackgroundMonitor.start_all(task=task) + # if we are reporting into a subprocess, no need to further patch the exit functions + if Task._report_subprocess_enabled: + return + # The signal handler method is Not enough, for the time being, we have both # even though it makes little sense # # if we got here patch the os._exit of our instance to call us @@ -244,6 +263,10 @@ class PatchOsFork(object): # noinspection PyProtectedMember, PyUnresolvedReferences os._org_exit = os._exit + # noinspection PyProtectedMember + # https://stackoverflow.com/a/34507557 + # NOTICE: subprocesses do not exit through exit, and in most cases not with _exit, + # this means at_exit calls are Not registered respected os._exit = _at_exit_callback @staticmethod @@ -261,3 +284,23 @@ class PatchOsFork(object): PatchOsFork._fork_callback_after_child() return ret + + @staticmethod + def unpatch_fork(): + try: + if PatchOsFork._original_fork and os._exit != PatchOsFork._original_fork: + os._exit = PatchOsFork._original_fork + PatchOsFork._original_fork = None + except Exception: + pass + + @staticmethod + def unpatch_process_run(): + try: + from multiprocessing.process import BaseProcess + + if PatchOsFork._original_process_run and BaseProcess.run != PatchOsFork._original_process_run: + BaseProcess.run = PatchOsFork._original_process_run + PatchOsFork._original_process_run = None + except Exception: + pass diff --git a/clearml/debugging/trace.py b/clearml/debugging/trace.py index c9bd3263..c5f46e2c 100644 --- a/clearml/debugging/trace.py +++ b/clearml/debugging/trace.py @@ -355,6 +355,21 @@ def stdout_print(*args, **kwargs): sys.stdout.write(line) +def debug_print(*args, **kwargs): + """ + Print directly to stdout, with process and timestamp from last print call + Example: [pid=123, t=0.003] message here + """ + global tic + tic = globals().get('tic', time.time()) + stdout_print( + "\033[1;33m[pid={}, t={:.04f}] ".format(os.getpid(), time.time()-tic) + + str(args[0] if len(args) == 1 else ("" if not args else args)) + "\033[0m" + , **kwargs + ) + tic = time.time() + + if __name__ == '__main__': # from clearml import Task # task = Task.init(project_name="examples", task_name="trace test") diff --git a/clearml/task.py b/clearml/task.py index 153a1695..475793c8 100644 --- a/clearml/task.py +++ b/clearml/task.py @@ -100,6 +100,7 @@ from .utilities.resource_monitor import ResourceMonitor from .utilities.seed import make_deterministic from .utilities.lowlevel.threads import get_current_thread_id from .utilities.process.mp import BackgroundMonitor, leave_process +from .utilities.process.exit_hooks import ExitHooks from .utilities.matching import matches_any_wildcard from .utilities.parallel import FutureTaskCaller from .utilities.networking import get_private_ip @@ -487,8 +488,6 @@ class Task(_Task): # unregister signal hooks, they cause subprocess to hang # noinspection PyProtectedMember cls.__main_task.__register_at_exit(cls.__main_task._at_exit) - # TODO: Check if the signal handler method is safe enough, for the time being, do not unhook - # cls.__main_task.__register_at_exit(None, only_remove_signal_and_exception_hooks=True) # start all reporting threads BackgroundMonitor.start_all(task=cls.__main_task) @@ -636,7 +635,11 @@ class Task(_Task): # register at exist only on the real (none deferred) Task if not is_deferred: # register the main task for at exit hooks (there should only be one) + # noinspection PyProtectedMember task.__register_at_exit(task._at_exit) + # noinspection PyProtectedMember + if cls.__exit_hook: + cls.__exit_hook.register_signal_and_exception_hooks() # always patch OS forking because of ProcessPool and the alike PatchOsFork.patch_fork(task) @@ -2078,6 +2081,8 @@ class Task(_Task): # unregister atexit callbacks and signal hooks, if we are the main task if is_main: self.__register_at_exit(None) + self._remove_signal_hooks() + self._remove_exception_hooks() if not is_sub_process: # make sure we enable multiple Task.init callas with reporting sub-processes BackgroundMonitor.clear_main_process(self) @@ -4220,172 +4225,21 @@ class Task(_Task): if not is_sub_process and BackgroundMonitor.is_subprocess_enabled(): BackgroundMonitor.wait_for_sub_process(self) + # we are done + return + @classmethod - def __register_at_exit(cls, exit_callback, only_remove_signal_and_exception_hooks=False): - class ExitHooks(object): - _orig_exit = None - _orig_exc_handler = None - remote_user_aborted = False + def _remove_exception_hooks(cls): + if cls.__exit_hook: + cls.__exit_hook.remove_exception_hooks() - def __init__(self, callback): - self.exit_code = None - self.exception = None - self.signal = None - self._exit_callback = callback - self._org_handlers = {} - self._signal_recursion_protection_flag = False - self._except_recursion_protection_flag = False - self._import_bind_path = os.path.join("clearml", "binding", "import_bind.py") - - def update_callback(self, callback): - if self._exit_callback and not six.PY2: - # noinspection PyBroadException - try: - atexit.unregister(self._exit_callback) - except Exception: - pass - self._exit_callback = callback - if callback: - self.hook() - else: - # un register int hook - if self._orig_exc_handler: - sys.excepthook = self._orig_exc_handler - self._orig_exc_handler = None - for h in self._org_handlers: - # noinspection PyBroadException - try: - signal.signal(h, self._org_handlers[h]) - except Exception: - pass - self._org_handlers = {} - - def hook(self): - if self._orig_exit is None: - self._orig_exit = sys.exit - sys.exit = self.exit - - if self._orig_exc_handler is None: - self._orig_exc_handler = sys.excepthook - sys.excepthook = self.exc_handler - - if self._exit_callback: - atexit.register(self._exit_callback) - - # TODO: check if sub-process hooks are safe enough, for the time being allow it - if not self._org_handlers: # ## and not Task._Task__is_subprocess(): - if sys.platform == 'win32': - catch_signals = [signal.SIGINT, signal.SIGTERM, signal.SIGSEGV, signal.SIGABRT, - signal.SIGILL, signal.SIGFPE] - else: - catch_signals = [signal.SIGINT, signal.SIGTERM, signal.SIGSEGV, signal.SIGABRT, - signal.SIGILL, signal.SIGFPE, signal.SIGQUIT] - for c in catch_signals: - # noinspection PyBroadException - try: - self._org_handlers[c] = signal.getsignal(c) - signal.signal(c, self.signal_handler) - except Exception: - pass - - def exit(self, code=0): - self.exit_code = code - self._orig_exit(code) - - def exc_handler(self, exctype, value, traceback, *args, **kwargs): - if self._except_recursion_protection_flag: - # noinspection PyArgumentList - return sys.__excepthook__(exctype, value, traceback, *args, **kwargs) - - self._except_recursion_protection_flag = True - self.exception = value - - try: - # remove us from import errors - if six.PY3 and isinstance(exctype, type) and issubclass(exctype, ImportError): - prev = cur = traceback - while cur is not None: - tb_next = cur.tb_next - # if this is the import frame, we should remove it - if cur.tb_frame.f_code.co_filename.endswith(self._import_bind_path): - # remove this frame by connecting the previous one to the next one - prev.tb_next = tb_next - cur.tb_next = None - del cur - cur = prev - - prev = cur - cur = tb_next - except: # noqa - pass - - if self._orig_exc_handler: - # noinspection PyArgumentList - ret = self._orig_exc_handler(exctype, value, traceback, *args, **kwargs) - else: - # noinspection PyNoneFunctionAssignment, PyArgumentList - ret = sys.__excepthook__(exctype, value, traceback, *args, **kwargs) - self._except_recursion_protection_flag = False - - return ret - - def signal_handler(self, sig, frame): - self.signal = sig - - org_handler = self._org_handlers.get(sig) - signal.signal(sig, org_handler or signal.SIG_DFL) - - # if this is a sig term, we wait until __at_exit is called (basically do nothing) - if sig == signal.SIGINT: - # return original handler result - return org_handler if not callable(org_handler) else org_handler(sig, frame) - - if self._signal_recursion_protection_flag: - # call original - os.kill(os.getpid(), sig) - return org_handler if not callable(org_handler) else org_handler(sig, frame) - - self._signal_recursion_protection_flag = True - - # call exit callback - if self._exit_callback: - # noinspection PyBroadException - try: - self._exit_callback() - except Exception: - pass - - # remove stdout logger, just in case - # noinspection PyBroadException - try: - # noinspection PyProtectedMember - Logger._remove_std_logger() - except Exception: - pass - - # noinspection PyUnresolvedReferences - os.kill(os.getpid(), sig) - - self._signal_recursion_protection_flag = False - # return handler result - return org_handler if not callable(org_handler) else org_handler(sig, frame) - - # we only remove the signals since this will hang subprocesses - if only_remove_signal_and_exception_hooks: - if not cls.__exit_hook: - return - if cls.__exit_hook._orig_exc_handler: - sys.excepthook = cls.__exit_hook._orig_exc_handler - cls.__exit_hook._orig_exc_handler = None - for s in cls.__exit_hook._org_handlers: - # noinspection PyBroadException - try: - signal.signal(s, cls.__exit_hook._org_handlers[s]) - except Exception: - pass - cls.__exit_hook._org_handlers = {} - return + @classmethod + def _remove_signal_hooks(cls): + if cls.__exit_hook: + cls.__exit_hook.remove_signal_hooks() + @classmethod + def __register_at_exit(cls, exit_callback): if cls.__exit_hook is None: # noinspection PyBroadException try: @@ -4396,12 +4250,6 @@ class Task(_Task): else: cls.__exit_hook.update_callback(exit_callback) - def _remove_at_exit_callbacks(self): - self.__register_at_exit(None, only_remove_signal_and_exception_hooks=True) - # noinspection PyProtectedMember - atexit.unregister(self.__exit_hook._exit_callback) - self._at_exit_called = True - @classmethod def __get_task( cls, diff --git a/clearml/utilities/process/exit_hooks.py b/clearml/utilities/process/exit_hooks.py new file mode 100644 index 00000000..24311396 --- /dev/null +++ b/clearml/utilities/process/exit_hooks.py @@ -0,0 +1,168 @@ +import os +import atexit +import sys +import signal +import six +from ...logger import Logger + + +class ExitHooks(object): + _orig_exit = None + _orig_exc_handler = None + remote_user_aborted = False + + def __init__(self, callback): + self.exit_code = None + self.exception = None + self.signal = None + self._exit_callback = callback + self._org_handlers = {} + self._signal_recursion_protection_flag = False + self._except_recursion_protection_flag = False + self._import_bind_path = os.path.join("clearml", "binding", "import_bind.py") + + def update_callback(self, callback): + if self._exit_callback and not six.PY2: + # noinspection PyBroadException + try: + atexit.unregister(self._exit_callback) + except Exception: + pass + self._exit_callback = callback + if callback: + self.hook() + else: + # un register int hook + if self._orig_exc_handler: + sys.excepthook = self._orig_exc_handler + self._orig_exc_handler = None + for h in self._org_handlers: + # noinspection PyBroadException + try: + signal.signal(h, self._org_handlers[h]) + except Exception: + pass + self._org_handlers = {} + + def hook(self): + if self._orig_exit is None: + self._orig_exit = sys.exit + sys.exit = self.exit + + if self._exit_callback: + atexit.register(self._exit_callback) + + def register_signal_and_exception_hooks(self): + if self._orig_exc_handler is None: + self._orig_exc_handler = sys.excepthook + + sys.excepthook = self.exc_handler + + if not self._org_handlers: + if sys.platform == "win32": + catch_signals = [signal.SIGINT, signal.SIGTERM, signal.SIGSEGV, signal.SIGABRT, + signal.SIGILL, signal.SIGFPE] + else: + catch_signals = [signal.SIGINT, signal.SIGTERM, signal.SIGSEGV, signal.SIGABRT, + signal.SIGILL, signal.SIGFPE, signal.SIGQUIT] + for c in catch_signals: + # noinspection PyBroadException + try: + self._org_handlers[c] = signal.getsignal(c) + signal.signal(c, self.signal_handler) + except Exception: + pass + + def remove_signal_hooks(self): + for org_handler_k, org_handler_v in self._org_handlers.items(): + signal.signal(org_handler_k, org_handler_v) + self._org_handlers = {} + + def remove_exception_hooks(self): + if self._orig_exc_handler: + sys.excepthook = self._orig_exc_handler + self._orig_exc_handler = None + + def exit(self, code=0): + self.exit_code = code + self._orig_exit(code) + + def exc_handler(self, exctype, value, traceback, *args, **kwargs): + if self._except_recursion_protection_flag or not self._orig_exc_handler: + # noinspection PyArgumentList + return sys.__excepthook__(exctype, value, traceback, *args, **kwargs) + + self._except_recursion_protection_flag = True + self.exception = value + + try: + # remove us from import errors + if six.PY3 and isinstance(exctype, type) and issubclass(exctype, ImportError): + prev = cur = traceback + while cur is not None: + tb_next = cur.tb_next + # if this is the import frame, we should remove it + if cur.tb_frame.f_code.co_filename.endswith(self._import_bind_path): + # remove this frame by connecting the previous one to the next one + prev.tb_next = tb_next + cur.tb_next = None + del cur + cur = prev + + prev = cur + cur = tb_next + except: # noqa + pass + + if self._orig_exc_handler: + # noinspection PyArgumentList + ret = self._orig_exc_handler(exctype, value, traceback, *args, **kwargs) + else: + # noinspection PyNoneFunctionAssignment, PyArgumentList + ret = sys.__excepthook__(exctype, value, traceback, *args, **kwargs) + self._except_recursion_protection_flag = False + + return ret + + def signal_handler(self, sig, frame): + org_handler = self._org_handlers.get(sig) + if not org_handler: + return signal.SIG_DFL + + self.signal = sig + signal.signal(sig, org_handler or signal.SIG_DFL) + + # if this is a sig term, we wait until __at_exit is called (basically do nothing) + if sig == signal.SIGINT: + # return original handler result + return org_handler if not callable(org_handler) else org_handler(sig, frame) + + if self._signal_recursion_protection_flag: + # call original + os.kill(os.getpid(), sig) + return org_handler if not callable(org_handler) else org_handler(sig, frame) + + self._signal_recursion_protection_flag = True + + # call exit callback + if self._exit_callback: + # noinspection PyBroadException + try: + self._exit_callback() + except Exception: + pass + + # remove stdout logger, just in case + # noinspection PyBroadException + try: + # noinspection PyProtectedMember + Logger._remove_std_logger() + except Exception: + pass + + # noinspection PyUnresolvedReferences + os.kill(os.getpid(), sig) + + self._signal_recursion_protection_flag = False + # return handler result + return org_handler if not callable(org_handler) else org_handler(sig, frame) diff --git a/clearml/utilities/process/mp.py b/clearml/utilities/process/mp.py index 24baec76..6a1d95e6 100644 --- a/clearml/utilities/process/mp.py +++ b/clearml/utilities/process/mp.py @@ -541,6 +541,15 @@ class BackgroundMonitor(object): self._event.set() if isinstance(self._thread, Thread): + # should we wait for the thread to finish + # noinspection PyBroadException + try: + # there is a race here, and if someone else closes the + # thread it can become True/None and we will fail, it is fine + self._thread.join() + except BaseException: + pass + try: self._get_instances().remove(self) except ValueError: @@ -669,21 +678,23 @@ class BackgroundMonitor(object): @classmethod def _background_process_start(cls, task_obj_id, event_start=None, parent_pid=None): # type: (int, Optional[SafeEvent], Optional[int]) -> None + # noinspection PyProtectedMember is_debugger_running = bool(getattr(sys, 'gettrace', None) and sys.gettrace()) # make sure we update the pid to our own cls._main_process = os.getpid() cls._main_process_proc_obj = psutil.Process(cls._main_process) - # restore original signal, this will prevent any deadlocks - # Do not change the exception we need to catch base exception as well # noinspection PyBroadException try: from ... import Task - # make sure we do not call Task.current_task() it will create a Task object for us on a subprocess! + if Task._Task__current_task and Task._Task__current_task._Task__exit_hook: # noqa + Task._Task__current_task._Task__exit_hook.register_signal_and_exception_hooks() # noqa + # noinspection PyProtectedMember - if Task._has_current_task_obj(): - # noinspection PyProtectedMember - Task.current_task()._remove_at_exit_callbacks() + from ...binding.environ_bind import PatchOsFork + PatchOsFork.unpatch_fork() + PatchOsFork.unpatch_process_run() except: # noqa + # Do not change the exception we need to catch base exception as well pass # if a debugger is running, wait for it to attach to the subprocess