Add better signal/exception binding

This commit is contained in:
allegroai 2019-06-13 01:56:21 +03:00
parent ea3b5856fd
commit cf1914fa64
2 changed files with 67 additions and 11 deletions

View File

@ -11,7 +11,7 @@ from .debugging.log import LoggerRoot
from .backend_interface.task.development.worker import DevWorker
from .backend_interface.task.log import TaskHandler
from .storage import StorageHelper
from .utilities.plotly import SeriesInfo
from .utilities.plotly_reporter import SeriesInfo
from .backend_interface import TaskStatusEnum
from .backend_interface.task import Task as _Task
from .config import running_remotely, get_cache_dir
@ -82,7 +82,7 @@ class Logger(object):
# noinspection PyBroadException
try:
Logger._stdout_original_write = sys.stdout.write
# this will only work in python 3, but we still better guard it with try/catch
# this will only work in python 3, guard it with try/catch
sys.stdout._original_write = sys.stdout.write
sys.stdout.write = stdout__patched__write__
sys.stderr._original_write = sys.stderr.write
@ -549,7 +549,8 @@ class Logger(object):
:param period: The period to flush the logger in seconds. If None or 0,
There will be no periodic flush.
"""
if self._task.is_main_task() and DevWorker.report_stdout and DevWorker.report_period and not running_remotely():
if self._task.is_main_task() and DevWorker.report_stdout and DevWorker.report_period and \
not running_remotely() and period is not None:
period = min(period or DevWorker.report_period, DevWorker.report_period)
if not period:
@ -562,6 +563,19 @@ class Logger(object):
self._flusher = _Flusher(self, period)
self._flusher.start()
@classmethod
def _remove_std_logger(self):
if isinstance(sys.stdout, PrintPatchLogger):
try:
sys.stdout.connect(None)
except Exception:
pass
if isinstance(sys.stderr, PrintPatchLogger):
try:
sys.stderr.connect(None)
except Exception:
pass
def _start_task_if_needed(self):
if self._task._status == TaskStatusEnum.created:
self._task.mark_started()
@ -634,6 +648,8 @@ class PrintPatchLogger(object):
self._terminal.write(message)
def connect(self, logger):
if self._log:
self._log._flush_stdout_handler()
self._log = logger
def __getattr__(self, attr):

View File

@ -3,7 +3,7 @@ import os
import sys
import threading
import time
import warnings
import signal
from argparse import ArgumentParser
from collections import OrderedDict, Callable
@ -951,10 +951,15 @@ class Task(_Task):
self.signal = None
self._exit_callback = callback
self._org_handlers = {}
self._signal_recursion_protection_flag = False
self._except_recursion_protection_flag = False
def update_callback(self, callback):
if self._exit_callback:
atexit.unregister(self._exit_callback)
if self._exit_callback and not six.PY2:
try:
atexit.unregister(self._exit_callback)
except Exception:
pass
self._exit_callback = callback
atexit.register(self._exit_callback)
@ -966,7 +971,6 @@ class Task(_Task):
self._orig_exc_handler = sys.excepthook
sys.excepthook = self.exc_handler
atexit.register(self._exit_callback)
import signal
catch_signals = [signal.SIGINT, signal.SIGTERM, signal.SIGSEGV, signal.SIGABRT,
signal.SIGILL, signal.SIGFPE, signal.SIGQUIT]
for s in catch_signals:
@ -982,16 +986,52 @@ class Task(_Task):
self._orig_exit(code)
def exc_handler(self, exctype, value, traceback, *args, **kwargs):
if self._except_recursion_protection_flag:
return sys.__excepthook__(exctype, value, traceback, *args, **kwargs)
self._except_recursion_protection_flag = True
self.exception = value
return self._orig_exc_handler(exctype, value, traceback, *args, **kwargs)
if self._orig_exc_handler:
ret = self._orig_exc_handler(exctype, value, traceback, *args, **kwargs)
else:
ret = sys.__excepthook__(exctype, value, traceback, *args, **kwargs)
self._except_recursion_protection_flag = False
return ret
def signal_handler(self, sig, frame):
if self._signal_recursion_protection_flag:
# call original
org_handler = self._org_handlers.get(sig)
if isinstance(org_handler, Callable):
org_handler = org_handler(sig, frame)
return org_handler
self._signal_recursion_protection_flag = True
# call exit callback
self.signal = sig
if self._exit_callback:
self._exit_callback()
org_handler = self._org_handlers[sig]
# noinspection PyBroadException
try:
self._exit_callback()
except Exception:
pass
# call original signal handler
org_handler = self._org_handlers.get(sig)
if isinstance(org_handler, Callable):
return org_handler(sig, frame)
# noinspection PyBroadException
try:
org_handler = org_handler(sig, frame)
except Exception:
org_handler = signal.SIG_DFL
# remove stdout logger, just in case
# noinspection PyBroadException
try:
Logger._remove_std_logger()
except Exception:
pass
self._signal_recursion_protection_flag = False
# return handler result
return org_handler
if cls.__exit_hook is None: