clearml-server/apiserver/tests/api_client.py
2023-07-26 18:46:28 +03:00

298 lines
9.2 KiB
Python

import json
import os
import time
from contextlib import contextmanager
from typing import Sequence
from typing import Union, Tuple, Type
import requests
import six
from boltons.iterutils import remap
from boltons.typeutils import issubclass
from requests.adapters import HTTPAdapter, Retry
from requests.auth import HTTPBasicAuth
from apiserver.apierrors.base import BaseError
from apiserver.config_repo import config
log = config.logger(__file__)
class APICallResult:
def __init__(self, res):
if isinstance(res, str):
self._dict = json.loads(res)
else:
self._dict = res
self.data = self._dict["data"]
self.meta = APICallResultMeta(self._dict["meta"])
def as_dict(self):
return self._dict
class APICallResultMeta:
def __init__(self, meta_dict):
m = meta_dict
self._dict = m
self.call_id = m["id"]
self.trx_id = m["trx"]
self.result_code = m["result_code"]
self.result_msg = m["result_msg"]
self.result_subcode = m["result_subcode"]
self.error_stack = m.get("error_stack") or ""
self.endpoint_name = m["endpoint"]["name"]
self.endpoint_version = m["endpoint"]["actual_version"]
self.requested_endpoint_version = m["endpoint"]["requested_version"]
self.codes = self.result_code, self.result_subcode
def format_duration(duration):
return "" if duration is None else f" ({int(duration)} sec)"
class APIError(Exception):
def __init__(self, result: APICallResult, duration=None):
super().__init__(result)
self.result = result
self.duration = duration
def __str__(self):
meta = self.result.meta
message = (
f"APIClient got {meta.result_code}/{meta.result_subcode} from {meta.endpoint_name}: "
f"{meta.result_msg}{format_duration(self.duration)}"
)
if meta.error_stack:
header = "\n--- SERVER ERROR {} ---\n"
formatted_traceback = "\n{}{}{}\n".format(header.format("START"), meta.error_stack, header.format("END"))
message += formatted_traceback
return message
class AttrDict(dict):
"""
``dict`` which supports attribute lookup syntax.
Use to implement polymorphism over ``dict``s and database objects, which don't support subscription syntax.
"""
def __init__(self, dct=None, **kwargs):
super().__init__(
remap(
{**(dct or {}), **kwargs},
lambda _, key, value: (
key,
type(self)(value)
if isinstance(value, dict) and not isinstance(value, type(self))
else value,
),
)
)
def __getattr__(self, item):
return self[item]
class APIClient:
def __init__(
self,
api_key=None,
secret_key=None,
base_url=None,
impersonated_user_id=None,
session_token=None,
):
if not session_token:
self.api_key = (
api_key
or os.environ.get("SM_API_KEY")
or config.get("apiclient.api_key")
)
if not self.api_key:
raise ValueError("APIClient requires api_key in constructor or config")
self.secret_key = (
secret_key
or os.environ.get("SM_API_SECRET")
or config.get("apiclient.secret_key")
)
if not self.secret_key:
raise ValueError(
"APIClient requires secret_key in constructor or config"
)
self.base_url = (
base_url or os.environ.get("SM_API_URL") or config.get("apiclient.base_url")
)
if not self.base_url:
raise ValueError("APIClient requires base_url in constructor or config")
if self.base_url.endswith("/"):
self.base_url = self.base_url[:-1]
self.session_token = session_token
# create http session
self.http_session = requests.session()
retries = config.get("apiclient.retries", 7)
backoff_factor = config.get("apiclient.backoff_factor", 0.3)
status_forcelist = config.get("apiclient.status_forcelist", (500, 502, 504))
retry = Retry(
total=retries,
read=retries,
connect=retries,
backoff_factor=backoff_factor,
status_forcelist=status_forcelist,
)
adapter = HTTPAdapter(max_retries=retry)
self.http_session.mount("http://", adapter)
self.http_session.mount("https://", adapter)
if impersonated_user_id:
self.http_session.headers["X-ClearML-Impersonate-As"] = impersonated_user_id
if not self.session_token:
self.login()
def login(self):
res, self.session_token = self.send("auth.login", extract="token")
def impersonate(self, user_id):
return type(self)(
impersonated_user_id=user_id,
**{
key: getattr(self, key)
for key in ("api_key", "secret_key", "base_url", "session_token")
},
)
def send_batch(
self,
endpoint,
items: Sequence[dict],
headers_overrides=None,
raise_errors=True,
log_response=False,
extract=None,
):
headers_overrides = headers_overrides or {}
headers_overrides.update({"Content-Type": "application/json-lines"})
assert isinstance(items, (list, tuple)), type(items)
json_lines = (json.dumps(item) for item in items)
data = "\n".join(json_lines)
return self.send(
endpoint,
data=data,
headers_overrides=headers_overrides,
raise_errors=raise_errors,
log_response=log_response,
extract=extract,
)
def send(
self,
endpoint,
data=None,
headers_overrides=None,
raise_errors=True,
log_response=False,
is_async=False,
extract=None,
):
if headers_overrides is None:
headers_overrides = {}
if data is None:
data = {}
headers = {"Content-Type": "application/json"}
headers.update(headers_overrides)
if is_async:
headers["X-ClearML-Async"] = "1"
if not isinstance(data, six.string_types):
data = json.dumps(data)
if not self.session_token:
auth = HTTPBasicAuth(self.api_key, self.secret_key)
else:
auth = None
headers["Authorization"] = "Bearer %s" % self.session_token
url = "%s/%s" % (self.base_url, endpoint)
start = time.time()
http_res = self.http_session.post(url, headers=headers, data=data, auth=auth)
if not http_res.text:
msg = "APIClient got non standard response from %s. http_status=%s " % (
url,
http_res.status_code,
)
log.error(msg)
raise Exception(msg)
res = APICallResult(http_res.text)
if res.meta.result_code == 202:
# poll server for async result
got_result = False
call_id = res.meta.call_id
async_res_url = "%s/async.result?id=%s" % (self.base_url, call_id)
async_res_headers = headers.copy()
async_res_headers.pop("X-ClearML-Async")
while not got_result:
log.info("Got 202. Checking async result for %s (%s)" % (url, call_id))
http_res = self.http_session.get(
async_res_url, headers=async_res_headers
)
if http_res.status_code == 202:
time.sleep(5)
else:
res = APICallResult(http_res.text)
got_result = True
duration = time.time() - start
if res.meta.result_code != 200:
error = APIError(res, duration)
log.error(error)
if raise_errors:
raise error
else:
msg = "APIClient got {} from {}{}".format(
res.meta.result_code, url, format_duration(duration)
)
log.info(msg)
if log_response:
log.debug(res.as_dict())
if extract is None:
return res, res.data
else:
return res, res.data.get(extract)
class Service(object):
def __init__(self, api, name):
self.api = api
self.name = name
def __getattr__(self, item):
return lambda **kwargs: AttrDict(
self.api.send("{}.{}".format(self.name, item), kwargs)[1]
)
def __getattr__(self, item):
return self.Service(self, item)
@staticmethod
@contextmanager
def raises(codes: Union[Type[BaseError], Tuple[int, int]]):
if issubclass(codes, BaseError):
codes = codes.codes
log.info("Expecting failure: %s, %s", *codes)
try:
yield
except APIError as e:
if e.result.meta.codes != codes:
raise
else:
assert False, "call should fail"