Fix default method might get updated after default argument was initialized

Fix default api method does not work when set in configuration
This commit is contained in:
allegroai 2022-11-21 17:06:17 +02:00
parent 164169b73a
commit 443c6dc814
3 changed files with 45 additions and 14 deletions
clearml/backend_api

View File

@ -8,12 +8,15 @@ ENV_FILES_HOST = EnvEntry("CLEARML_FILES_HOST", "TRAINS_FILES_HOST")
ENV_ACCESS_KEY = EnvEntry("CLEARML_API_ACCESS_KEY", "TRAINS_API_ACCESS_KEY")
ENV_SECRET_KEY = EnvEntry("CLEARML_API_SECRET_KEY", "TRAINS_API_SECRET_KEY")
ENV_AUTH_TOKEN = EnvEntry("CLEARML_AUTH_TOKEN")
ENV_VERBOSE = EnvEntry("CLEARML_API_VERBOSE", "TRAINS_API_VERBOSE", type=bool, default=False)
ENV_VERBOSE = EnvEntry(
"CLEARML_API_VERBOSE", "TRAINS_API_VERBOSE", converter=safe_text_to_bool, type=bool, default=False
)
ENV_HOST_VERIFY_CERT = EnvEntry("CLEARML_API_HOST_VERIFY_CERT", "TRAINS_API_HOST_VERIFY_CERT",
type=bool, default=True)
ENV_OFFLINE_MODE = EnvEntry("CLEARML_OFFLINE_MODE", "TRAINS_OFFLINE_MODE", type=bool, converter=safe_text_to_bool)
ENV_CLEARML_NO_DEFAULT_SERVER = EnvEntry("CLEARML_NO_DEFAULT_SERVER", "TRAINS_NO_DEFAULT_SERVER",
converter=safe_text_to_bool, type=bool, default=True)
ENV_CLEARML_NO_DEFAULT_SERVER = EnvEntry(
"CLEARML_NO_DEFAULT_SERVER", "TRAINS_NO_DEFAULT_SERVER", converter=safe_text_to_bool, type=bool, default=True
)
ENV_DISABLE_VAULT_SUPPORT = EnvEntry('CLEARML_DISABLE_VAULT_SUPPORT', type=bool)
ENV_ENABLE_ENV_CONFIG_SECTION = EnvEntry('CLEARML_ENABLE_ENV_CONFIG_SECTION', type=bool)
ENV_ENABLE_FILES_CONFIG_SECTION = EnvEntry('CLEARML_ENABLE_FILES_CONFIG_SECTION', type=bool)

View File

