Refactor service_repo

Code cleanup
This commit is contained in:
allegroai 2021-01-05 18:50:42 +02:00
parent 64c63d2560
commit 6870d8aba9
8 changed files with 73 additions and 74 deletions

View File

@ -1 +1 @@
from .basic import BasicConfig, ConfigurationError, Factory from .basic import BasicConfig, ConfigurationError

View File

@ -6,7 +6,7 @@ from functools import reduce
from os import getenv from os import getenv
from os.path import expandvars from os.path import expandvars
from pathlib import Path from pathlib import Path
from typing import List, Any, Type, TypeVar from typing import List, Any, TypeVar
from pyhocon import ConfigTree, ConfigFactory from pyhocon import ConfigTree, ConfigFactory
from pyparsing import ( from pyparsing import (
@ -169,29 +169,8 @@ class BasicConfig:
class ConfigurationError(Exception): class ConfigurationError(Exception):
def __init__(self, msg, file_path=None, *args): def __init__(self, msg, file_path=None, *args):
super(ConfigurationError, self).__init__(msg, *args) super().__init__(msg, *args)
self.file_path = file_path self.file_path = file_path
ConfigType = TypeVar("ConfigType", bound=BasicConfig) ConfigType = TypeVar("ConfigType", bound=BasicConfig)
class Factory:
_config_cls: Type[ConfigType] = BasicConfig
@classmethod
def get(cls) -> BasicConfig:
config = cls._config_cls()
config.initialize_logging()
return config
@classmethod
def set_cls(cls, cls_: Type[ConfigType]):
cls._config_cls = cls_
__all__ = [
"Factory",
"BasicConfig",
"ConfigurationError",
]

View File

@ -1,3 +1,4 @@
from apiserver.config import Factory from apiserver.config import BasicConfig
config = Factory.get() config = BasicConfig()
config.initialize_logging()

View File

@ -17,7 +17,6 @@ from apiserver.utilities.partial_version import PartialVersion
log = config.logger(__file__) log = config.logger(__file__)
root = Path(__file__).parent / "services"
ALL_ROLES = "*" ALL_ROLES = "*"
@ -196,11 +195,12 @@ class Schema:
@attr.s() @attr.s()
class SchemaReader: class SchemaReader:
root = Path(__file__).parent / "services"
cache_path: Path = None cache_path: Path = None
def __attrs_post_init__(self): def __attrs_post_init__(self):
if not self.cache_path: if not self.cache_path:
self.cache_path = root / "_cache.json" self.cache_path = self.root / "_cache.json"
@staticmethod @staticmethod
def mod_time(path): def mod_time(path):
@ -220,7 +220,7 @@ class SchemaReader:
""" """
services = [ services = [
service service
for service in root.glob("*.conf") for service in self.root.glob("*.conf")
if not service.name.startswith("_") if not service.name.startswith("_")
] ]
@ -244,7 +244,7 @@ class SchemaReader:
log.info("regenerating schema cache") log.info("regenerating schema cache")
services = {path.stem: self.read_file(path) for path in services} services = {path.stem: self.read_file(path) for path in services}
api_defaults = self.read_file(root / "_api_defaults.conf") api_defaults = self.read_file(self.root / "_api_defaults.conf")
try: try:
self.cache_path.write_text( self.cache_path.write_text(

View File

@ -13,8 +13,6 @@ log = config.logger(__file__)
class RequestHandlers: class RequestHandlers:
_request_strip_prefix = config.get("apiserver.request.strip_prefix", None) _request_strip_prefix = config.get("apiserver.request.strip_prefix", None)
_service_repo_cls = ServiceRepo
_api_call_cls = APICall
def before_app_first_request(self): def before_app_first_request(self):
pass pass
@ -27,7 +25,7 @@ class RequestHandlers:
try: try:
call = self._create_api_call(request) call = self._create_api_call(request)
content, content_type = self._service_repo_cls.handle_call(call) content, content_type = ServiceRepo.handle_call(call)
if call.result.redirect: if call.result.redirect:
response = redirect(call.result.redirect.url, call.result.redirect.code) response = redirect(call.result.redirect.url, call.result.redirect.code)
@ -39,7 +37,10 @@ class RequestHandlers:
} }
response = Response( response = Response(
content, mimetype=content_type, status=call.result.code, headers=headers content,
mimetype=content_type,
status=call.result.code,
headers=headers,
) )
if call.result.cookies: if call.result.cookies:
@ -47,13 +48,11 @@ class RequestHandlers:
kwargs = config.get("apiserver.auth.cookies") kwargs = config.get("apiserver.auth.cookies")
if value is None: if value is None:
kwargs = kwargs.copy() kwargs = kwargs.copy()
kwargs['max_age'] = 0 kwargs["max_age"] = 0
kwargs['expires'] = 0 kwargs["expires"] = 0
response.set_cookie(key, "", **kwargs) response.set_cookie(key, "", **kwargs)
else: else:
response.set_cookie( response.set_cookie(key, value, **kwargs)
key, value, **kwargs
)
return response return response
except Exception as ex: except Exception as ex:
@ -96,7 +95,7 @@ class RequestHandlers:
call.data = json_body or form or {} call.data = json_body or form or {}
def _call_or_empty_with_error(self, call, req, msg, code=500, subcode=0): def _call_or_empty_with_error(self, call, req, msg, code=500, subcode=0):
call = call or self._api_call_cls( call = call or APICall(
"", remote_addr=req.remote_addr, headers=dict(req.headers), files=req.files "", remote_addr=req.remote_addr, headers=dict(req.headers), files=req.files
) )
call.set_error_result(msg=msg, code=code, subcode=subcode) call.set_error_result(msg=msg, code=code, subcode=subcode)
@ -107,9 +106,11 @@ class RequestHandlers:
try: try:
# Parse the request path # Parse the request path
path = req.path path = req.path
if self._request_strip_prefix and path.startswith(self._request_strip_prefix): if self._request_strip_prefix and path.startswith(
path = path[len(self._request_strip_prefix):] self._request_strip_prefix
endpoint_version, endpoint_name = self._service_repo_cls.parse_endpoint_path(path) ):
path = path[len(self._request_strip_prefix) :]
endpoint_version, endpoint_name = ServiceRepo.parse_endpoint_path(path)
# Resolve authorization: if cookies contain an authorization token, use it as a starting point. # Resolve authorization: if cookies contain an authorization token, use it as a starting point.
# in any case, request headers always take precedence. # in any case, request headers always take precedence.
@ -126,7 +127,7 @@ class RequestHandlers:
) # add (possibly override with) the headers ) # add (possibly override with) the headers
# Construct call instance # Construct call instance
call = self._api_call_cls( call = APICall(
endpoint_name=endpoint_name, endpoint_name=endpoint_name,
remote_addr=req.remote_addr, remote_addr=req.remote_addr,
endpoint_version=endpoint_version, endpoint_version=endpoint_version,
@ -145,9 +146,13 @@ class RequestHandlers:
except BadRequest as ex: except BadRequest as ex:
call = self._call_or_empty_with_error(call, req, ex.description, 400) call = self._call_or_empty_with_error(call, req, ex.description, 400)
except BaseError as ex: except BaseError as ex:
call = self._call_or_empty_with_error(call, req, ex.msg, ex.code, ex.subcode) call = self._call_or_empty_with_error(
call, req, ex.msg, ex.code, ex.subcode
)
except Exception as ex: except Exception as ex:
log.exception("Error creating call") log.exception("Error creating call")
call = self._call_or_empty_with_error(call, req, ex.args[0] if ex.args else type(ex).__name__, 500) call = self._call_or_empty_with_error(
call, req, ex.args[0] if ex.args else type(ex).__name__, 500
)
return call return call

View File

@ -186,7 +186,7 @@ class APICallResult(DataContainer):
error_data=None, error_data=None,
cookies=None, cookies=None,
): ):
super(APICallResult, self).__init__(data) super().__init__(data)
self._code = code self._code = code
self._subcode = subcode self._subcode = subcode
self._msg = msg self._msg = msg
@ -297,9 +297,7 @@ class MissingIdentity(Exception):
def _get_headers(name: str) -> Tuple[str, ...]: def _get_headers(name: str) -> Tuple[str, ...]:
return tuple( return tuple("-".join(("X", p, name)) for p in ("ClearML", "Trains"))
"-".join(("X", p, name)) for p in ("ClearML", "Trains")
)
class APICall(DataContainer): class APICall(DataContainer):
@ -308,8 +306,6 @@ class APICall(DataContainer):
HEADER_FORWARDED_FOR = "X-Forwarded-For" HEADER_FORWARDED_FOR = "X-Forwarded-For"
""" Standard headers """ """ Standard headers """
_call_result_cls = APICallResult
_transaction_headers = _get_headers("Trx") _transaction_headers = _get_headers("Trx")
""" Transaction ID """ """ Transaction ID """
@ -358,7 +354,7 @@ class APICall(DataContainer):
host=None, host=None,
auth_cookie=None, auth_cookie=None,
): ):
super(APICall, self).__init__(data=data, batched_data=batched_data) super().__init__(data=data, batched_data=batched_data)
self._id = database.utils.id() self._id = database.utils.id()
self._files = files # currently dic of key to flask's FileStorage) self._files = files # currently dic of key to flask's FileStorage)
@ -375,7 +371,7 @@ class APICall(DataContainer):
self._log_api = True self._log_api = True
if headers: if headers:
self._headers.update(headers) self._headers.update(headers)
self._result = self._call_result_cls() self._result = APICallResult()
self._auth = None self._auth = None
self._impersonation = None self._impersonation = None
if trx: if trx:
@ -640,7 +636,7 @@ class APICall(DataContainer):
self, msg, code=500, subcode=0, include_stack=False, error_data=None self, msg, code=500, subcode=0, include_stack=False, error_data=None
): ):
tb = format_exc() if include_stack else None tb = format_exc() if include_stack else None
self._result = self._call_result_cls( self._result = APICallResult(
data=self._result.data, data=self._result.data,
code=code, code=code,
subcode=subcode, subcode=subcode,

View File

@ -38,7 +38,6 @@ class Endpoint(object):
:param response_data_model: response jsonschema model, will be validated if validate_schema=False :param response_data_model: response jsonschema model, will be validated if validate_schema=False
:param validate_schema: whether request and response schema should be validated :param validate_schema: whether request and response schema should be validated
""" """
super(Endpoint, self).__init__()
self.name = name self.name = name
self.min_version = PartialVersion(min_version) self.min_version = PartialVersion(min_version)
self.func = func self.func = func

View File

@ -2,13 +2,11 @@ import re
from importlib import import_module from importlib import import_module
from itertools import chain from itertools import chain
from pathlib import Path from pathlib import Path
from typing import cast, Iterable, List, MutableMapping, Optional, Tuple from typing import cast, Iterable, List, MutableMapping, Optional, Tuple, Callable
import jsonmodels.models import jsonmodels.models
from apiserver import timing_context from apiserver.apierrors import APIError, errors
from apiserver.apierrors import APIError
from apiserver.apierrors.errors.bad_request import RequestPathHasInvalidVersion
from apiserver.config_repo import config from apiserver.config_repo import config
from apiserver.utilities.partial_version import PartialVersion from apiserver.utilities.partial_version import PartialVersion
from .apicall import APICall from .apicall import APICall
@ -77,18 +75,36 @@ class ServiceRepo(object):
""" Token for internal calls """ """ Token for internal calls """
@classmethod @classmethod
def load(cls, root_module="services"): def _load_from_path(
root_module = Path(__file__).parents[1] / root_module cls,
root_module: Path,
module_prefix: Optional[str] = None,
predicate: Optional[Callable[[Path], bool]] = None,
):
log.info(f"Loading services from {str(root_module.absolute())}")
sub_module = None sub_module = None
for sub_module in root_module.glob("*"): for sub_module in root_module.glob("*"):
if predicate and not predicate(sub_module):
continue
if ( if (
sub_module.is_file() sub_module.is_file()
and sub_module.suffix == ".py" and sub_module.suffix == ".py"
and not sub_module.stem == "__init__" and not sub_module.stem == "__init__"
): ):
import_module(f"apiserver.{root_module.stem}.{sub_module.stem}") import_module(
if sub_module.is_dir(): ".".join(
import_module(f"apiserver.{root_module.stem}.{sub_module.stem}") filter(None, (module_prefix, root_module.stem, sub_module.stem))
)
)
if sub_module.is_dir() and not sub_module.stem == "__pycache__":
import_module(
".".join(
filter(None, (module_prefix, root_module.stem, sub_module.stem))
)
)
# leave no trace of the 'sub_module' local # leave no trace of the 'sub_module' local
del sub_module del sub_module
@ -101,8 +117,14 @@ class ServiceRepo(object):
) )
@classmethod @classmethod
def register(cls, endpoint): def load(cls, root_module="services"):
assert isinstance(endpoint, Endpoint) cls._load_from_path(
root_module=Path(__file__).parents[1] / root_module,
module_prefix="apiserver",
)
@classmethod
def register(cls, endpoint: Endpoint):
if cls._endpoints.get(endpoint.name): if cls._endpoints.get(endpoint.name):
if any( if any(
ep.min_version == endpoint.min_version ep.min_version == endpoint.min_version
@ -149,7 +171,6 @@ class ServiceRepo(object):
@classmethod @classmethod
def _resolve_endpoint_from_call(cls, call: APICall) -> Optional[Endpoint]: def _resolve_endpoint_from_call(cls, call: APICall) -> Optional[Endpoint]:
assert isinstance(call, APICall)
endpoint = cls._get_endpoint( endpoint = cls._get_endpoint(
call.endpoint_name, call.requested_endpoint_version call.endpoint_name, call.requested_endpoint_version
) )
@ -165,7 +186,6 @@ class ServiceRepo(object):
) )
return return
assert isinstance(endpoint, Endpoint)
call.actual_endpoint_version = endpoint.min_version call.actual_endpoint_version = endpoint.min_version
call.requires_authorization = endpoint.authorize call.requires_authorization = endpoint.authorize
return endpoint return endpoint
@ -185,7 +205,9 @@ class ServiceRepo(object):
try: try:
version = PartialVersion(version) version = PartialVersion(version)
except ValueError as e: except ValueError as e:
raise RequestPathHasInvalidVersion(version=version, reason=e) raise errors.bad_request.RequestPathHasInvalidVersion(
version=version, reason=e
)
if cls._check_max_version and version > cls._max_version: if cls._check_max_version and version > cls._max_version:
raise InvalidVersionError( raise InvalidVersionError(
f"Invalid API version (max. supported version is {cls._max_version})" f"Invalid API version (max. supported version is {cls._max_version})"
@ -232,8 +254,6 @@ class ServiceRepo(object):
@classmethod @classmethod
def handle_call(cls, call: APICall): def handle_call(cls, call: APICall):
try: try:
assert isinstance(call, APICall)
if call.failed: if call.failed:
raise CallFailedError() raise CallFailedError()
@ -242,8 +262,7 @@ class ServiceRepo(object):
if call.failed: if call.failed:
raise CallFailedError() raise CallFailedError()
with timing_context.TimingContext("service_repo", "validate_call"): validate_all(call, endpoint)
validate_all(call, endpoint)
if call.failed: if call.failed:
raise CallFailedError() raise CallFailedError()