diff --git a/trains/backend_interface/task/log.py b/trains/backend_interface/task/log.py
index a92709cc..428c8b13 100644
--- a/trains/backend_interface/task/log.py
+++ b/trains/backend_interface/task/log.py
@@ -1,3 +1,4 @@
+import sys
 import time
 from logging import LogRecord, getLogger, basicConfig
 from logging.handlers import BufferingHandler
@@ -28,7 +29,8 @@ class TaskHandler(BufferingHandler):
         self.last_timestamp = 0
         self.counter = 1
         self._last_event = None
-        self._thread_pool = ThreadPool(processes=1)
+        self._thread_pool = None
+        self._pending = 0
 
     def shouldFlush(self, record):
         """
@@ -37,7 +39,8 @@ class TaskHandler(BufferingHandler):
         Returns true if the buffer is up to capacity.
         This method can be overridden to implement custom flushing strategies.
         """
-
+        if self._task_id is None:
+            return False
         # Notice! protect against infinite loops, i.e. flush while sending previous records
         # if self.lock._is_owned():
         #     return False
@@ -67,6 +70,8 @@ class TaskHandler(BufferingHandler):
 
     def _record_to_event(self, record):
         # type: (LogRecord) -> events.TaskLogEvent
+        if self._task_id is None:
+            return None
         timestamp = int(record.created * 1000)
         if timestamp == self.last_timestamp:
             timestamp += self.counter
@@ -92,43 +97,95 @@ class TaskHandler(BufferingHandler):
         return self._last_event
 
     def flush(self):
+        if self._task_id is None:
+            return
+
         if not self.buffer:
             return
 
         self.acquire()
+        if not self.buffer:
+            self.release()
+            return
         buffer = self.buffer
+        self.buffer = []
         try:
-            if not buffer:
-                return
-            self.buffer = []
             record_events = [self._record_to_event(record) for record in buffer]
             self._last_event = None
             batch_requests = events.AddBatchRequest(requests=[events.AddRequest(e) for e in record_events if e])
         except Exception:
+            # print("Failed logging task to backend ({:d} lines)".format(len(buffer)))
             batch_requests = None
-            print("Failed logging task to backend ({:d} lines)".format(len(buffer)))
-        finally:
-            self.release()
 
         if batch_requests:
+            if not self._thread_pool:
+                self._thread_pool = ThreadPool(processes=1)
+            self._pending += 1
             self._thread_pool.apply_async(self._send_events, args=(batch_requests, ))
 
-    def wait_for_flush(self):
-        self.acquire()
-        try:
-            self._thread_pool.close()
-            self._thread_pool.join()
-        except Exception:
-            pass
-        self._thread_pool = ThreadPool(processes=1)
         self.release()
 
+    def wait_for_flush(self, shutdown=False):
+        msg = 'Task.log.wait_for_flush: %d'
+        ll = self.__log_stderr
+        ll(msg % 0)
+        self.acquire()
+        ll(msg % 1)
+        if self._thread_pool:
+            ll(msg % 2)
+            t = self._thread_pool
+            ll(msg % 3)
+            self._thread_pool = None
+            ll(msg % 4)
+            try:
+                ll(msg % 5)
+                t.close()
+                ll(msg % 6)
+                t.join()
+                ll(msg % 7)
+            except Exception:
+                ll(msg % 8)
+                pass
+        if shutdown:
+            ll(msg % 9)
+            self._task_id = None
+            ll(msg % 10)
+        self.release()
+        ll(msg % 11)
+
+    def close(self, wait=True):
+        # super already calls self.flush()
+        super(TaskHandler, self).close()
+        # shut down the TaskHandler, from this point onwards. No events will be logged
+        if not wait:
+            self.acquire()
+            self._thread_pool = None
+            self._task_id = None
+            self.release()
+        else:
+            self.wait_for_flush(shutdown=True)
+
     def _send_events(self, a_request):
         try:
+            if self._thread_pool is None:
+                self.__log_stderr('Warning: trains.Task - '
+                                  'Task.close() flushing remaining logs ({})'.format(self._pending))
+            self._pending -= 1
             res = self.session.send(a_request)
             if not res.ok():
-                print("Failed logging task to backend ({:d} lines, {})".format(len(a_request.requests), str(res.meta)))
+                self.__log_stderr("Warning: trains.log._send_events: failed logging task to backend "
+                                  "({:d} lines, {})".format(len(a_request.requests), str(res.meta)))
         except Exception as ex:
-            print("Retrying, failed logging task to backend ({:d} lines): {}".format(len(a_request.requests), ex))
+            self.__log_stderr("Warning: trains.log._send_events: Retrying, "
+                              "failed logging task to backend ({:d} lines): {}".format(len(a_request.requests), ex))
             # we should push ourselves back into the thread pool
-            self._thread_pool.apply_async(self._send_events, args=(a_request, ))
+            if self._thread_pool:
+                self._pending += 1
+                self._thread_pool.apply_async(self._send_events, args=(a_request, ))
+
+    @staticmethod
+    def __log_stderr(t):
+        if hasattr(sys.stderr, '_original_write'):
+            sys.stderr._original_write(t + '\n')
+        else:
+            sys.stderr.write(t + '\n')
diff --git a/trains/backend_interface/task/task.py b/trains/backend_interface/task/task.py
index 9d6daec9..b0dfe487 100644
--- a/trains/backend_interface/task/task.py
+++ b/trains/backend_interface/task/task.py
@@ -34,8 +34,7 @@ from ...config import get_config_for_bucket, get_remote_task_id, TASK_ID_ENV_VAR
     running_remotely, get_cache_dir, DOCKER_IMAGE_ENV_VAR
 from ...debugging import get_logger
 from ...debugging.log import LoggerRoot
-from ...storage import StorageHelper
-from ...storage.helper import StorageError
+from ...storage.helper import StorageHelper, StorageError
 from .access import AccessMixin
 from .log import TaskHandler
 from .repo import ScriptInfo
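
For context, a minimal sketch (not part of the diff, and not the trains API): the log.py changes switch TaskHandler to a lazily created single-worker ThreadPool, have flush() hand batches off to it, and make close()/wait_for_flush(shutdown=True) drain the pool before the handler stops accepting records. The standalone BatchedHandler below illustrates that buffering-plus-worker-pool shutdown pattern using only the standard library; BatchedHandler, _send and _closed are hypothetical names introduced for this sketch.

# Illustration only: a simplified stand-in for the TaskHandler pattern above,
# not the trains implementation. Batches are pushed to a single-worker
# ThreadPool, and close() drains the pool before the handler shuts down.
import logging
from logging.handlers import BufferingHandler
from multiprocessing.pool import ThreadPool


class BatchedHandler(BufferingHandler):
    def __init__(self, capacity=10):
        super(BatchedHandler, self).__init__(capacity)
        self._thread_pool = None   # created lazily on first flush
        self._pending = 0          # batches handed to the pool but not yet sent
        self._closed = False       # once set, no further records are accepted

    def flush(self):
        if self._closed or not self.buffer:
            return
        self.acquire()
        try:
            batch, self.buffer = self.buffer, []
        finally:
            self.release()
        if batch:
            if not self._thread_pool:
                self._thread_pool = ThreadPool(processes=1)
            self._pending += 1
            self._thread_pool.apply_async(self._send, args=(batch,))

    def _send(self, batch):
        # placeholder for the backend call (e.g. sending a batch request)
        self._pending -= 1
        print('sent {} records, {} batch(es) still pending'.format(len(batch), self._pending))

    def wait_for_flush(self, shutdown=False):
        self.acquire()
        try:
            pool, self._thread_pool = self._thread_pool, None
            if pool:
                pool.close()
                pool.join()
            if shutdown:
                self._closed = True
        finally:
            self.release()

    def close(self):
        super(BatchedHandler, self).close()   # BufferingHandler.close() flushes first
        self.wait_for_flush(shutdown=True)    # then drain the worker pool and shut down


if __name__ == '__main__':
    handler = BatchedHandler(capacity=3)
    log = logging.getLogger('demo')
    log.addHandler(handler)
    log.setLevel(logging.INFO)
    for i in range(7):
        log.info('line %d', i)
    handler.close()   # remaining records are flushed and the pool is joined

Joining the pool only at shutdown, instead of recreating it after every wait as the removed wait_for_flush() did, avoids re-spawning worker threads and lets close() guarantee that queued batches are sent before logging is disabled.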