# Mirror of an upstream repository (source location not preserved in this copy).
# Synced 2025-03-19 01:22:40 +00:00
# 544 lines, 16 KiB
from __future__ import unicode_literals
import abc
import os
from argparse import Namespace
from collections import OrderedDict
from enum import Enum
from functools import reduce, wraps, WRAPPER_ASSIGNMENTS
from importlib import import_module
from itertools import chain
from operator import itemgetter
from types import ModuleType
from typing import Dict, Text, Tuple, Type, Any, Sequence
import six
from ... import services as api_services
from ....backend_api.session import CallResult
from ....backend_api.session import Session, Request as APIRequest
from ....backend_api.session.response import ResponseMeta
from ....backend_config.defs import LOCAL_CONFIG_FILE_OVERRIDE_VAR
# Overrides for entity class names, keyed by the lower-cased singular service
# name; consulted by entity_class_name().
SERVICE_TO_ENTITY_CLASS_NAMES = {
    "storage": "StorageItem",
}
def entity_class_name(service):
    # type: (ModuleType) -> Text
    """
    Return the entity class name to use for a service module.

    The singular service name is looked up, lower-cased, in
    ``SERVICE_TO_ENTITY_CLASS_NAMES``; when no override is registered the
    singular name itself is returned unchanged.
    """
    singular = api_entity_name(service)
    try:
        return SERVICE_TO_ENTITY_CLASS_NAMES[singular.lower()]
    except KeyError:
        return singular
def api_entity_name(service):
    """
    Return the service's module-derived name with all trailing "s" characters removed.
    """
    service_name = module_name(service)
    return service_name.rstrip("s")
class APIError(Exception):
    """
    Class for representing an API error.

    self.data - ``dict`` of all returned JSON data
    self.code - HTTP response code
    self.subcode - server response subcode
    self.codes - (self.code, self.subcode) tuple
    self.message - result message sent from server
    """

    def __init__(self, response, extra_info=None):
        """
        Create a new APIError from a server response

        :param response: call result; its ``response_data`` and ``meta`` are copied onto the error
        :param extra_info: optional text inserted into the string form of the error
        """
        super(APIError, self).__init__()
        self._response = response  # type: CallResult
        self.extra_info = extra_info
        self.data = response.response_data  # type: Dict
        self.meta = response.meta  # type: ResponseMeta
        self.code = response.meta.result_code  # type: int
        self.subcode = response.meta.result_subcode  # type: int
        self.message = response.meta.result_msg  # type: Text
        self.codes = (self.code, self.subcode)  # type: Tuple[int, int]

    def get_traceback(self):
        """
        Return server traceback for error, or None if doesn't exist.
        """
        try:
            return self.meta.error_stack
        except AttributeError:
            return None

    def __str__(self):
        """
        Format as ``<ClassName>: [<extra_info>: ]code <code>[/<subcode>][: <message>]``.

        Falls back to a short explanation when the meta or the error code is missing.
        """
        message = "{}: ".format(type(self).__name__)
        if self.extra_info:
            message += "{}: ".format(self.extra_info)
        if not self.meta:
            message += "no meta available"
            return message
        if not self.code:
            message += "no error code available"
            return message
        message += "code {0.code}".format(self)
        if self.subcode:
            message += "/{.subcode}".format(self)
        if self.message:
            message += ": {.message}".format(self)
        return message
class StrictSession(Session):
    """
    Session that raises exceptions on errors, and can be configured with an explicit ``config_file`` path.
    """

    def __init__(self, config_file=None, initialize_logging=False, *args, **kwargs):
        """
        :param config_file: configuration file to use, else use the default
        :type config_file: Path | Text
        :param initialize_logging: forwarded to the parent ``Session`` constructor
        """

        def init():
            super(StrictSession, self).__init__(
                initialize_logging=initialize_logging, *args, **kwargs
            )

        if not config_file:
            init()
            return

        # NOTE(review): the control flow below is reconstructed; the extracted
        # text had lost the init() calls and the try/finally/else lines.
        # The parent session is pointed at ``config_file`` through the override
        # environment variable, which is restored to its previous state afterwards.
        original = os.environ.get(LOCAL_CONFIG_FILE_OVERRIDE_VAR, None)
        os.environ[LOCAL_CONFIG_FILE_OVERRIDE_VAR] = str(config_file)
        try:
            init()
        finally:
            if original is None:
                # The variable was not set before: remove it rather than
                # storing None (os.environ only accepts strings).
                os.environ.pop(LOCAL_CONFIG_FILE_OVERRIDE_VAR, None)
            else:
                os.environ[LOCAL_CONFIG_FILE_OVERRIDE_VAR] = original

    def send(self, request, *args, **kwargs):
        """
        Send ``request`` through the parent session and return the result.

        :raises APIError: if ``result.ok()`` is false, or if the result carries no response
        """
        result = super(StrictSession, self).send(request, *args, **kwargs)
        if not result.ok():
            raise APIError(result)
        if not result.response:
            raise APIError(result, extra_info="Invalid response")
        return result
class Response(object):
    """
    Proxy object for API result data.
    Exposes "meta" of the original result.
    """

    def __init__(self, result, dest=None):
        """
        :param result: result of endpoint call
        :type result: CallResult
        :param dest: if all of a response's data is contained in one field, use that field
        :type dest: Text
        """
        # Set first so that __getattr__ can always consult it.
        self.response = None
        self._result = result
        response = getattr(result, "response", result)
        # NOTE(review): only this action name is visible in the extracted text;
        # further names may have been lost - verify against upstream.
        histogram_actions = ("scalar_metrics_iter_histogram",)
        # Default to None so results without service/action markers (e.g. plain
        # lists) do not raise AttributeError here.
        if (
            getattr(response, "_service", None) == "events"
            and getattr(response, "_action", None) in histogram_actions
        ):
            # put all the response data under metrics:
            response.metrics = result.response_data
            if 'metrics' not in response.__class__._get_data_props():
                response.__class__._data_props_list['metrics'] = 'metrics'
        if dest:
            response = getattr(response, dest)
        self.response = response

    def __getattr__(self, attr):
        """
        Forward unknown attributes to the wrapped response (None if there is no response).
        """
        if attr == "response":
            # Reached only when the instance was never initialised; fail
            # cleanly instead of recursing through ``self.response``.
            raise AttributeError(attr)
        if self.response is None:
            return None
        return getattr(self.response, attr)

    # NOTE(review): upstream probably declares this as a property; decorator
    # lines appear to have been lost in extraction. Kept as a plain method
    # because that is what the visible text shows - confirm.
    def meta(self):
        """
        Return the ``meta`` of the original result.
        """
        return self._result.meta

    def __repr__(self):
        return repr(self.response)

    def __dir__(self):
        """
        List own attributes plus the wrapped response's properties, hiding ``response``.
        """
        fields = [
            name
            for name in dir(self.response)
            if isinstance(getattr(type(self.response), name, None), property)
        ]
        return list(set(chain(super(Response, self).__dir__(), fields)) - {"response"})
