diff --git a/trains/logger.py b/trains/logger.py index 629e5943..7bfa881c 100644 --- a/trains/logger.py +++ b/trains/logger.py @@ -11,7 +11,7 @@ from .debugging.log import LoggerRoot from .backend_interface.task.development.worker import DevWorker from .backend_interface.task.log import TaskHandler from .storage import StorageHelper -from .utilities.plotly import SeriesInfo +from .utilities.plotly_reporter import SeriesInfo from .backend_interface import TaskStatusEnum from .backend_interface.task import Task as _Task from .config import running_remotely, get_cache_dir @@ -82,7 +82,7 @@ class Logger(object): # noinspection PyBroadException try: Logger._stdout_original_write = sys.stdout.write - # this will only work in python 3, but we still better guard it with try/catch + # this will only work in python 3, guard it with try/catch sys.stdout._original_write = sys.stdout.write sys.stdout.write = stdout__patched__write__ sys.stderr._original_write = sys.stderr.write @@ -549,7 +549,8 @@ class Logger(object): :param period: The period to flush the logger in seconds. If None or 0, There will be no periodic flush. """ - if self._task.is_main_task() and DevWorker.report_stdout and DevWorker.report_period and not running_remotely(): + if self._task.is_main_task() and DevWorker.report_stdout and DevWorker.report_period and \ + not running_remotely() and period is not None: period = min(period or DevWorker.report_period, DevWorker.report_period) if not period: @@ -562,6 +563,19 @@ class Logger(object): self._flusher = _Flusher(self, period) self._flusher.start() + @classmethod + def _remove_std_logger(self): + if isinstance(sys.stdout, PrintPatchLogger): + try: + sys.stdout.connect(None) + except Exception: + pass + if isinstance(sys.stderr, PrintPatchLogger): + try: + sys.stderr.connect(None) + except Exception: + pass + def _start_task_if_needed(self): if self._task._status == TaskStatusEnum.created: self._task.mark_started() @@ -634,6 +648,8 @@ class PrintPatchLogger(object): self._terminal.write(message) def connect(self, logger): + if self._log: + self._log._flush_stdout_handler() self._log = logger def __getattr__(self, attr): diff --git a/trains/task.py b/trains/task.py index 5c121b38..f460f329 100644 --- a/trains/task.py +++ b/trains/task.py @@ -3,7 +3,7 @@ import os import sys import threading import time -import warnings +import signal from argparse import ArgumentParser from collections import OrderedDict, Callable @@ -951,10 +951,15 @@ class Task(_Task): self.signal = None self._exit_callback = callback self._org_handlers = {} + self._signal_recursion_protection_flag = False + self._except_recursion_protection_flag = False def update_callback(self, callback): - if self._exit_callback: - atexit.unregister(self._exit_callback) + if self._exit_callback and not six.PY2: + try: + atexit.unregister(self._exit_callback) + except Exception: + pass self._exit_callback = callback atexit.register(self._exit_callback) @@ -966,7 +971,6 @@ class Task(_Task): self._orig_exc_handler = sys.excepthook sys.excepthook = self.exc_handler atexit.register(self._exit_callback) - import signal catch_signals = [signal.SIGINT, signal.SIGTERM, signal.SIGSEGV, signal.SIGABRT, signal.SIGILL, signal.SIGFPE, signal.SIGQUIT] for s in catch_signals: @@ -982,16 +986,52 @@ class Task(_Task): self._orig_exit(code) def exc_handler(self, exctype, value, traceback, *args, **kwargs): + if self._except_recursion_protection_flag: + return sys.__excepthook__(exctype, value, traceback, *args, **kwargs) + + self._except_recursion_protection_flag = True self.exception = value - return self._orig_exc_handler(exctype, value, traceback, *args, **kwargs) + if self._orig_exc_handler: + ret = self._orig_exc_handler(exctype, value, traceback, *args, **kwargs) + else: + ret = sys.__excepthook__(exctype, value, traceback, *args, **kwargs) + self._except_recursion_protection_flag = False + + return ret def signal_handler(self, sig, frame): + if self._signal_recursion_protection_flag: + # call original + org_handler = self._org_handlers.get(sig) + if isinstance(org_handler, Callable): + org_handler = org_handler(sig, frame) + return org_handler + + self._signal_recursion_protection_flag = True + # call exit callback self.signal = sig if self._exit_callback: - self._exit_callback() - org_handler = self._org_handlers[sig] + # noinspection PyBroadException + try: + self._exit_callback() + except Exception: + pass + # call original signal handler + org_handler = self._org_handlers.get(sig) if isinstance(org_handler, Callable): - return org_handler(sig, frame) + # noinspection PyBroadException + try: + org_handler = org_handler(sig, frame) + except Exception: + org_handler = signal.SIG_DFL + # remove stdout logger, just in case + # noinspection PyBroadException + try: + Logger._remove_std_logger() + except Exception: + pass + self._signal_recursion_protection_flag = False + # return handler result return org_handler if cls.__exit_hook is None: