from __future__ import unicode_literals import abc import os from argparse import Namespace from collections import OrderedDict from enum import Enum from functools import reduce, wraps, WRAPPER_ASSIGNMENTS from importlib import import_module from itertools import chain from operator import itemgetter from types import ModuleType from typing import Dict, Text, Tuple, Type, Any, Sequence import six from ... import services as api_services from ....backend_api.session import CallResult from ....backend_api.session import Session, Request as APIRequest from ....backend_api.session.response import ResponseMeta from ....backend_config.defs import LOCAL_CONFIG_FILE_OVERRIDE_VAR SERVICE_TO_ENTITY_CLASS_NAMES = {"storage": "StorageItem"} def entity_class_name(service): # type: (ModuleType) -> Text service_name = api_entity_name(service) return SERVICE_TO_ENTITY_CLASS_NAMES.get(service_name.lower(), service_name) def api_entity_name(service): return module_name(service).rstrip("s") @six.python_2_unicode_compatible class APIError(Exception): """ Class for representing an API error. self.data - ``dict`` of all returned JSON data self.code - HTTP response code self.subcode - server response subcode self.codes - (self.code, self.subcode) tuple self.message - result message sent from server """ def __init__(self, response, extra_info=None): """ Create a new APIError from a server response """ super(APIError, self).__init__() self._response = response # type: CallResult self.extra_info = extra_info self.data = response.response_data # type: Dict self.meta = response.meta # type: ResponseMeta self.code = response.meta.result_code # type: int self.subcode = response.meta.result_subcode # type: int self.message = response.meta.result_msg # type: Text self.codes = (self.code, self.subcode) # type: Tuple[int, int] def get_traceback(self): """ Return server traceback for error, or None if doesn't exist. """ try: return self.meta.error_stack except AttributeError: return None def __str__(self): message = "{}: ".format(type(self).__name__) if self.extra_info: message += "{}: ".format(self.extra_info) if not self.meta: message += "no meta available" return message if not self.code: message += "no error code available" return message message += "code {0.code}".format(self) if self.subcode: message += "/{.subcode}".format(self) if self.message: message += ": {.message}".format(self) return message class StrictSession(Session): """ Session that raises exceptions on errors, and be configured with explicit ``config_file`` path. """ def __init__(self, config_file=None, initialize_logging=False, *args, **kwargs): """ :param config_file: configuration file to use, else use the default :type config_file: Path | Text """ def init(): super(StrictSession, self).__init__( initialize_logging=initialize_logging, *args, **kwargs ) if not config_file: init() return original = os.environ.get(LOCAL_CONFIG_FILE_OVERRIDE_VAR, None) try: os.environ[LOCAL_CONFIG_FILE_OVERRIDE_VAR] = str(config_file) init() finally: if original is None: os.environ.pop(LOCAL_CONFIG_FILE_OVERRIDE_VAR, None) else: os.environ[LOCAL_CONFIG_FILE_OVERRIDE_VAR] = original def send(self, request, *args, **kwargs): result = super(StrictSession, self).send(request, *args, **kwargs) if not result.ok(): raise APIError(result) if not result.response: raise APIError(result, extra_info="Invalid response") return result class Response(object): """ Proxy object for API result data. Exposes "meta" of the original result. """ def __init__(self, result, dest=None): """ :param result: result of endpoint call :type result: CallResult :param dest: if all of a response's data is contained in one field, use that field :type dest: Text """ self.response = None self._result = result response = getattr(result, "response", result) if getattr(response, "_service") == "events" and \ getattr(response, "_action") in ("scalar_metrics_iter_histogram", "multi_task_scalar_metrics_iter_histogram", "vector_metrics_iter_histogram", ): # put all the response data under metrics: response.metrics = result.response_data if 'metrics' not in response.__class__._get_data_props(): response.__class__._data_props_list['metrics'] = 'metrics' if dest: response = getattr(response, dest) self.response = response def __getattr__(self, attr): if self.response is None: return None return getattr(self.response, attr) @property def meta(self): return self._result.meta def __repr__(self): return repr(self.response) def __dir__(self): fields = [ name for name in dir(self.response) if isinstance(getattr(type(self.response), name, None), property) ] return list(set(chain(super(Response, self).__dir__(), fields)) - {"response"}) @six.python_2_unicode_compatible class TableResponse(Response): """ Representation of result containing an array of entities """ def __init__( self, service, # type: Service entity, # type: Type[entity] fields=None, # type: Sequence[Text] *args, **kwargs ): """ :param service: service of entity :param entity: class representing entity :param fields: entity attributes requested by client """ super(TableResponse, self).__init__(*args, **kwargs) self.service = service self.entity = entity self.fields = fields or ("id", "name") self.response = [entity(service, item) for item in self] def __repr__(self, fields=None): return self._format_table(fields=fields) __str__ = __repr__ def _format_table(self, fields=None): """ Display attributes of each element in a table :param fields: """ def getter(obj, attr): result = reduce( lambda x, name: x if x is None else getattr(x, name, None), attr.split("."), obj, ) return "" if result is None else result fields = fields or self.fields from trains_agent.helper.base import create_table return create_table( (dict((attr, getter(item, attr)) for attr in fields) for item in self), titles=fields, columns=fields, headers=True, ) def display(self, fields=None): print(self._format_table(fields=fields)) def where(self, predicate=None, **kwargs): """ Filter items. is a callable from a single item to a boolean. Items for which is True will be returned. Keyword arguments are interpreted as attribute equivalence, meaning: >>> datasets.where(name='foo') will return only datasets whose name is "foo". Giving more than one condition (predicate and keyword arguments) establishes an "and" relation. """ def compare_enum(x, y): return x == y or isinstance(x, Enum) and x.value == y return TableResponse( self.service, self.entity, self.fields, [ item for item in self if (not predicate or predicate(item)) and all( compare_enum(getattr(item, key), value) for key, value in kwargs.items() ) ], ) def __getitem__(self, item): return self.response[item] def __iter__(self): return iter(self.response) def __len__(self): return len(self.response) @six.add_metaclass(abc.ABCMeta) class Entity(object): """ Represent a server object. Enables calls like: >>> entity = client.service.get_by_id(entity_id) >>> entity.action(**kwargs) instead of: >>> client.service.action(id=entity_id, **kwargs) """ @abc.abstractproperty def entity_name(self): # type: () -> Text """ Singular name of entity """ pass @abc.abstractproperty def get_by_id_request(self): # type: () -> Type[APIRequest] """ get_by_id request class """ pass def __init__(self, service, data): self._service = service self.data = getattr(data, self.entity_name, data) self.__doc__ = self.data.__doc__ def fetch(self): """ Update the entity data from the server. """ result = self._service.session.send(self.get_by_id_request(self.data.id)) self.data = getattr(result.response, self.entity_name) def _get_default_kwargs(self): return {self.entity_name: self.data.id} def __getattr__(self, attr): """ Inject the entity's ID to the method call. All missing properties are assumed to be functions. """ try: return getattr(self.data, attr) except AttributeError: pass func = getattr(self._service, attr) @wrap_request_class(func) def new_func(*args, **kwargs): kwargs = dict(self._get_default_kwargs(), **kwargs) return func(*args, **kwargs) return new_func def __dir__(self): """ Add ``self._service``'s methods to ``dir()`` result. """ try: dir_ = super(Entity, self).__dir__ except AttributeError: base = self.__dict__ else: base = dir_() return list(set(base).union(dir(self._service), dir(self.data))) def __repr__(self): """ Display entity type, ID, and - if available - name. """ parts = (type(self).__name__, ": ", "id={}".format(self.data.id)) try: parts += (", ", 'name="{}"'.format(self.data.name)) except AttributeError: pass return "<{}>".format("".join(parts)) def wrap_request_class(cls): return wraps(cls, assigned=WRAPPER_ASSIGNMENTS + ("from_dict",)) def make_action(service, request_cls): action = request_cls._action try: get_by_id_request = service.GetByIdRequest except AttributeError: get_by_id_request = None wrap = wrap_request_class(request_cls) if action not in ["get_all", "get_all_ex", "get_by_id", "create"]: @wrap def new_func(self, *args, **kwargs): return Response(self.session.send(request_cls(*args, **kwargs))) new_func.__name__ = new_func.__qualname__ = action return new_func entity_name = api_entity_name(service) class_name = entity_class_name(service).capitalize() properties = { "__module__": __name__, "entity_name": entity_name.lower(), "get_by_id_request": get_by_id_request, } entity = type(str(class_name), (Entity,), properties) if action == "get_by_id": @wrap def get(self, *args, **kwargs): return entity( self, self.session.send(request_cls(*args, **kwargs)).response ) elif action == "create": @wrap def get(self, *args, **kwargs): return entity( self, Namespace( id=self.session.send(request_cls(*args, **kwargs)).response.id ), ) elif action in ["get_all", "get_all_ex"]: dest = service.response_mapping[request_cls]._get_data_props().popitem()[0] @wrap def get(self, *args, **kwargs): return TableResponse( service=self, entity=entity, result=self.session.send(request_cls(*args, **kwargs)), dest=dest, fields=kwargs.pop("only_fields", None), ) else: assert False get.__name__ = get.__qualname__ = action return get @six.add_metaclass(abc.ABCMeta) class Service(object): """ Superclass for action-grouping classes. """ name = abc.abstractproperty() __doc__ = abc.abstractproperty() def __init__(self, session): self.session = session def get_requests(service): return OrderedDict( (key, value) for key, value in sorted(vars(service).items(), key=itemgetter(0)) if isinstance(value, type) and issubclass(value, APIRequest) and value._action ) def make_service_class(module): # type: (...) -> Type[Service] """ Create a service class from service module. """ properties = OrderedDict( [ ("__module__", __name__), ("__doc__", module.__doc__), ("name", module_name(module)), ] ) properties.update( (f.__name__, f) for f in ( make_action(module, value) for key, value in get_requests(module).items() ) ) # noinspection PyTypeChecker return type(str(module_name(module)), (Service,), properties) def module_name(module): try: module = module.__name__ except AttributeError: pass base_name = module.split(".")[-1] return "".join(s.capitalize() for s in base_name.split("_")) class Version(Entity): entity_name = "version" get_by_id_request = None def fetch(self): try: published = self.data.status == "published" except AttributeError: published = False self.data = self._service.get_versions( dataset=self.dataset, only_published=published, versions=[self.id] )[0].data def _get_default_kwargs(self): return dict( super(Version, self)._get_default_kwargs(), **{"dataset": self.data.dataset} ) class APIClient(object): auth = None # type: Any debug = None # type: Any queues = None # type: Any tasks = None # type: Any workers = None # type: Any events = None # type: Any def __init__(self, session=None, api_version=None): self.session = session or StrictSession() def import_(*args, **kwargs): try: return import_module(*args, **kwargs) except ImportError: return None if api_version: api_version = "v{}".format(str(api_version).replace(".", "_")) services = OrderedDict( (name, mod) for name, mod in ( ( name, import_(".".join((api_services.__name__, api_version, name))), ) for name in api_services.__all__ ) if mod ) else: services = OrderedDict( (name, getattr(api_services, name)) for name in api_services.__all__ ) self.__dict__.update( dict( { name: make_service_class(module)(self.session) for name, module in services.items() }, ) )