class TableResponse(Response):
    """
    Representation of result containing an array of entities
    """

    def __init__(
        self,
        service,  # type: Service
        entity,  # type: Type[Entity]
        fields=None,  # type: Sequence[Text]
        *args,
        **kwargs
    ):
        """
        :param service: service of entity
        :param entity: class representing entity
        :param fields: entity attributes requested by client
        Remaining arguments are forwarded to ``Response.__init__``.
        """
        super(TableResponse, self).__init__(*args, **kwargs)
        self.service = service
        self.entity = entity
        self.fields = fields or ("id", "name")
        # Wrap every raw item of the response in an entity object.
        self.response = [entity(service, item) for item in self]

    def __repr__(self, fields=None):
        return self._format_table(fields=fields)

    __str__ = __repr__

    def _format_table(self, fields=None):
        """
        Display <fields> attributes of each element in a table
        :param fields: attribute names to show; defaults to ``self.fields``
        """

        def getter(obj, attr):
            # NOTE(review): the reduce() arguments were lost in extraction;
            # reconstructed as dotted-path traversal starting at ``obj``.
            result = reduce(
                lambda x, name: x if x is None else getattr(x, name, None),
                attr.split("."),
                obj,
            )
            return "" if result is None else result

        fields = fields or self.fields
        from trains_agent.helper.base import create_table
        return create_table(
            (dict((attr, getter(item, attr)) for attr in fields) for item in self),
            titles=fields, columns=fields, headers=True,
        )

    def display(self, fields=None):
        """
        Print the table produced by ``_format_table``.
        """
        # NOTE(review): the body was lost in extraction; reconstructed from the
        # method name and ``_format_table`` - confirm against upstream.
        print(self._format_table(fields=fields))

    def where(self, predicate=None, **kwargs):
        """
        Filter items.
        <predicate> is a callable from a single item to a boolean. Items for which <predicate> is True will be returned.
        Keyword arguments are interpreted as attribute equivalence, meaning:
        >>> datasets.where(name='foo')
        will return only datasets whose name is "foo".
        Giving more than one condition (predicate and keyword arguments) establishes an "and" relation.
        """

        def compare_enum(x, y):
            # An Enum attribute also matches when its ``value`` equals ``y``.
            return x == y or isinstance(x, Enum) and x.value == y

        # NOTE(review): the constructor arguments other than the filtered list
        # were lost in extraction; reconstructed from ``__init__``'s signature.
        return TableResponse(
            service=self.service,
            entity=self.entity,
            fields=self.fields,
            result=[
                item
                for item in self
                if (not predicate or predicate(item))
                and all(
                    compare_enum(getattr(item, key), value)
                    for key, value in kwargs.items()
                )
            ],
        )

    def __getitem__(self, item):
        return self.response[item]

    def __iter__(self):
        return iter(self.response)

    def __len__(self):
        return len(self.response)
class Entity(object):
    """
    Represent a server object.

    Enables calls like:
    >>> entity = client.service.get_by_id(entity_id)
    >>> entity.action(**kwargs)
    instead of:
    >>> client.service.action(id=entity_id, **kwargs)

    Subclasses replace ``entity_name`` and ``get_by_id_request`` with class attributes.
    """

    def entity_name(self):  # type: () -> Text
        """
        Singular name of entity
        """

    def get_by_id_request(self):  # type: () -> Type[APIRequest]
        """
        get_by_id request class
        """

    def __init__(self, service, data):
        """
        :param service: service object used to perform calls for this entity
        :param data: entity data, or an object holding it under the ``entity_name`` attribute
        """
        self._service = service
        self.data = getattr(data, self.entity_name, data)
        self.__doc__ = self.data.__doc__

    def fetch(self):
        """
        Update the entity data from the server.
        """
        result = self._service.session.send(self.get_by_id_request(self.data.id))
        self.data = getattr(result.response, self.entity_name)

    def _get_default_kwargs(self):
        """
        Return the keyword arguments injected into every forwarded service call.
        """
        return {self.entity_name: self.data.id}

    def __getattr__(self, attr):
        """
        Inject the entity's ID to the method call.
        All missing properties are assumed to be functions.
        """
        if attr in ("data", "_service"):
            # Reached only before __init__ has run (e.g. while copying); fail
            # cleanly instead of recursing through the lookups below.
            raise AttributeError(attr)
        try:
            return getattr(self.data, attr)
        except AttributeError:
            pass

        func = getattr(self._service, attr)

        def new_func(*args, **kwargs):
            # Explicit keyword arguments take precedence over the defaults.
            kwargs = dict(self._get_default_kwargs(), **kwargs)
            return func(*args, **kwargs)

        return new_func

    def __dir__(self):
        """
        Add ``self._service``'s methods to ``dir()`` result.
        """
        try:
            dir_ = super(Entity, self).__dir__
        except AttributeError:
            # No ``__dir__`` on the base object: fall back to instance attributes.
            base = self.__dict__
        else:
            base = dir_()
        return list(set(base).union(dir(self._service), dir(self.data)))

    def __repr__(self):
        """
        Display entity type, ID, and - if available - name.
        """
        parts = (type(self).__name__, ": ", "id={}".format(self.data.id))
        try:
            parts += (", ", 'name="{}"'.format(self.data.name))
        except AttributeError:
            pass
        return "<{}>".format("".join(parts))
