mirror of
https://github.com/clearml/clearml
synced 2025-06-26 18:16:07 +00:00
Add better signal/exception binding
This commit is contained in:
parent
ea3b5856fd
commit
cf1914fa64
@ -11,7 +11,7 @@ from .debugging.log import LoggerRoot
|
|||||||
from .backend_interface.task.development.worker import DevWorker
|
from .backend_interface.task.development.worker import DevWorker
|
||||||
from .backend_interface.task.log import TaskHandler
|
from .backend_interface.task.log import TaskHandler
|
||||||
from .storage import StorageHelper
|
from .storage import StorageHelper
|
||||||
from .utilities.plotly import SeriesInfo
|
from .utilities.plotly_reporter import SeriesInfo
|
||||||
from .backend_interface import TaskStatusEnum
|
from .backend_interface import TaskStatusEnum
|
||||||
from .backend_interface.task import Task as _Task
|
from .backend_interface.task import Task as _Task
|
||||||
from .config import running_remotely, get_cache_dir
|
from .config import running_remotely, get_cache_dir
|
||||||
@ -82,7 +82,7 @@ class Logger(object):
|
|||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
Logger._stdout_original_write = sys.stdout.write
|
Logger._stdout_original_write = sys.stdout.write
|
||||||
# this will only work in python 3, but we still better guard it with try/catch
|
# this will only work in python 3, guard it with try/catch
|
||||||
sys.stdout._original_write = sys.stdout.write
|
sys.stdout._original_write = sys.stdout.write
|
||||||
sys.stdout.write = stdout__patched__write__
|
sys.stdout.write = stdout__patched__write__
|
||||||
sys.stderr._original_write = sys.stderr.write
|
sys.stderr._original_write = sys.stderr.write
|
||||||
@ -549,7 +549,8 @@ class Logger(object):
|
|||||||
:param period: The period to flush the logger in seconds. If None or 0,
|
:param period: The period to flush the logger in seconds. If None or 0,
|
||||||
There will be no periodic flush.
|
There will be no periodic flush.
|
||||||
"""
|
"""
|
||||||
if self._task.is_main_task() and DevWorker.report_stdout and DevWorker.report_period and not running_remotely():
|
if self._task.is_main_task() and DevWorker.report_stdout and DevWorker.report_period and \
|
||||||
|
not running_remotely() and period is not None:
|
||||||
period = min(period or DevWorker.report_period, DevWorker.report_period)
|
period = min(period or DevWorker.report_period, DevWorker.report_period)
|
||||||
|
|
||||||
if not period:
|
if not period:
|
||||||
@ -562,6 +563,19 @@ class Logger(object):
|
|||||||
self._flusher = _Flusher(self, period)
|
self._flusher = _Flusher(self, period)
|
||||||
self._flusher.start()
|
self._flusher.start()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _remove_std_logger(self):
|
||||||
|
if isinstance(sys.stdout, PrintPatchLogger):
|
||||||
|
try:
|
||||||
|
sys.stdout.connect(None)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
if isinstance(sys.stderr, PrintPatchLogger):
|
||||||
|
try:
|
||||||
|
sys.stderr.connect(None)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
def _start_task_if_needed(self):
|
def _start_task_if_needed(self):
|
||||||
if self._task._status == TaskStatusEnum.created:
|
if self._task._status == TaskStatusEnum.created:
|
||||||
self._task.mark_started()
|
self._task.mark_started()
|
||||||
@ -634,6 +648,8 @@ class PrintPatchLogger(object):
|
|||||||
self._terminal.write(message)
|
self._terminal.write(message)
|
||||||
|
|
||||||
def connect(self, logger):
|
def connect(self, logger):
|
||||||
|
if self._log:
|
||||||
|
self._log._flush_stdout_handler()
|
||||||
self._log = logger
|
self._log = logger
|
||||||
|
|
||||||
def __getattr__(self, attr):
|
def __getattr__(self, attr):
|
||||||
|
@ -3,7 +3,7 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
import warnings
|
import signal
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
from collections import OrderedDict, Callable
|
from collections import OrderedDict, Callable
|
||||||
|
|
||||||
@ -951,10 +951,15 @@ class Task(_Task):
|
|||||||
self.signal = None
|
self.signal = None
|
||||||
self._exit_callback = callback
|
self._exit_callback = callback
|
||||||
self._org_handlers = {}
|
self._org_handlers = {}
|
||||||
|
self._signal_recursion_protection_flag = False
|
||||||
|
self._except_recursion_protection_flag = False
|
||||||
|
|
||||||
def update_callback(self, callback):
|
def update_callback(self, callback):
|
||||||
if self._exit_callback:
|
if self._exit_callback and not six.PY2:
|
||||||
|
try:
|
||||||
atexit.unregister(self._exit_callback)
|
atexit.unregister(self._exit_callback)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
self._exit_callback = callback
|
self._exit_callback = callback
|
||||||
atexit.register(self._exit_callback)
|
atexit.register(self._exit_callback)
|
||||||
|
|
||||||
@ -966,7 +971,6 @@ class Task(_Task):
|
|||||||
self._orig_exc_handler = sys.excepthook
|
self._orig_exc_handler = sys.excepthook
|
||||||
sys.excepthook = self.exc_handler
|
sys.excepthook = self.exc_handler
|
||||||
atexit.register(self._exit_callback)
|
atexit.register(self._exit_callback)
|
||||||
import signal
|
|
||||||
catch_signals = [signal.SIGINT, signal.SIGTERM, signal.SIGSEGV, signal.SIGABRT,
|
catch_signals = [signal.SIGINT, signal.SIGTERM, signal.SIGSEGV, signal.SIGABRT,
|
||||||
signal.SIGILL, signal.SIGFPE, signal.SIGQUIT]
|
signal.SIGILL, signal.SIGFPE, signal.SIGQUIT]
|
||||||
for s in catch_signals:
|
for s in catch_signals:
|
||||||
@ -982,16 +986,52 @@ class Task(_Task):
|
|||||||
self._orig_exit(code)
|
self._orig_exit(code)
|
||||||
|
|
||||||
def exc_handler(self, exctype, value, traceback, *args, **kwargs):
|
def exc_handler(self, exctype, value, traceback, *args, **kwargs):
|
||||||
|
if self._except_recursion_protection_flag:
|
||||||
|
return sys.__excepthook__(exctype, value, traceback, *args, **kwargs)
|
||||||
|
|
||||||
|
self._except_recursion_protection_flag = True
|
||||||
self.exception = value
|
self.exception = value
|
||||||
return self._orig_exc_handler(exctype, value, traceback, *args, **kwargs)
|
if self._orig_exc_handler:
|
||||||
|
ret = self._orig_exc_handler(exctype, value, traceback, *args, **kwargs)
|
||||||
|
else:
|
||||||
|
ret = sys.__excepthook__(exctype, value, traceback, *args, **kwargs)
|
||||||
|
self._except_recursion_protection_flag = False
|
||||||
|
|
||||||
|
return ret
|
||||||
|
|
||||||
def signal_handler(self, sig, frame):
|
def signal_handler(self, sig, frame):
|
||||||
|
if self._signal_recursion_protection_flag:
|
||||||
|
# call original
|
||||||
|
org_handler = self._org_handlers.get(sig)
|
||||||
|
if isinstance(org_handler, Callable):
|
||||||
|
org_handler = org_handler(sig, frame)
|
||||||
|
return org_handler
|
||||||
|
|
||||||
|
self._signal_recursion_protection_flag = True
|
||||||
|
# call exit callback
|
||||||
self.signal = sig
|
self.signal = sig
|
||||||
if self._exit_callback:
|
if self._exit_callback:
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
self._exit_callback()
|
self._exit_callback()
|
||||||
org_handler = self._org_handlers[sig]
|
except Exception:
|
||||||
|
pass
|
||||||
|
# call original signal handler
|
||||||
|
org_handler = self._org_handlers.get(sig)
|
||||||
if isinstance(org_handler, Callable):
|
if isinstance(org_handler, Callable):
|
||||||
return org_handler(sig, frame)
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
org_handler = org_handler(sig, frame)
|
||||||
|
except Exception:
|
||||||
|
org_handler = signal.SIG_DFL
|
||||||
|
# remove stdout logger, just in case
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
Logger._remove_std_logger()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
self._signal_recursion_protection_flag = False
|
||||||
|
# return handler result
|
||||||
return org_handler
|
return org_handler
|
||||||
|
|
||||||
if cls.__exit_hook is None:
|
if cls.__exit_hook is None:
|
||||||
|
Loading…
Reference in New Issue
Block a user