mirror of
https://github.com/clearml/clearml-server
synced 2025-06-23 08:45:30 +00:00
Refactor service_repo
Code cleanup
This commit is contained in:
parent
64c63d2560
commit
6870d8aba9
@ -1 +1 @@
|
|||||||
from .basic import BasicConfig, ConfigurationError, Factory
|
from .basic import BasicConfig, ConfigurationError
|
||||||
|
@ -6,7 +6,7 @@ from functools import reduce
|
|||||||
from os import getenv
|
from os import getenv
|
||||||
from os.path import expandvars
|
from os.path import expandvars
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Any, Type, TypeVar
|
from typing import List, Any, TypeVar
|
||||||
|
|
||||||
from pyhocon import ConfigTree, ConfigFactory
|
from pyhocon import ConfigTree, ConfigFactory
|
||||||
from pyparsing import (
|
from pyparsing import (
|
||||||
@ -169,29 +169,8 @@ class BasicConfig:
|
|||||||
|
|
||||||
class ConfigurationError(Exception):
|
class ConfigurationError(Exception):
|
||||||
def __init__(self, msg, file_path=None, *args):
|
def __init__(self, msg, file_path=None, *args):
|
||||||
super(ConfigurationError, self).__init__(msg, *args)
|
super().__init__(msg, *args)
|
||||||
self.file_path = file_path
|
self.file_path = file_path
|
||||||
|
|
||||||
|
|
||||||
ConfigType = TypeVar("ConfigType", bound=BasicConfig)
|
ConfigType = TypeVar("ConfigType", bound=BasicConfig)
|
||||||
|
|
||||||
|
|
||||||
class Factory:
|
|
||||||
_config_cls: Type[ConfigType] = BasicConfig
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get(cls) -> BasicConfig:
|
|
||||||
config = cls._config_cls()
|
|
||||||
config.initialize_logging()
|
|
||||||
return config
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def set_cls(cls, cls_: Type[ConfigType]):
|
|
||||||
cls._config_cls = cls_
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"Factory",
|
|
||||||
"BasicConfig",
|
|
||||||
"ConfigurationError",
|
|
||||||
]
|
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
from apiserver.config import Factory
|
from apiserver.config import BasicConfig
|
||||||
|
|
||||||
config = Factory.get()
|
config = BasicConfig()
|
||||||
|
config.initialize_logging()
|
||||||
|
@ -17,7 +17,6 @@ from apiserver.utilities.partial_version import PartialVersion
|
|||||||
|
|
||||||
log = config.logger(__file__)
|
log = config.logger(__file__)
|
||||||
|
|
||||||
root = Path(__file__).parent / "services"
|
|
||||||
ALL_ROLES = "*"
|
ALL_ROLES = "*"
|
||||||
|
|
||||||
|
|
||||||
@ -196,11 +195,12 @@ class Schema:
|
|||||||
|
|
||||||
@attr.s()
|
@attr.s()
|
||||||
class SchemaReader:
|
class SchemaReader:
|
||||||
|
root = Path(__file__).parent / "services"
|
||||||
cache_path: Path = None
|
cache_path: Path = None
|
||||||
|
|
||||||
def __attrs_post_init__(self):
|
def __attrs_post_init__(self):
|
||||||
if not self.cache_path:
|
if not self.cache_path:
|
||||||
self.cache_path = root / "_cache.json"
|
self.cache_path = self.root / "_cache.json"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def mod_time(path):
|
def mod_time(path):
|
||||||
@ -220,7 +220,7 @@ class SchemaReader:
|
|||||||
"""
|
"""
|
||||||
services = [
|
services = [
|
||||||
service
|
service
|
||||||
for service in root.glob("*.conf")
|
for service in self.root.glob("*.conf")
|
||||||
if not service.name.startswith("_")
|
if not service.name.startswith("_")
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -244,7 +244,7 @@ class SchemaReader:
|
|||||||
|
|
||||||
log.info("regenerating schema cache")
|
log.info("regenerating schema cache")
|
||||||
services = {path.stem: self.read_file(path) for path in services}
|
services = {path.stem: self.read_file(path) for path in services}
|
||||||
api_defaults = self.read_file(root / "_api_defaults.conf")
|
api_defaults = self.read_file(self.root / "_api_defaults.conf")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.cache_path.write_text(
|
self.cache_path.write_text(
|
||||||
|
@ -13,8 +13,6 @@ log = config.logger(__file__)
|
|||||||
|
|
||||||
class RequestHandlers:
|
class RequestHandlers:
|
||||||
_request_strip_prefix = config.get("apiserver.request.strip_prefix", None)
|
_request_strip_prefix = config.get("apiserver.request.strip_prefix", None)
|
||||||
_service_repo_cls = ServiceRepo
|
|
||||||
_api_call_cls = APICall
|
|
||||||
|
|
||||||
def before_app_first_request(self):
|
def before_app_first_request(self):
|
||||||
pass
|
pass
|
||||||
@ -27,7 +25,7 @@ class RequestHandlers:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
call = self._create_api_call(request)
|
call = self._create_api_call(request)
|
||||||
content, content_type = self._service_repo_cls.handle_call(call)
|
content, content_type = ServiceRepo.handle_call(call)
|
||||||
|
|
||||||
if call.result.redirect:
|
if call.result.redirect:
|
||||||
response = redirect(call.result.redirect.url, call.result.redirect.code)
|
response = redirect(call.result.redirect.url, call.result.redirect.code)
|
||||||
@ -39,7 +37,10 @@ class RequestHandlers:
|
|||||||
}
|
}
|
||||||
|
|
||||||
response = Response(
|
response = Response(
|
||||||
content, mimetype=content_type, status=call.result.code, headers=headers
|
content,
|
||||||
|
mimetype=content_type,
|
||||||
|
status=call.result.code,
|
||||||
|
headers=headers,
|
||||||
)
|
)
|
||||||
|
|
||||||
if call.result.cookies:
|
if call.result.cookies:
|
||||||
@ -47,13 +48,11 @@ class RequestHandlers:
|
|||||||
kwargs = config.get("apiserver.auth.cookies")
|
kwargs = config.get("apiserver.auth.cookies")
|
||||||
if value is None:
|
if value is None:
|
||||||
kwargs = kwargs.copy()
|
kwargs = kwargs.copy()
|
||||||
kwargs['max_age'] = 0
|
kwargs["max_age"] = 0
|
||||||
kwargs['expires'] = 0
|
kwargs["expires"] = 0
|
||||||
response.set_cookie(key, "", **kwargs)
|
response.set_cookie(key, "", **kwargs)
|
||||||
else:
|
else:
|
||||||
response.set_cookie(
|
response.set_cookie(key, value, **kwargs)
|
||||||
key, value, **kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
return response
|
return response
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
@ -96,7 +95,7 @@ class RequestHandlers:
|
|||||||
call.data = json_body or form or {}
|
call.data = json_body or form or {}
|
||||||
|
|
||||||
def _call_or_empty_with_error(self, call, req, msg, code=500, subcode=0):
|
def _call_or_empty_with_error(self, call, req, msg, code=500, subcode=0):
|
||||||
call = call or self._api_call_cls(
|
call = call or APICall(
|
||||||
"", remote_addr=req.remote_addr, headers=dict(req.headers), files=req.files
|
"", remote_addr=req.remote_addr, headers=dict(req.headers), files=req.files
|
||||||
)
|
)
|
||||||
call.set_error_result(msg=msg, code=code, subcode=subcode)
|
call.set_error_result(msg=msg, code=code, subcode=subcode)
|
||||||
@ -107,9 +106,11 @@ class RequestHandlers:
|
|||||||
try:
|
try:
|
||||||
# Parse the request path
|
# Parse the request path
|
||||||
path = req.path
|
path = req.path
|
||||||
if self._request_strip_prefix and path.startswith(self._request_strip_prefix):
|
if self._request_strip_prefix and path.startswith(
|
||||||
path = path[len(self._request_strip_prefix):]
|
self._request_strip_prefix
|
||||||
endpoint_version, endpoint_name = self._service_repo_cls.parse_endpoint_path(path)
|
):
|
||||||
|
path = path[len(self._request_strip_prefix) :]
|
||||||
|
endpoint_version, endpoint_name = ServiceRepo.parse_endpoint_path(path)
|
||||||
|
|
||||||
# Resolve authorization: if cookies contain an authorization token, use it as a starting point.
|
# Resolve authorization: if cookies contain an authorization token, use it as a starting point.
|
||||||
# in any case, request headers always take precedence.
|
# in any case, request headers always take precedence.
|
||||||
@ -126,7 +127,7 @@ class RequestHandlers:
|
|||||||
) # add (possibly override with) the headers
|
) # add (possibly override with) the headers
|
||||||
|
|
||||||
# Construct call instance
|
# Construct call instance
|
||||||
call = self._api_call_cls(
|
call = APICall(
|
||||||
endpoint_name=endpoint_name,
|
endpoint_name=endpoint_name,
|
||||||
remote_addr=req.remote_addr,
|
remote_addr=req.remote_addr,
|
||||||
endpoint_version=endpoint_version,
|
endpoint_version=endpoint_version,
|
||||||
@ -145,9 +146,13 @@ class RequestHandlers:
|
|||||||
except BadRequest as ex:
|
except BadRequest as ex:
|
||||||
call = self._call_or_empty_with_error(call, req, ex.description, 400)
|
call = self._call_or_empty_with_error(call, req, ex.description, 400)
|
||||||
except BaseError as ex:
|
except BaseError as ex:
|
||||||
call = self._call_or_empty_with_error(call, req, ex.msg, ex.code, ex.subcode)
|
call = self._call_or_empty_with_error(
|
||||||
|
call, req, ex.msg, ex.code, ex.subcode
|
||||||
|
)
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
log.exception("Error creating call")
|
log.exception("Error creating call")
|
||||||
call = self._call_or_empty_with_error(call, req, ex.args[0] if ex.args else type(ex).__name__, 500)
|
call = self._call_or_empty_with_error(
|
||||||
|
call, req, ex.args[0] if ex.args else type(ex).__name__, 500
|
||||||
|
)
|
||||||
|
|
||||||
return call
|
return call
|
||||||
|
@ -186,7 +186,7 @@ class APICallResult(DataContainer):
|
|||||||
error_data=None,
|
error_data=None,
|
||||||
cookies=None,
|
cookies=None,
|
||||||
):
|
):
|
||||||
super(APICallResult, self).__init__(data)
|
super().__init__(data)
|
||||||
self._code = code
|
self._code = code
|
||||||
self._subcode = subcode
|
self._subcode = subcode
|
||||||
self._msg = msg
|
self._msg = msg
|
||||||
@ -297,9 +297,7 @@ class MissingIdentity(Exception):
|
|||||||
|
|
||||||
|
|
||||||
def _get_headers(name: str) -> Tuple[str, ...]:
|
def _get_headers(name: str) -> Tuple[str, ...]:
|
||||||
return tuple(
|
return tuple("-".join(("X", p, name)) for p in ("ClearML", "Trains"))
|
||||||
"-".join(("X", p, name)) for p in ("ClearML", "Trains")
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class APICall(DataContainer):
|
class APICall(DataContainer):
|
||||||
@ -308,8 +306,6 @@ class APICall(DataContainer):
|
|||||||
HEADER_FORWARDED_FOR = "X-Forwarded-For"
|
HEADER_FORWARDED_FOR = "X-Forwarded-For"
|
||||||
""" Standard headers """
|
""" Standard headers """
|
||||||
|
|
||||||
_call_result_cls = APICallResult
|
|
||||||
|
|
||||||
_transaction_headers = _get_headers("Trx")
|
_transaction_headers = _get_headers("Trx")
|
||||||
""" Transaction ID """
|
""" Transaction ID """
|
||||||
|
|
||||||
@ -358,7 +354,7 @@ class APICall(DataContainer):
|
|||||||
host=None,
|
host=None,
|
||||||
auth_cookie=None,
|
auth_cookie=None,
|
||||||
):
|
):
|
||||||
super(APICall, self).__init__(data=data, batched_data=batched_data)
|
super().__init__(data=data, batched_data=batched_data)
|
||||||
|
|
||||||
self._id = database.utils.id()
|
self._id = database.utils.id()
|
||||||
self._files = files # currently dic of key to flask's FileStorage)
|
self._files = files # currently dic of key to flask's FileStorage)
|
||||||
@ -375,7 +371,7 @@ class APICall(DataContainer):
|
|||||||
self._log_api = True
|
self._log_api = True
|
||||||
if headers:
|
if headers:
|
||||||
self._headers.update(headers)
|
self._headers.update(headers)
|
||||||
self._result = self._call_result_cls()
|
self._result = APICallResult()
|
||||||
self._auth = None
|
self._auth = None
|
||||||
self._impersonation = None
|
self._impersonation = None
|
||||||
if trx:
|
if trx:
|
||||||
@ -640,7 +636,7 @@ class APICall(DataContainer):
|
|||||||
self, msg, code=500, subcode=0, include_stack=False, error_data=None
|
self, msg, code=500, subcode=0, include_stack=False, error_data=None
|
||||||
):
|
):
|
||||||
tb = format_exc() if include_stack else None
|
tb = format_exc() if include_stack else None
|
||||||
self._result = self._call_result_cls(
|
self._result = APICallResult(
|
||||||
data=self._result.data,
|
data=self._result.data,
|
||||||
code=code,
|
code=code,
|
||||||
subcode=subcode,
|
subcode=subcode,
|
||||||
|
@ -38,7 +38,6 @@ class Endpoint(object):
|
|||||||
:param response_data_model: response jsonschema model, will be validated if validate_schema=False
|
:param response_data_model: response jsonschema model, will be validated if validate_schema=False
|
||||||
:param validate_schema: whether request and response schema should be validated
|
:param validate_schema: whether request and response schema should be validated
|
||||||
"""
|
"""
|
||||||
super(Endpoint, self).__init__()
|
|
||||||
self.name = name
|
self.name = name
|
||||||
self.min_version = PartialVersion(min_version)
|
self.min_version = PartialVersion(min_version)
|
||||||
self.func = func
|
self.func = func
|
||||||
|
@ -2,13 +2,11 @@ import re
|
|||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import cast, Iterable, List, MutableMapping, Optional, Tuple
|
from typing import cast, Iterable, List, MutableMapping, Optional, Tuple, Callable
|
||||||
|
|
||||||
import jsonmodels.models
|
import jsonmodels.models
|
||||||
|
|
||||||
from apiserver import timing_context
|
from apiserver.apierrors import APIError, errors
|
||||||
from apiserver.apierrors import APIError
|
|
||||||
from apiserver.apierrors.errors.bad_request import RequestPathHasInvalidVersion
|
|
||||||
from apiserver.config_repo import config
|
from apiserver.config_repo import config
|
||||||
from apiserver.utilities.partial_version import PartialVersion
|
from apiserver.utilities.partial_version import PartialVersion
|
||||||
from .apicall import APICall
|
from .apicall import APICall
|
||||||
@ -77,18 +75,36 @@ class ServiceRepo(object):
|
|||||||
""" Token for internal calls """
|
""" Token for internal calls """
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(cls, root_module="services"):
|
def _load_from_path(
|
||||||
root_module = Path(__file__).parents[1] / root_module
|
cls,
|
||||||
|
root_module: Path,
|
||||||
|
module_prefix: Optional[str] = None,
|
||||||
|
predicate: Optional[Callable[[Path], bool]] = None,
|
||||||
|
):
|
||||||
|
log.info(f"Loading services from {str(root_module.absolute())}")
|
||||||
sub_module = None
|
sub_module = None
|
||||||
for sub_module in root_module.glob("*"):
|
for sub_module in root_module.glob("*"):
|
||||||
|
if predicate and not predicate(sub_module):
|
||||||
|
continue
|
||||||
|
|
||||||
if (
|
if (
|
||||||
sub_module.is_file()
|
sub_module.is_file()
|
||||||
and sub_module.suffix == ".py"
|
and sub_module.suffix == ".py"
|
||||||
and not sub_module.stem == "__init__"
|
and not sub_module.stem == "__init__"
|
||||||
):
|
):
|
||||||
import_module(f"apiserver.{root_module.stem}.{sub_module.stem}")
|
import_module(
|
||||||
if sub_module.is_dir():
|
".".join(
|
||||||
import_module(f"apiserver.{root_module.stem}.{sub_module.stem}")
|
filter(None, (module_prefix, root_module.stem, sub_module.stem))
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if sub_module.is_dir() and not sub_module.stem == "__pycache__":
|
||||||
|
import_module(
|
||||||
|
".".join(
|
||||||
|
filter(None, (module_prefix, root_module.stem, sub_module.stem))
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# leave no trace of the 'sub_module' local
|
# leave no trace of the 'sub_module' local
|
||||||
del sub_module
|
del sub_module
|
||||||
|
|
||||||
@ -101,8 +117,14 @@ class ServiceRepo(object):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def register(cls, endpoint):
|
def load(cls, root_module="services"):
|
||||||
assert isinstance(endpoint, Endpoint)
|
cls._load_from_path(
|
||||||
|
root_module=Path(__file__).parents[1] / root_module,
|
||||||
|
module_prefix="apiserver",
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register(cls, endpoint: Endpoint):
|
||||||
if cls._endpoints.get(endpoint.name):
|
if cls._endpoints.get(endpoint.name):
|
||||||
if any(
|
if any(
|
||||||
ep.min_version == endpoint.min_version
|
ep.min_version == endpoint.min_version
|
||||||
@ -149,7 +171,6 @@ class ServiceRepo(object):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _resolve_endpoint_from_call(cls, call: APICall) -> Optional[Endpoint]:
|
def _resolve_endpoint_from_call(cls, call: APICall) -> Optional[Endpoint]:
|
||||||
assert isinstance(call, APICall)
|
|
||||||
endpoint = cls._get_endpoint(
|
endpoint = cls._get_endpoint(
|
||||||
call.endpoint_name, call.requested_endpoint_version
|
call.endpoint_name, call.requested_endpoint_version
|
||||||
)
|
)
|
||||||
@ -165,7 +186,6 @@ class ServiceRepo(object):
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
assert isinstance(endpoint, Endpoint)
|
|
||||||
call.actual_endpoint_version = endpoint.min_version
|
call.actual_endpoint_version = endpoint.min_version
|
||||||
call.requires_authorization = endpoint.authorize
|
call.requires_authorization = endpoint.authorize
|
||||||
return endpoint
|
return endpoint
|
||||||
@ -185,7 +205,9 @@ class ServiceRepo(object):
|
|||||||
try:
|
try:
|
||||||
version = PartialVersion(version)
|
version = PartialVersion(version)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise RequestPathHasInvalidVersion(version=version, reason=e)
|
raise errors.bad_request.RequestPathHasInvalidVersion(
|
||||||
|
version=version, reason=e
|
||||||
|
)
|
||||||
if cls._check_max_version and version > cls._max_version:
|
if cls._check_max_version and version > cls._max_version:
|
||||||
raise InvalidVersionError(
|
raise InvalidVersionError(
|
||||||
f"Invalid API version (max. supported version is {cls._max_version})"
|
f"Invalid API version (max. supported version is {cls._max_version})"
|
||||||
@ -232,8 +254,6 @@ class ServiceRepo(object):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def handle_call(cls, call: APICall):
|
def handle_call(cls, call: APICall):
|
||||||
try:
|
try:
|
||||||
assert isinstance(call, APICall)
|
|
||||||
|
|
||||||
if call.failed:
|
if call.failed:
|
||||||
raise CallFailedError()
|
raise CallFailedError()
|
||||||
|
|
||||||
@ -242,8 +262,7 @@ class ServiceRepo(object):
|
|||||||
if call.failed:
|
if call.failed:
|
||||||
raise CallFailedError()
|
raise CallFailedError()
|
||||||
|
|
||||||
with timing_context.TimingContext("service_repo", "validate_call"):
|
validate_all(call, endpoint)
|
||||||
validate_all(call, endpoint)
|
|
||||||
|
|
||||||
if call.failed:
|
if call.failed:
|
||||||
raise CallFailedError()
|
raise CallFailedError()
|
||||||
|
Loading…
Reference in New Issue
Block a user