def wrap_request_class(cls):
    """
    Return a ``functools.wraps`` decorator for ``cls`` that also copies ``from_dict``.
    """
    copied_attributes = WRAPPER_ASSIGNMENTS + ("from_dict",)
    return wraps(cls, assigned=copied_attributes)
def make_action(service, request_cls):
    """
    Build the service method that performs ``request_cls``'s action.

    Actions other than ``get_all``, ``get_all_ex``, ``get_by_id`` and ``create``
    return a plain ``Response``. Those four return entity objects instead: a
    single entity for ``get_by_id``/``create`` and a ``TableResponse`` for
    ``get_all``/``get_all_ex``.

    :param service: service module the request class belongs to
    :param request_cls: request class; its ``_action`` names the generated method
    :return: function taking ``self`` (a service instance) plus the request's arguments
    """
    action = request_cls._action
    get_by_id_request = getattr(service, "GetByIdRequest", None)

    # NOTE(review): ``wrap`` was assigned but never applied in the extracted
    # text; decorator lines appear to have been lost, so it is applied to the
    # generated functions here - confirm against upstream.
    wrap = wrap_request_class(request_cls)

    if action not in ["get_all", "get_all_ex", "get_by_id", "create"]:

        @wrap
        def new_func(self, *args, **kwargs):
            return Response(self.session.send(request_cls(*args, **kwargs)))

        new_func.__name__ = new_func.__qualname__ = action
        return new_func

    entity_name = api_entity_name(service)
    # NOTE(review): str.capitalize() lower-cases everything after the first
    # character (e.g. "StorageItem" -> "Storageitem"); kept as in the original.
    class_name = entity_class_name(service).capitalize()
    properties = {
        "__module__": __name__,
        "entity_name": entity_name.lower(),
        "get_by_id_request": get_by_id_request,
    }
    entity = type(str(class_name), (Entity,), properties)

    if action == "get_by_id":

        @wrap
        def get(self, *args, **kwargs):
            return entity(
                self, self.session.send(request_cls(*args, **kwargs)).response
            )

    elif action == "create":

        @wrap
        def get(self, *args, **kwargs):
            # NOTE(review): the wrapper around ``id=`` was lost in extraction;
            # reconstructed with the otherwise unused ``Namespace`` import.
            return entity(
                self,
                Namespace(
                    id=self.session.send(request_cls(*args, **kwargs)).response.id
                ),
            )

    elif action in ["get_all", "get_all_ex"]:
        # Name of the response field holding the list of items.
        dest = service.response_mapping[request_cls]._get_data_props().popitem()[0]

        @wrap
        def get(self, *args, **kwargs):
            # ``only_fields`` is popped only after the request has been built,
            # so the request receives it as well.
            return TableResponse(
                service=self,
                entity=entity,
                result=self.session.send(request_cls(*args, **kwargs)),
                dest=dest,
                fields=kwargs.pop("only_fields", None),
            )

    else:
        # Unreachable: all other actions returned above.
        raise AssertionError("unexpected action: {}".format(action))

    get.__name__ = get.__qualname__ = action
    return get
class Service(object):
    """
    Superclass for action-grouping classes.
    """

    # Placeholders; make_service_class() supplies concrete "name" and
    # "__doc__" values for every generated subclass.
    name = abc.abstractproperty()
    __doc__ = abc.abstractproperty()

    def __init__(self, session):
        """
        :param session: session object used by the service's actions to send requests
        """
        self.session = session
