mirror of
https://github.com/clearml/clearml
synced 2025-06-04 03:47:57 +00:00
1055 lines
40 KiB
Python
1055 lines
40 KiB
Python
from __future__ import print_function
|
|
import json as json_lib
|
|
import logging
|
|
import os
|
|
import sys
|
|
import types
|
|
import weakref
|
|
from socket import gethostname
|
|
from time import sleep
|
|
|
|
import jwt
|
|
import requests
|
|
import six
|
|
from requests.auth import HTTPBasicAuth
|
|
from six.moves.urllib.parse import urlparse, urlunparse
|
|
from typing import List, Optional
|
|
|
|
from .callresult import CallResult
|
|
from .defs import (
|
|
ENV_VERBOSE,
|
|
ENV_HOST,
|
|
ENV_ACCESS_KEY,
|
|
ENV_SECRET_KEY,
|
|
ENV_WEB_HOST,
|
|
ENV_FILES_HOST,
|
|
ENV_OFFLINE_MODE,
|
|
ENV_AUTH_TOKEN,
|
|
ENV_DISABLE_VAULT_SUPPORT,
|
|
ENV_ENABLE_ENV_CONFIG_SECTION,
|
|
ENV_ENABLE_FILES_CONFIG_SECTION,
|
|
ENV_API_EXTRA_RETRY_CODES,
|
|
ENV_API_DEFAULT_REQ_METHOD,
|
|
ENV_FORCE_MAX_API_VERSION,
|
|
MissingConfigError
|
|
)
|
|
from .request import Request, BatchRequest # noqa: F401
|
|
from .token_manager import TokenManager
|
|
from ..utils import get_http_session_with_retry, urllib_log_warning_setup
|
|
from ...backend_config.defs import get_config_file
|
|
from ...debugging import get_logger
|
|
from ...debugging.log import resolve_logging_level
|
|
from ...utilities.pyhocon import ConfigTree, ConfigFactory
|
|
from ...version import __version__
|
|
from ...backend_config.utils import apply_files, apply_environment
|
|
|
|
try:
|
|
from OpenSSL.SSL import Error as SSLError
|
|
except ImportError:
|
|
from requests.exceptions import SSLError
|
|
|
|
|
|
class LoginError(Exception):
    """Raised when obtaining an authentication token from the ClearML API server fails."""
|
|
|
|
class MaxRequestSizeError(Exception):
    """Raised when a single batch item cannot fit into the maximum allowed request size."""
