diff --git a/trains/backend_api/session/datamodel.py b/trains/backend_api/session/datamodel.py index f859ea55..c3a0f896 100644 --- a/trains/backend_api/session/datamodel.py +++ b/trains/backend_api/session/datamodel.py @@ -128,11 +128,14 @@ class NonStrictDataModelMixin(object): :summary: supplies an __init__ method that warns about unused keywords """ def __init__(self, **kwargs): - unexpected = [key for key in kwargs if not key.startswith('_')] - if unexpected: - message = '{}: unused keyword argument(s) {}' \ - .format(type(self).__name__, unexpected) - warnings.warn(message, UnusedKwargsWarning) + # unexpected = [key for key in kwargs if not key.startswith('_')] + # if unexpected: + # message = '{}: unused keyword argument(s) {}' \ + # .format(type(self).__name__, unexpected) + # warnings.warn(message, UnusedKwargsWarning) + + # ignore extra data warnings + pass class NonStrictDataModel(DataModel, NonStrictDataModelMixin): diff --git a/trains/backend_api/session/response.py b/trains/backend_api/session/response.py index 4e6d159c..0cc6787e 100644 --- a/trains/backend_api/session/response.py +++ b/trains/backend_api/session/response.py @@ -1,5 +1,6 @@ import requests +import six import jsonmodels.models import jsonmodels.fields import jsonmodels.errors @@ -8,14 +9,21 @@ from .apimodel import ApiModel from .datamodel import NonStrictDataModelMixin +class FloatOrStringField(jsonmodels.fields.BaseField): + + """String field.""" + + types = (float, six.string_types,) + + class Response(ApiModel, NonStrictDataModelMixin): pass class _ResponseEndpoint(jsonmodels.models.Base): name = jsonmodels.fields.StringField() - requested_version = jsonmodels.fields.FloatField() - actual_version = jsonmodels.fields.FloatField() + requested_version = FloatOrStringField() + actual_version = FloatOrStringField() class ResponseMeta(jsonmodels.models.Base): @@ -42,8 +50,8 @@ class ResponseMeta(jsonmodels.models.Base): def __str__(self): if self.result_code == requests.codes.ok: - return "<%d: %s/v%.1f>" % (self.result_code, self.endpoint.name, self.endpoint.actual_version) + return "<%d: %s/v%s>" % (self.result_code, self.endpoint.name, self.endpoint.actual_version) elif self._is_valid: - return "<%d/%d: %s/v%.1f (%s)>" % (self.result_code, self.result_subcode, self.endpoint.name, - self.endpoint.actual_version, self.result_msg) + return "<%d/%d: %s/v%s (%s)>" % (self.result_code, self.result_subcode, self.endpoint.name, + self.endpoint.actual_version, self.result_msg) return "<%d/%d: %s (%s)>" % (self.result_code, self.result_subcode, self.endpoint.name, self.result_msg) diff --git a/trains/backend_api/session/session.py b/trains/backend_api/session/session.py index 1e267d83..1497fdbe 100644 --- a/trains/backend_api/session/session.py +++ b/trains/backend_api/session/session.py @@ -5,6 +5,7 @@ from socket import gethostname import requests import six +import jwt from pyhocon import ConfigTree from requests.auth import HTTPBasicAuth @@ -34,6 +35,8 @@ class Session(TokenManager): _session_initial_timeout = (1.0, 10) _session_timeout = (5.0, None) + api_version = '2.1' + # TODO: add requests.codes.gateway_timeout once we support async commits _retry_codes = [ requests.codes.bad_gateway, @@ -127,6 +130,13 @@ class Session(TokenManager): self.refresh_token() + # update api version from server response + try: + api_version = jwt.decode(self.token, verify=False).get('api_version', Session.api_version) + Session.api_version = str(api_version) + except (jwt.DecodeError, ValueError): + pass + def _send_request( self, service,