Add support for overriding initial server connection behavior using the CLEARML_AGENT_INITIAL_CONNECT_RETRY_OVERRIDE env var (defaults to true, allows boolean value or an explicit number specifying the number of connect retries)

This commit is contained in:
allegroai 2021-08-27 19:14:53 +03:00
parent 21c4857795
commit 5a080798cb
5 changed files with 102 additions and 50 deletions

View File

@ -31,7 +31,9 @@
}
auth {
# When creating a request, if token will expire in less than this value, try to refresh the token
token_expiration_threshold_sec = 360
# When creating a request, if the token will expire in less than this many seconds, try to refresh it. Default: 12 hours (43200 seconds)
token_expiration_threshold_sec: 43200
# When requesting a token, request a specific expiration time (in seconds). Server default (and maximum) is 30 days
# request_token_expiration_sec: 2592000
}
}

View File

@ -1,3 +1,4 @@
from ...backend_config.converters import safe_text_to_bool
from ...backend_config.environment import EnvEntry
@ -12,3 +13,6 @@ ENV_HOST_VERIFY_CERT = EnvEntry("CLEARML_API_HOST_VERIFY_CERT", "TRAINS_API_HOST
ENV_CONDA_ENV_PACKAGE = EnvEntry("CLEARML_CONDA_ENV_PACKAGE", "TRAINS_CONDA_ENV_PACKAGE")
ENV_NO_DEFAULT_SERVER = EnvEntry("CLEARML_NO_DEFAULT_SERVER", "TRAINS_NO_DEFAULT_SERVER", type=bool, default=True)
ENV_DISABLE_VAULT_SUPPORT = EnvEntry('CLEARML_AGENT_DISABLE_VAULT_SUPPORT', type=bool)
ENV_INITIAL_CONNECT_RETRY_OVERRIDE = EnvEntry(
'CLEARML_AGENT_INITIAL_CONNECT_RETRY_OVERRIDE', default=True, converter=safe_text_to_bool
)

View File

@ -1,8 +1,10 @@
import json as json_lib
import os
import sys
import types
from socket import gethostname
from typing import Optional
import jwt
import requests
@ -13,7 +15,7 @@ from six.moves.urllib.parse import urlparse, urlunparse
from .callresult import CallResult
from .defs import ENV_VERBOSE, ENV_HOST, ENV_ACCESS_KEY, ENV_SECRET_KEY, ENV_WEB_HOST, ENV_FILES_HOST, ENV_AUTH_TOKEN, \
ENV_NO_DEFAULT_SERVER, ENV_DISABLE_VAULT_SUPPORT
ENV_NO_DEFAULT_SERVER, ENV_DISABLE_VAULT_SUPPORT, ENV_INITIAL_CONNECT_RETRY_OVERRIDE
from .request import Request, BatchRequest
from .token_manager import TokenManager
from ..config import load
@ -42,7 +44,7 @@ class Session(TokenManager):
_session_requests = 0
_session_initial_timeout = (3.0, 10.)
_session_timeout = (10.0, 30.)
_session_initial_connect_retry = 4
_session_initial_retry_connect_override = 4
_write_session_data_size = 15000
_write_session_timeout = (30.0, 30.)
@ -102,9 +104,7 @@ class Session(TokenManager):
if initialize_logging:
self.config.initialize_logging(debug=kwargs.get('debug', False))
token_expiration_threshold_sec = self.config.get(
"auth.token_expiration_threshold_sec", 60
)
super(Session, self).__init__(config=config, **kwargs)
self._verbose = verbose if verbose is not None else ENV_VERBOSE.get()
self._logger = logger
@ -113,10 +113,7 @@ class Session(TokenManager):
if ENV_AUTH_TOKEN.get(
value_cb=lambda key, value: print("Using environment access token {}=********".format(key))
):
self._set_auth_token(ENV_AUTH_TOKEN.get())
# if we use a token we override make sure we are at least 3600 seconds (1 hour)
# away from the token expiration date, ask for a new one.
token_expiration_threshold_sec = max(token_expiration_threshold_sec, 3600)
self.set_auth_token(ENV_AUTH_TOKEN.get())
else:
self.__access_key = api_key or ENV_ACCESS_KEY.get(
default=(self.config.get("api.credentials.access_key", None) or self.default_key),
@ -136,10 +133,6 @@ class Session(TokenManager):
"Missing secret_key. Please set in configuration file or pass in session init."
)
super(Session, self).__init__(
token_expiration_threshold_sec=token_expiration_threshold_sec, **kwargs
)
if self.access_key == self.default_key and self.secret_key == self.default_secret:
print("Using built-in ClearML default key/secret")
@ -153,10 +146,6 @@ class Session(TokenManager):
)
self.__host = host.strip("/")
http_retries_config = http_retries_config or self.config.get(
"api.http.retries", ConfigTree()
).as_plain_ordered_dict()
http_retries_config["status_forcelist"] = self._retry_codes
self.__worker = worker or gethostname()
@ -167,13 +156,15 @@ class Session(TokenManager):
self.client = client or "api-{}".format(__version__)
# limit the reconnect retries, so we get an error if we are starting the session
http_no_retries_config = dict(**http_retries_config)
http_no_retries_config['connect'] = self._session_initial_connect_retry
self.__http_session = get_http_session_with_retry(**http_no_retries_config)
_, self.__http_session = self._setup_session(
http_retries_config,
initial_session=True,
default_initial_connect_override=(False if kwargs.get("command") == "execute" else None)
)
# try to connect with the server
self.refresh_token()
# create the default session with many retries
self.__http_session = get_http_session_with_retry(**http_retries_config)
http_retries_config, self.__http_session = self._setup_session(http_retries_config)
# update api version from server response
try:
@ -194,6 +185,31 @@ class Session(TokenManager):
self._load_vaults()
def _setup_session(self, http_retries_config, initial_session=False, default_initial_connect_override=None):
    # type: (dict, bool, Optional[bool]) -> (dict, requests.Session)
    """
    Create an HTTP session from a retries configuration.

    :param http_retries_config: Retries configuration. When empty or None, the
        "api.http.retries" section of the session configuration is used instead.
    :param initial_session: When True, the number of connect retries may be
        overridden according to ENV_INITIAL_CONNECT_RETRY_OVERRIDE.
    :param default_initial_connect_override: Optional default passed to the
        environment entry lookup; when None, the entry's own default applies.
    :return: A tuple of (retries configuration actually used, HTTP session).
    """
    if not http_retries_config:
        http_retries_config = self.config.get(
            "api.http.retries", ConfigTree()
        ).as_plain_ordered_dict()
    # NOTE: this updates the dict in place, including one handed in by the caller
    http_retries_config["status_forcelist"] = self._retry_codes

    if initial_session:
        lookup_kwargs = {}
        if default_initial_connect_override is not None:
            lookup_kwargs["default"] = default_initial_connect_override

        if ENV_INITIAL_CONNECT_RETRY_OVERRIDE.get(**lookup_kwargs):
            # start from the class-level retry count, replaced below if the
            # environment value is an explicit number
            retries = self._session_initial_retry_connect_override
            try:
                raw_value = ENV_INITIAL_CONNECT_RETRY_OVERRIDE.get(converter=str)
                if not isinstance(raw_value, bool):
                    retries = abs(int(raw_value))
            except ValueError:
                # value is not numeric (e.g. a boolean word): keep the class-level count
                pass

            # work on a copy so the override does not leak into the dict set up above
            overridden_config = dict(**http_retries_config)
            overridden_config['connect'] = retries
            http_retries_config = overridden_config

    http_session = get_http_session_with_retry(**http_retries_config)
    return http_retries_config, http_session
