diff --git a/apiserver/config/__init__.py b/apiserver/config/__init__.py index 3b65241..c84920b 100644 --- a/apiserver/config/__init__.py +++ b/apiserver/config/__init__.py @@ -1 +1 @@ -from .basic import BasicConfig, ConfigurationError, Factory +from .basic import BasicConfig, ConfigurationError diff --git a/apiserver/config/basic.py b/apiserver/config/basic.py index d9ec908..0a031d7 100644 --- a/apiserver/config/basic.py +++ b/apiserver/config/basic.py @@ -6,7 +6,7 @@ from functools import reduce from os import getenv from os.path import expandvars from pathlib import Path -from typing import List, Any, Type, TypeVar +from typing import List, Any, TypeVar from pyhocon import ConfigTree, ConfigFactory from pyparsing import ( @@ -169,29 +169,8 @@ class BasicConfig: class ConfigurationError(Exception): def __init__(self, msg, file_path=None, *args): - super(ConfigurationError, self).__init__(msg, *args) + super().__init__(msg, *args) self.file_path = file_path ConfigType = TypeVar("ConfigType", bound=BasicConfig) - - -class Factory: - _config_cls: Type[ConfigType] = BasicConfig - - @classmethod - def get(cls) -> BasicConfig: - config = cls._config_cls() - config.initialize_logging() - return config - - @classmethod - def set_cls(cls, cls_: Type[ConfigType]): - cls._config_cls = cls_ - - -__all__ = [ - "Factory", - "BasicConfig", - "ConfigurationError", -] diff --git a/apiserver/config_repo.py b/apiserver/config_repo.py index b72a1f6..631f0d8 100644 --- a/apiserver/config_repo.py +++ b/apiserver/config_repo.py @@ -1,3 +1,4 @@ -from apiserver.config import Factory +from apiserver.config import BasicConfig -config = Factory.get() +config = BasicConfig() +config.initialize_logging() diff --git a/apiserver/schema/schema_reader.py b/apiserver/schema/schema_reader.py index bf7cbbc..ef00976 100644 --- a/apiserver/schema/schema_reader.py +++ b/apiserver/schema/schema_reader.py @@ -17,7 +17,6 @@ from apiserver.utilities.partial_version import PartialVersion log = config.logger(__file__) -root = Path(__file__).parent / "services" ALL_ROLES = "*" @@ -196,11 +195,12 @@ class Schema: @attr.s() class SchemaReader: + root = Path(__file__).parent / "services" cache_path: Path = None def __attrs_post_init__(self): if not self.cache_path: - self.cache_path = root / "_cache.json" + self.cache_path = self.root / "_cache.json" @staticmethod def mod_time(path): @@ -220,7 +220,7 @@ class SchemaReader: """ services = [ service - for service in root.glob("*.conf") + for service in self.root.glob("*.conf") if not service.name.startswith("_") ] @@ -244,7 +244,7 @@ class SchemaReader: log.info("regenerating schema cache") services = {path.stem: self.read_file(path) for path in services} - api_defaults = self.read_file(root / "_api_defaults.conf") + api_defaults = self.read_file(self.root / "_api_defaults.conf") try: self.cache_path.write_text( diff --git a/apiserver/server_init/request_handlers.py b/apiserver/server_init/request_handlers.py index ee604b6..c75c399 100644 --- a/apiserver/server_init/request_handlers.py +++ b/apiserver/server_init/request_handlers.py @@ -13,8 +13,6 @@ log = config.logger(__file__) class RequestHandlers: _request_strip_prefix = config.get("apiserver.request.strip_prefix", None) - _service_repo_cls = ServiceRepo - _api_call_cls = APICall def before_app_first_request(self): pass @@ -27,7 +25,7 @@ class RequestHandlers: try: call = self._create_api_call(request) - content, content_type = self._service_repo_cls.handle_call(call) + content, content_type = ServiceRepo.handle_call(call) if call.result.redirect: response = redirect(call.result.redirect.url, call.result.redirect.code) @@ -39,7 +37,10 @@ class RequestHandlers: } response = Response( - content, mimetype=content_type, status=call.result.code, headers=headers + content, + mimetype=content_type, + status=call.result.code, + headers=headers, ) if call.result.cookies: @@ -47,13 +48,11 @@ class RequestHandlers: kwargs = config.get("apiserver.auth.cookies") if value is None: kwargs = kwargs.copy() - kwargs['max_age'] = 0 - kwargs['expires'] = 0 + kwargs["max_age"] = 0 + kwargs["expires"] = 0 response.set_cookie(key, "", **kwargs) else: - response.set_cookie( - key, value, **kwargs - ) + response.set_cookie(key, value, **kwargs) return response except Exception as ex: @@ -96,7 +95,7 @@ class RequestHandlers: call.data = json_body or form or {} def _call_or_empty_with_error(self, call, req, msg, code=500, subcode=0): - call = call or self._api_call_cls( + call = call or APICall( "", remote_addr=req.remote_addr, headers=dict(req.headers), files=req.files ) call.set_error_result(msg=msg, code=code, subcode=subcode) @@ -107,9 +106,11 @@ class RequestHandlers: try: # Parse the request path path = req.path - if self._request_strip_prefix and path.startswith(self._request_strip_prefix): - path = path[len(self._request_strip_prefix):] - endpoint_version, endpoint_name = self._service_repo_cls.parse_endpoint_path(path) + if self._request_strip_prefix and path.startswith( + self._request_strip_prefix + ): + path = path[len(self._request_strip_prefix) :] + endpoint_version, endpoint_name = ServiceRepo.parse_endpoint_path(path) # Resolve authorization: if cookies contain an authorization token, use it as a starting point. # in any case, request headers always take precedence. @@ -126,7 +127,7 @@ class RequestHandlers: ) # add (possibly override with) the headers # Construct call instance - call = self._api_call_cls( + call = APICall( endpoint_name=endpoint_name, remote_addr=req.remote_addr, endpoint_version=endpoint_version, @@ -145,9 +146,13 @@ class RequestHandlers: except BadRequest as ex: call = self._call_or_empty_with_error(call, req, ex.description, 400) except BaseError as ex: - call = self._call_or_empty_with_error(call, req, ex.msg, ex.code, ex.subcode) + call = self._call_or_empty_with_error( + call, req, ex.msg, ex.code, ex.subcode + ) except Exception as ex: log.exception("Error creating call") - call = self._call_or_empty_with_error(call, req, ex.args[0] if ex.args else type(ex).__name__, 500) + call = self._call_or_empty_with_error( + call, req, ex.args[0] if ex.args else type(ex).__name__, 500 + ) return call diff --git a/apiserver/service_repo/apicall.py b/apiserver/service_repo/apicall.py index 0139041..8087369 100644 --- a/apiserver/service_repo/apicall.py +++ b/apiserver/service_repo/apicall.py @@ -186,7 +186,7 @@ class APICallResult(DataContainer): error_data=None, cookies=None, ): - super(APICallResult, self).__init__(data) + super().__init__(data) self._code = code self._subcode = subcode self._msg = msg @@ -297,9 +297,7 @@ class MissingIdentity(Exception): def _get_headers(name: str) -> Tuple[str, ...]: - return tuple( - "-".join(("X", p, name)) for p in ("ClearML", "Trains") - ) + return tuple("-".join(("X", p, name)) for p in ("ClearML", "Trains")) class APICall(DataContainer): @@ -308,8 +306,6 @@ class APICall(DataContainer): HEADER_FORWARDED_FOR = "X-Forwarded-For" """ Standard headers """ - _call_result_cls = APICallResult - _transaction_headers = _get_headers("Trx") """ Transaction ID """ @@ -358,7 +354,7 @@ class APICall(DataContainer): host=None, auth_cookie=None, ): - super(APICall, self).__init__(data=data, batched_data=batched_data) + super().__init__(data=data, batched_data=batched_data) self._id = database.utils.id() self._files = files # currently dic of key to flask's FileStorage) @@ -375,7 +371,7 @@ class APICall(DataContainer): self._log_api = True if headers: self._headers.update(headers) - self._result = self._call_result_cls() + self._result = APICallResult() self._auth = None self._impersonation = None if trx: @@ -640,7 +636,7 @@ class APICall(DataContainer): self, msg, code=500, subcode=0, include_stack=False, error_data=None ): tb = format_exc() if include_stack else None - self._result = self._call_result_cls( + self._result = APICallResult( data=self._result.data, code=code, subcode=subcode, diff --git a/apiserver/service_repo/endpoint.py b/apiserver/service_repo/endpoint.py index 1fdc0b5..a2949b8 100644 --- a/apiserver/service_repo/endpoint.py +++ b/apiserver/service_repo/endpoint.py @@ -38,7 +38,6 @@ class Endpoint(object): :param response_data_model: response jsonschema model, will be validated if validate_schema=False :param validate_schema: whether request and response schema should be validated """ - super(Endpoint, self).__init__() self.name = name self.min_version = PartialVersion(min_version) self.func = func diff --git a/apiserver/service_repo/service_repo.py b/apiserver/service_repo/service_repo.py index 4196d06..ca8a8f9 100644 --- a/apiserver/service_repo/service_repo.py +++ b/apiserver/service_repo/service_repo.py @@ -2,13 +2,11 @@ import re from importlib import import_module from itertools import chain from pathlib import Path -from typing import cast, Iterable, List, MutableMapping, Optional, Tuple +from typing import cast, Iterable, List, MutableMapping, Optional, Tuple, Callable import jsonmodels.models -from apiserver import timing_context -from apiserver.apierrors import APIError -from apiserver.apierrors.errors.bad_request import RequestPathHasInvalidVersion +from apiserver.apierrors import APIError, errors from apiserver.config_repo import config from apiserver.utilities.partial_version import PartialVersion from .apicall import APICall @@ -77,18 +75,36 @@ class ServiceRepo(object): """ Token for internal calls """ @classmethod - def load(cls, root_module="services"): - root_module = Path(__file__).parents[1] / root_module + def _load_from_path( + cls, + root_module: Path, + module_prefix: Optional[str] = None, + predicate: Optional[Callable[[Path], bool]] = None, + ): + log.info(f"Loading services from {str(root_module.absolute())}") sub_module = None for sub_module in root_module.glob("*"): + if predicate and not predicate(sub_module): + continue + if ( sub_module.is_file() and sub_module.suffix == ".py" and not sub_module.stem == "__init__" ): - import_module(f"apiserver.{root_module.stem}.{sub_module.stem}") - if sub_module.is_dir(): - import_module(f"apiserver.{root_module.stem}.{sub_module.stem}") + import_module( + ".".join( + filter(None, (module_prefix, root_module.stem, sub_module.stem)) + ) + ) + + if sub_module.is_dir() and not sub_module.stem == "__pycache__": + import_module( + ".".join( + filter(None, (module_prefix, root_module.stem, sub_module.stem)) + ) + ) + # leave no trace of the 'sub_module' local del sub_module @@ -101,8 +117,14 @@ class ServiceRepo(object): ) @classmethod - def register(cls, endpoint): - assert isinstance(endpoint, Endpoint) + def load(cls, root_module="services"): + cls._load_from_path( + root_module=Path(__file__).parents[1] / root_module, + module_prefix="apiserver", + ) + + @classmethod + def register(cls, endpoint: Endpoint): if cls._endpoints.get(endpoint.name): if any( ep.min_version == endpoint.min_version @@ -149,7 +171,6 @@ class ServiceRepo(object): @classmethod def _resolve_endpoint_from_call(cls, call: APICall) -> Optional[Endpoint]: - assert isinstance(call, APICall) endpoint = cls._get_endpoint( call.endpoint_name, call.requested_endpoint_version ) @@ -165,7 +186,6 @@ class ServiceRepo(object): ) return - assert isinstance(endpoint, Endpoint) call.actual_endpoint_version = endpoint.min_version call.requires_authorization = endpoint.authorize return endpoint @@ -185,7 +205,9 @@ class ServiceRepo(object): try: version = PartialVersion(version) except ValueError as e: - raise RequestPathHasInvalidVersion(version=version, reason=e) + raise errors.bad_request.RequestPathHasInvalidVersion( + version=version, reason=e + ) if cls._check_max_version and version > cls._max_version: raise InvalidVersionError( f"Invalid API version (max. supported version is {cls._max_version})" @@ -232,8 +254,6 @@ class ServiceRepo(object): @classmethod def handle_call(cls, call: APICall): try: - assert isinstance(call, APICall) - if call.failed: raise CallFailedError() @@ -242,8 +262,7 @@ class ServiceRepo(object): if call.failed: raise CallFailedError() - with timing_context.TimingContext("service_repo", "validate_call"): - validate_all(call, endpoint) + validate_all(call, endpoint) if call.failed: raise CallFailedError()