Add support for overriding initial server connection behavior using the CLEARML_AGENT_INITIAL_CONNECT_RETRY_OVERRIDE env var (defaults to true, allows boolean value or an explicit number specifying the number of connect retries)

This commit is contained in:
allegroai 2021-08-27 19:14:53 +03:00
parent 21c4857795
commit 5a080798cb
5 changed files with 102 additions and 50 deletions

View File

@ -31,7 +31,9 @@
} }
auth { auth {
# When creating a request, if token will expire in less than this value, try to refresh the token # When creating a request, if token will expire in less than this value, try to refresh the token. Default 12 hours
token_expiration_threshold_sec = 360 token_expiration_threshold_sec: 43200
# When requesting a token, request a specific expiration time. The server default (and maximum) is 30 days
# request_token_expiration_sec: 2592000
} }
} }

View File

@ -1,3 +1,4 @@
from ...backend_config.converters import safe_text_to_bool
from ...backend_config.environment import EnvEntry from ...backend_config.environment import EnvEntry
@ -12,3 +13,6 @@ ENV_HOST_VERIFY_CERT = EnvEntry("CLEARML_API_HOST_VERIFY_CERT", "TRAINS_API_HOST
ENV_CONDA_ENV_PACKAGE = EnvEntry("CLEARML_CONDA_ENV_PACKAGE", "TRAINS_CONDA_ENV_PACKAGE") ENV_CONDA_ENV_PACKAGE = EnvEntry("CLEARML_CONDA_ENV_PACKAGE", "TRAINS_CONDA_ENV_PACKAGE")
ENV_NO_DEFAULT_SERVER = EnvEntry("CLEARML_NO_DEFAULT_SERVER", "TRAINS_NO_DEFAULT_SERVER", type=bool, default=True) ENV_NO_DEFAULT_SERVER = EnvEntry("CLEARML_NO_DEFAULT_SERVER", "TRAINS_NO_DEFAULT_SERVER", type=bool, default=True)
ENV_DISABLE_VAULT_SUPPORT = EnvEntry('CLEARML_AGENT_DISABLE_VAULT_SUPPORT', type=bool) ENV_DISABLE_VAULT_SUPPORT = EnvEntry('CLEARML_AGENT_DISABLE_VAULT_SUPPORT', type=bool)
# Overrides the connect-retry behavior of the initial server connection.
# The value may be a boolean-like string or an explicit number giving the
# number of connect retries; when the variable is not set it defaults to True.
# Values are parsed using safe_text_to_bool as the converter.
ENV_INITIAL_CONNECT_RETRY_OVERRIDE = EnvEntry(
'CLEARML_AGENT_INITIAL_CONNECT_RETRY_OVERRIDE', default=True, converter=safe_text_to_bool
)

View File

@ -1,8 +1,10 @@
import json as json_lib import json as json_lib
import os
import sys import sys
import types import types
from socket import gethostname from socket import gethostname
from typing import Optional
import jwt import jwt
import requests import requests
@ -13,7 +15,7 @@ from six.moves.urllib.parse import urlparse, urlunparse
from .callresult import CallResult from .callresult import CallResult
from .defs import ENV_VERBOSE, ENV_HOST, ENV_ACCESS_KEY, ENV_SECRET_KEY, ENV_WEB_HOST, ENV_FILES_HOST, ENV_AUTH_TOKEN, \ from .defs import ENV_VERBOSE, ENV_HOST, ENV_ACCESS_KEY, ENV_SECRET_KEY, ENV_WEB_HOST, ENV_FILES_HOST, ENV_AUTH_TOKEN, \
ENV_NO_DEFAULT_SERVER, ENV_DISABLE_VAULT_SUPPORT ENV_NO_DEFAULT_SERVER, ENV_DISABLE_VAULT_SUPPORT, ENV_INITIAL_CONNECT_RETRY_OVERRIDE
from .request import Request, BatchRequest from .request import Request, BatchRequest
from .token_manager import TokenManager from .token_manager import TokenManager
from ..config import load from ..config import load
@ -42,7 +44,7 @@ class Session(TokenManager):
_session_requests = 0 _session_requests = 0
_session_initial_timeout = (3.0, 10.) _session_initial_timeout = (3.0, 10.)
_session_timeout = (10.0, 30.) _session_timeout = (10.0, 30.)
_session_initial_connect_retry = 4 _session_initial_retry_connect_override = 4
_write_session_data_size = 15000 _write_session_data_size = 15000
_write_session_timeout = (30.0, 30.) _write_session_timeout = (30.0, 30.)
@ -102,9 +104,7 @@ class Session(TokenManager):
if initialize_logging: if initialize_logging:
self.config.initialize_logging(debug=kwargs.get('debug', False)) self.config.initialize_logging(debug=kwargs.get('debug', False))
token_expiration_threshold_sec = self.config.get( super(Session, self).__init__(config=config, **kwargs)
"auth.token_expiration_threshold_sec", 60
)
self._verbose = verbose if verbose is not None else ENV_VERBOSE.get() self._verbose = verbose if verbose is not None else ENV_VERBOSE.get()
self._logger = logger self._logger = logger
@ -113,10 +113,7 @@ class Session(TokenManager):
if ENV_AUTH_TOKEN.get( if ENV_AUTH_TOKEN.get(
value_cb=lambda key, value: print("Using environment access token {}=********".format(key)) value_cb=lambda key, value: print("Using environment access token {}=********".format(key))
): ):
self._set_auth_token(ENV_AUTH_TOKEN.get()) self.set_auth_token(ENV_AUTH_TOKEN.get())
# if we use a token we override make sure we are at least 3600 seconds (1 hour)
# away from the token expiration date, ask for a new one.
token_expiration_threshold_sec = max(token_expiration_threshold_sec, 3600)
else: else:
self.__access_key = api_key or ENV_ACCESS_KEY.get( self.__access_key = api_key or ENV_ACCESS_KEY.get(
default=(self.config.get("api.credentials.access_key", None) or self.default_key), default=(self.config.get("api.credentials.access_key", None) or self.default_key),
@ -136,10 +133,6 @@ class Session(TokenManager):
"Missing secret_key. Please set in configuration file or pass in session init." "Missing secret_key. Please set in configuration file or pass in session init."
) )
super(Session, self).__init__(
token_expiration_threshold_sec=token_expiration_threshold_sec, **kwargs
)
if self.access_key == self.default_key and self.secret_key == self.default_secret: if self.access_key == self.default_key and self.secret_key == self.default_secret:
print("Using built-in ClearML default key/secret") print("Using built-in ClearML default key/secret")
@ -153,10 +146,6 @@ class Session(TokenManager):
) )
self.__host = host.strip("/") self.__host = host.strip("/")
http_retries_config = http_retries_config or self.config.get(
"api.http.retries", ConfigTree()
).as_plain_ordered_dict()
http_retries_config["status_forcelist"] = self._retry_codes
self.__worker = worker or gethostname() self.__worker = worker or gethostname()
@ -167,13 +156,15 @@ class Session(TokenManager):
self.client = client or "api-{}".format(__version__) self.client = client or "api-{}".format(__version__)
# limit the reconnect retries, so we get an error if we are starting the session # limit the reconnect retries, so we get an error if we are starting the session
http_no_retries_config = dict(**http_retries_config) _, self.__http_session = self._setup_session(
http_no_retries_config['connect'] = self._session_initial_connect_retry http_retries_config,
self.__http_session = get_http_session_with_retry(**http_no_retries_config) initial_session=True,
default_initial_connect_override=(False if kwargs.get("command") == "execute" else None)
)
# try to connect with the server # try to connect with the server
self.refresh_token() self.refresh_token()
# create the default session with many retries # create the default session with many retries
self.__http_session = get_http_session_with_retry(**http_retries_config) http_retries_config, self.__http_session = self._setup_session(http_retries_config)
# update api version from server response # update api version from server response
try: try:
@ -194,6 +185,31 @@ class Session(TokenManager):
self._load_vaults() self._load_vaults()
def _setup_session(self, http_retries_config, initial_session=False, default_initial_connect_override=None):
    # type: (dict, bool, Optional[bool]) -> (dict, requests.Session)
    """
    Build the HTTP retries configuration and create an HTTP session using it.

    :param http_retries_config: Retry settings to use. When empty or None, the
        "api.http.retries" section of the session configuration is used instead.
        The passed dictionary is never modified; a copy is returned.
    :param initial_session: When True, the "connect" retry count may be replaced
        according to ENV_INITIAL_CONNECT_RETRY_OVERRIDE.
    :param default_initial_connect_override: When not None, passed as the default
        used when reading ENV_INITIAL_CONNECT_RETRY_OVERRIDE.
    :return: A tuple of (effective retries configuration, HTTP session created
        by get_http_session_with_retry with that configuration).
    """
    # Work on a private copy so the caller's dictionary is not mutated in place
    http_retries_config = dict(
        http_retries_config
        or self.config.get("api.http.retries", ConfigTree()).as_plain_ordered_dict()
    )
    http_retries_config["status_forcelist"] = self._retry_codes

    if initial_session:
        env_kwargs = {}
        if default_initial_connect_override is not None:
            env_kwargs["default"] = default_initial_connect_override

        if ENV_INITIAL_CONNECT_RETRY_OVERRIDE.get(**env_kwargs):
            # Start from the class-level default number of connect retries
            connect_retries = self._session_initial_retry_connect_override
            try:
                # Re-read the raw value: an explicit number sets the retry count
                value = ENV_INITIAL_CONNECT_RETRY_OVERRIDE.get(converter=str)
                if not isinstance(value, bool):
                    connect_retries = abs(int(value))
            except ValueError:
                # Value is not a number (e.g. a boolean-like string):
                # keep the class-level default retry count
                pass
            http_retries_config['connect'] = connect_retries

    return http_retries_config, get_http_session_with_retry(**http_retries_config)
