diff --git a/trains/backend_interface/logger.py b/trains/backend_interface/logger.py index bce167f1..9b0fcd70 100644 --- a/trains/backend_interface/logger.py +++ b/trains/backend_interface/logger.py @@ -11,6 +11,7 @@ class StdStreamPatch(object): _stdout_proxy = None _stderr_proxy = None _stdout_original_write = None + _stderr_original_write = None @staticmethod def patch_std_streams(logger): @@ -22,6 +23,8 @@ class StdStreamPatch(object): try: if StdStreamPatch._stdout_original_write is None: StdStreamPatch._stdout_original_write = sys.stdout.write + if StdStreamPatch._stderr_original_write is None: + StdStreamPatch._stderr_original_write = sys.stderr.write # this will only work in python 3, guard it with try/catch if not hasattr(sys.stdout, '_original_write'): sys.stdout._original_write = sys.stdout.write @@ -93,6 +96,11 @@ class StdStreamPatch(object): if StdStreamPatch._stdout_original_write: StdStreamPatch._stdout_original_write(*args, **kwargs) + @staticmethod + def stderr_original_write(*args, **kwargs): + if StdStreamPatch._stderr_original_write: + StdStreamPatch._stderr_original_write(*args, **kwargs) + @staticmethod def _stdout__patched__write__(*args, **kwargs): if StdStreamPatch._stdout_proxy: diff --git a/trains/logger.py b/trains/logger.py index c09f37fe..c6b8d2c5 100644 --- a/trains/logger.py +++ b/trains/logger.py @@ -607,8 +607,13 @@ class Logger(object): if isinstance(h, TaskHandler) and h.task_id == self._task.id][0] self._task_handler.emit(record) except Exception: - LoggerRoot.get_base_logger().warning(msg='Logger failed sending log: [level %s]: "%s"' - % (str(level), str(msg))) + # avoid infinite loop, output directly to stderr + try: + # make sure we are writing to the original stdout + StdStreamPatch.stderr_original_write( + 'trains.Logger failed sending log [level {}]: "{}"\n'.format(level, msg)) + except Exception: + pass if not omit_console: # if we are here and we grabbed the stdout, we need to print the real thing @@ -728,10 +733,11 @@ class Logger(object): if self._task_handler and DevWorker.report_stdout: self._task_handler.flush() - def _flush_wait_stdout_handler(self): + def _close_stdout_handler(self, wait=True): if self._task_handler and DevWorker.report_stdout: - self._task_handler.flush() - self._task_handler.wait_for_flush() + t = self._task_handler + self._task_handler = None + t.close(wait) def _touch_title_series(self, title, series): if title not in self._graph_titles: diff --git a/trains/task.py b/trains/task.py index 54795015..6011ff81 100644 --- a/trains/task.py +++ b/trains/task.py @@ -741,12 +741,8 @@ class Task(_Task): # flush any outstanding logs if self._logger: - if wait_for_uploads: - # noinspection PyProtectedMember - self._logger._flush_wait_stdout_handler() - else: - # noinspection PyProtectedMember - self._logger._flush_stdout_handler() + # noinspection PyProtectedMember + self._logger._flush_stdout_handler() if self._reporter: self.reporter.flush() LoggerRoot.flush() @@ -1556,6 +1552,8 @@ class Task(_Task): if self._logger: self._logger.set_flush_period(None) + if wait_for_uploads: + self._logger._close_stdout_handler() # this is so in theory we can close a main task and start a new one if self.is_main_task(): Task.__main_task = None