|
|
|
|
class Session(TokenManager):
    """ ClearML API Session class.

    Holds the connection parameters (host, credentials, worker name), owns a retrying HTTP session,
    obtains / refreshes the authentication token (through ``TokenManager``) and tracks the API version
    and feature set reported by the server.
    """

    _AUTHORIZATION_HEADER = "Authorization"
    _WORKER_HEADER = ("X-ClearML-Worker", "X-Trains-Worker", )
    _ASYNC_HEADER = ("X-ClearML-Async", "X-Trains-Async", )
    _CLIENT_HEADER = ("X-ClearML-Client", "X-Trains-Client", )

    _async_status_code = 202
    _session_requests = 0
    _session_initial_timeout = (3.0, 10.)
    _session_timeout = (10.0, 300.)
    _write_session_data_size = 15000
    _write_session_timeout = (300.0, 300.)
    # pause between retries while the server reports maintenance (503), so we do not busy-loop the server
    _maintenance_retry_sleep_sec = 1.0
    _sessions_created = 0
    _ssl_error_count_verbosity = 2
    _offline_mode = ENV_OFFLINE_MODE.get()
    _offline_default_version = '2.9'
    # we want to keep track of sessions, but we also want to allow them to be collected by the GC if they are not used anymore
    _sessions_weakrefs = []

    _client = [(__package__.partition(".")[0], __version__)]

    api_version = '2.9'  # this default version should match the lowest api version we have under service
    max_api_version = '2.9'
    feature_set = 'basic'
    default_demo_host = "https://demoapi.demo.clear.ml"
    default_host = "https://api.clear.ml"
    default_web = "https://app.clear.ml"
    default_files = "https://files.clear.ml"
    default_key = ""  # "EGRTCO8JMSIGI6S39GTP43NFWXDQOW"
    default_secret = ""  # "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8"
    force_max_api_version = ENV_FORCE_MAX_API_VERSION.get()

    legacy_file_servers = ["https://files.community.clear.ml"]

    # TODO: add requests.codes.gateway_timeout once we support async commits
    _retry_codes = [
        requests.codes.bad_gateway,
        requests.codes.service_unavailable,
        requests.codes.bandwidth_limit_exceeded,
        requests.codes.too_many_requests,
    ]

    @property
    def access_key(self):
        """API access key in use (``None`` when authenticating with an auth token)."""
        return self.__access_key

    @property
    def secret_key(self):
        """API secret key in use (``None`` when authenticating with an auth token)."""
        return self.__secret_key

    @property
    def auth_token(self):
        """Initial auth token (from ``ENV_AUTH_TOKEN``); cleared once it has been used to log in."""
        return self.__auth_token

    @property
    def host(self):
        """API server address, without trailing slashes."""
        return self.__host

    @property
    def worker(self):
        """Worker name sent in the worker header of every request."""
        return self.__worker

    def __init__(
        self,
        worker=None,
        api_key=None,
        secret_key=None,
        host=None,
        logger=None,
        verbose=None,
        config=None,
        http_retries_config=None,
        **kwargs
    ):
        """Create a session and connect to the ClearML API server (unless in offline mode).

        :param worker: worker name sent with every request (default: ``get_worker_host_name()``)
        :param api_key: access key, overrides environment / configuration
        :param secret_key: secret key, overrides environment / configuration
        :param host: API server address, overrides ``get_api_server_host()``
        :param logger: logger to use; when verbose and no logger is given, one is created
        :param verbose: verbose request logging (default: ``ENV_VERBOSE``)
        :param config: configuration object (default: ``ConfigWrapper._init()``)
        :param http_retries_config: keyword arguments for the retrying HTTP session
            (default: the ``api.http.retries`` configuration section). It is copied, never modified.
        :param kwargs: passed on to ``TokenManager.__init__``
        """
        self.__class__._sessions_weakrefs.append(weakref.ref(self))

        self._verbose = verbose if verbose is not None else ENV_VERBOSE.get()
        self._logger = logger
        if self._verbose and not self._logger:
            level = resolve_logging_level(ENV_VERBOSE.get(converter=str))
            # compare values, not identity (identity of ints is an interpreter implementation detail)
            self._logger = get_logger(level=level, stream=sys.stderr if level == logging.DEBUG else None)
        self.__worker = worker or self.get_worker_host_name()
        self.client = ", ".join("{}-{}".format(*x) for x in self._client)

        self.__init_api_key = api_key
        self.__init_secret_key = secret_key
        self.__init_host = host
        self.__init_http_retries_config = http_retries_config
        self.__token_manager_kwargs = kwargs
        if config is not None:
            self.config = config
        else:
            from clearml.config import ConfigWrapper
            self.config = ConfigWrapper._init()

        self._connect()

    @classmethod
    def add_client(cls, client, value, first=True):
        """Register an extra ``(client, value)`` pair for the client header, unless ``client`` is already listed.

        :param client: client name
        :param value: client version / value
        :param first: insert at the beginning of the list (otherwise append)
        """
        # noinspection PyBroadException
        try:
            if not any(c[0] == client for c in cls._client):
                if first:
                    cls._client.insert(0, (client, value))
                else:
                    cls._client.append((client, value))
                cls.client = ", ".join("{}-{}".format(*x) for x in cls._client)
        except Exception:
            pass

    def _connect(self):
        """(Re)configure the connection: host, HTTP session, credentials, token, API version, vaults.

        Does nothing in offline mode. Called from ``__init__`` and ``_make_all_sessions_go_online``.

        :raises ValueError: if no host or no ``api.http.max_req_size`` is configured
        :raises MissingConfigError: if neither credentials nor an auth token are available
        """
        if self._offline_mode:
            return

        self._ssl_error_count_verbosity = self.config.get(
            "api.ssl_error_count_verbosity", self._ssl_error_count_verbosity)

        self.__host = self.__init_host or self.get_api_server_host(config=self.config)
        if not self.__host:
            raise ValueError("ClearML host was not set, check your configuration file or environment variable")
        self.__host = self.__host.strip("/")
        # Work on a copy: keys are added below, and the dict passed by the caller (kept in
        # self.__init_http_retries_config and reused on every _connect() call) must not be mutated.
        self.__http_retries_config = dict(
            self.__init_http_retries_config
            or self.config.get("api.http.retries", ConfigTree()).as_plain_ordered_dict()
        )

        self.__http_retries_config["status_forcelist"] = self._get_retry_codes()
        self.__http_retries_config["config"] = self.config
        self.__http_session = get_http_session_with_retry(**self.__http_retries_config)
        self.__http_session.write_timeout = self._write_session_timeout
        self.__http_session.request_size_threshold = self._write_session_data_size

        self.__max_req_size = self.config.get("api.http.max_req_size", None)
        if not self.__max_req_size:
            raise ValueError("missing max request size")

        token_expiration_threshold_sec = self.config.get(
            "auth.token_expiration_threshold_sec", 60
        )
        req_token_expiration_sec = self.config.get("api.auth.req_token_expiration_sec", None)
        self.__auth_token = None
        self._update_default_api_method()
        if ENV_AUTH_TOKEN.get():
            self.__access_key = self.__secret_key = None
            self.__auth_token = ENV_AUTH_TOKEN.get()
            # when authenticating with a token, ask for a new one once we are less than
            # 3600 seconds (1 hour) away from the token expiration date
            token_expiration_threshold_sec = max(token_expiration_threshold_sec, 3600)
        else:
            self.__access_key = self.__init_api_key or ENV_ACCESS_KEY.get(
                default=(self.config.get("api.credentials.access_key", None) or self.default_key)
            )
            self.__secret_key = self.__init_secret_key or ENV_SECRET_KEY.get(
                default=(self.config.get("api.credentials.secret_key", None) or self.default_secret)
            )

        if not self.secret_key and not self.access_key and not self.__auth_token:
            raise MissingConfigError()

        super(Session, self).__init__(
            **self.__token_manager_kwargs,
            token_expiration_threshold_sec=token_expiration_threshold_sec,
            req_token_expiration_sec=req_token_expiration_sec
        )
        self.refresh_token()

        local_logger = self._LocalLogger(self._logger)

        # update api version from server response
        try:
            token_dict = TokenManager.get_decoded_token(self.token)
            api_version = token_dict.get('api_version')
            if not api_version:
                api_version = '2.2' if token_dict.get('env', '') == 'prod' else Session.api_version
            if token_dict.get('server_version'):
                self.add_client('clearml-server', token_dict.get('server_version'))

            Session.max_api_version = Session.api_version = str(api_version)
            Session.feature_set = str(token_dict.get('feature_set', self.feature_set) or "basic")
        except (jwt.DecodeError, ValueError):
            local_logger().warning(
                "Failed parsing server API level, defaulting to {}".format(Session.api_version))

        # now setup the session reporting, so one consecutive retries will show warning
        # we do that here, so if we have problems authenticating, we see them immediately
        # notice: this is across the board warning omission
        urllib_log_warning_setup(total_retries=self.__http_retries_config.get('total', 0), display_warning_after=3)

        if self.force_max_api_version and self.check_min_api_version(self.force_max_api_version):
            Session.max_api_version = Session.api_version = str(self.force_max_api_version)

        # update only after we have max_api
        self.__class__._sessions_created += 1

        if self._load_vaults():
            from clearml.config import ConfigWrapper, ConfigSDKWrapper
            ConfigWrapper.set_config_impl(self.config)
            ConfigSDKWrapper.clear_config_impl()

        self._apply_config_sections(local_logger)

        self._update_default_api_method()

    def _update_default_api_method(self):
        """Apply ``api.http.default_method`` from the configuration as the default request method.

        Skipped when ``ENV_API_DEFAULT_REQ_METHOD`` is set or the configuration value is missing.

        :raises ValueError: if the configured method is not GET, POST or PUT (case-insensitive)
        """
        if not ENV_API_DEFAULT_REQ_METHOD.get(default=None) and self.config.get("api.http.default_method", None):
            def_method = str(self.config.get("api.http.default_method", None)).strip()
            if def_method.upper() not in ("GET", "POST", "PUT"):
                raise ValueError(
                    "api.http.default_method variable must be 'get', 'post' or 'put' (any case is allowed)."
                )
            Request.def_method = def_method
            Request._method = Request.def_method

    def _get_retry_codes(self):
        # type: () -> List[int]
        """Return the HTTP status codes to retry on: the built-in ones plus any extra configured codes.

        Extra codes come from ``ENV_API_EXTRA_RETRY_CODES`` (comma separated) when set, otherwise from
        ``api.http.extra_retry_codes``. Invalid codes are reported and ignored.
        """
        retry_codes = set(self._retry_codes)

        extra = self.config.get("api.http.extra_retry_codes", [])
        if ENV_API_EXTRA_RETRY_CODES.get():
            extra = [s.strip() for s in ENV_API_EXTRA_RETRY_CODES.get().split(",") if s.strip()]

        for code in extra or []:
            try:
                retry_codes.add(int(code))
            except (ValueError, TypeError):
                print("Warning: invalid extra HTTP retry code detected: {}".format(code))

        added_codes = retry_codes.difference(self._retry_codes)
        if added_codes:
            print("Using extra HTTP retry codes {}".format(sorted(added_codes)))

        return list(retry_codes)

    def _read_vaults(self):
        # type: () -> Optional[List[dict]]
        """Fetch the enabled ``config`` vaults from the server and parse their contents.

        :return: list of parsed vault configurations, or ``None`` when vaults are not supported by the
            server, none were found, or fetching failed (failures are logged as warnings)
        """
        if not self.check_min_api_version("2.15") or self.feature_set == "basic":
            return

        def parse(vault):
            # noinspection PyBroadException
            try:
                d = vault.get('data', None)
                if d:
                    r = ConfigFactory.parse_string(d)
                    if isinstance(r, (ConfigTree, dict)):
                        return r
            except Exception as e:
                (self._logger or get_logger()).warning("Failed parsing vault {}: {}".format(
                    vault.get("description", "<unknown>"), e))

        # noinspection PyBroadException
        try:
            # Use params and not data/json otherwise payload might be dropped if we're using GET with a strict firewall
            res = self.send_request("users", "get_vaults", params="enabled=true&types=config&types=config")
            if res.ok:
                vaults = res.json().get("data", {}).get("vaults", [])
                data = list(filter(None, map(parse, vaults)))
                if data:
                    return data
            elif res.status_code != 404:
                raise Exception(res.json().get("meta", {}).get("result_msg", res.text))
        except Exception as ex:
            (self._logger or get_logger()).warning("Failed getting vaults: {}".format(ex))

    def _load_vaults(self):
        # type: () -> Optional[bool]
        """Apply server vaults as configuration overrides.

        :return: ``True`` if any vault configuration was applied, ``None`` otherwise (including when
            vault support is disabled through ``ENV_DISABLE_VAULT_SUPPORT``)
        """
        if ENV_DISABLE_VAULT_SUPPORT.get():
            # vault support is disabled
            return

        data = self._read_vaults()
        if data:
            self.config.set_overrides(*data)
            return True

    def _apply_config_sections(self, local_logger):
        # type: (_LocalLogger) -> None  # noqa: F821
        """Apply the configuration's environment and files sections when enabled.

        Enabled by ``ENV_ENABLE_ENV_CONFIG_SECTION`` / ``ENV_ENABLE_FILES_CONFIG_SECTION``, defaulting to
        ``sdk.apply_environment`` / ``sdk.apply_files``. Failures are logged as warnings, not raised.
        """
        default = self.config.get("sdk.apply_environment", False)
        if ENV_ENABLE_ENV_CONFIG_SECTION.get(default=default):
            try:
                keys = apply_environment(self.config)
                if keys:
                    print("Environment variables set from configuration: {}".format(keys))
            except Exception as ex:
                local_logger().warning("Failed applying environment from configuration: {}".format(ex))

        default = self.config.get("sdk.apply_files", default=False)
        if ENV_ENABLE_FILES_CONFIG_SECTION.get(default=default):
            try:
                apply_files(self.config)
            except Exception as ex:
                local_logger().warning("Failed applying files from configuration: {}".format(ex))

    def _send_request(
        self,
        service,
        action,
        version=None,
        method=None,
        headers=None,
        auth=None,
        data=None,
        json=None,
        refresh_token_if_unauthorized=True,
        params=None,
    ):
        """ Internal implementation for making a raw API request.
        - Constructs the api endpoint name
        - Injects the worker id into the headers
        - Allows custom authorization using a requests auth object
        - Intercepts `Unauthorized` responses and automatically attempts to refresh the session token once in this
          case (only once). This is done since permissions are embedded in the token, and addresses a case where
          server-side permissions have changed but are not reflected in the current token. Refreshing the token will
          generate a token with the updated permissions.
        - Retries on SSL errors, and (when ``api.http.wait_on_maintenance_forever`` is enabled) on
          "service unavailable" responses, pausing between attempts.

        :return: the requests Response, or ``None`` in offline mode
        """
        if self._offline_mode:
            return None

        if not method:
            method = Request.def_method

        res = None
        host = self.host
        headers = headers.copy() if headers else {}
        for h in self._WORKER_HEADER:
            headers[h] = self.worker
        for h in self._CLIENT_HEADER:
            headers[h] = self.client

        token_refreshed_on_error = False
        url = (
            "{host}/v{version}/{service}.{action}"
            if version
            else "{host}/{service}.{action}"
        ).format(host=host, version=version, service=service, action=action)
        retry_counter = 0
        while True:
            if data and len(data) > self._write_session_data_size:
                timeout = self._write_session_timeout
            elif self._session_requests < 1:
                timeout = self._session_initial_timeout
            else:
                timeout = self._session_timeout
            try:
                if self._verbose and self._logger:
                    size = len(data or "")
                    if json and self._logger.level == logging.DEBUG:
                        size += len(json_lib.dumps(json))
                    self._logger.debug("%s: %s [%d bytes, %d headers]", method.upper(), url, size, len(headers or {}))
                res = self.__http_session.request(
                    method, url, headers=headers, auth=auth, data=data, json=json, timeout=timeout, params=params)
                if self._verbose and self._logger:
                    self._logger.debug("--> took %s", res.elapsed)
            except SSLError as ex:
                retry_counter += 1
                # we should retry; only start warning after a few consecutive SSL errors
                if retry_counter >= self._ssl_error_count_verbosity:
                    (self._logger or get_logger()).warning("SSLError Retrying {}".format(ex))
                sleep(0.1)
                continue

            if (
                refresh_token_if_unauthorized
                and res.status_code == requests.codes.unauthorized
                and not token_refreshed_on_error
            ):
                # it seems we're unauthorized, so we'll try to refresh our token once in case permissions changed since
                # the last time we got the token, and try again
                self.refresh_token()
                token_refreshed_on_error = True
                # try again
                retry_counter += 1
                continue
            if (
                res.status_code == requests.codes.service_unavailable
                and self.config.get("api.http.wait_on_maintenance_forever", True)
            ):
                (self._logger or get_logger()).warning(
                    "Service unavailable: {} is undergoing maintenance, retrying...".format(
                        host
                    )
                )
                retry_counter += 1
                # do not hammer the server (and flood the log) while it is under maintenance
                sleep(self._maintenance_retry_sleep_sec)
                continue
            break
        self._session_requests += 1
        return res

    def add_auth_headers(self, headers):
        """Set the bearer token authorization header on ``headers`` (in place) and return it."""
        headers[self._AUTHORIZATION_HEADER] = "Bearer {}".format(self.token)
        return headers

    def send_request(
        self,
        service,
        action,
        version=None,
        method=None,
        headers=None,
        data=None,
        json=None,
        async_enable=False,
        params=None,
    ):
        """
        Send a raw API request.
        :param service: service name
        :param action: action name
        :param version: version number (default is the preconfigured api version)
        :param method: method type (default is ``Request.def_method``)
        :param headers: request headers (authorization and content type headers will be automatically added)
        :param json: json to send in the request body (jsonable object or builtin types construct. if used,
                     content type will be application/json)
        :param data: Dictionary, bytes, or file-like object to send in the request body
        :param async_enable: whether request is asynchronous
        :param params: additional query parameters
        :return: requests Response instance (``None`` in offline mode)
        """
        if not method:
            method = Request.def_method
        headers = self.add_auth_headers(
            headers.copy() if headers else {}
        )
        if async_enable:
            for h in self._ASYNC_HEADER:
                headers[h] = "1"
        return self._send_request(
            service=service,
            action=action,
            version=version,
            method=method,
            headers=headers,
            data=data,
            json=json,
            params=params,
        )

    def send_request_batch(
        self,
        service,
        action,
        version=None,
        headers=None,
        data=None,
        json=None,
        method=None,
    ):
        """
        Send a raw batch API request. Batch requests always use application/json-lines content type.
        The payload is split into several requests, on line boundaries, so that no request exceeds the
        configured maximum request size; sending stops at the first non-200 response.

        :param service: service name
        :param action: action name
        :param version: version number (default is the preconfigured api version)
        :param headers: request headers (authorization and content type headers will be automatically added)
        :param json: iterable of json items (batched items, jsonable objects or builtin types constructs). These will
                     be sent as a multi-line payload in the request body.
        :param data: iterable of bytes objects (batched items). These will be sent as a multi-line payload in the
                     request body.
        :param method: HTTP method
        :return: list of requests Response instances (one per request sent), or ``None`` if there was nothing to send
        :raises ValueError: if ``data`` / ``json`` are not a list, tuple or generator
        :raises MaxRequestSizeError: if a single item does not fit into the maximum request size
        """
        if not all(
            isinstance(x, (list, tuple, type(None), types.GeneratorType))
            for x in (data, json)
        ):
            raise ValueError("Expecting list, tuple or generator in 'data' or 'json'")

        if not data and not json:
            # Missing data (data or json), batch requests are meaningless without it.
            return None

        if not method:
            method = Request.def_method

        headers = headers.copy() if headers else {}
        headers["Content-Type"] = "application/json-lines"

        if data:
            req_data = "\n".join(data)
        else:
            req_data = "\n".join(json_lib.dumps(x) for x in json)

        total_size = len(req_data)
        cur = 0
        results = []
        while True:
            size = self.__max_req_size
            chunk = req_data[cur: cur + size]
            if not chunk:
                break
            if cur + len(chunk) >= total_size:
                # this chunk reaches the end of the payload (the remainder), no need to search for newline.
                # Note: this must also cover a remainder of exactly `size` characters, otherwise a payload
                # that fits exactly would be wrongly rejected below.
                pass
            elif chunk[-1] != "\n":
                # search for the last newline in order to send a coherent request
                size = chunk.rfind("\n") + 1
                # readjust the chunk
                chunk = req_data[cur: cur + size]
            if not chunk:
                # a single item is larger than the maximum request size
                raise MaxRequestSizeError('Error: {}.{} request exceeds limit {} > {} bytes'.format(
                    service, action, len(req_data), self.__max_req_size))
            res = self.send_request(
                method=method,
                service=service,
                action=action,
                data=chunk,
                headers=headers,
                version=version,
            )
            results.append(res)
            if res.status_code != requests.codes.ok:
                break
            cur += size
        return results

    def validate_request(self, req_obj):
        """ Validate an API request against the current version and the request's schema

        :raises TypeError: if ``req_obj`` has no ``validate`` method
        """

        try:
            # make sure we're using a compatible version for this request
            # validate the request (checks required fields and specific field version restrictions)
            validate = req_obj.validate
        except AttributeError:
            raise TypeError(
                '"req_obj" parameter must be an backend_api.session.Request object'
            )

        validate()

    def send_async(self, req_obj):
        """
        Asynchronously sends an API request using a request object.
        :param req_obj: The request object
        :type req_obj: Request
        :return: CallResult object containing the raw response, response metadata and parsed response object.
        """
        return self.send(req_obj=req_obj, async_enable=True)

    def send(self, req_obj, async_enable=False, headers=None):
        """
        Sends an API request using a request object.
        :param req_obj: The request object
        :type req_obj: Request
        :param async_enable: Request this method be executed in an asynchronous manner
        :param headers: Additional headers to send with request
        :return: CallResult object containing the raw response, response metadata and parsed response object
            (``None`` in offline mode).
        :raises NotImplementedError: when ``async_enable`` is requested for a batch request
        """
        self.validate_request(req_obj)

        if self._offline_mode:
            return None

        if isinstance(req_obj, BatchRequest):
            # TODO: support async for batch requests as well
            if async_enable:
                raise NotImplementedError(
                    "Async behavior is currently not implemented for batch requests"
                )

            json_data = req_obj.get_json()
            res = self.send_request_batch(
                service=req_obj._service,
                action=req_obj._action,
                version=req_obj._version,
                json=json_data,
                method=req_obj._method,
                headers=headers,
            )
            # TODO: handle multiple results in this case
            if res is not None:
                # report the first failed response, otherwise the first one (all are 200).
                # An empty result list (nothing was sent, e.g. an empty generator) is treated like "no data".
                res = next((r for r in res if r.status_code != 200), res[0] if res else None)
        else:
            res = self.send_request(
                service=req_obj._service,
                action=req_obj._action,
                version=req_obj._version,
                json=req_obj.to_dict(),
                method=req_obj._method,
                async_enable=async_enable,
                headers=headers,
            )

        call_result = CallResult.from_result(
            res=res,
            request_cls=req_obj.__class__,
            logger=self._logger,
            service=req_obj._service,
            action=req_obj._action,
            session=self,
        )

        return call_result

    @classmethod
    def _make_all_sessions_go_online(cls):
        """Re-run ``_connect`` on every session that is still alive."""
        for active_session in cls._get_all_active_sessions():
            # noinspection PyProtectedMember
            active_session._connect()

    @classmethod
    def _get_all_active_sessions(cls):
        """Return the sessions that are still alive, pruning weak references to collected ones."""
        active_sessions = []
        new_sessions_weakrefs = []
        for session_weakref in cls._sessions_weakrefs:
            session = session_weakref()
            if session:
                active_sessions.append(session)
                new_sessions_weakrefs.append(session_weakref)
        cls._sessions_weakrefs = new_sessions_weakrefs
        return active_sessions

    @classmethod
    def get_api_server_host(cls, config=None):
        """Return the API server address (``ENV_HOST``, ``api.api_server``, ``api.host`` or the default),
        without a trailing slash."""
        if not config:
            from ...config import config_obj
            config = config_obj
        return ENV_HOST.get(default=(config.get("api.api_server", None) or
                                     config.get("api.host", None) or cls.default_host)).rstrip('/')

    @classmethod
    def get_app_server_host(cls, config=None):
        """Return the web application server address.

        Uses ``ENV_WEB_HOST`` / ``api.web_server`` when set, the default web server for the default API
        host, or derives it from the API host (``demoapi.``→``demoapp.``, ``api.``→``app.``, port 8008→8080).

        :raises ValueError: if the address cannot be derived
        """
        if not config:
            from ...config import config_obj
            config = config_obj

        # get from config/environment
        web_host = ENV_WEB_HOST.get(default=config.get("api.web_server", "")).rstrip('/')
        if web_host:
            return web_host

        # return default
        host = cls.get_api_server_host(config)
        if host == cls.default_host and cls.default_web:
            return cls.default_web

        # compose ourselves
        if '://demoapi.' in host:
            return host.replace('://demoapi.', '://demoapp.', 1)
        if '://api.' in host:
            return host.replace('://api.', '://app.', 1)

        parsed = urlparse(host)
        if parsed.port == 8008:
            return host.replace(':8008', ':8080', 1)

        raise ValueError('Could not detect ClearML web application server')

    @classmethod
    def get_files_server_host(cls, config=None):
        """Return the files server address.

        Uses ``ENV_FILES_HOST`` / ``api.files_server`` when set, the default files server for the default
        API host, or derives it from the web application address (port→8081, ``demoapp.``→``demofiles.``,
        ``app.``→``files.``, otherwise append ``:8081``).
        """
        if not config:
            from ...config import config_obj
            config = config_obj
        # get from config/environment
        files_host = ENV_FILES_HOST.get(default=(config.get("api.files_server", ""))).rstrip('/')
        if files_host:
            return files_host

        # return default
        host = cls.get_api_server_host(config)
        if host == cls.default_host and cls.default_files:
            return cls.default_files

        # compose ourselves
        app_host = cls.get_app_server_host(config)
        parsed = urlparse(app_host)
        if parsed.port:
            parsed = parsed._replace(netloc=parsed.netloc.replace(':%d' % parsed.port, ':8081', 1))
        elif parsed.netloc.startswith('demoapp.'):
            parsed = parsed._replace(netloc=parsed.netloc.replace('demoapp.', 'demofiles.', 1))
        elif parsed.netloc.startswith('app.'):
            parsed = parsed._replace(netloc=parsed.netloc.replace('app.', 'files.', 1))
        else:
            parsed = parsed._replace(netloc=parsed.netloc + ':8081')

        return urlunparse(parsed)

    @classmethod
    def check_min_api_version(cls, min_api_version):
        """
        Return True if Session.api_version is greater or equal >= to min_api_version
        """
        # If no session was created, create a default one, in order to get the backend api version.
        if cls._sessions_created <= 0:
            if cls._offline_mode:
                # allow to change the offline mode version by setting ENV_OFFLINE_MODE to the required API version
                if cls.api_version != cls._offline_default_version:
                    offline_api = ENV_OFFLINE_MODE.get(converter=lambda x: x)
                    if offline_api:
                        try:
                            # check cast to float, but leave original str if we pass it.
                            # minimum version is 2.3
                            if float(offline_api) >= 2.3:
                                cls._offline_default_version = str(offline_api)
                        except ValueError:
                            pass
                    cls.max_api_version = cls.api_version = cls._offline_default_version
            else:
                # if the requested version is lower then the minimum we support,
                # no need to actually check what the server has, we assume it must have at least our version.
                if cls._version_tuple(cls.api_version) >= cls._version_tuple(str(min_api_version)):
                    return True

                # noinspection PyBroadException
                try:
                    cls()
                except Exception:
                    pass

        return cls._version_tuple(cls.api_version) >= cls._version_tuple(str(min_api_version))

    @classmethod
    def check_min_api_server_version(cls, min_api_version):
        """
        Return True if Session.max_api_version is greater or equal >= to min_api_version
        Notice this is the api version server reported, not the current SDK max supported api version
        """
        if cls.check_min_api_version(min_api_version):
            return True

        return cls._version_tuple(cls.max_api_version) >= cls._version_tuple(str(min_api_version))

    @classmethod
    def get_worker_host_name(cls):
        """Return the configured dev worker name, falling back to the machine host name."""
        from ...config import dev_worker_name
        return dev_worker_name() or gethostname()

    @classmethod
    def get_clients(cls):
        """Return the registered ``(client, value)`` pairs."""
        return cls._client

    @staticmethod
    def _version_tuple(v):
        """Parse a dotted version string (e.g. ``"2.9"``) into an int tuple padded with zeros to 3 elements."""
        v = tuple(map(int, (v.split("."))))
        return v + (0,) * max(0, 3 - len(v))

    def _do_refresh_token(self, old_token, exp=None):
        """ TokenManager abstract method implementation.
        Here we ignore the old token and simply obtain a new token.

        :param old_token: the current token (ignored)
        :param exp: requested expiration, sent as the ``expiration_sec`` query parameter (optional)
        :return: the new token
        :raises LoginError: if authentication fails
        :raises ValueError: if the server replied 200 without a token (misconfigured api_server)
        """
        verbose = self._verbose and self._logger
        if verbose:
            self._logger.info(
                "Refreshing token from {} (access_key={}, exp={})".format(
                    self.host, self.access_key, exp
                )
            )

        headers = None
        # use token only once (the second time the token is already built into the http session)
        if self.__auth_token:
            headers = dict(Authorization="Bearer {}".format(self.__auth_token))
            self.__auth_token = None

        auth = HTTPBasicAuth(self.access_key, self.secret_key) if self.access_key and self.secret_key else None
        res = None
        try:
            res = self._send_request(
                method=Request.def_method,
                service="auth",
                action="login",
                auth=auth,
                headers=headers,
                refresh_token_if_unauthorized=False,
                params={"expiration_sec": exp} if exp else {},
            )
            try:
                resp = res.json()
            except ValueError:
                resp = {}
            if res.status_code != 200:
                msg = resp.get("meta", {}).get("result_msg", res.reason)
                raise LoginError(
                    "Failed getting token (error {} from {}): {}".format(
                        res.status_code, self.host, msg
                    )
                )
            if verbose:
                self._logger.info("Received new token")

            # make sure we keep the token updated on the OS environment, so that child processes will have access.
            if ENV_AUTH_TOKEN.get():
                ENV_AUTH_TOKEN.set(resp["data"]["token"])

            return resp["data"]["token"]
        except LoginError:
            six.reraise(*sys.exc_info())
        except KeyError as ex:
            # check if this is a misconfigured api server (getting 200 without the data section)
            if res and res.status_code == 200:
                raise ValueError('It seems *api_server* is misconfigured. '
                                 'Is this the ClearML API server {} ?'.format(self.host))
            else:
                raise LoginError("Response data mismatch: No 'token' in 'data' value from res, receive : {}, "
                                 "exception: {}".format(res, ex))
        except Exception as ex:
            raise LoginError('Unrecognized Authentication Error: {} {}'.format(type(ex), ex))

    @staticmethod
    def __get_browser_token(webserver):
        """Try to obtain a token from an existing logged-in browser session (e.g. Google Colab).

        :param webserver: web application server address
        :return: the token, or ``None`` when not in a notebook or the browser request could not be made
        :raises ValueError: when the server answered with a non-200 result code
        """
        # try to get the token if we are running inside a browser session (i.e. CoLab, Kaggle etc.)
        if not os.environ.get("JPY_PARENT_PID"):
            return None

        try:
            from google.colab import output  # noqa
            from google.colab._message import MessageError  # noqa
            from IPython import display  # noqa

            # must have cookie to same-origin: None for this one to work
            display.display(
                display.Javascript(
                    """
                window._ApiKey = new Promise((resolve, reject) => {
                    const timeout = setTimeout(() => reject("Failed authenticating existing browser session"), 5000)
                    fetch("%s/api/auth.login", {
                      method: 'GET',
                      credentials: 'include'
                    })
                      .then((response) => resolve(response.json()))
                      .then((json) => {
                        clearTimeout(timeout);
                      }).catch((err) => {
                        clearTimeout(timeout);
                        reject(err);
                    });
                });
                """ % webserver.rstrip("/")
                ))

            response = output.eval_js("_ApiKey")
            if not response:
                return None
            result_code = response.get("meta", {}).get("result_code")
            token = response.get("data", {}).get("token")
        except Exception:  # noqa
            # best effort: any failure (no colab, blocked request, ...) means "no browser token";
            # KeyboardInterrupt is deliberately not swallowed
            return None

        if result_code != 200:
            raise ValueError(
                "Automatic authenticating failed, please login to {} and try again".format(webserver))

        return token

    def __str__(self):
        # with token authentication there is no secret key, so mask an empty string instead of crashing
        secret_key = self.secret_key or ""
        return "{self.__class__.__name__}[{self.host}, {self.access_key}/{secret_key}]".format(
            self=self, secret_key=secret_key[:5] + "*" * (len(secret_key) - 5)
        )

    class _LocalLogger:
        """Lazy logger holder: returns the given logger, creating a default one on first use if none was given."""

        def __init__(self, local_logger):
            self.logger = local_logger

        def __call__(self):
            if not self.logger:
                self.logger = get_logger()
            return self.logger