def _load_vaults(self):
if not self.check_min_api_version("2.15") or self.feature_set == "basic":
return
@ -299,13 +315,9 @@ class Session(TokenManager):
headers[self._AUTHORIZATION_HEADER] = "Bearer {}".format(self.token)
return headers
def _set_auth_token(self, auth_token):
self.__access_key = self.__secret_key = None
self.__auth_token = auth_token
def set_auth_token(self, auth_token):
self._set_auth_token(auth_token)
self.refresh_token()
self.__access_key = self.__secret_key = None
self._set_token(auth_token)
def send_request(
self,
@ -576,7 +588,7 @@ class Session(TokenManager):
return v + (0,) * max(0, 3 - len(v))
return version_tuple(cls.api_version) >= version_tuple(str(min_api_version))
def _do_refresh_token(self, old_token, exp=None):
def _do_refresh_token(self, current_token, exp=None):
""" TokenManager abstract method implementation.
Here we ignore the old token and simply obtain a new token.
"""
@ -588,13 +600,13 @@ class Session(TokenManager):
)
)
auth = None
headers = None
# use token only once (the second time the token is already built into the http session)
if self.__auth_token:
headers = dict(Authorization="Bearer {}".format(self.__auth_token))
self.__auth_token = None
if self.access_key and self.secret_key:
auth = HTTPBasicAuth(self.access_key, self.secret_key)
elif current_token:
headers = dict(Authorization="Bearer {}".format(current_token))
auth = HTTPBasicAuth(self.access_key, self.secret_key) if self.access_key and self.secret_key else None
res = None
try:
data = {"expiration_sec": exp} if exp else {}
@ -619,7 +631,10 @@ class Session(TokenManager):
)
if verbose:
self._logger.info("Received new token")
return resp["data"]["token"]
token = resp["data"]["token"]
if ENV_AUTH_TOKEN.get():
os.environ[ENV_AUTH_TOKEN.key] = token
return token
except LoginError:
six.reraise(*sys.exc_info())
except KeyError as ex:

View File

@ -9,6 +9,8 @@ import six
@six.add_metaclass(ABCMeta)
class TokenManager(object):
_default_token_exp_threshold_sec = 12 * 60 * 60
_default_req_token_expiration_sec = None
@property
def token_expiration_threshold_sec(self):
@ -41,17 +43,30 @@ class TokenManager(object):
return self.__token
def __init__(
self,
token=None,
req_token_expiration_sec=None,
token_history=None,
token_expiration_threshold_sec=60,
**kwargs
self,
token=None,
req_token_expiration_sec=None,
token_history=None,
token_expiration_threshold_sec=None,
config=None,
**kwargs
):
super(TokenManager, self).__init__()
assert isinstance(token_history, (type(None), dict))
self.token_expiration_threshold_sec = token_expiration_threshold_sec
self.req_token_expiration_sec = req_token_expiration_sec
if config:
req_token_expiration_sec = req_token_expiration_sec or config.get(
"api.auth.request_token_expiration_sec", None
)
token_expiration_threshold_sec = (
token_expiration_threshold_sec
or config.get("api.auth.token_expiration_threshold_sec", None)
)
self.token_expiration_threshold_sec = (
token_expiration_threshold_sec or self._default_token_exp_threshold_sec
)
self.req_token_expiration_sec = (
req_token_expiration_sec or self._default_req_token_expiration_sec
)
self._set_token(token)
def _calc_token_valid_period_sec(self, token, exp=None, at_least_sec=None):
@ -59,7 +74,9 @@ class TokenManager(object):
try:
exp = exp or self._get_token_exp(token)
if at_least_sec:
at_least_sec = max(at_least_sec, self.token_expiration_threshold_sec)
at_least_sec = max(
at_least_sec, self.token_expiration_threshold_sec
)
else:
at_least_sec = self.token_expiration_threshold_sec
return max(0, (exp - time() - at_least_sec))
@ -71,14 +88,16 @@ class TokenManager(object):
def get_decoded_token(cls, token, verify=False):
""" Get token expiration time. If not present, assume forever """
return jwt.decode(
token, verify=verify,
token,
verify=verify,
options=dict(verify_signature=False),
algorithms=get_default_algorithms())
algorithms=get_default_algorithms(),
)
@classmethod
def _get_token_exp(cls, token):
""" Get token expiration time. If not present, assume forever """
return cls.get_decoded_token(token).get('exp', sys.maxsize)
return cls.get_decoded_token(token).get("exp", sys.maxsize)
def _set_token(self, token):
if token:
@ -89,7 +108,9 @@ class TokenManager(object):
self.__token_expiration_sec = 0
def get_token_valid_period_sec(self):
return self._calc_token_valid_period_sec(self.__token, self.token_expiration_sec)
return self._calc_token_valid_period_sec(
self.__token, self.token_expiration_sec
)
def _get_token(self):
if self.get_token_valid_period_sec() <= 0:
@ -101,4 +122,6 @@ class TokenManager(object):
pass
def refresh_token(self):
self._set_token(self._do_refresh_token(self.__token, exp=self.req_token_expiration_sec))
self._set_token(
self._do_refresh_token(self.__token, exp=self.req_token_expiration_sec)
)

View File

@ -24,6 +24,14 @@ def text_to_bool(value):
return bool(strtobool(value))
def safe_text_to_bool(value):
    # type: (Text) -> bool
    """
    Convert text to a boolean without raising on unrecognized values.

    The value is first parsed with text_to_bool(); if that raises ValueError,
    the plain truthiness of the value is returned instead.
    """
    try:
        result = text_to_bool(value)
    except ValueError:
        # not a recognized boolean text: fall back to the value's truthiness
        result = bool(value)
    return result
def any_to_bool(value):
# type: (Optional[Union[int, float, Text]]) -> bool
if isinstance(value, six.text_type):