NOTE(review): this patch was rebuilt from a whitespace-collapsed copy. The leading
indentation of context lines and the position of blank context lines are inferred,
not verified against the repository - apply with `git apply --ignore-whitespace`
(or `patch -l`) and confirm, or regenerate the patch from the source tree.
The `index` blob hashes below are the original ones and do not reflect the
documentation / error-handling changes made to the added lines.

diff --git a/trains/backend_interface/metrics/interface.py b/trains/backend_interface/metrics/interface.py
index 09f2d3d7..7a53b652 100644
--- a/trains/backend_interface/metrics/interface.py
+++ b/trains/backend_interface/metrics/interface.py
@@ -196,3 +196,25 @@ class Metrics(InterfaceBase):
             return self.send(req, raise_on_errors=False)
 
         return None
+
+    @staticmethod
+    def close_async_threads():
+        """
+        Close and join the module-level `file_upload_pool` and `upload_pool`.
+
+        Best effort: a failure on one pool is ignored so the other pool is still closed.
+        """
+        global file_upload_pool
+        global upload_pool
+        try:
+            file_upload_pool.close()
+            file_upload_pool.join()
+        except Exception:
+            # best-effort shutdown; do not swallow KeyboardInterrupt / SystemExit
+            pass
+
+        try:
+            upload_pool.close()
+            upload_pool.join()
+        except Exception:
+            pass
diff --git a/trains/backend_interface/metrics/reporter.py b/trains/backend_interface/metrics/reporter.py
index 04e8cd56..acf71920 100644
--- a/trains/backend_interface/metrics/reporter.py
+++ b/trains/backend_interface/metrics/reporter.py
@@ -92,6 +92,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
         if self.get_num_results() > 0:
             self.wait_for_results()
         # make sure we flushed everything
+        self._async_enable = False
         self._write()
         if self.get_num_results() > 0:
             self.wait_for_results()
diff --git a/trains/backend_interface/task/log.py b/trains/backend_interface/task/log.py
index f855263c..1a5b2e78 100644
--- a/trains/backend_interface/task/log.py
+++ b/trains/backend_interface/task/log.py
@@ -113,6 +113,23 @@ class TaskHandler(BufferingHandler):
         if batch_requests:
             self._thread_pool.apply_async(self._send_events, args=(batch_requests, ))
 
+    def wait_for_flush(self):
+        """
+        Close and join the current thread pool, then replace it with a fresh one.
+        """
+        self.acquire()
+        try:
+            try:
+                self._thread_pool.close()
+                self._thread_pool.join()
+            except Exception:
+                # best effort: a failure closing the old pool must not prevent replacing it
+                pass
+            self._thread_pool = ThreadPool(processes=1)
+        finally:
+            # always release the handler lock, even if creating the new pool fails
+            self.release()
+
     def _send_events(self, a_request):
         try:
             res = self.session.send(a_request)
diff --git a/trains/backend_interface/task/task.py b/trains/backend_interface/task/task.py
index 175d44a5..f2e9c643 100644
--- a/trains/backend_interface/task/task.py
+++ b/trains/backend_interface/task/task.py
@@ -220,11 +220,12 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
         self.reload()
         # if jupyter is present, requirements will be created in the background, when saving a snapshot
         if result.script and script_requirements:
-            requirements = script_requirements.get_requirements()
+            requirements, conda_requirements = script_requirements.get_requirements()
             if requirements:
                 if not result.script['requirements']:
                     result.script['requirements'] = {}
                 result.script['requirements']['pip'] = requirements
+                result.script['requirements']['conda'] = conda_requirements
                 self._update_requirements(result.script.get('requirements') or '')
             self.reload()
 
diff --git a/trains/logger.py b/trains/logger.py
index bbb3bf03..e8fbe64a 100644
--- a/trains/logger.py
+++ b/trains/logger.py
@@ -650,6 +650,12 @@ class Logger(object):
         if self._task_handler and DevWorker.report_stdout:
             self._task_handler.flush()
 
+    def _flush_wait_stdout_handler(self):
+        """Flush the task log handler, then call its `wait_for_flush()`."""
+        if self._task_handler and DevWorker.report_stdout:
+            self._task_handler.flush()
+            self._task_handler.wait_for_flush()
+
     def _touch_title_series(self, title, series):
         if title not in self._graph_titles:
             self._graph_titles[title] = set()
diff --git a/trains/task.py b/trains/task.py
index 5f157afa..a7bab712 100644
--- a/trains/task.py
+++ b/trains/task.py
@@ -21,6 +21,7 @@ from pathlib2 import Path
 
 from .backend_api.services import tasks, projects, queues
 from .backend_api.session.session import Session
+from .backend_interface.metrics import Metrics
 from .backend_interface.model import Model as BackendModel
 from .backend_interface.task import Task as _Task
 from .backend_interface.task.args import _Arguments
@@ -683,8 +684,12 @@ class Task(_Task):
 
         # flush any outstanding logs
         if self._logger:
-            # noinspection PyProtectedMember
-            self._logger._flush_stdout_handler()
+            if wait_for_uploads:
+                # noinspection PyProtectedMember
+                self._logger._flush_wait_stdout_handler()
+            else:
+                # noinspection PyProtectedMember
+                self._logger._flush_stdout_handler()
         if self._reporter:
             self.reporter.flush()
         LoggerRoot.flush()
@@ -1365,6 +1370,11 @@ class Task(_Task):
                 # wait until the reporter flush everything
                 if self._reporter:
                     self.reporter.stop()
+                if self.is_main_task():
+                    # notice: this will close the reporting for all the Tasks in the system
+                    Metrics.close_async_threads()
+                    # notice: this will close the jupyter monitoring
+                    ScriptInfo.close()
                 if print_done_waiting:
                     self.log.info('Finished uploading')
             elif self._logger: