mirror of
https://github.com/clearml/clearml-agent
synced 2025-06-09 16:08:39 +00:00
Add support for overriding initial server connection behavior using the CLEARML_AGENT_INITIAL_CONNECT_RETRY_OVERRIDE env var (defaults to true, allows boolean value or an explicit number specifying the number of connect retries)
This commit is contained in:
parent
21c4857795
commit
5a080798cb
@ -31,7 +31,9 @@
|
|||||||
}
|
}
|
||||||
|
|
||||||
auth {
|
auth {
|
||||||
# When creating a request, if token will expire in less than this value, try to refresh the token
|
# When creating a request, if token will expire in less than this value, try to refresh the token. Default 12 hours
|
||||||
token_expiration_threshold_sec = 360
|
token_expiration_threshold_sec: 43200
|
||||||
|
# When requesting a token, request specific expiration time. Server default (and maximum) is 30 days
|
||||||
|
# request_token_expiration_sec: 2592000
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
from ...backend_config.converters import safe_text_to_bool
|
||||||
from ...backend_config.environment import EnvEntry
|
from ...backend_config.environment import EnvEntry
|
||||||
|
|
||||||
|
|
||||||
@ -12,3 +13,6 @@ ENV_HOST_VERIFY_CERT = EnvEntry("CLEARML_API_HOST_VERIFY_CERT", "TRAINS_API_HOST
|
|||||||
ENV_CONDA_ENV_PACKAGE = EnvEntry("CLEARML_CONDA_ENV_PACKAGE", "TRAINS_CONDA_ENV_PACKAGE")
|
ENV_CONDA_ENV_PACKAGE = EnvEntry("CLEARML_CONDA_ENV_PACKAGE", "TRAINS_CONDA_ENV_PACKAGE")
|
||||||
ENV_NO_DEFAULT_SERVER = EnvEntry("CLEARML_NO_DEFAULT_SERVER", "TRAINS_NO_DEFAULT_SERVER", type=bool, default=True)
|
ENV_NO_DEFAULT_SERVER = EnvEntry("CLEARML_NO_DEFAULT_SERVER", "TRAINS_NO_DEFAULT_SERVER", type=bool, default=True)
|
||||||
ENV_DISABLE_VAULT_SUPPORT = EnvEntry('CLEARML_AGENT_DISABLE_VAULT_SUPPORT', type=bool)
|
ENV_DISABLE_VAULT_SUPPORT = EnvEntry('CLEARML_AGENT_DISABLE_VAULT_SUPPORT', type=bool)
|
||||||
|
ENV_INITIAL_CONNECT_RETRY_OVERRIDE = EnvEntry(
|
||||||
|
'CLEARML_AGENT_INITIAL_CONNECT_RETRY_OVERRIDE', default=True, converter=safe_text_to_bool
|
||||||
|
)
|
||||||
|
@ -1,8 +1,10 @@
|
|||||||
|
|
||||||
import json as json_lib
|
import json as json_lib
|
||||||
|
import os
|
||||||
import sys
|
import sys
|
||||||
import types
|
import types
|
||||||
from socket import gethostname
|
from socket import gethostname
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import jwt
|
import jwt
|
||||||
import requests
|
import requests
|
||||||
@ -13,7 +15,7 @@ from six.moves.urllib.parse import urlparse, urlunparse
|
|||||||
|
|
||||||
from .callresult import CallResult
|
from .callresult import CallResult
|
||||||
from .defs import ENV_VERBOSE, ENV_HOST, ENV_ACCESS_KEY, ENV_SECRET_KEY, ENV_WEB_HOST, ENV_FILES_HOST, ENV_AUTH_TOKEN, \
|
from .defs import ENV_VERBOSE, ENV_HOST, ENV_ACCESS_KEY, ENV_SECRET_KEY, ENV_WEB_HOST, ENV_FILES_HOST, ENV_AUTH_TOKEN, \
|
||||||
ENV_NO_DEFAULT_SERVER, ENV_DISABLE_VAULT_SUPPORT
|
ENV_NO_DEFAULT_SERVER, ENV_DISABLE_VAULT_SUPPORT, ENV_INITIAL_CONNECT_RETRY_OVERRIDE
|
||||||
from .request import Request, BatchRequest
|
from .request import Request, BatchRequest
|
||||||
from .token_manager import TokenManager
|
from .token_manager import TokenManager
|
||||||
from ..config import load
|
from ..config import load
|
||||||
@ -42,7 +44,7 @@ class Session(TokenManager):
|
|||||||
_session_requests = 0
|
_session_requests = 0
|
||||||
_session_initial_timeout = (3.0, 10.)
|
_session_initial_timeout = (3.0, 10.)
|
||||||
_session_timeout = (10.0, 30.)
|
_session_timeout = (10.0, 30.)
|
||||||
_session_initial_connect_retry = 4
|
_session_initial_retry_connect_override = 4
|
||||||
_write_session_data_size = 15000
|
_write_session_data_size = 15000
|
||||||
_write_session_timeout = (30.0, 30.)
|
_write_session_timeout = (30.0, 30.)
|
||||||
|
|
||||||
@ -102,9 +104,7 @@ class Session(TokenManager):
|
|||||||
if initialize_logging:
|
if initialize_logging:
|
||||||
self.config.initialize_logging(debug=kwargs.get('debug', False))
|
self.config.initialize_logging(debug=kwargs.get('debug', False))
|
||||||
|
|
||||||
token_expiration_threshold_sec = self.config.get(
|
super(Session, self).__init__(config=config, **kwargs)
|
||||||
"auth.token_expiration_threshold_sec", 60
|
|
||||||
)
|
|
||||||
|
|
||||||
self._verbose = verbose if verbose is not None else ENV_VERBOSE.get()
|
self._verbose = verbose if verbose is not None else ENV_VERBOSE.get()
|
||||||
self._logger = logger
|
self._logger = logger
|
||||||
@ -113,10 +113,7 @@ class Session(TokenManager):
|
|||||||
if ENV_AUTH_TOKEN.get(
|
if ENV_AUTH_TOKEN.get(
|
||||||
value_cb=lambda key, value: print("Using environment access token {}=********".format(key))
|
value_cb=lambda key, value: print("Using environment access token {}=********".format(key))
|
||||||
):
|
):
|
||||||
self._set_auth_token(ENV_AUTH_TOKEN.get())
|
self.set_auth_token(ENV_AUTH_TOKEN.get())
|
||||||
# if we use a token we override make sure we are at least 3600 seconds (1 hour)
|
|
||||||
# away from the token expiration date, ask for a new one.
|
|
||||||
token_expiration_threshold_sec = max(token_expiration_threshold_sec, 3600)
|
|
||||||
else:
|
else:
|
||||||
self.__access_key = api_key or ENV_ACCESS_KEY.get(
|
self.__access_key = api_key or ENV_ACCESS_KEY.get(
|
||||||
default=(self.config.get("api.credentials.access_key", None) or self.default_key),
|
default=(self.config.get("api.credentials.access_key", None) or self.default_key),
|
||||||
@ -136,10 +133,6 @@ class Session(TokenManager):
|
|||||||
"Missing secret_key. Please set in configuration file or pass in session init."
|
"Missing secret_key. Please set in configuration file or pass in session init."
|
||||||
)
|
)
|
||||||
|
|
||||||
super(Session, self).__init__(
|
|
||||||
token_expiration_threshold_sec=token_expiration_threshold_sec, **kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.access_key == self.default_key and self.secret_key == self.default_secret:
|
if self.access_key == self.default_key and self.secret_key == self.default_secret:
|
||||||
print("Using built-in ClearML default key/secret")
|
print("Using built-in ClearML default key/secret")
|
||||||
|
|
||||||
@ -153,10 +146,6 @@ class Session(TokenManager):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.__host = host.strip("/")
|
self.__host = host.strip("/")
|
||||||
http_retries_config = http_retries_config or self.config.get(
|
|
||||||
"api.http.retries", ConfigTree()
|
|
||||||
).as_plain_ordered_dict()
|
|
||||||
http_retries_config["status_forcelist"] = self._retry_codes
|
|
||||||
|
|
||||||
self.__worker = worker or gethostname()
|
self.__worker = worker or gethostname()
|
||||||
|
|
||||||
@ -167,13 +156,15 @@ class Session(TokenManager):
|
|||||||
self.client = client or "api-{}".format(__version__)
|
self.client = client or "api-{}".format(__version__)
|
||||||
|
|
||||||
# limit the reconnect retries, so we get an error if we are starting the session
|
# limit the reconnect retries, so we get an error if we are starting the session
|
||||||
http_no_retries_config = dict(**http_retries_config)
|
_, self.__http_session = self._setup_session(
|
||||||
http_no_retries_config['connect'] = self._session_initial_connect_retry
|
http_retries_config,
|
||||||
self.__http_session = get_http_session_with_retry(**http_no_retries_config)
|
initial_session=True,
|
||||||
|
default_initial_connect_override=(False if kwargs.get("command") == "execute" else None)
|
||||||
|
)
|
||||||
# try to connect with the server
|
# try to connect with the server
|
||||||
self.refresh_token()
|
self.refresh_token()
|
||||||
# create the default session with many retries
|
# create the default session with many retries
|
||||||
self.__http_session = get_http_session_with_retry(**http_retries_config)
|
http_retries_config, self.__http_session = self._setup_session(http_retries_config)
|
||||||
|
|
||||||
# update api version from server response
|
# update api version from server response
|
||||||
try:
|
try:
|
||||||
@ -194,6 +185,31 @@ class Session(TokenManager):
|
|||||||
|
|
||||||
self._load_vaults()
|
self._load_vaults()
|
||||||
|
|
||||||
|
def _setup_session(self, http_retries_config, initial_session=False, default_initial_connect_override=None):
|
||||||
|
# type: (dict, bool, Optional[bool]) -> (dict, requests.Session)
|
||||||
|
http_retries_config = http_retries_config or self.config.get(
|
||||||
|
"api.http.retries", ConfigTree()
|
||||||
|
).as_plain_ordered_dict()
|
||||||
|
http_retries_config["status_forcelist"] = self._retry_codes
|
||||||
|
|
||||||
|
if initial_session:
|
||||||
|
kwargs = {} if default_initial_connect_override is None else {
|
||||||
|
"default": default_initial_connect_override
|
||||||
|
}
|
||||||
|
if ENV_INITIAL_CONNECT_RETRY_OVERRIDE.get(**kwargs):
|
||||||
|
connect_retries = self._session_initial_retry_connect_override
|
||||||
|
try:
|
||||||
|
value = ENV_INITIAL_CONNECT_RETRY_OVERRIDE.get(converter=str)
|
||||||
|
if not isinstance(value, bool):
|
||||||
|
connect_retries = abs(int(value))
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
http_retries_config = dict(**http_retries_config)
|
||||||
|
http_retries_config['connect'] = connect_retries
|
||||||
|
|
||||||
|
return http_retries_config, get_http_session_with_retry(**http_retries_config)
|
||||||
|
|
||||||
def _load_vaults(self):
|
def _load_vaults(self):
|
||||||
if not self.check_min_api_version("2.15") or self.feature_set == "basic":
|
if not self.check_min_api_version("2.15") or self.feature_set == "basic":
|
||||||
return
|
return
|
||||||
@ -299,13 +315,9 @@ class Session(TokenManager):
|
|||||||
headers[self._AUTHORIZATION_HEADER] = "Bearer {}".format(self.token)
|
headers[self._AUTHORIZATION_HEADER] = "Bearer {}".format(self.token)
|
||||||
return headers
|
return headers
|
||||||
|
|
||||||
def _set_auth_token(self, auth_token):
|
|
||||||
self.__access_key = self.__secret_key = None
|
|
||||||
self.__auth_token = auth_token
|
|
||||||
|
|
||||||
def set_auth_token(self, auth_token):
|
def set_auth_token(self, auth_token):
|
||||||
self._set_auth_token(auth_token)
|
self.__access_key = self.__secret_key = None
|
||||||
self.refresh_token()
|
self._set_token(auth_token)
|
||||||
|
|
||||||
def send_request(
|
def send_request(
|
||||||
self,
|
self,
|
||||||
@ -576,7 +588,7 @@ class Session(TokenManager):
|
|||||||
return v + (0,) * max(0, 3 - len(v))
|
return v + (0,) * max(0, 3 - len(v))
|
||||||
return version_tuple(cls.api_version) >= version_tuple(str(min_api_version))
|
return version_tuple(cls.api_version) >= version_tuple(str(min_api_version))
|
||||||
|
|
||||||
def _do_refresh_token(self, old_token, exp=None):
|
def _do_refresh_token(self, current_token, exp=None):
|
||||||
""" TokenManager abstract method implementation.
|
""" TokenManager abstract method implementation.
|
||||||
Here we ignore the old token and simply obtain a new token.
|
Here we ignore the old token and simply obtain a new token.
|
||||||
"""
|
"""
|
||||||
@ -588,13 +600,13 @@ class Session(TokenManager):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
auth = None
|
||||||
headers = None
|
headers = None
|
||||||
# use token only once (the second time the token is already built into the http session)
|
if self.access_key and self.secret_key:
|
||||||
if self.__auth_token:
|
auth = HTTPBasicAuth(self.access_key, self.secret_key)
|
||||||
headers = dict(Authorization="Bearer {}".format(self.__auth_token))
|
elif current_token:
|
||||||
self.__auth_token = None
|
headers = dict(Authorization="Bearer {}".format(current_token))
|
||||||
|
|
||||||
auth = HTTPBasicAuth(self.access_key, self.secret_key) if self.access_key and self.secret_key else None
|
|
||||||
res = None
|
res = None
|
||||||
try:
|
try:
|
||||||
data = {"expiration_sec": exp} if exp else {}
|
data = {"expiration_sec": exp} if exp else {}
|
||||||
@ -619,7 +631,10 @@ class Session(TokenManager):
|
|||||||
)
|
)
|
||||||
if verbose:
|
if verbose:
|
||||||
self._logger.info("Received new token")
|
self._logger.info("Received new token")
|
||||||
return resp["data"]["token"]
|
token = resp["data"]["token"]
|
||||||
|
if ENV_AUTH_TOKEN.get():
|
||||||
|
os.environ[ENV_AUTH_TOKEN.key] = token
|
||||||
|
return token
|
||||||
except LoginError:
|
except LoginError:
|
||||||
six.reraise(*sys.exc_info())
|
six.reraise(*sys.exc_info())
|
||||||
except KeyError as ex:
|
except KeyError as ex:
|
||||||
|
@ -9,6 +9,8 @@ import six
|
|||||||
|
|
||||||
@six.add_metaclass(ABCMeta)
|
@six.add_metaclass(ABCMeta)
|
||||||
class TokenManager(object):
|
class TokenManager(object):
|
||||||
|
_default_token_exp_threshold_sec = 12 * 60 * 60
|
||||||
|
_default_req_token_expiration_sec = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def token_expiration_threshold_sec(self):
|
def token_expiration_threshold_sec(self):
|
||||||
@ -41,17 +43,30 @@ class TokenManager(object):
|
|||||||
return self.__token
|
return self.__token
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
token=None,
|
token=None,
|
||||||
req_token_expiration_sec=None,
|
req_token_expiration_sec=None,
|
||||||
token_history=None,
|
token_history=None,
|
||||||
token_expiration_threshold_sec=60,
|
token_expiration_threshold_sec=None,
|
||||||
**kwargs
|
config=None,
|
||||||
|
**kwargs
|
||||||
):
|
):
|
||||||
super(TokenManager, self).__init__()
|
super(TokenManager, self).__init__()
|
||||||
assert isinstance(token_history, (type(None), dict))
|
assert isinstance(token_history, (type(None), dict))
|
||||||
self.token_expiration_threshold_sec = token_expiration_threshold_sec
|
if config:
|
||||||
self.req_token_expiration_sec = req_token_expiration_sec
|
req_token_expiration_sec = req_token_expiration_sec or config.get(
|
||||||
|
"api.auth.request_token_expiration_sec", None
|
||||||
|
)
|
||||||
|
token_expiration_threshold_sec = (
|
||||||
|
token_expiration_threshold_sec
|
||||||
|
or config.get("api.auth.token_expiration_threshold_sec", None)
|
||||||
|
)
|
||||||
|
self.token_expiration_threshold_sec = (
|
||||||
|
token_expiration_threshold_sec or self._default_token_exp_threshold_sec
|
||||||
|
)
|
||||||
|
self.req_token_expiration_sec = (
|
||||||
|
req_token_expiration_sec or self._default_req_token_expiration_sec
|
||||||
|
)
|
||||||
self._set_token(token)
|
self._set_token(token)
|
||||||
|
|
||||||
def _calc_token_valid_period_sec(self, token, exp=None, at_least_sec=None):
|
def _calc_token_valid_period_sec(self, token, exp=None, at_least_sec=None):
|
||||||
@ -59,7 +74,9 @@ class TokenManager(object):
|
|||||||
try:
|
try:
|
||||||
exp = exp or self._get_token_exp(token)
|
exp = exp or self._get_token_exp(token)
|
||||||
if at_least_sec:
|
if at_least_sec:
|
||||||
at_least_sec = max(at_least_sec, self.token_expiration_threshold_sec)
|
at_least_sec = max(
|
||||||
|
at_least_sec, self.token_expiration_threshold_sec
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
at_least_sec = self.token_expiration_threshold_sec
|
at_least_sec = self.token_expiration_threshold_sec
|
||||||
return max(0, (exp - time() - at_least_sec))
|
return max(0, (exp - time() - at_least_sec))
|
||||||
@ -71,14 +88,16 @@ class TokenManager(object):
|
|||||||
def get_decoded_token(cls, token, verify=False):
|
def get_decoded_token(cls, token, verify=False):
|
||||||
""" Get token expiration time. If not present, assume forever """
|
""" Get token expiration time. If not present, assume forever """
|
||||||
return jwt.decode(
|
return jwt.decode(
|
||||||
token, verify=verify,
|
token,
|
||||||
|
verify=verify,
|
||||||
options=dict(verify_signature=False),
|
options=dict(verify_signature=False),
|
||||||
algorithms=get_default_algorithms())
|
algorithms=get_default_algorithms(),
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _get_token_exp(cls, token):
|
def _get_token_exp(cls, token):
|
||||||
""" Get token expiration time. If not present, assume forever """
|
""" Get token expiration time. If not present, assume forever """
|
||||||
return cls.get_decoded_token(token).get('exp', sys.maxsize)
|
return cls.get_decoded_token(token).get("exp", sys.maxsize)
|
||||||
|
|
||||||
def _set_token(self, token):
|
def _set_token(self, token):
|
||||||
if token:
|
if token:
|
||||||
@ -89,7 +108,9 @@ class TokenManager(object):
|
|||||||
self.__token_expiration_sec = 0
|
self.__token_expiration_sec = 0
|
||||||
|
|
||||||
def get_token_valid_period_sec(self):
|
def get_token_valid_period_sec(self):
|
||||||
return self._calc_token_valid_period_sec(self.__token, self.token_expiration_sec)
|
return self._calc_token_valid_period_sec(
|
||||||
|
self.__token, self.token_expiration_sec
|
||||||
|
)
|
||||||
|
|
||||||
def _get_token(self):
|
def _get_token(self):
|
||||||
if self.get_token_valid_period_sec() <= 0:
|
if self.get_token_valid_period_sec() <= 0:
|
||||||
@ -101,4 +122,6 @@ class TokenManager(object):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def refresh_token(self):
|
def refresh_token(self):
|
||||||
self._set_token(self._do_refresh_token(self.__token, exp=self.req_token_expiration_sec))
|
self._set_token(
|
||||||
|
self._do_refresh_token(self.__token, exp=self.req_token_expiration_sec)
|
||||||
|
)
|
||||||
|
@ -24,6 +24,14 @@ def text_to_bool(value):
|
|||||||
return bool(strtobool(value))
|
return bool(strtobool(value))
|
||||||
|
|
||||||
|
|
||||||
|
def safe_text_to_bool(value):
|
||||||
|
# type: (Text) -> bool
|
||||||
|
try:
|
||||||
|
return text_to_bool(value)
|
||||||
|
except ValueError:
|
||||||
|
return bool(value)
|
||||||
|
|
||||||
|
|
||||||
def any_to_bool(value):
|
def any_to_bool(value):
|
||||||
# type: (Optional[Union[int, float, Text]]) -> bool
|
# type: (Optional[Union[int, float, Text]]) -> bool
|
||||||
if isinstance(value, six.text_type):
|
if isinstance(value, six.text_type):
|
||||||
|
Loading…
Reference in New Issue
Block a user