|
|
|
|
def _browser_login_api_server(clearml_server):
    # type: (str) -> str
    """Convert a user supplied ClearML server address (web application or API) into the API server address.

    A missing scheme defaults to ``http://``. An explicit port is switched to the API port 8008 and the
    host name is kept as-is; otherwise the ``demoapp.`` / ``app.`` sub-domain is mapped to ``demoapi.`` /
    ``api.``, and ``api.`` is prefixed when no known sub-domain is present.
    """
    if not clearml_server.lower().startswith("http"):
        clearml_server = "http://{}".format(clearml_server)

    parsed = urlparse(clearml_server)
    if parsed.port:
        # port based deployment: only switch to the API port. Prefixing "api." to a host name / IP address
        # here would produce an invalid address (same rule as Session.get_files_server_host)
        parsed = parsed._replace(netloc=parsed.netloc.replace(':%d' % parsed.port, ':8008', 1))
    elif parsed.netloc.startswith('demoapp.'):
        parsed = parsed._replace(netloc=parsed.netloc.replace('demoapp.', 'demoapi.', 1))
    elif parsed.netloc.startswith('app.'):
        parsed = parsed._replace(netloc=parsed.netloc.replace('app.', 'api.', 1))
    elif parsed.netloc.startswith('api.'):
        pass
    else:
        parsed = parsed._replace(netloc='api.' + parsed.netloc)

    return urlunparse(parsed)