def get_requests(service):
    """
    Collect the request classes defined in a service module.

    :param service: service module to inspect
    :return: ``OrderedDict`` of attribute name to class, sorted by name, holding
        every ``APIRequest`` subclass in ``vars(service)`` with a truthy ``_action``
    """
    return OrderedDict(
        (key, value)
        for key, value in sorted(vars(service).items(), key=itemgetter(0))
        if isinstance(value, type) and issubclass(value, APIRequest) and value._action
    )
def make_service_class(module):
    # type: (...) -> Type[Service]
    """
    Create a service class from service module.

    The class is named after the module, subclasses ``Service`` and gets one
    method per request class found by ``get_requests``.
    """
    properties = OrderedDict(
        [
            ("__module__", __name__),
            ("__doc__", module.__doc__),
            ("name", module_name(module)),
        ]
    )
    # One generated method per request class, keyed by the method's name.
    properties.update(
        (f.__name__, f)
        for f in (
            make_action(module, request_cls)
            for request_cls in get_requests(module).values()
        )
    )
    # noinspection PyTypeChecker
    return type(str(module_name(module)), (Service,), properties)
def module_name(module):
    """
    Return the CamelCase form of a module's base name.

    :param module: a module object, or a (possibly dotted) module name string
    :return: last dotted component with underscores removed and each part capitalized
    """
    try:
        module = module.__name__
    except AttributeError:
        # Already a name string.
        pass
    base_name = module.split(".")[-1]
    return "".join(s.capitalize() for s in base_name.split("_"))
class Version(Entity):
    """
    Entity for a dataset version; service calls also receive the version's ``dataset``.
    """

    entity_name = "version"
    get_by_id_request = None

    def fetch(self):
        """
        Reload the version data through the service's ``get_versions`` call.
        """
        try:
            published = self.data.status == "published"
        except AttributeError:
            # No status on the current data: do not restrict to published versions.
            published = False
        # NOTE(review): the end of this statement was lost in extraction; if
        # ``get_versions`` returns a collection, the single matching element may
        # need to be unwrapped - verify against upstream.
        self.data = self._service.get_versions(
            dataset=self.dataset, only_published=published, versions=[self.id]
        )

    def _get_default_kwargs(self):
        """
        Return the base default kwargs extended with the version's ``dataset``.
        """
        return dict(
            super(Version, self)._get_default_kwargs(), **{"dataset": self.data.dataset}
        )
class APIClient(object):
    """
    Client object holding one service instance per API service.
    """

    # Class-level placeholders for service attributes; presumably present for
    # static analysis, the instances are assigned in __init__ - confirm.
    auth = None  # type: Any
    debug = None  # type: Any
    queues = None  # type: Any
    tasks = None  # type: Any
    workers = None  # type: Any
    events = None  # type: Any

    def __init__(self, session=None, api_version=None):
        """
        :param session: session used by all services; a new ``StrictSession`` when omitted
        :param api_version: when given, service modules are imported from the
            ``v<version>`` sub-package (dots replaced by underscores) and modules
            that fail to import are skipped
        """
        self.session = session or StrictSession()

        def import_(*args, **kwargs):
            # Best-effort import: a missing module yields None.
            try:
                return import_module(*args, **kwargs)
            except ImportError:
                return None

        if api_version:
            api_version = "v{}".format(str(api_version).replace(".", "_"))
            services = OrderedDict(
                (name, mod)
                for name, mod in (
                    (
                        name,
                        import_(".".join((api_services.__name__, api_version, name))),
                    )
                    for name in api_services.__all__
                )
                if mod
            )
        else:
            services = OrderedDict(
                (name, getattr(api_services, name)) for name in api_services.__all__
            )

        # NOTE(review): the statement consuming this mapping was lost in
        # extraction; reconstructed as assigning each service instance as an
        # attribute of the client - confirm against upstream.
        self.__dict__.update(
            {
                name: make_service_class(module)(self.session)
                for name, module in services.items()
            }
        )