mirror of
https://github.com/clearml/clearml-server
synced 2025-01-31 02:46:53 +00:00
Refactor service_repo
Code cleanup
This commit is contained in:
parent
64c63d2560
commit
6870d8aba9
@ -1 +1 @@
|
||||
from .basic import BasicConfig, ConfigurationError, Factory
|
||||
from .basic import BasicConfig, ConfigurationError
|
||||
|
@ -6,7 +6,7 @@ from functools import reduce
|
||||
from os import getenv
|
||||
from os.path import expandvars
|
||||
from pathlib import Path
|
||||
from typing import List, Any, Type, TypeVar
|
||||
from typing import List, Any, TypeVar
|
||||
|
||||
from pyhocon import ConfigTree, ConfigFactory
|
||||
from pyparsing import (
|
||||
@ -169,29 +169,8 @@ class BasicConfig:
|
||||
|
||||
class ConfigurationError(Exception):
|
||||
def __init__(self, msg, file_path=None, *args):
|
||||
super(ConfigurationError, self).__init__(msg, *args)
|
||||
super().__init__(msg, *args)
|
||||
self.file_path = file_path
|
||||
|
||||
|
||||
ConfigType = TypeVar("ConfigType", bound=BasicConfig)
|
||||
|
||||
|
||||
class Factory:
|
||||
_config_cls: Type[ConfigType] = BasicConfig
|
||||
|
||||
@classmethod
|
||||
def get(cls) -> BasicConfig:
|
||||
config = cls._config_cls()
|
||||
config.initialize_logging()
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def set_cls(cls, cls_: Type[ConfigType]):
|
||||
cls._config_cls = cls_
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Factory",
|
||||
"BasicConfig",
|
||||
"ConfigurationError",
|
||||
]
|
||||
|
@ -1,3 +1,4 @@
|
||||
from apiserver.config import Factory
|
||||
from apiserver.config import BasicConfig
|
||||
|
||||
config = Factory.get()
|
||||
config = BasicConfig()
|
||||
config.initialize_logging()
|
||||
|
@ -17,7 +17,6 @@ from apiserver.utilities.partial_version import PartialVersion
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
root = Path(__file__).parent / "services"
|
||||
ALL_ROLES = "*"
|
||||
|
||||
|
||||
@ -196,11 +195,12 @@ class Schema:
|
||||
|
||||
@attr.s()
|
||||
class SchemaReader:
|
||||
root = Path(__file__).parent / "services"
|
||||
cache_path: Path = None
|
||||
|
||||
def __attrs_post_init__(self):
|
||||
if not self.cache_path:
|
||||
self.cache_path = root / "_cache.json"
|
||||
self.cache_path = self.root / "_cache.json"
|
||||
|
||||
@staticmethod
|
||||
def mod_time(path):
|
||||
@ -220,7 +220,7 @@ class SchemaReader:
|
||||
"""
|
||||
services = [
|
||||
service
|
||||
for service in root.glob("*.conf")
|
||||
for service in self.root.glob("*.conf")
|
||||
if not service.name.startswith("_")
|
||||
]
|
||||
|
||||
@ -244,7 +244,7 @@ class SchemaReader:
|
||||
|
||||
log.info("regenerating schema cache")
|
||||
services = {path.stem: self.read_file(path) for path in services}
|
||||
api_defaults = self.read_file(root / "_api_defaults.conf")
|
||||
api_defaults = self.read_file(self.root / "_api_defaults.conf")
|
||||
|
||||
try:
|
||||
self.cache_path.write_text(
|
||||
|
@ -13,8 +13,6 @@ log = config.logger(__file__)
|
||||
|
||||
class RequestHandlers:
|
||||
_request_strip_prefix = config.get("apiserver.request.strip_prefix", None)
|
||||
_service_repo_cls = ServiceRepo
|
||||
_api_call_cls = APICall
|
||||
|
||||
def before_app_first_request(self):
|
||||
pass
|
||||
@ -27,7 +25,7 @@ class RequestHandlers:
|
||||
|
||||
try:
|
||||
call = self._create_api_call(request)
|
||||
content, content_type = self._service_repo_cls.handle_call(call)
|
||||
content, content_type = ServiceRepo.handle_call(call)
|
||||
|
||||
if call.result.redirect:
|
||||
response = redirect(call.result.redirect.url, call.result.redirect.code)
|
||||
@ -39,7 +37,10 @@ class RequestHandlers:
|
||||
}
|
||||
|
||||
response = Response(
|
||||
content, mimetype=content_type, status=call.result.code, headers=headers
|
||||
content,
|
||||
mimetype=content_type,
|
||||
status=call.result.code,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
if call.result.cookies:
|
||||
@ -47,13 +48,11 @@ class RequestHandlers:
|
||||
kwargs = config.get("apiserver.auth.cookies")
|
||||
if value is None:
|
||||
kwargs = kwargs.copy()
|
||||
kwargs['max_age'] = 0
|
||||
kwargs['expires'] = 0
|
||||
kwargs["max_age"] = 0
|
||||
kwargs["expires"] = 0
|
||||
response.set_cookie(key, "", **kwargs)
|
||||
else:
|
||||
response.set_cookie(
|
||||
key, value, **kwargs
|
||||
)
|
||||
response.set_cookie(key, value, **kwargs)
|
||||
|
||||
return response
|
||||
except Exception as ex:
|
||||
@ -96,7 +95,7 @@ class RequestHandlers:
|
||||
call.data = json_body or form or {}
|
||||
|
||||
def _call_or_empty_with_error(self, call, req, msg, code=500, subcode=0):
|
||||
call = call or self._api_call_cls(
|
||||
call = call or APICall(
|
||||
"", remote_addr=req.remote_addr, headers=dict(req.headers), files=req.files
|
||||
)
|
||||
call.set_error_result(msg=msg, code=code, subcode=subcode)
|
||||
@ -107,9 +106,11 @@ class RequestHandlers:
|
||||
try:
|
||||
# Parse the request path
|
||||
path = req.path
|
||||
if self._request_strip_prefix and path.startswith(self._request_strip_prefix):
|
||||
path = path[len(self._request_strip_prefix):]
|
||||
endpoint_version, endpoint_name = self._service_repo_cls.parse_endpoint_path(path)
|
||||
if self._request_strip_prefix and path.startswith(
|
||||
self._request_strip_prefix
|
||||
):
|
||||
path = path[len(self._request_strip_prefix) :]
|
||||
endpoint_version, endpoint_name = ServiceRepo.parse_endpoint_path(path)
|
||||
|
||||
# Resolve authorization: if cookies contain an authorization token, use it as a starting point.
|
||||
# in any case, request headers always take precedence.
|
||||
@ -126,7 +127,7 @@ class RequestHandlers:
|
||||
) # add (possibly override with) the headers
|
||||
|
||||
# Construct call instance
|
||||
call = self._api_call_cls(
|
||||
call = APICall(
|
||||
endpoint_name=endpoint_name,
|
||||
remote_addr=req.remote_addr,
|
||||
endpoint_version=endpoint_version,
|
||||
@ -145,9 +146,13 @@ class RequestHandlers:
|
||||
except BadRequest as ex:
|
||||
call = self._call_or_empty_with_error(call, req, ex.description, 400)
|
||||
except BaseError as ex:
|
||||
call = self._call_or_empty_with_error(call, req, ex.msg, ex.code, ex.subcode)
|
||||
call = self._call_or_empty_with_error(
|
||||
call, req, ex.msg, ex.code, ex.subcode
|
||||
)
|
||||
except Exception as ex:
|
||||
log.exception("Error creating call")
|
||||
call = self._call_or_empty_with_error(call, req, ex.args[0] if ex.args else type(ex).__name__, 500)
|
||||
call = self._call_or_empty_with_error(
|
||||
call, req, ex.args[0] if ex.args else type(ex).__name__, 500
|
||||
)
|
||||
|
||||
return call
|
||||
|
@ -186,7 +186,7 @@ class APICallResult(DataContainer):
|
||||
error_data=None,
|
||||
cookies=None,
|
||||
):
|
||||
super(APICallResult, self).__init__(data)
|
||||
super().__init__(data)
|
||||
self._code = code
|
||||
self._subcode = subcode
|
||||
self._msg = msg
|
||||
@ -297,9 +297,7 @@ class MissingIdentity(Exception):
|
||||
|
||||
|
||||
def _get_headers(name: str) -> Tuple[str, ...]:
|
||||
return tuple(
|
||||
"-".join(("X", p, name)) for p in ("ClearML", "Trains")
|
||||
)
|
||||
return tuple("-".join(("X", p, name)) for p in ("ClearML", "Trains"))
|
||||
|
||||
|
||||
class APICall(DataContainer):
|
||||
@ -308,8 +306,6 @@ class APICall(DataContainer):
|
||||
HEADER_FORWARDED_FOR = "X-Forwarded-For"
|
||||
""" Standard headers """
|
||||
|
||||
_call_result_cls = APICallResult
|
||||
|
||||
_transaction_headers = _get_headers("Trx")
|
||||
""" Transaction ID """
|
||||
|
||||
@ -358,7 +354,7 @@ class APICall(DataContainer):
|
||||
host=None,
|
||||
auth_cookie=None,
|
||||
):
|
||||
super(APICall, self).__init__(data=data, batched_data=batched_data)
|
||||
super().__init__(data=data, batched_data=batched_data)
|
||||
|
||||
self._id = database.utils.id()
|
||||
self._files = files # currently dic of key to flask's FileStorage)
|
||||
@ -375,7 +371,7 @@ class APICall(DataContainer):
|
||||
self._log_api = True
|
||||
if headers:
|
||||
self._headers.update(headers)
|
||||
self._result = self._call_result_cls()
|
||||
self._result = APICallResult()
|
||||
self._auth = None
|
||||
self._impersonation = None
|
||||
if trx:
|
||||
@ -640,7 +636,7 @@ class APICall(DataContainer):
|
||||
self, msg, code=500, subcode=0, include_stack=False, error_data=None
|
||||
):
|
||||
tb = format_exc() if include_stack else None
|
||||
self._result = self._call_result_cls(
|
||||
self._result = APICallResult(
|
||||
data=self._result.data,
|
||||
code=code,
|
||||
subcode=subcode,
|
||||
|
@ -38,7 +38,6 @@ class Endpoint(object):
|
||||
:param response_data_model: response jsonschema model, will be validated if validate_schema=False
|
||||
:param validate_schema: whether request and response schema should be validated
|
||||
"""
|
||||
super(Endpoint, self).__init__()
|
||||
self.name = name
|
||||
self.min_version = PartialVersion(min_version)
|
||||
self.func = func
|
||||
|
@ -2,13 +2,11 @@ import re
|
||||
from importlib import import_module
|
||||
from itertools import chain
|
||||
from pathlib import Path
|
||||
from typing import cast, Iterable, List, MutableMapping, Optional, Tuple
|
||||
from typing import cast, Iterable, List, MutableMapping, Optional, Tuple, Callable
|
||||
|
||||
import jsonmodels.models
|
||||
|
||||
from apiserver import timing_context
|
||||
from apiserver.apierrors import APIError
|
||||
from apiserver.apierrors.errors.bad_request import RequestPathHasInvalidVersion
|
||||
from apiserver.apierrors import APIError, errors
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.utilities.partial_version import PartialVersion
|
||||
from .apicall import APICall
|
||||
@ -77,18 +75,36 @@ class ServiceRepo(object):
|
||||
""" Token for internal calls """
|
||||
|
||||
@classmethod
|
||||
def load(cls, root_module="services"):
|
||||
root_module = Path(__file__).parents[1] / root_module
|
||||
def _load_from_path(
|
||||
cls,
|
||||
root_module: Path,
|
||||
module_prefix: Optional[str] = None,
|
||||
predicate: Optional[Callable[[Path], bool]] = None,
|
||||
):
|
||||
log.info(f"Loading services from {str(root_module.absolute())}")
|
||||
sub_module = None
|
||||
for sub_module in root_module.glob("*"):
|
||||
if predicate and not predicate(sub_module):
|
||||
continue
|
||||
|
||||
if (
|
||||
sub_module.is_file()
|
||||
and sub_module.suffix == ".py"
|
||||
and not sub_module.stem == "__init__"
|
||||
):
|
||||
import_module(f"apiserver.{root_module.stem}.{sub_module.stem}")
|
||||
if sub_module.is_dir():
|
||||
import_module(f"apiserver.{root_module.stem}.{sub_module.stem}")
|
||||
import_module(
|
||||
".".join(
|
||||
filter(None, (module_prefix, root_module.stem, sub_module.stem))
|
||||
)
|
||||
)
|
||||
|
||||
if sub_module.is_dir() and not sub_module.stem == "__pycache__":
|
||||
import_module(
|
||||
".".join(
|
||||
filter(None, (module_prefix, root_module.stem, sub_module.stem))
|
||||
)
|
||||
)
|
||||
|
||||
# leave no trace of the 'sub_module' local
|
||||
del sub_module
|
||||
|
||||
@ -101,8 +117,14 @@ class ServiceRepo(object):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def register(cls, endpoint):
|
||||
assert isinstance(endpoint, Endpoint)
|
||||
def load(cls, root_module="services"):
|
||||
cls._load_from_path(
|
||||
root_module=Path(__file__).parents[1] / root_module,
|
||||
module_prefix="apiserver",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def register(cls, endpoint: Endpoint):
|
||||
if cls._endpoints.get(endpoint.name):
|
||||
if any(
|
||||
ep.min_version == endpoint.min_version
|
||||
@ -149,7 +171,6 @@ class ServiceRepo(object):
|
||||
|
||||
@classmethod
|
||||
def _resolve_endpoint_from_call(cls, call: APICall) -> Optional[Endpoint]:
|
||||
assert isinstance(call, APICall)
|
||||
endpoint = cls._get_endpoint(
|
||||
call.endpoint_name, call.requested_endpoint_version
|
||||
)
|
||||
@ -165,7 +186,6 @@ class ServiceRepo(object):
|
||||
)
|
||||
return
|
||||
|
||||
assert isinstance(endpoint, Endpoint)
|
||||
call.actual_endpoint_version = endpoint.min_version
|
||||
call.requires_authorization = endpoint.authorize
|
||||
return endpoint
|
||||
@ -185,7 +205,9 @@ class ServiceRepo(object):
|
||||
try:
|
||||
version = PartialVersion(version)
|
||||
except ValueError as e:
|
||||
raise RequestPathHasInvalidVersion(version=version, reason=e)
|
||||
raise errors.bad_request.RequestPathHasInvalidVersion(
|
||||
version=version, reason=e
|
||||
)
|
||||
if cls._check_max_version and version > cls._max_version:
|
||||
raise InvalidVersionError(
|
||||
f"Invalid API version (max. supported version is {cls._max_version})"
|
||||
@ -232,8 +254,6 @@ class ServiceRepo(object):
|
||||
@classmethod
|
||||
def handle_call(cls, call: APICall):
|
||||
try:
|
||||
assert isinstance(call, APICall)
|
||||
|
||||
if call.failed:
|
||||
raise CallFailedError()
|
||||
|
||||
@ -242,8 +262,7 @@ class ServiceRepo(object):
|
||||
if call.failed:
|
||||
raise CallFailedError()
|
||||
|
||||
with timing_context.TimingContext("service_repo", "validate_call"):
|
||||
validate_all(call, endpoint)
|
||||
validate_all(call, endpoint)
|
||||
|
||||
if call.failed:
|
||||
raise CallFailedError()
|
||||
|
Loading…
Reference in New Issue
Block a user