From 5a080798cb4292e198948fbe16cba70136cb6bdf Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Fri, 27 Aug 2021 19:14:53 +0300 Subject: [PATCH] Add support for overriding initial server connection behavior using the CLEARML_AGENT_INITIAL_CONNECT_RETRY_OVERRIDE env var (defaults to true, allows boolean value or an explicit number specifying the number of connect retries) --- .../backend_api/config/default/api.conf | 6 +- clearml_agent/backend_api/session/defs.py | 4 + clearml_agent/backend_api/session/session.py | 83 +++++++++++-------- .../backend_api/session/token_manager.py | 51 ++++++++---- clearml_agent/backend_config/converters.py | 8 ++ 5 files changed, 102 insertions(+), 50 deletions(-) diff --git a/clearml_agent/backend_api/config/default/api.conf b/clearml_agent/backend_api/config/default/api.conf index 25c3e04..9f8aa62 100644 --- a/clearml_agent/backend_api/config/default/api.conf +++ b/clearml_agent/backend_api/config/default/api.conf @@ -31,7 +31,9 @@ } auth { - # When creating a request, if token will expire in less than this value, try to refresh the token - token_expiration_threshold_sec = 360 + # When creating a request, if token will expire in less than this value, try to refresh the token. Default 12 hours + token_expiration_threshold_sec: 43200 + # When requesting a token, request specific expiration time. Server default (and maximum) is 30 days + # request_token_expiration_sec: 2592000 } } diff --git a/clearml_agent/backend_api/session/defs.py b/clearml_agent/backend_api/session/defs.py index e209c18..d311660 100644 --- a/clearml_agent/backend_api/session/defs.py +++ b/clearml_agent/backend_api/session/defs.py @@ -1,3 +1,4 @@ +from ...backend_config.converters import safe_text_to_bool from ...backend_config.environment import EnvEntry @@ -12,3 +13,6 @@ ENV_HOST_VERIFY_CERT = EnvEntry("CLEARML_API_HOST_VERIFY_CERT", "TRAINS_API_HOST ENV_CONDA_ENV_PACKAGE = EnvEntry("CLEARML_CONDA_ENV_PACKAGE", "TRAINS_CONDA_ENV_PACKAGE") ENV_NO_DEFAULT_SERVER = EnvEntry("CLEARML_NO_DEFAULT_SERVER", "TRAINS_NO_DEFAULT_SERVER", type=bool, default=True) ENV_DISABLE_VAULT_SUPPORT = EnvEntry('CLEARML_AGENT_DISABLE_VAULT_SUPPORT', type=bool) +ENV_INITIAL_CONNECT_RETRY_OVERRIDE = EnvEntry( + 'CLEARML_AGENT_INITIAL_CONNECT_RETRY_OVERRIDE', default=True, converter=safe_text_to_bool +) diff --git a/clearml_agent/backend_api/session/session.py b/clearml_agent/backend_api/session/session.py index b4d625c..cbe9074 100644 --- a/clearml_agent/backend_api/session/session.py +++ b/clearml_agent/backend_api/session/session.py @@ -1,8 +1,10 @@ import json as json_lib +import os import sys import types from socket import gethostname +from typing import Optional import jwt import requests @@ -13,7 +15,7 @@ from six.moves.urllib.parse import urlparse, urlunparse from .callresult import CallResult from .defs import ENV_VERBOSE, ENV_HOST, ENV_ACCESS_KEY, ENV_SECRET_KEY, ENV_WEB_HOST, ENV_FILES_HOST, ENV_AUTH_TOKEN, \ - ENV_NO_DEFAULT_SERVER, ENV_DISABLE_VAULT_SUPPORT + ENV_NO_DEFAULT_SERVER, ENV_DISABLE_VAULT_SUPPORT, ENV_INITIAL_CONNECT_RETRY_OVERRIDE from .request import Request, BatchRequest from .token_manager import TokenManager from ..config import load @@ -42,7 +44,7 @@ class Session(TokenManager): _session_requests = 0 _session_initial_timeout = (3.0, 10.) _session_timeout = (10.0, 30.) - _session_initial_connect_retry = 4 + _session_initial_retry_connect_override = 4 _write_session_data_size = 15000 _write_session_timeout = (30.0, 30.) @@ -102,9 +104,7 @@ class Session(TokenManager): if initialize_logging: self.config.initialize_logging(debug=kwargs.get('debug', False)) - token_expiration_threshold_sec = self.config.get( - "auth.token_expiration_threshold_sec", 60 - ) + super(Session, self).__init__(config=config, **kwargs) self._verbose = verbose if verbose is not None else ENV_VERBOSE.get() self._logger = logger @@ -113,10 +113,7 @@ class Session(TokenManager): if ENV_AUTH_TOKEN.get( value_cb=lambda key, value: print("Using environment access token {}=********".format(key)) ): - self._set_auth_token(ENV_AUTH_TOKEN.get()) - # if we use a token we override make sure we are at least 3600 seconds (1 hour) - # away from the token expiration date, ask for a new one. - token_expiration_threshold_sec = max(token_expiration_threshold_sec, 3600) + self.set_auth_token(ENV_AUTH_TOKEN.get()) else: self.__access_key = api_key or ENV_ACCESS_KEY.get( default=(self.config.get("api.credentials.access_key", None) or self.default_key), @@ -136,10 +133,6 @@ class Session(TokenManager): "Missing secret_key. Please set in configuration file or pass in session init." ) - super(Session, self).__init__( - token_expiration_threshold_sec=token_expiration_threshold_sec, **kwargs - ) - if self.access_key == self.default_key and self.secret_key == self.default_secret: print("Using built-in ClearML default key/secret") @@ -153,10 +146,6 @@ class Session(TokenManager): ) self.__host = host.strip("/") - http_retries_config = http_retries_config or self.config.get( - "api.http.retries", ConfigTree() - ).as_plain_ordered_dict() - http_retries_config["status_forcelist"] = self._retry_codes self.__worker = worker or gethostname() @@ -167,13 +156,15 @@ class Session(TokenManager): self.client = client or "api-{}".format(__version__) # limit the reconnect retries, so we get an error if we are starting the session - http_no_retries_config = dict(**http_retries_config) - http_no_retries_config['connect'] = self._session_initial_connect_retry - self.__http_session = get_http_session_with_retry(**http_no_retries_config) + _, self.__http_session = self._setup_session( + http_retries_config, + initial_session=True, + default_initial_connect_override=(False if kwargs.get("command") == "execute" else None) + ) # try to connect with the server self.refresh_token() # create the default session with many retries - self.__http_session = get_http_session_with_retry(**http_retries_config) + http_retries_config, self.__http_session = self._setup_session(http_retries_config) # update api version from server response try: @@ -194,6 +185,31 @@ class Session(TokenManager): self._load_vaults() + def _setup_session(self, http_retries_config, initial_session=False, default_initial_connect_override=None): + # type: (dict, bool, Optional[bool]) -> (dict, requests.Session) + http_retries_config = http_retries_config or self.config.get( + "api.http.retries", ConfigTree() + ).as_plain_ordered_dict() + http_retries_config["status_forcelist"] = self._retry_codes + + if initial_session: + kwargs = {} if default_initial_connect_override is None else { + "default": default_initial_connect_override + } + if ENV_INITIAL_CONNECT_RETRY_OVERRIDE.get(**kwargs): + connect_retries = self._session_initial_retry_connect_override + try: + value = ENV_INITIAL_CONNECT_RETRY_OVERRIDE.get(converter=str) + if not isinstance(value, bool): + connect_retries = abs(int(value)) + except ValueError: + pass + + http_retries_config = dict(**http_retries_config) + http_retries_config['connect'] = connect_retries + + return http_retries_config, get_http_session_with_retry(**http_retries_config) + def _load_vaults(self): if not self.check_min_api_version("2.15") or self.feature_set == "basic": return @@ -299,13 +315,9 @@ class Session(TokenManager): headers[self._AUTHORIZATION_HEADER] = "Bearer {}".format(self.token) return headers - def _set_auth_token(self, auth_token): - self.__access_key = self.__secret_key = None - self.__auth_token = auth_token - def set_auth_token(self, auth_token): - self._set_auth_token(auth_token) - self.refresh_token() + self.__access_key = self.__secret_key = None + self._set_token(auth_token) def send_request( self, @@ -576,7 +588,7 @@ class Session(TokenManager): return v + (0,) * max(0, 3 - len(v)) return version_tuple(cls.api_version) >= version_tuple(str(min_api_version)) - def _do_refresh_token(self, old_token, exp=None): + def _do_refresh_token(self, current_token, exp=None): """ TokenManager abstract method implementation. Here we ignore the old token and simply obtain a new token. """ @@ -588,13 +600,13 @@ class Session(TokenManager): ) ) + auth = None headers = None - # use token only once (the second time the token is already built into the http session) - if self.__auth_token: - headers = dict(Authorization="Bearer {}".format(self.__auth_token)) - self.__auth_token = None + if self.access_key and self.secret_key: + auth = HTTPBasicAuth(self.access_key, self.secret_key) + elif current_token: + headers = dict(Authorization="Bearer {}".format(current_token)) - auth = HTTPBasicAuth(self.access_key, self.secret_key) if self.access_key and self.secret_key else None res = None try: data = {"expiration_sec": exp} if exp else {} @@ -619,7 +631,10 @@ class Session(TokenManager): ) if verbose: self._logger.info("Received new token") - return resp["data"]["token"] + token = resp["data"]["token"] + if ENV_AUTH_TOKEN.get(): + os.environ[ENV_AUTH_TOKEN.key] = token + return token except LoginError: six.reraise(*sys.exc_info()) except KeyError as ex: diff --git a/clearml_agent/backend_api/session/token_manager.py b/clearml_agent/backend_api/session/token_manager.py index 635f20b..c00722c 100644 --- a/clearml_agent/backend_api/session/token_manager.py +++ b/clearml_agent/backend_api/session/token_manager.py @@ -9,6 +9,8 @@ import six @six.add_metaclass(ABCMeta) class TokenManager(object): + _default_token_exp_threshold_sec = 12 * 60 * 60 + _default_req_token_expiration_sec = None @property def token_expiration_threshold_sec(self): @@ -41,17 +43,30 @@ class TokenManager(object): return self.__token def __init__( - self, - token=None, - req_token_expiration_sec=None, - token_history=None, - token_expiration_threshold_sec=60, - **kwargs + self, + token=None, + req_token_expiration_sec=None, + token_history=None, + token_expiration_threshold_sec=None, + config=None, + **kwargs ): super(TokenManager, self).__init__() assert isinstance(token_history, (type(None), dict)) - self.token_expiration_threshold_sec = token_expiration_threshold_sec - self.req_token_expiration_sec = req_token_expiration_sec + if config: + req_token_expiration_sec = req_token_expiration_sec or config.get( + "api.auth.request_token_expiration_sec", None + ) + token_expiration_threshold_sec = ( + token_expiration_threshold_sec + or config.get("api.auth.token_expiration_threshold_sec", None) + ) + self.token_expiration_threshold_sec = ( + token_expiration_threshold_sec or self._default_token_exp_threshold_sec + ) + self.req_token_expiration_sec = ( + req_token_expiration_sec or self._default_req_token_expiration_sec + ) self._set_token(token) def _calc_token_valid_period_sec(self, token, exp=None, at_least_sec=None): @@ -59,7 +74,9 @@ class TokenManager(object): try: exp = exp or self._get_token_exp(token) if at_least_sec: - at_least_sec = max(at_least_sec, self.token_expiration_threshold_sec) + at_least_sec = max( + at_least_sec, self.token_expiration_threshold_sec + ) else: at_least_sec = self.token_expiration_threshold_sec return max(0, (exp - time() - at_least_sec)) @@ -71,14 +88,16 @@ class TokenManager(object): def get_decoded_token(cls, token, verify=False): """ Get token expiration time. If not present, assume forever """ return jwt.decode( - token, verify=verify, + token, + verify=verify, options=dict(verify_signature=False), - algorithms=get_default_algorithms()) + algorithms=get_default_algorithms(), + ) @classmethod def _get_token_exp(cls, token): """ Get token expiration time. If not present, assume forever """ - return cls.get_decoded_token(token).get('exp', sys.maxsize) + return cls.get_decoded_token(token).get("exp", sys.maxsize) def _set_token(self, token): if token: @@ -89,7 +108,9 @@ class TokenManager(object): self.__token_expiration_sec = 0 def get_token_valid_period_sec(self): - return self._calc_token_valid_period_sec(self.__token, self.token_expiration_sec) + return self._calc_token_valid_period_sec( + self.__token, self.token_expiration_sec + ) def _get_token(self): if self.get_token_valid_period_sec() <= 0: @@ -101,4 +122,6 @@ class TokenManager(object): pass def refresh_token(self): - self._set_token(self._do_refresh_token(self.__token, exp=self.req_token_expiration_sec)) + self._set_token( + self._do_refresh_token(self.__token, exp=self.req_token_expiration_sec) + ) diff --git a/clearml_agent/backend_config/converters.py b/clearml_agent/backend_config/converters.py index 64901f3..cce1829 100644 --- a/clearml_agent/backend_config/converters.py +++ b/clearml_agent/backend_config/converters.py @@ -24,6 +24,14 @@ def text_to_bool(value): return bool(strtobool(value)) +def safe_text_to_bool(value): + # type: (Text) -> bool + try: + return text_to_bool(value) + except ValueError: + return bool(value) + + def any_to_bool(value): # type: (Optional[Union[int, float, Text]]) -> bool if isinstance(value, six.text_type):