def browser_login(clearml_server=None):
    # type: (Optional[str]) -> None
    """
    Alternative authentication / login method, (instead of configuring ~/clearml.conf or Environment variables)
    ** Only applicable when running inside a browser session,
    for example Google Colab, Kaggle notebook, Jupyter Notebooks etc. **

    Notice: If called inside a python script, or when running with an agent, this function is ignored

    If no working configuration exists, it tries to get a token from the browser session; if that fails
    the user is asked to log in, and eventually to paste an access key and secret key, which are then
    stored in the configuration file.

    :param clearml_server: Optional, set the clearml server address, default: https://app.clear.ml
    """

    # check if we are running inside a Jupyter notebook of a sort
    if not os.environ.get("JPY_PARENT_PID"):
        return

    # if we are running remotely, skip login
    from clearml.config import running_remotely
    # noinspection PyProtectedMember
    if running_remotely():
        return

    # if we have working local configuration (or offline mode), nothing to do
    try:
        Session()
        # make sure we set environment variables to point to our api/app/files hosts
        ENV_WEB_HOST.set(Session.get_app_server_host())
        ENV_HOST.set(Session.get_api_server_host())
        ENV_FILES_HOST.set(Session.get_files_server_host())
        return
    except Exception:  # noqa
        # no working configuration, continue with the interactive login (KeyboardInterrupt is not swallowed)
        pass

    # conform clearml_server address
    if clearml_server:
        clearml_server = _browser_login_api_server(clearml_server)

        # set for later usage
        ENV_HOST.set(clearml_server)

    token = None
    counter = 0
    clearml_app_server = Session.get_app_server_host()
    while not token:
        # try to get authentication token
        try:
            # noinspection PyProtectedMember
            token = Session._Session__get_browser_token(clearml_app_server)
        except Exception:  # noqa
            token = None
        # if we could not get a token, instruct the user to login
        if not token:
            if not counter:
                print(
                    "ClearML automatic browser login failed, please login or create a new account\n"
                    "To get started with ClearML: setup your own `clearml-server`, "
                    "or create a free account at {}\n".format(clearml_app_server)
                )
                print("Please login to {} , then press [Enter] to connect ".format(clearml_app_server), end="")
                input()
            elif counter < 2:
                # second attempt (previously unreachable: `counter < 1` can never hold here)
                print("Oh no we failed to connect \N{worried face}, "
                      "try to logout and login again - Press [Enter] to retry ", end="")
                input()
            else:
                print(
                    "\n"
                    "We cannot connect automatically (adblocker / incognito?) \N{worried face} \n"
                    "Please go to {}/settings/workspace-configuration \n"
                    "Then press \x1B[1m\x1B[48;2;26;30;44m\x1B[37m + Create new credentials \x1b[0m \n"
                    "And copy/paste your \x1B[1m\x1B[4mAccess Key\x1b[0m here: ".format(
                        clearml_app_server.rstrip("/")), end="")

                creds = input()
                if creds:
                    print(" Setting access key ")
                    ENV_ACCESS_KEY.set(creds.strip())

                print("Now copy/paste your \x1B[1m\x1B[4mSecret Key\x1b[0m here: ", end="")
                creds = input()
                if creds:
                    print(" Setting secret key ")
                    ENV_SECRET_KEY.set(creds.strip())

                if ENV_ACCESS_KEY.get() and ENV_SECRET_KEY.get():
                    # store in conf file for persistence in runtime
                    # noinspection PyBroadException
                    try:
                        with open(get_config_file(), "wt") as f:
                            f.write("api.credentials.access_key={}\napi.credentials.secret_key={}\n".format(
                                ENV_ACCESS_KEY.get(), ENV_SECRET_KEY.get()
                            ))
                    except Exception:
                        pass
                    break

        counter += 1

    print("")
    if counter:
        # these emojis actually requires python 3.6+
        # print("\nHurrah! \N{face with party horn and party hat} \N{confetti ball} \N{party popper}")
        print("\nHurrah! \U0001f973 \U0001f38a \U0001f389")

    if token:
        # set Token
        ENV_AUTH_TOKEN.set(token)

    if token or (ENV_ACCESS_KEY.get() and ENV_SECRET_KEY.get()):
        # make sure we set environment variables to point to our api/app/files hosts
        ENV_WEB_HOST.set(Session.get_app_server_host())
        ENV_HOST.set(Session.get_api_server_host())
        ENV_FILES_HOST.set(Session.get_files_server_host())
        # verify token
        Session()
        # success
        print("\N{robot face} ClearML connected successfully - let's build something! \N{rocket}")