@ -187,6 +187,7 @@ class Session(TokenManager):
"api.http.retries", ConfigTree()).as_plain_ordered_dict()
http_retries_config["status_forcelist"] = self._get_retry_codes()
http_retries_config["config"] = self.config
self.__http_session = get_http_session_with_retry(**http_retries_config)
self.__http_session.write_timeout = self._write_session_timeout
self.__http_session.request_size_threshold = self._write_session_data_size
@ -237,6 +238,18 @@ class Session(TokenManager):
self._apply_config_sections(local_logger)
self._update_default_api_method()
def _update_default_api_method(self):
if not ENV_API_DEFAULT_REQ_METHOD.get(default=None) and self.config.get("api.http.default_method", None):
def_method = str(self.config.get("api.http.default_method", None)).strip()
if def_method.upper() not in ("GET", "POST", "PUT"):
raise ValueError(
"api.http.default_method variable must be 'get' or 'post' (any case is allowed)."
)
Request.def_method = def_method
Request._method = Request.def_method
def _get_retry_codes(self):
# type: () -> List[int]
retry_codes = set(self._retry_codes)
@ -278,7 +291,8 @@ class Session(TokenManager):
# noinspection PyBroadException
try:
res = self.send_request("users", "get_vaults", json={"enabled": True, "types": ["config"]})
# Use params and not data/json otherwise payload might be dropped if we're using GET with a strict firewall
res = self.send_request("users", "get_vaults", params="enabled=true&types=config&types=config")
if res.ok:
vaults = res.json().get("data", {}).get("vaults", [])
data = list(filter(None, map(parse, vaults)))
@ -312,12 +326,13 @@ class Session(TokenManager):
service,
action,
version=None,
method=Request.def_method,
method=None,
headers=None,
auth=None,
data=None,
json=None,
refresh_token_if_unauthorized=True,
params=None,
):
""" Internal implementation for making a raw API request.
- Constructs the api endpoint name
@ -331,6 +346,9 @@ class Session(TokenManager):
if self._offline_mode:
return None
if not method:
method = Request.def_method
res = None
host = self.host
headers = headers.copy() if headers else {}
@ -401,11 +419,12 @@ class Session(TokenManager):
service,
action,
version=None,
method=Request.def_method,
method=None,
headers=None,
data=None,
json=None,
async_enable=False,
params=None,
):
"""
Send a raw API request.
@ -420,6 +439,8 @@ class Session(TokenManager):
:param async_enable: whether request is asynchronous
:return: requests Response instance
"""
if not method:
method = Request.def_method
headers = self.add_auth_headers(
headers.copy() if headers else {}
)
@ -434,6 +455,7 @@ class Session(TokenManager):
headers=headers,
data=data,
json=json,
params=params,
)
def send_request_batch(
@ -444,7 +466,7 @@ class Session(TokenManager):
headers=None,
data=None,
json=None,
method=Request.def_method,
method=None,
):
"""
Send a raw batch API request. Batch requests always use application/json-lines content type.
@ -469,6 +491,9 @@ class Session(TokenManager):
# Missing data (data or json), batch requests are meaningless without it.
return None
if not method:
method = Request.def_method
headers = headers.copy() if headers else {}
headers["Content-Type"] = "application/json-lines"
@ -677,7 +702,7 @@ class Session(TokenManager):
pass
cls.max_api_version = cls.api_version = cls._offline_default_version
else:
# if the requested version is lower then the minium we support,
# if the requested version is lower then the minimum we support,
# no need to actually check what the server has, we assume it must have at least our version.
if cls._version_tuple(cls.api_version) >= cls._version_tuple(str(min_api_version)):
return True
@ -736,15 +761,14 @@ class Session(TokenManager):
auth = HTTPBasicAuth(self.access_key, self.secret_key) if self.access_key and self.secret_key else None
res = None
try:
data = {"expiration_sec": exp} if exp else {}
res = self._send_request(
method=Request.def_method,
service="auth",
action="login",
auth=auth,
json=data,
headers=headers,
refresh_token_if_unauthorized=False,
params={"expiration_sec": exp} if exp else {},
)
try:
resp = res.json()

View File

@ -95,7 +95,9 @@ def get_http_session_with_retry(
backoff_factor=0,
backoff_max=None,
pool_connections=None,
pool_maxsize=None):
pool_maxsize=None,
config=None
):
global __disable_certificate_verification_warning
if not all(isinstance(x, (int, type(None))) for x in (total, connect, read, redirect, status)):
raise ValueError('Bad configuration. All retry count values must be null or int')
@ -103,16 +105,18 @@ def get_http_session_with_retry(
if status_forcelist and not all(isinstance(x, int) for x in status_forcelist):
raise ValueError('Bad configuration. Retry status_forcelist must be null or list of ints')
config = config or get_config()
pool_maxsize = (
pool_maxsize
if pool_maxsize is not None
else get_config().get('api.http.pool_maxsize', 512)
else config.get('api.http.pool_maxsize', 512)
)
pool_connections = (
pool_connections
if pool_connections is not None
else get_config().get('api.http.pool_connections', 512)
else config.get('api.http.pool_connections', 512)
)
session = SessionWithTimeout()
@ -135,7 +139,7 @@ def get_http_session_with_retry(
session.mount('http://', adapter)
session.mount('https://', adapter)
# update verify host certificate
session.verify = ENV_HOST_VERIFY_CERT.get(default=get_config().get('api.verify_certificate', True))
session.verify = ENV_HOST_VERIFY_CERT.get(default=config.get('api.verify_certificate', True))
if not session.verify and __disable_certificate_verification_warning < 2:
# show warning
__disable_certificate_verification_warning += 1