From e0e6d9159bfd6aced4e694f439c14a6c82347be4 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Thu, 10 Oct 2019 21:10:51 +0300 Subject: [PATCH] Fix python 2.7 and Windows support --- trains/backend_interface/task/task.py | 7 ++- trains/backend_interface/util.py | 22 +++++++-- trains/task.py | 70 ++++++++++++++++----------- trains/utilities/check_updates.py | 33 ++++++------- 4 files changed, 81 insertions(+), 51 deletions(-) diff --git a/trains/backend_interface/task/task.py b/trains/backend_interface/task/task.py index b8db2856..5b154cc4 100644 --- a/trains/backend_interface/task/task.py +++ b/trains/backend_interface/task/task.py @@ -2,6 +2,7 @@ import collections import itertools import logging +import os from enum import Enum from threading import Thread from multiprocessing import RLock @@ -189,9 +190,11 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): latest_version = CheckPackageUpdates.check_new_package_available(only_once=True) if latest_version: if not latest_version[1]: + sep = os.linesep self.get_logger().report_text( - 'TRAINS new package available: UPGRADE to v{} is recommended!'.format( - latest_version[0]), + 'TRAINS new package available: UPGRADE to v{} is recommended! ' + '{}'.format( + latest_version[0], sep.join(latest_version[2])), ) else: self.get_logger().report_text( diff --git a/trains/backend_interface/util.py b/trains/backend_interface/util.py index ec899149..f245f455 100644 --- a/trains/backend_interface/util.py +++ b/trains/backend_interface/util.py @@ -1,7 +1,23 @@ import getpass import re from _socket import gethostname -from datetime import datetime, timezone +from datetime import datetime +try: + from datetime import timezone + utc_timezone = timezone.utc +except ImportError: + from datetime import tzinfo, timedelta + + class UTC(tzinfo): + def utcoffset(self, dt): + return timedelta(0) + + def tzname(self, dt): + return "UTC" + + def dst(self, dt): + return timedelta(0) + utc_timezone = UTC() from ..backend_api.services import projects from ..debugging.log import get_logger @@ -26,8 +42,8 @@ def get_or_create_project(session, project_name, description=None): # Hack for supporting windows -def get_epoch_beginning_of_time(tzinfo=None): - return datetime(1970, 1, 1, tzinfo=tzinfo if tzinfo else timezone.utc) +def get_epoch_beginning_of_time(timezone_info=None): + return datetime(1970, 1, 1).replace(tzinfo=timezone_info if timezone_info else utc_timezone) def get_single_result(entity, query, results, log=None, show_results=10, raise_on_error=True, sort_by_date=True): diff --git a/trains/task.py b/trains/task.py index bcaff5aa..49a8331a 100644 --- a/trains/task.py +++ b/trains/task.py @@ -41,8 +41,6 @@ from .utilities.resource_monitor import ResourceMonitor from .utilities.seed import make_deterministic from .utilities.dicts import ReadOnlyDict -NotSet = object() - class Task(_Task): """ @@ -67,6 +65,8 @@ class Task(_Task): TaskTypes = _Task.TaskTypes + NotSet = object() + __create_protection = object() __main_task = None __exit_hook = None @@ -566,37 +566,15 @@ class Task(_Task): raise Exception('Unsupported mutable type %s: no connect function found' % type(mutable).__name__) - def get_logger(self, flush_period=NotSet): - # type: (Optional[float]) -> Logger + def get_logger(self): + # type: () -> Logger """ - get a logger object for reporting based on the task - - :param flush_period: The period of the logger flush. - If None of any other False value, will not flush periodically. - If a logger was created before, this will be the new period and - the old one will be discarded. + get a logger object for reporting, for this task context. + All reports (metrics, text etc.) related to this task are accessible in the web UI :return: Logger object """ - if not self._logger: - # force update of base logger to this current task (this is the main logger task) - self._setup_log(replace_existing=self.is_main_task()) - # Get a logger object - self._logger = Logger(private_task=self) - # make sure we set our reported to async mode - # we make sure we flush it in self._at_exit - self.reporter.async_enable = True - # if we just created the logger, set default flush period - if not flush_period or flush_period is NotSet: - flush_period = DevWorker.report_period - - if isinstance(flush_period, (int, float)): - flush_period = int(abs(flush_period)) - - if flush_period is None or isinstance(flush_period, int): - self._logger.set_flush_period(flush_period) - - return self._logger + return self._get_logger() def mark_started(self): """ @@ -819,6 +797,40 @@ class Task(_Task): if secret: Session.default_secret = secret + def _get_logger(self, flush_period=NotSet): + # type: (Optional[float]) -> Logger + """ + get a logger object for reporting based on the task + + :param flush_period: The period of the logger flush. + If None of any other False value, will not flush periodically. + If a logger was created before, this will be the new period and + the old one will be discarded. + + :return: Logger object + """ + pass + + if not self._logger: + # force update of base logger to this current task (this is the main logger task) + self._setup_log(replace_existing=self.is_main_task()) + # Get a logger object + self._logger = Logger(private_task=self) + # make sure we set our reported to async mode + # we make sure we flush it in self._at_exit + self.reporter.async_enable = True + # if we just created the logger, set default flush period + if not flush_period or flush_period is self.NotSet: + flush_period = DevWorker.report_period + + if isinstance(flush_period, (int, float)): + flush_period = int(abs(flush_period)) + + if flush_period is None or isinstance(flush_period, int): + self._logger.set_flush_period(flush_period) + + return self._logger + def _connect_output_model(self, model): assert isinstance(model, OutputModel) model.connect(self) diff --git a/trains/utilities/check_updates.py b/trains/utilities/check_updates.py index 523a2856..6f41b23f 100644 --- a/trains/utilities/check_updates.py +++ b/trains/utilities/check_updates.py @@ -316,24 +316,21 @@ class CheckPackageUpdates(object): try: from ..version import __version__ cls._package_version_checked = True - # Sending the request only for statistics - update_statistics = threading.Thread(target=CheckPackageUpdates.get_version_from_updates_server, - args=(__version__,)) - update_statistics.daemon = True - update_statistics.start() - - releases = requests.get('https://pypi.python.org/pypi/trains/json', timeout=3.0).json()['releases'].keys() - - releases = [Version(r) for r in releases] - latest_version = sorted(releases) cur_version = Version(__version__) - if not cur_version.is_devrelease and not cur_version.is_prerelease: - latest_version = [r for r in latest_version if not r.is_devrelease and not r.is_prerelease] - - if cur_version >= latest_version[-1]: + update_server_releases = requests.get('https://updates.trainsai.io/updates', + data=json.dumps({"versions": {"trains": str(cur_version)}}), + timeout=3.0) + if update_server_releases.ok: + update_server_releases = update_server_releases.json() + else: return None - not_patch_upgrade = latest_version[-1].release[:2] != cur_version.release[:2] - return str(latest_version[-1]), not_patch_upgrade + trains_answer = update_server_releases.get("trains", {}) + latest_version = Version(trains_answer.get("version")) + + if cur_version >= latest_version: + return None + not_patch_upgrade = latest_version.release[:2] == cur_version.release[:2] + return str(latest_version), not_patch_upgrade, trains_answer.get("description").split("\r\n") except Exception: return None @@ -341,6 +338,8 @@ class CheckPackageUpdates(object): def get_version_from_updates_server(cur_version): try: _ = requests.get('https://updates.trainsai.io/updates', - params=json.dumps({'versions': {'trains': str(cur_version)}}), timeout=1.0) + data=json.dumps({"versions": {"trains": str(cur_version)}}), + timeout=1.0) + return except Exception: pass