Add better signal/exception binding

This commit is contained in:
allegroai 2019-06-13 01:56:21 +03:00
parent ea3b5856fd
commit cf1914fa64
2 changed files with 67 additions and 11 deletions

View File

@ -11,7 +11,7 @@ from .debugging.log import LoggerRoot
from .backend_interface.task.development.worker import DevWorker from .backend_interface.task.development.worker import DevWorker
from .backend_interface.task.log import TaskHandler from .backend_interface.task.log import TaskHandler
from .storage import StorageHelper from .storage import StorageHelper
from .utilities.plotly import SeriesInfo from .utilities.plotly_reporter import SeriesInfo
from .backend_interface import TaskStatusEnum from .backend_interface import TaskStatusEnum
from .backend_interface.task import Task as _Task from .backend_interface.task import Task as _Task
from .config import running_remotely, get_cache_dir from .config import running_remotely, get_cache_dir
@ -82,7 +82,7 @@ class Logger(object):
# noinspection PyBroadException # noinspection PyBroadException
try: try:
Logger._stdout_original_write = sys.stdout.write Logger._stdout_original_write = sys.stdout.write
# this will only work in python 3, but we still better guard it with try/catch # this will only work in python 3, guard it with try/catch
sys.stdout._original_write = sys.stdout.write sys.stdout._original_write = sys.stdout.write
sys.stdout.write = stdout__patched__write__ sys.stdout.write = stdout__patched__write__
sys.stderr._original_write = sys.stderr.write sys.stderr._original_write = sys.stderr.write
@ -549,7 +549,8 @@ class Logger(object):
:param period: The period to flush the logger in seconds. If None or 0, :param period: The period to flush the logger in seconds. If None or 0,
There will be no periodic flush. There will be no periodic flush.
""" """
if self._task.is_main_task() and DevWorker.report_stdout and DevWorker.report_period and not running_remotely(): if self._task.is_main_task() and DevWorker.report_stdout and DevWorker.report_period and \
not running_remotely() and period is not None:
period = min(period or DevWorker.report_period, DevWorker.report_period) period = min(period or DevWorker.report_period, DevWorker.report_period)
if not period: if not period:
@ -562,6 +563,19 @@ class Logger(object):
self._flusher = _Flusher(self, period) self._flusher = _Flusher(self, period)
self._flusher.start() self._flusher.start()
@classmethod
def _remove_std_logger(cls):
    """Disconnect the logger from the patched standard streams.

    For each of ``sys.stdout`` and ``sys.stderr`` that is currently a
    ``PrintPatchLogger``, call ``connect(None)`` on it. Any exception raised
    while disconnecting is deliberately ignored (best-effort cleanup).
    """
    for stream in (sys.stdout, sys.stderr):
        if not isinstance(stream, PrintPatchLogger):
            continue
        # noinspection PyBroadException
        try:
            stream.connect(None)
        except Exception:
            # best effort: a failing stream must not prevent handling the other one
            pass
def _start_task_if_needed(self): def _start_task_if_needed(self):
if self._task._status == TaskStatusEnum.created: if self._task._status == TaskStatusEnum.created:
self._task.mark_started() self._task.mark_started()
@ -634,6 +648,8 @@ class PrintPatchLogger(object):
self._terminal.write(message) self._terminal.write(message)
def connect(self, logger):
    """Replace the attached logger with ``logger``.

    If a logger is currently attached (``self._log`` is truthy), its
    ``_flush_stdout_handler()`` is called before it is replaced.

    :param logger: The new logger object to attach, or None to detach.
    """
    previous = self._log
    if previous:
        previous._flush_stdout_handler()
    self._log = logger
def __getattr__(self, attr): def __getattr__(self, attr):

View File

