mirror of
				https://github.com/clearml/clearml
				synced 2025-06-26 18:16:07 +00:00 
			
		
		
		
	Add better signal/exception binding
This commit is contained in:
		
							parent
							
								
									ea3b5856fd
								
							
						
					
					
						commit
						cf1914fa64
					
				| @ -11,7 +11,7 @@ from .debugging.log import LoggerRoot | ||||
| from .backend_interface.task.development.worker import DevWorker | ||||
| from .backend_interface.task.log import TaskHandler | ||||
| from .storage import StorageHelper | ||||
| from .utilities.plotly import SeriesInfo | ||||
| from .utilities.plotly_reporter import SeriesInfo | ||||
| from .backend_interface import TaskStatusEnum | ||||
| from .backend_interface.task import Task as _Task | ||||
| from .config import running_remotely, get_cache_dir | ||||
| @ -82,7 +82,7 @@ class Logger(object): | ||||
|             # noinspection PyBroadException | ||||
|             try: | ||||
|                 Logger._stdout_original_write = sys.stdout.write | ||||
|                 # this will only work in python 3, but we still better guard it with try/catch | ||||
|                 # this will only work in python 3, guard it with try/catch | ||||
|                 sys.stdout._original_write = sys.stdout.write | ||||
|                 sys.stdout.write = stdout__patched__write__ | ||||
|                 sys.stderr._original_write = sys.stderr.write | ||||
| @ -549,7 +549,8 @@ class Logger(object): | ||||
|         :param period: The period to flush the logger in seconds. If None or 0, | ||||
|             There will be no periodic flush. | ||||
|         """ | ||||
|         if self._task.is_main_task() and DevWorker.report_stdout and DevWorker.report_period and not running_remotely(): | ||||
|         if self._task.is_main_task() and DevWorker.report_stdout and DevWorker.report_period and \ | ||||
|                 not running_remotely() and period is not None: | ||||
|             period = min(period or DevWorker.report_period, DevWorker.report_period) | ||||
| 
 | ||||
|         if not period: | ||||
| @ -562,6 +563,19 @@ class Logger(object): | ||||
|             self._flusher = _Flusher(self, period) | ||||
|             self._flusher.start() | ||||
| 
 | ||||
|     @classmethod | ||||
|     def _remove_std_logger(self): | ||||
|         if isinstance(sys.stdout, PrintPatchLogger): | ||||
|             try: | ||||
|                 sys.stdout.connect(None) | ||||
|             except Exception: | ||||
|                 pass | ||||
|         if isinstance(sys.stderr, PrintPatchLogger): | ||||
|             try: | ||||
|                 sys.stderr.connect(None) | ||||
|             except Exception: | ||||
|                 pass | ||||
| 
 | ||||
|     def _start_task_if_needed(self): | ||||
|         if self._task._status == TaskStatusEnum.created: | ||||
|             self._task.mark_started() | ||||
| @ -634,6 +648,8 @@ class PrintPatchLogger(object): | ||||
|                 self._terminal.write(message) | ||||
| 
 | ||||
|     def connect(self, logger): | ||||
|         if self._log: | ||||
|             self._log._flush_stdout_handler() | ||||
|         self._log = logger | ||||
| 
 | ||||
|     def __getattr__(self, attr): | ||||
|  | ||||
| @ -3,7 +3,7 @@ import os | ||||
| import sys | ||||
| import threading | ||||
| import time | ||||
| import warnings | ||||
| import signal | ||||
| from argparse import ArgumentParser | ||||
| from collections import OrderedDict, Callable | ||||
| 
 | ||||
| @ -951,10 +951,15 @@ class Task(_Task): | ||||
|                 self.signal = None | ||||
|                 self._exit_callback = callback | ||||
|                 self._org_handlers = {} | ||||
|                 self._signal_recursion_protection_flag = False | ||||
|                 self._except_recursion_protection_flag = False | ||||
| 
 | ||||
|             def update_callback(self, callback): | ||||
|                 if self._exit_callback: | ||||
|                     atexit.unregister(self._exit_callback) | ||||
|                 if self._exit_callback and not six.PY2: | ||||
|                     try: | ||||
|                         atexit.unregister(self._exit_callback) | ||||
|                     except Exception: | ||||
|                         pass | ||||
|                 self._exit_callback = callback | ||||
|                 atexit.register(self._exit_callback) | ||||
| 
 | ||||
| @ -966,7 +971,6 @@ class Task(_Task): | ||||
|                     self._orig_exc_handler = sys.excepthook | ||||
|                 sys.excepthook = self.exc_handler | ||||
|                 atexit.register(self._exit_callback) | ||||
|                 import signal | ||||
|                 catch_signals = [signal.SIGINT, signal.SIGTERM, signal.SIGSEGV, signal.SIGABRT, | ||||
|                                  signal.SIGILL, signal.SIGFPE, signal.SIGQUIT] | ||||
|                 for s in catch_signals: | ||||
| @ -982,16 +986,52 @@ class Task(_Task): | ||||
|                 self._orig_exit(code) | ||||
| 
 | ||||
|             def exc_handler(self, exctype, value, traceback, *args, **kwargs): | ||||
|                 if self._except_recursion_protection_flag: | ||||
|                     return sys.__excepthook__(exctype, value, traceback, *args, **kwargs) | ||||
| 
 | ||||
|                 self._except_recursion_protection_flag = True | ||||
|                 self.exception = value | ||||
|                 return self._orig_exc_handler(exctype, value, traceback, *args, **kwargs) | ||||
|                 if self._orig_exc_handler: | ||||
|                     ret = self._orig_exc_handler(exctype, value, traceback, *args, **kwargs) | ||||
|                 else: | ||||
|                     ret = sys.__excepthook__(exctype, value, traceback, *args, **kwargs) | ||||
|                 self._except_recursion_protection_flag = False | ||||
| 
 | ||||
|                 return ret | ||||
| 
 | ||||
|             def signal_handler(self, sig, frame): | ||||
|                 if self._signal_recursion_protection_flag: | ||||
|                     # call original | ||||
|                     org_handler = self._org_handlers.get(sig) | ||||
|                     if isinstance(org_handler, Callable): | ||||
|                         org_handler = org_handler(sig, frame) | ||||
|                     return org_handler | ||||
| 
 | ||||
|                 self._signal_recursion_protection_flag = True | ||||
|                 # call exit callback | ||||
|                 self.signal = sig | ||||
|                 if self._exit_callback: | ||||
|                     self._exit_callback() | ||||
|                 org_handler = self._org_handlers[sig] | ||||
|                     # noinspection PyBroadException | ||||
|                     try: | ||||
|                         self._exit_callback() | ||||
|                     except Exception: | ||||
|                         pass | ||||
|                 # call original signal handler | ||||
|                 org_handler = self._org_handlers.get(sig) | ||||
|                 if isinstance(org_handler, Callable): | ||||
|                     return org_handler(sig, frame) | ||||
|                     # noinspection PyBroadException | ||||
|                     try: | ||||
|                         org_handler = org_handler(sig, frame) | ||||
|                     except Exception: | ||||
|                         org_handler = signal.SIG_DFL | ||||
|                 # remove stdout logger, just in case | ||||
|                 # noinspection PyBroadException | ||||
|                 try: | ||||
|                     Logger._remove_std_logger() | ||||
|                 except Exception: | ||||
|                     pass | ||||
|                 self._signal_recursion_protection_flag = False | ||||
|                 # return handler result | ||||
|                 return org_handler | ||||
| 
 | ||||
|         if cls.__exit_hook is None: | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user
	 allegroai
						allegroai