diff --git a/trains/task.py b/trains/task.py index ddaf2531..f93d66ea 100644 --- a/trains/task.py +++ b/trains/task.py @@ -55,6 +55,7 @@ from .utilities.proxy_object import ProxyDictPreWrite, ProxyDictPostWrite, flatt nested_from_flat_dictionary, naive_nested_from_flat_dictionary from .utilities.resource_monitor import ResourceMonitor from .utilities.seed import make_deterministic +from .utilities.lowlevel.threads import get_current_thread_id # noinspection PyProtectedMember from .backend_interface.task.args import _Arguments @@ -2307,7 +2308,7 @@ class Task(_Task): def _at_exit(self): # protect sub-process at_exit (should never happen) - if self._at_exit_called: + if self._at_exit_called and self._at_exit_called != get_current_thread_id(): return # shutdown will clear the main, so we have to store it before. # is_main = self.is_main_task() @@ -2324,14 +2325,22 @@ class Task(_Task): """ # protect sub-process at_exit if self._at_exit_called: + # if we are called twice (signal in the middle of the shutdown), + # make sure we flush stdout, this is the best we can do. + if self._at_exit_called == get_current_thread_id() and self._logger and self.__is_subprocess(): + self._logger.set_flush_period(None) + # noinspection PyProtectedMember + self._logger._close_stdout_handler(wait=True) + self._at_exit_called = True return + # from here only a single thread can re-enter + self._at_exit_called = get_current_thread_id() + is_sub_process = self.__is_subprocess() # noinspection PyBroadException try: - # from here do not get into watch dog - self._at_exit_called = True wait_for_uploads = True # first thing mark task as stopped, so we will not end up with "running" on lost tasks # if we are running remotely, the daemon will take care of it @@ -2464,6 +2473,9 @@ class Task(_Task): pass self._edit_lock = None + # make sure no one will re-enter the shutdown method + self._at_exit_called = True + @classmethod def __register_at_exit(cls, exit_callback, only_remove_signal_and_exception_hooks=False): class ExitHooks(object): @@ -2556,8 +2568,11 @@ class Task(_Task): if self._signal_recursion_protection_flag: # call original org_handler = self._org_handlers.get(sig) - if isinstance(org_handler, Callable): + if callable(org_handler): org_handler = org_handler(sig, frame) + else: + signal.signal(sig, org_handler or signal.SIG_DFL) + os.kill(os.getpid(), sig) return org_handler self._signal_recursion_protection_flag = True @@ -2571,12 +2586,13 @@ class Task(_Task): pass # call original signal handler org_handler = self._org_handlers.get(sig) - if isinstance(org_handler, Callable): - # noinspection PyBroadException - try: - org_handler = org_handler(sig, frame) - except Exception: - org_handler = signal.SIG_DFL + self._org_handlers[sig] = None + if callable(org_handler): + ret = org_handler(sig, frame) + else: + signal.signal(sig, org_handler or signal.SIG_DFL) + ret = 0 + # remove stdout logger, just in case # noinspection PyBroadException try: @@ -2584,9 +2600,14 @@ class Task(_Task): Logger._remove_std_logger() except Exception: pass + + if not callable(org_handler): + os.kill(os.getpid(), sig) + self._signal_recursion_protection_flag = False + # return handler result - return org_handler + return ret # we only remove the signals since this will hang subprocesses if only_remove_signal_and_exception_hooks: diff --git a/trains/utilities/lowlevel/threads.py b/trains/utilities/lowlevel/threads.py index 0bf92373..2499afda 100644 --- a/trains/utilities/lowlevel/threads.py +++ b/trains/utilities/lowlevel/threads.py @@ -1,9 +1,14 @@ import ctypes import threading +import six import sys import time +def get_current_thread_id(): + return threading._get_ident() if six.PY2 else threading.get_ident() + + # Nasty hack to raise exception for other threads def _lowlevel_async_raise(thread_obj, exception=None): NULL = 0