From 13ce783fa349e33e4620d6278142acc77a575db6 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Sun, 26 Apr 2020 22:53:18 +0300 Subject: [PATCH] Check for updates based on session version --- trains/backend_api/session/session.py | 5 +++-- trains/backend_interface/base.py | 1 - trains/backend_interface/task/task.py | 4 ++-- trains/utilities/check_updates.py | 31 +++++++++++++++++---------- 4 files changed, 25 insertions(+), 16 deletions(-) diff --git a/trains/backend_api/session/session.py b/trains/backend_api/session/session.py index 726aa8ef..7d1ccac2 100644 --- a/trains/backend_api/session/session.py +++ b/trains/backend_api/session/session.py @@ -51,6 +51,8 @@ class Session(TokenManager): _sessions_created = 0 _ssl_error_count_verbosity = 2 + _client = [(__package__.partition(".")[0], __version__)] + api_version = '2.1' default_host = "https://demoapi.trains.allegro.ai" default_web = "https://demoapp.trains.allegro.ai" @@ -91,7 +93,6 @@ class Session(TokenManager): logger=None, verbose=None, initialize_logging=True, - client=None, config=None, http_retries_config=None, **kwargs @@ -150,7 +151,7 @@ class Session(TokenManager): if not self.__max_req_size: raise ValueError("missing max request size") - self.client = client or "api-{}".format(__version__) + self.client = ", ".join("{}-{}".format(*x) for x in self._client) self.refresh_token() diff --git a/trains/backend_interface/base.py b/trains/backend_interface/base.py index 7bf9ecf5..5b3dad7c 100644 --- a/trains/backend_interface/base.py +++ b/trains/backend_interface/base.py @@ -97,7 +97,6 @@ class InterfaceBase(SessionInterface): if not InterfaceBase._default_session: InterfaceBase._default_session = Session( initialize_logging=False, - client='sdk-%s' % __version__, config=config_obj, api_key=ENV_ACCESS_KEY.get(), secret_key=ENV_SECRET_KEY.get(), diff --git a/trains/backend_interface/task/task.py b/trains/backend_interface/task/task.py index eee03e28..e6ae6bc0 100644 --- a/trains/backend_interface/task/task.py +++ b/trains/backend_interface/task/task.py @@ -201,8 +201,8 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): if not latest_version[1]: sep = os.linesep self.get_logger().report_text( - 'TRAINS new package available: UPGRADE to v{} is recommended!\nRelease Notes:\n{}'.format( - latest_version[0], sep.join(latest_version[2])), + '{} new package available: UPGRADE to v{} is recommended!\nRelease Notes:\n{}'.format( + Session._client[0][0].upper(), latest_version[0], sep.join(latest_version[2])), ) else: self.get_logger().report_text( diff --git a/trains/utilities/check_updates.py b/trains/utilities/check_updates.py index f87c176d..25133fa4 100644 --- a/trains/utilities/check_updates.py +++ b/trains/utilities/check_updates.py @@ -3,7 +3,6 @@ from __future__ import absolute_import, division, print_function import collections import json import re -import threading import requests import six @@ -12,6 +11,8 @@ if six.PY3: else: inf = float('inf') +from ..backend_api.session import Session + class InvalidVersion(ValueError): """ @@ -314,23 +315,31 @@ class CheckPackageUpdates(object): # noinspection PyBroadException try: - from ..version import __version__ cls._package_version_checked = True - cur_version = Version(__version__) - update_server_releases = requests.get('https://updates.trains.allegro.ai/updates', - data=json.dumps({"versions": {"trains": str(cur_version)}}), - timeout=3.0) + client, version = Session._client[0] + version = Version(version) + + update_server_releases = requests.get( + 'https://updates.trains.allegro.ai/updates', + json={"versions": {c: str(v) for c, v in Session._client}}, + timeout=3.0 + ) + if update_server_releases.ok: update_server_releases = update_server_releases.json() else: return None - trains_answer = update_server_releases.get("trains", {}) - latest_version = Version(trains_answer.get("version")) - if cur_version >= latest_version: + client_answer = update_server_releases.get(client, {}) + if "version" not in client_answer: return None - not_patch_upgrade = latest_version.release[:2] == cur_version.release[:2] - return str(latest_version), not_patch_upgrade, trains_answer.get("description").split("\r\n") + + latest_version = Version(client_answer["version"]) + + if version >= latest_version: + return None + not_patch_upgrade = latest_version.release[:2] == version.release[:2] + return str(latest_version), not_patch_upgrade, client_answer.get("description").split("\r\n") except Exception: return None