def _load_vaults(self): def _load_vaults(self):
if not self.check_min_api_version("2.15") or self.feature_set == "basic": if not self.check_min_api_version("2.15") or self.feature_set == "basic":
return return
@ -299,13 +315,9 @@ class Session(TokenManager):
headers[self._AUTHORIZATION_HEADER] = "Bearer {}".format(self.token) headers[self._AUTHORIZATION_HEADER] = "Bearer {}".format(self.token)
return headers return headers
def _set_auth_token(self, auth_token):
self.__access_key = self.__secret_key = None
self.__auth_token = auth_token
def set_auth_token(self, auth_token): def set_auth_token(self, auth_token):
self._set_auth_token(auth_token) self.__access_key = self.__secret_key = None
self.refresh_token() self._set_token(auth_token)
def send_request( def send_request(
self, self,
@ -576,7 +588,7 @@ class Session(TokenManager):
return v + (0,) * max(0, 3 - len(v)) return v + (0,) * max(0, 3 - len(v))
return version_tuple(cls.api_version) >= version_tuple(str(min_api_version)) return version_tuple(cls.api_version) >= version_tuple(str(min_api_version))
def _do_refresh_token(self, old_token, exp=None): def _do_refresh_token(self, current_token, exp=None):
""" TokenManager abstract method implementation. """ TokenManager abstract method implementation.
Here we ignore the old token and simply obtain a new token. Here we ignore the old token and simply obtain a new token.
""" """
@ -588,13 +600,13 @@ class Session(TokenManager):
) )
) )
auth = None
headers = None headers = None
# use token only once (the second time the token is already built into the http session) if self.access_key and self.secret_key:
if self.__auth_token: auth = HTTPBasicAuth(self.access_key, self.secret_key)
headers = dict(Authorization="Bearer {}".format(self.__auth_token)) elif current_token:
self.__auth_token = None headers = dict(Authorization="Bearer {}".format(current_token))
auth = HTTPBasicAuth(self.access_key, self.secret_key) if self.access_key and self.secret_key else None
res = None res = None
try: try:
data = {"expiration_sec": exp} if exp else {} data = {"expiration_sec": exp} if exp else {}
@ -619,7 +631,10 @@ class Session(TokenManager):
) )
if verbose: if verbose:
self._logger.info("Received new token") self._logger.info("Received new token")
return resp["data"]["token"] token = resp["data"]["token"]
if ENV_AUTH_TOKEN.get():
os.environ[ENV_AUTH_TOKEN.key] = token
return token
except LoginError: except LoginError:
six.reraise(*sys.exc_info()) six.reraise(*sys.exc_info())
except KeyError as ex: except KeyError as ex:

View File

@ -9,6 +9,8 @@ import six
@six.add_metaclass(ABCMeta) @six.add_metaclass(ABCMeta)
class TokenManager(object): class TokenManager(object):
_default_token_exp_threshold_sec = 12 * 60 * 60
_default_req_token_expiration_sec = None
@property @property
def token_expiration_threshold_sec(self): def token_expiration_threshold_sec(self):
@ -41,17 +43,30 @@ class TokenManager(object):
return self.__token return self.__token
def __init__( def __init__(
self, self,
token=None, token=None,
req_token_expiration_sec=None, req_token_expiration_sec=None,
token_history=None, token_history=None,
token_expiration_threshold_sec=60, token_expiration_threshold_sec=None,
**kwargs config=None,
**kwargs
): ):
super(TokenManager, self).__init__() super(TokenManager, self).__init__()
assert isinstance(token_history, (type(None), dict)) assert isinstance(token_history, (type(None), dict))
self.token_expiration_threshold_sec = token_expiration_threshold_sec if config:
self.req_token_expiration_sec = req_token_expiration_sec req_token_expiration_sec = req_token_expiration_sec or config.get(
"api.auth.request_token_expiration_sec", None
)
token_expiration_threshold_sec = (
token_expiration_threshold_sec
or config.get("api.auth.token_expiration_threshold_sec", None)
)
self.token_expiration_threshold_sec = (
token_expiration_threshold_sec or self._default_token_exp_threshold_sec
)
self.req_token_expiration_sec = (
req_token_expiration_sec or self._default_req_token_expiration_sec
)
self._set_token(token) self._set_token(token)
def _calc_token_valid_period_sec(self, token, exp=None, at_least_sec=None): def _calc_token_valid_period_sec(self, token, exp=None, at_least_sec=None):
@ -59,7 +74,9 @@ class TokenManager(object):
try: try:
exp = exp or self._get_token_exp(token) exp = exp or self._get_token_exp(token)
if at_least_sec: if at_least_sec:
at_least_sec = max(at_least_sec, self.token_expiration_threshold_sec) at_least_sec = max(
at_least_sec, self.token_expiration_threshold_sec
)
else: else:
at_least_sec = self.token_expiration_threshold_sec at_least_sec = self.token_expiration_threshold_sec
return max(0, (exp - time() - at_least_sec)) return max(0, (exp - time() - at_least_sec))
@ -71,14 +88,16 @@ class TokenManager(object):
def get_decoded_token(cls, token, verify=False): def get_decoded_token(cls, token, verify=False):
""" Get token expiration time. If not present, assume forever """ """ Get token expiration time. If not present, assume forever """
return jwt.decode( return jwt.decode(
token, verify=verify, token,
verify=verify,
options=dict(verify_signature=False), options=dict(verify_signature=False),
algorithms=get_default_algorithms()) algorithms=get_default_algorithms(),
)
@classmethod @classmethod
def _get_token_exp(cls, token): def _get_token_exp(cls, token):
""" Get token expiration time. If not present, assume forever """ """ Get token expiration time. If not present, assume forever """
return cls.get_decoded_token(token).get('exp', sys.maxsize) return cls.get_decoded_token(token).get("exp", sys.maxsize)
def _set_token(self, token): def _set_token(self, token):
if token: if token:
@ -89,7 +108,9 @@ class TokenManager(object):
self.__token_expiration_sec = 0 self.__token_expiration_sec = 0
def get_token_valid_period_sec(self): def get_token_valid_period_sec(self):
return self._calc_token_valid_period_sec(self.__token, self.token_expiration_sec) return self._calc_token_valid_period_sec(
self.__token, self.token_expiration_sec
)
def _get_token(self): def _get_token(self):
if self.get_token_valid_period_sec() <= 0: if self.get_token_valid_period_sec() <= 0:
@ -101,4 +122,6 @@ class TokenManager(object):
pass pass
def refresh_token(self): def refresh_token(self):
self._set_token(self._do_refresh_token(self.__token, exp=self.req_token_expiration_sec)) self._set_token(
self._do_refresh_token(self.__token, exp=self.req_token_expiration_sec)
)

View File

@ -24,6 +24,14 @@ def text_to_bool(value):
return bool(strtobool(value)) return bool(strtobool(value))
def safe_text_to_bool(value):
    # type: (Text) -> bool
    """
    Convert text to a boolean without raising ValueError.

    The value is first converted using text_to_bool. If that raises
    ValueError, the plain truthiness of the value is returned instead.
    """
    try:
        result = text_to_bool(value)
    except ValueError:
        result = bool(value)
    return result
def any_to_bool(value): def any_to_bool(value):
# type: (Optional[Union[int, float, Text]]) -> bool # type: (Optional[Union[int, float, Text]]) -> bool
if isinstance(value, six.text_type): if isinstance(value, six.text_type):