clearml-agent/clearml_agent/backend_api/session/client/client.py
2020-12-22 23:00:57 +02:00

544 lines
16 KiB
Python

from __future__ import unicode_literals
import abc
import os
from argparse import Namespace
from collections import OrderedDict
from enum import Enum
from functools import reduce, wraps, WRAPPER_ASSIGNMENTS
from importlib import import_module
from itertools import chain
from operator import itemgetter
from types import ModuleType
from typing import Dict, Text, Tuple, Type, Any, Sequence
import six
from ... import services as api_services
from ....backend_api.session import CallResult
from ....backend_api.session import Session, Request as APIRequest
from ....backend_api.session.response import ResponseMeta
from ....backend_config.defs import LOCAL_CONFIG_FILE_OVERRIDE_VAR
SERVICE_TO_ENTITY_CLASS_NAMES = {"storage": "StorageItem"}
def entity_class_name(service):
# type: (ModuleType) -> Text
service_name = api_entity_name(service)
return SERVICE_TO_ENTITY_CLASS_NAMES.get(service_name.lower(), service_name)
def api_entity_name(service):
return module_name(service).rstrip("s")
@six.python_2_unicode_compatible
class APIError(Exception):
"""
Class for representing an API error.
self.data - ``dict`` of all returned JSON data
self.code - HTTP response code
self.subcode - server response subcode
self.codes - (self.code, self.subcode) tuple
self.message - result message sent from server
"""
def __init__(self, response, extra_info=None):
"""
Create a new APIError from a server response
"""
super(APIError, self).__init__()
self._response = response # type: CallResult
self.extra_info = extra_info
self.data = response.response_data # type: Dict
self.meta = response.meta # type: ResponseMeta
self.code = response.meta.result_code # type: int
self.subcode = response.meta.result_subcode # type: int
self.message = response.meta.result_msg # type: Text
self.codes = (self.code, self.subcode) # type: Tuple[int, int]
def get_traceback(self):
"""
Return server traceback for error, or None if doesn't exist.
"""
try:
return self.meta.error_stack
except AttributeError:
return None
def __str__(self):
message = "{}: ".format(type(self).__name__)
if self.extra_info:
message += "{}: ".format(self.extra_info)
if not self.meta:
message += "no meta available"
return message
if not self.code:
message += "no error code available"
return message
message += "code {0.code}".format(self)
if self.subcode:
message += "/{.subcode}".format(self)
if self.message:
message += ": {.message}".format(self)
return message
class StrictSession(Session):
"""
Session that raises exceptions on errors, and be configured with explicit ``config_file`` path.
"""
def __init__(self, config_file=None, initialize_logging=False, *args, **kwargs):
"""
:param config_file: configuration file to use, else use the default
:type config_file: Path | Text
"""
def init():
super(StrictSession, self).__init__(
initialize_logging=initialize_logging, *args, **kwargs
)
if not config_file:
init()
return
original = LOCAL_CONFIG_FILE_OVERRIDE_VAR.get() or None
try:
LOCAL_CONFIG_FILE_OVERRIDE_VAR.set(str(config_file))
init()
finally:
if original is None:
LOCAL_CONFIG_FILE_OVERRIDE_VAR.pop()
else:
LOCAL_CONFIG_FILE_OVERRIDE_VAR.set(original)
def send(self, request, *args, **kwargs):
result = super(StrictSession, self).send(request, *args, **kwargs)
if not result.ok():
raise APIError(result)
if not result.response:
raise APIError(result, extra_info="Invalid response")
return result
class Response(object):
"""
Proxy object for API result data.
Exposes "meta" of the original result.
"""
def __init__(self, result, dest=None):
"""
:param result: result of endpoint call
:type result: CallResult
:param dest: if all of a response's data is contained in one field, use that field
:type dest: Text
"""
self.response = None
self._result = result
response = getattr(result, "response", result)
if getattr(response, "_service") == "events" and \
getattr(response, "_action") in ("scalar_metrics_iter_histogram",
"multi_task_scalar_metrics_iter_histogram",
"vector_metrics_iter_histogram",
):
# put all the response data under metrics:
response.metrics = result.response_data
if 'metrics' not in response.__class__._get_data_props():
response.__class__._data_props_list['metrics'] = 'metrics'
if dest:
response = getattr(response, dest)
self.response = response
def __getattr__(self, attr):
if self.response is None:
return None
return getattr(self.response, attr)
@property
def meta(self):
return self._result.meta
def __repr__(self):
return repr(self.response)
def __dir__(self):
fields = [
name
for name in dir(self.response)
if isinstance(getattr(type(self.response), name, None), property)
]
return list(set(chain(super(Response, self).__dir__(), fields)) - {"response"})
@six.python_2_unicode_compatible
class TableResponse(Response):
"""
Representation of result containing an array of entities
"""
def __init__(
self,
service, # type: Service
entity, # type: Type[entity]
fields=None, # type: Sequence[Text]
*args,
**kwargs
):
"""
:param service: service of entity
:param entity: class representing entity
:param fields: entity attributes requested by client
"""
super(TableResponse, self).__init__(*args, **kwargs)
self.service = service
self.entity = entity
self.fields = fields or ("id", "name")
self.response = [entity(service, item) for item in self]
def __repr__(self, fields=None):
return self._format_table(fields=fields)
__str__ = __repr__
def _format_table(self, fields=None):
"""
Display <fields> attributes of each element in a table
:param fields:
"""
def getter(obj, attr):
result = reduce(
lambda x, name: x if x is None else getattr(x, name, None),
attr.split("."),
obj,
)
return "" if result is None else result
fields = fields or self.fields
from clearml_agent.helper.base import create_table
return create_table(
(dict((attr, getter(item, attr)) for attr in fields) for item in self),
titles=fields, columns=fields, headers=True,
)
def display(self, fields=None):
print(self._format_table(fields=fields))
def where(self, predicate=None, **kwargs):
"""
Filter items.
<predicate> is a callable from a single item to a boolean. Items for which <predicate> is True will be returned.
Keyword arguments are interpreted as attribute equivalence, meaning:
>>> datasets.where(name='foo')
will return only datasets whose name is "foo".
Giving more than one condition (predicate and keyword arguments) establishes an "and" relation.
"""
def compare_enum(x, y):
return x == y or isinstance(x, Enum) and x.value == y
return TableResponse(
self.service,
self.entity,
self.fields,
[
item
for item in self
if (not predicate or predicate(item))
and all(
compare_enum(getattr(item, key), value)
for key, value in kwargs.items()
)
],
)
def __getitem__(self, item):
return self.response[item]
def __iter__(self):
return iter(self.response)
def __len__(self):
return len(self.response)
@six.add_metaclass(abc.ABCMeta)
class Entity(object):
"""
Represent a server object.
Enables calls like:
>>> entity = client.service.get_by_id(entity_id)
>>> entity.action(**kwargs)
instead of:
>>> client.service.action(id=entity_id, **kwargs)
"""
@abc.abstractproperty
def entity_name(self): # type: () -> Text
"""
Singular name of entity
"""
pass
@abc.abstractproperty
def get_by_id_request(self): # type: () -> Type[APIRequest]
"""
get_by_id request class
"""
pass
def __init__(self, service, data):
self._service = service
self.data = getattr(data, self.entity_name, data)
self.__doc__ = self.data.__doc__
def fetch(self):
"""
Update the entity data from the server.
"""
result = self._service.session.send(self.get_by_id_request(self.data.id))
self.data = getattr(result.response, self.entity_name)
def _get_default_kwargs(self):
return {self.entity_name: self.data.id}
def __getattr__(self, attr):
"""
Inject the entity's ID to the method call.
All missing properties are assumed to be functions.
"""
try:
return getattr(self.data, attr)
except AttributeError:
pass
func = getattr(self._service, attr)
@wrap_request_class(func)
def new_func(*args, **kwargs):
kwargs = dict(self._get_default_kwargs(), **kwargs)
return func(*args, **kwargs)
return new_func
def __dir__(self):
"""
Add ``self._service``'s methods to ``dir()`` result.
"""
try:
dir_ = super(Entity, self).__dir__
except AttributeError:
base = self.__dict__
else:
base = dir_()
return list(set(base).union(dir(self._service), dir(self.data)))
def __repr__(self):
"""
Display entity type, ID, and - if available - name.
"""
parts = (type(self).__name__, ": ", "id={}".format(self.data.id))
try:
parts += (", ", 'name="{}"'.format(self.data.name))
except AttributeError:
pass
return "<{}>".format("".join(parts))
def wrap_request_class(cls):
return wraps(cls, assigned=WRAPPER_ASSIGNMENTS + ("from_dict",))
def make_action(service, request_cls):
action = request_cls._action
try:
get_by_id_request = service.GetByIdRequest
except AttributeError:
get_by_id_request = None
wrap = wrap_request_class(request_cls)
if action not in ["get_all", "get_all_ex", "get_by_id", "create"]:
@wrap
def new_func(self, *args, **kwargs):
return Response(self.session.send(request_cls(*args, **kwargs)))
new_func.__name__ = new_func.__qualname__ = action
return new_func
entity_name = api_entity_name(service)
class_name = entity_class_name(service).capitalize()
properties = {
"__module__": __name__,
"entity_name": entity_name.lower(),
"get_by_id_request": get_by_id_request,
}
entity = type(str(class_name), (Entity,), properties)
if action == "get_by_id":
@wrap
def get(self, *args, **kwargs):
return entity(
self, self.session.send(request_cls(*args, **kwargs)).response
)
elif action == "create":
@wrap
def get(self, *args, **kwargs):
return entity(
self,
Namespace(
id=self.session.send(request_cls(*args, **kwargs)).response.id
),
)
elif action in ["get_all", "get_all_ex"]:
dest = service.response_mapping[request_cls]._get_data_props().popitem()[0]
@wrap
def get(self, *args, **kwargs):
return TableResponse(
service=self,
entity=entity,
result=self.session.send(request_cls(*args, **kwargs)),
dest=dest,
fields=kwargs.pop("only_fields", None),
)
else:
assert False
get.__name__ = get.__qualname__ = action
return get
@six.add_metaclass(abc.ABCMeta)
class Service(object):
"""
Superclass for action-grouping classes.
"""
name = abc.abstractproperty()
__doc__ = abc.abstractproperty()
def __init__(self, session):
self.session = session
def get_requests(service):
return OrderedDict(
(key, value)
for key, value in sorted(vars(service).items(), key=itemgetter(0))
if isinstance(value, type) and issubclass(value, APIRequest) and value._action
)
def make_service_class(module):
# type: (...) -> Type[Service]
"""
Create a service class from service module.
"""
properties = OrderedDict(
[
("__module__", __name__),
("__doc__", module.__doc__),
("name", module_name(module)),
]
)
properties.update(
(f.__name__, f)
for f in (
make_action(module, value) for key, value in get_requests(module).items()
)
)
# noinspection PyTypeChecker
return type(str(module_name(module)), (Service,), properties)
def module_name(module):
try:
module = module.__name__
except AttributeError:
pass
base_name = module.split(".")[-1]
return "".join(s.capitalize() for s in base_name.split("_"))
class Version(Entity):
entity_name = "version"
get_by_id_request = None
def fetch(self):
try:
published = self.data.status == "published"
except AttributeError:
published = False
self.data = self._service.get_versions(
dataset=self.dataset, only_published=published, versions=[self.id]
)[0].data
def _get_default_kwargs(self):
return dict(
super(Version, self)._get_default_kwargs(), **{"dataset": self.data.dataset}
)
class APIClient(object):
auth = None # type: Any
debug = None # type: Any
queues = None # type: Any
tasks = None # type: Any
workers = None # type: Any
events = None # type: Any
def __init__(self, session=None, api_version=None):
self.session = session or StrictSession()
def import_(*args, **kwargs):
try:
return import_module(*args, **kwargs)
except ImportError:
return None
if api_version:
api_version = "v{}".format(str(api_version).replace(".", "_"))
services = OrderedDict(
(name, mod)
for name, mod in (
(
name,
import_(".".join((api_services.__name__, api_version, name))),
)
for name in api_services.__all__
)
if mod
)
else:
services = OrderedDict(
(name, getattr(api_services, name)) for name in api_services.__all__
)
self.__dict__.update(
dict(
{
name: make_service_class(module)(self.session)
for name, module in services.items()
},
)
)