From 11420adce7976b4b17294b7f8dfd5aff5874ff18 Mon Sep 17 00:00:00 2001
From: allegroai <>
Date: Thu, 9 Apr 2020 12:24:37 +0300
Subject: [PATCH] Log reports at the end of the task

---
 trains/backend_interface/task/log.py  | 95 +++++++++++++++++++++------
 trains/backend_interface/task/task.py |  3 +-
 2 files changed, 77 insertions(+), 21 deletions(-)

diff --git a/trains/backend_interface/task/log.py b/trains/backend_interface/task/log.py
index a92709cc..428c8b13 100644
--- a/trains/backend_interface/task/log.py
+++ b/trains/backend_interface/task/log.py
@@ -1,3 +1,4 @@
+import sys
 import time
 from logging import LogRecord, getLogger, basicConfig
 from logging.handlers import BufferingHandler
@@ -28,7 +29,8 @@ class TaskHandler(BufferingHandler):
         self.last_timestamp = 0
         self.counter = 1
         self._last_event = None
-        self._thread_pool = ThreadPool(processes=1)
+        self._thread_pool = None
+        self._pending = 0
 
     def shouldFlush(self, record):
         """
@@ -37,7 +39,8 @@ class TaskHandler(BufferingHandler):
         Returns true if the buffer is up to capacity. This method can be
         overridden to implement custom flushing strategies.
         """
-
+        if self._task_id is None:
+            return False
         # Notice! protect against infinite loops, i.e. flush while sending previous records
         # if self.lock._is_owned():
         #     return False
@@ -67,6 +70,8 @@ class TaskHandler(BufferingHandler):
 
     def _record_to_event(self, record):
         # type: (LogRecord) -> events.TaskLogEvent
+        if self._task_id is None:
+            return None
         timestamp = int(record.created * 1000)
         if timestamp == self.last_timestamp:
             timestamp += self.counter
@@ -92,43 +97,95 @@ class TaskHandler(BufferingHandler):
         return self._last_event
 
     def flush(self):
+        if self._task_id is None:
+            return
+
         if not self.buffer:
             return
 
         self.acquire()
+        if not self.buffer:
+            self.release()
+            return
         buffer = self.buffer
+        self.buffer = []
         try:
-            if not buffer:
-                return
-            self.buffer = []
             record_events = [self._record_to_event(record) for record in buffer]
             self._last_event = None
             batch_requests = events.AddBatchRequest(requests=[events.AddRequest(e) for e in record_events if e])
         except Exception:
+            # print("Failed logging task to backend ({:d} lines)".format(len(buffer)))
             batch_requests = None
-            print("Failed logging task to backend ({:d} lines)".format(len(buffer)))
-        finally:
-            self.release()
 
         if batch_requests:
+            if not self._thread_pool:
+                self._thread_pool = ThreadPool(processes=1)
+            self._pending += 1
             self._thread_pool.apply_async(self._send_events, args=(batch_requests, ))
 
-    def wait_for_flush(self):
-        self.acquire()
-        try:
-            self._thread_pool.close()
-            self._thread_pool.join()
-        except Exception:
-            pass
-        self._thread_pool = ThreadPool(processes=1)
         self.release()
 
+    def wait_for_flush(self, shutdown=False):
+        msg = 'Task.log.wait_for_flush: %d'
+        ll = self.__log_stderr
+        ll(msg % 0)
+        self.acquire()
+        ll(msg % 1)
+        if self._thread_pool:
+            ll(msg % 2)
+            t = self._thread_pool
+            ll(msg % 3)
+            self._thread_pool = None
+            ll(msg % 4)
+            try:
+                ll(msg % 5)
+                t.close()
+                ll(msg % 6)
+                t.join()
+                ll(msg % 7)
+            except Exception:
+                ll(msg % 8)
+                pass
+        if shutdown:
+            ll(msg % 9)
+            self._task_id = None
+        ll(msg % 10)
+        self.release()
+        ll(msg % 11)
+
+    def close(self, wait=True):
+        # super already calls self.flush()
+        super(TaskHandler, self).close()
+        # shut down the TaskHandler, from this point onwards. No events will be logged
+        if not wait:
+            self.acquire()
+            self._thread_pool = None
+            self._task_id = None
+            self.release()
+        else:
+            self.wait_for_flush(shutdown=True)
+
     def _send_events(self, a_request):
         try:
+            if self._thread_pool is None:
+                self.__log_stderr('Warning: trains.Task - '
+                                  'Task.close() flushing remaining logs ({})'.format(self._pending))
+            self._pending -= 1
             res = self.session.send(a_request)
             if not res.ok():
-                print("Failed logging task to backend ({:d} lines, {})".format(len(a_request.requests), str(res.meta)))
+                self.__log_stderr("Warning: trains.log._send_events: failed logging task to backend "
+                                  "({:d} lines, {})".format(len(a_request.requests), str(res.meta)))
         except Exception as ex:
-            print("Retrying, failed logging task to backend ({:d} lines): {}".format(len(a_request.requests), ex))
+            self.__log_stderr("Warning: trains.log._send_events: Retrying, "
+                              "failed logging task to backend ({:d} lines): {}".format(len(a_request.requests), ex))
             # we should push ourselves back into the thread pool
-            self._thread_pool.apply_async(self._send_events, args=(a_request, ))
+            if self._thread_pool:
+                self._pending += 1
+                self._thread_pool.apply_async(self._send_events, args=(a_request, ))
+
+    @staticmethod
+    def __log_stderr(t):
+        if hasattr(sys.stderr, '_original_write'):
+            sys.stderr._original_write(t + '\n')
+        else:
+            sys.stderr.write(t + '\n')
diff --git a/trains/backend_interface/task/task.py b/trains/backend_interface/task/task.py
index 9d6daec9..b0dfe487 100644
--- a/trains/backend_interface/task/task.py
+++ b/trains/backend_interface/task/task.py
@@ -34,8 +34,7 @@ from ...config import get_config_for_bucket, get_remote_task_id, TASK_ID_ENV_VAR
     running_remotely, get_cache_dir, DOCKER_IMAGE_ENV_VAR
 from ...debugging import get_logger
 from ...debugging.log import LoggerRoot
-from ...storage import StorageHelper
-from ...storage.helper import StorageError
+from ...storage.helper import StorageHelper, StorageError
 from .access import AccessMixin
 from .log import TaskHandler
 from .repo import ScriptInfo