Refactor service_repo

Code cleanup
This commit is contained in:
allegroai 2021-01-05 18:50:42 +02:00
parent 64c63d2560
commit 6870d8aba9
8 changed files with 73 additions and 74 deletions

View File

@ -1 +1 @@
from .basic import BasicConfig, ConfigurationError, Factory
from .basic import BasicConfig, ConfigurationError

View File

@ -6,7 +6,7 @@ from functools import reduce
from os import getenv
from os.path import expandvars
from pathlib import Path
from typing import List, Any, Type, TypeVar
from typing import List, Any, TypeVar
from pyhocon import ConfigTree, ConfigFactory
from pyparsing import (
@ -169,29 +169,8 @@ class BasicConfig:
class ConfigurationError(Exception):
def __init__(self, msg, file_path=None, *args):
super(ConfigurationError, self).__init__(msg, *args)
super().__init__(msg, *args)
self.file_path = file_path
ConfigType = TypeVar("ConfigType", bound=BasicConfig)
class Factory:
_config_cls: Type[ConfigType] = BasicConfig
@classmethod
def get(cls) -> BasicConfig:
config = cls._config_cls()
config.initialize_logging()
return config
@classmethod
def set_cls(cls, cls_: Type[ConfigType]):
cls._config_cls = cls_
__all__ = [
"Factory",
"BasicConfig",
"ConfigurationError",
]

View File

@ -1,3 +1,4 @@
from apiserver.config import Factory
from apiserver.config import BasicConfig
config = Factory.get()
config = BasicConfig()
config.initialize_logging()

View File

@ -17,7 +17,6 @@ from apiserver.utilities.partial_version import PartialVersion
log = config.logger(__file__)
root = Path(__file__).parent / "services"
ALL_ROLES = "*"
@ -196,11 +195,12 @@ class Schema:
@attr.s()
class SchemaReader:
root = Path(__file__).parent / "services"
cache_path: Path = None
def __attrs_post_init__(self):
if not self.cache_path:
self.cache_path = root / "_cache.json"
self.cache_path = self.root / "_cache.json"
@staticmethod
def mod_time(path):
@ -220,7 +220,7 @@ class SchemaReader:
"""
services = [
service
for service in root.glob("*.conf")
for service in self.root.glob("*.conf")
if not service.name.startswith("_")
]
@ -244,7 +244,7 @@ class SchemaReader:
log.info("regenerating schema cache")
services = {path.stem: self.read_file(path) for path in services}
api_defaults = self.read_file(root / "_api_defaults.conf")
api_defaults = self.read_file(self.root / "_api_defaults.conf")
try:
self.cache_path.write_text(

View File

@ -13,8 +13,6 @@ log = config.logger(__file__)
class RequestHandlers:
_request_strip_prefix = config.get("apiserver.request.strip_prefix", None)
_service_repo_cls = ServiceRepo
_api_call_cls = APICall
def before_app_first_request(self):
pass
@ -27,7 +25,7 @@ class RequestHandlers:
try:
call = self._create_api_call(request)
content, content_type = self._service_repo_cls.handle_call(call)
content, content_type = ServiceRepo.handle_call(call)
if call.result.redirect:
response = redirect(call.result.redirect.url, call.result.redirect.code)
@ -39,7 +37,10 @@ class RequestHandlers:
}
response = Response(
content, mimetype=content_type, status=call.result.code, headers=headers
content,
mimetype=content_type,
status=call.result.code,
headers=headers,
)
if call.result.cookies:
@ -47,13 +48,11 @@ class RequestHandlers:
kwargs = config.get("apiserver.auth.cookies")
if value is None:
kwargs = kwargs.copy()
kwargs['max_age'] = 0
kwargs['expires'] = 0
kwargs["max_age"] = 0
kwargs["expires"] = 0
response.set_cookie(key, "", **kwargs)
else:
response.set_cookie(
key, value, **kwargs
)
response.set_cookie(key, value, **kwargs)
return response
except Exception as ex:
@ -96,7 +95,7 @@ class RequestHandlers:
call.data = json_body or form or {}
def _call_or_empty_with_error(self, call, req, msg, code=500, subcode=0):
call = call or self._api_call_cls(
call = call or APICall(
"", remote_addr=req.remote_addr, headers=dict(req.headers), files=req.files
)
call.set_error_result(msg=msg, code=code, subcode=subcode)
@ -107,9 +106,11 @@ class RequestHandlers:
try:
# Parse the request path
path = req.path
if self._request_strip_prefix and path.startswith(self._request_strip_prefix):
path = path[len(self._request_strip_prefix):]
endpoint_version, endpoint_name = self._service_repo_cls.parse_endpoint_path(path)
if self._request_strip_prefix and path.startswith(
self._request_strip_prefix
):
path = path[len(self._request_strip_prefix) :]
endpoint_version, endpoint_name = ServiceRepo.parse_endpoint_path(path)
# Resolve authorization: if cookies contain an authorization token, use it as a starting point.
# in any case, request headers always take precedence.
@ -126,7 +127,7 @@ class RequestHandlers:
) # add (possibly override with) the headers
# Construct call instance
call = self._api_call_cls(
call = APICall(
endpoint_name=endpoint_name,
remote_addr=req.remote_addr,
endpoint_version=endpoint_version,
@ -145,9 +146,13 @@ class RequestHandlers:
except BadRequest as ex:
call = self._call_or_empty_with_error(call, req, ex.description, 400)
except BaseError as ex:
call = self._call_or_empty_with_error(call, req, ex.msg, ex.code, ex.subcode)
call = self._call_or_empty_with_error(
call, req, ex.msg, ex.code, ex.subcode
)
except Exception as ex:
log.exception("Error creating call")
call = self._call_or_empty_with_error(call, req, ex.args[0] if ex.args else type(ex).__name__, 500)
call = self._call_or_empty_with_error(
call, req, ex.args[0] if ex.args else type(ex).__name__, 500
)
return call

View File

@ -186,7 +186,7 @@ class APICallResult(DataContainer):
error_data=None,
cookies=None,
):
super(APICallResult, self).__init__(data)
super().__init__(data)
self._code = code
self._subcode = subcode
self._msg = msg
@ -297,9 +297,7 @@ class MissingIdentity(Exception):
def _get_headers(name: str) -> Tuple[str, ...]:
return tuple(
"-".join(("X", p, name)) for p in ("ClearML", "Trains")
)
return tuple("-".join(("X", p, name)) for p in ("ClearML", "Trains"))
class APICall(DataContainer):
@ -308,8 +306,6 @@ class APICall(DataContainer):
HEADER_FORWARDED_FOR = "X-Forwarded-For"
""" Standard headers """
_call_result_cls = APICallResult
_transaction_headers = _get_headers("Trx")
""" Transaction ID """
@ -358,7 +354,7 @@ class APICall(DataContainer):
host=None,
auth_cookie=None,
):
super(APICall, self).__init__(data=data, batched_data=batched_data)
super().__init__(data=data, batched_data=batched_data)
self._id = database.utils.id()
self._files = files # currently dic of key to flask's FileStorage)
@ -375,7 +371,7 @@ class APICall(DataContainer):
self._log_api = True
if headers:
self._headers.update(headers)
self._result = self._call_result_cls()
self._result = APICallResult()
self._auth = None
self._impersonation = None
if trx:
@ -640,7 +636,7 @@ class APICall(DataContainer):
self, msg, code=500, subcode=0, include_stack=False, error_data=None
):
tb = format_exc() if include_stack else None
self._result = self._call_result_cls(
self._result = APICallResult(
data=self._result.data,
code=code,
subcode=subcode,

View File

@ -38,7 +38,6 @@ class Endpoint(object):
:param response_data_model: response jsonschema model, will be validated if validate_schema=False
:param validate_schema: whether request and response schema should be validated
"""
super(Endpoint, self).__init__()
self.name = name
self.min_version = PartialVersion(min_version)
self.func = func

View File

@ -2,13 +2,11 @@ import re
from importlib import import_module
from itertools import chain
from pathlib import Path
from typing import cast, Iterable, List, MutableMapping, Optional, Tuple
from typing import cast, Iterable, List, MutableMapping, Optional, Tuple, Callable
import jsonmodels.models
from apiserver import timing_context
from apiserver.apierrors import APIError
from apiserver.apierrors.errors.bad_request import RequestPathHasInvalidVersion
from apiserver.apierrors import APIError, errors
from apiserver.config_repo import config
from apiserver.utilities.partial_version import PartialVersion
from .apicall import APICall
@ -77,18 +75,36 @@ class ServiceRepo(object):
""" Token for internal calls """
@classmethod
def load(cls, root_module="services"):
root_module = Path(__file__).parents[1] / root_module
def _load_from_path(
cls,
root_module: Path,
module_prefix: Optional[str] = None,
predicate: Optional[Callable[[Path], bool]] = None,
):
log.info(f"Loading services from {str(root_module.absolute())}")
sub_module = None
for sub_module in root_module.glob("*"):
if predicate and not predicate(sub_module):
continue
if (
sub_module.is_file()
and sub_module.suffix == ".py"
and not sub_module.stem == "__init__"
):
import_module(f"apiserver.{root_module.stem}.{sub_module.stem}")
if sub_module.is_dir():
import_module(f"apiserver.{root_module.stem}.{sub_module.stem}")
import_module(
".".join(
filter(None, (module_prefix, root_module.stem, sub_module.stem))
)
)
if sub_module.is_dir() and not sub_module.stem == "__pycache__":
import_module(
".".join(
filter(None, (module_prefix, root_module.stem, sub_module.stem))
)
)
# leave no trace of the 'sub_module' local
del sub_module
@ -101,8 +117,14 @@ class ServiceRepo(object):
)
@classmethod
def register(cls, endpoint):
assert isinstance(endpoint, Endpoint)
def load(cls, root_module="services"):
cls._load_from_path(
root_module=Path(__file__).parents[1] / root_module,
module_prefix="apiserver",
)
@classmethod
def register(cls, endpoint: Endpoint):
if cls._endpoints.get(endpoint.name):
if any(
ep.min_version == endpoint.min_version
@ -149,7 +171,6 @@ class ServiceRepo(object):
@classmethod
def _resolve_endpoint_from_call(cls, call: APICall) -> Optional[Endpoint]:
assert isinstance(call, APICall)
endpoint = cls._get_endpoint(
call.endpoint_name, call.requested_endpoint_version
)
@ -165,7 +186,6 @@ class ServiceRepo(object):
)
return
assert isinstance(endpoint, Endpoint)
call.actual_endpoint_version = endpoint.min_version
call.requires_authorization = endpoint.authorize
return endpoint
@ -185,7 +205,9 @@ class ServiceRepo(object):
try:
version = PartialVersion(version)
except ValueError as e:
raise RequestPathHasInvalidVersion(version=version, reason=e)
raise errors.bad_request.RequestPathHasInvalidVersion(
version=version, reason=e
)
if cls._check_max_version and version > cls._max_version:
raise InvalidVersionError(
f"Invalid API version (max. supported version is {cls._max_version})"
@ -232,8 +254,6 @@ class ServiceRepo(object):
@classmethod
def handle_call(cls, call: APICall):
try:
assert isinstance(call, APICall)
if call.failed:
raise CallFailedError()
@ -242,8 +262,7 @@ class ServiceRepo(object):
if call.failed:
raise CallFailedError()
with timing_context.TimingContext("service_repo", "validate_call"):
validate_all(call, endpoint)
validate_all(call, endpoint)
if call.failed:
raise CallFailedError()