@ -3,7 +3,7 @@ import os
import sys import sys
import threading import threading
import time import time
import warnings import signal
from argparse import ArgumentParser from argparse import ArgumentParser
from collections import OrderedDict, Callable from collections import OrderedDict, Callable
@ -951,10 +951,15 @@ class Task(_Task):
self.signal = None self.signal = None
self._exit_callback = callback self._exit_callback = callback
self._org_handlers = {} self._org_handlers = {}
self._signal_recursion_protection_flag = False
self._except_recursion_protection_flag = False
def update_callback(self, callback):
    """Swap the registered at-exit callback for ``callback``.

    When a callback is already stored and we are not running under Python 2,
    try to unregister it from ``atexit`` first (failures are ignored).
    The new callback is then stored and registered with ``atexit``.

    :param callback: The callable to run at interpreter exit.
    """
    previous = self._exit_callback
    if previous and not six.PY2:
        # noinspection PyBroadException
        try:
            atexit.unregister(previous)
        except Exception:
            pass
    self._exit_callback = callback
    atexit.register(callback)
@ -966,7 +971,6 @@ class Task(_Task):
self._orig_exc_handler = sys.excepthook self._orig_exc_handler = sys.excepthook
sys.excepthook = self.exc_handler sys.excepthook = self.exc_handler
atexit.register(self._exit_callback) atexit.register(self._exit_callback)
import signal
catch_signals = [signal.SIGINT, signal.SIGTERM, signal.SIGSEGV, signal.SIGABRT, catch_signals = [signal.SIGINT, signal.SIGTERM, signal.SIGSEGV, signal.SIGABRT,
signal.SIGILL, signal.SIGFPE, signal.SIGQUIT] signal.SIGILL, signal.SIGFPE, signal.SIGQUIT]
for s in catch_signals: for s in catch_signals:
@ -982,16 +986,52 @@ class Task(_Task):
self._orig_exit(code) self._orig_exit(code)
def exc_handler(self, exctype, value, traceback, *args, **kwargs):
    """Record an uncaught exception and delegate to the previous excepthook.

    :param exctype: The exception class.
    :param value: The exception instance; stored on ``self.exception``.
    :param traceback: The traceback object.
    :return: Whatever the delegated hook returns.

    If this handler is re-entered while it is already running, the call goes
    straight to ``sys.__excepthook__`` without touching ``self.exception``.
    """
    if self._except_recursion_protection_flag:
        return sys.__excepthook__(exctype, value, traceback, *args, **kwargs)

    self._except_recursion_protection_flag = True
    try:
        self.exception = value
        # fall back to the interpreter default when no previous hook was saved
        handler = self._orig_exc_handler or sys.__excepthook__
        return handler(exctype, value, traceback, *args, **kwargs)
    finally:
        # always release the guard, even if the delegated hook itself raises;
        # otherwise every later exception would skip the bookkeeping above
        self._except_recursion_protection_flag = False
def signal_handler(self, sig, frame):
    """Run the exit callback for a caught signal, then chain to the original handler.

    :param sig: The signal number; stored on ``self.signal``.
    :param frame: The current stack frame, passed through to the original handler.
    :return: The original handler's return value if it is callable, otherwise
        the stored (non-callable) handler entry, or None when no handler was
        saved for ``sig``. ``signal.SIG_DFL`` is returned when the original
        handler raises an ``Exception``.

    If a signal arrives while this handler is already running, only the
    original handler is invoked (the exit callback is not run again).
    """
    if self._signal_recursion_protection_flag:
        # call original
        org_handler = self._org_handlers.get(sig)
        # builtin callable() instead of collections.Callable (removed in Python 3.10)
        if callable(org_handler):
            org_handler = org_handler(sig, frame)
        return org_handler

    self._signal_recursion_protection_flag = True
    try:
        # call exit callback
        self.signal = sig
        if self._exit_callback:
            # noinspection PyBroadException
            try:
                self._exit_callback()
            except Exception:
                pass

        # call original signal handler
        org_handler = self._org_handlers.get(sig)
        if callable(org_handler):
            # noinspection PyBroadException
            try:
                org_handler = org_handler(sig, frame)
            except Exception:
                org_handler = signal.SIG_DFL

        # remove stdout logger, just in case
        # noinspection PyBroadException
        try:
            Logger._remove_std_logger()
        except Exception:
            pass

        # return handler result
        return org_handler
    finally:
        # release the guard even when KeyboardInterrupt/SystemExit propagates
        # out of the exit callback or the original handler
        self._signal_recursion_protection_flag = False
if cls.__exit_hook is None: if cls.__exit_hook is None: