mirror of
https://github.com/clearml/clearml-server
synced 2025-01-31 19:06:55 +00:00
331 lines
12 KiB
Python
331 lines
12 KiB
Python
import re
|
|
from importlib import import_module
|
|
from itertools import chain
|
|
from pathlib import Path
|
|
from typing import cast, Iterable, List, MutableMapping, Optional, Tuple, Callable
|
|
|
|
import jsonmodels.models
|
|
|
|
from apiserver.apierrors import APIError, errors
|
|
from apiserver.config_repo import config
|
|
from apiserver.database.errors import translate_errors_context
|
|
from apiserver.utilities.partial_version import PartialVersion
|
|
from .apicall import APICall
|
|
from .auth import Identity
|
|
from .endpoint import Endpoint
|
|
from .errors import MalformedPathError, InvalidVersionError, CallFailedError
|
|
from .util import parse_return_stack_on_code
|
|
from .validators import (
|
|
validate_data,
|
|
validate_auth,
|
|
validate_role,
|
|
validate_impersonation,
|
|
)
|
|
|
|
log = config.logger(__file__)
|
|
|
|
|
|
class ServiceRepo(object):
|
|
_endpoints: MutableMapping[str, List[Endpoint]] = {}
|
|
"""
|
|
Registered endpoints, in the format of {endpoint_name: Endpoint)}
|
|
the list of endpoints is sorted by min_version
|
|
"""
|
|
|
|
_version_required = config.get("apiserver.version.required")
|
|
""" If version is required, parsing will fail for endpoint paths that do not contain a valid version """
|
|
|
|
_check_max_version = config.get("apiserver.version.check_max_version")
|
|
"""If the check is set, parsing will fail for endpoint request with the version that is grater than the current
|
|
maximum """
|
|
|
|
_max_version = PartialVersion("2.27")
|
|
""" Maximum version number (the highest min_version value across all endpoints) """
|
|
|
|
_endpoint_exp = (
|
|
re.compile(
|
|
r"^/?v(?P<endpoint_version>\d+\.?\d+)/(?P<endpoint_name>[a-zA-Z_]\w+\.[a-zA-Z_]\w+)/?$"
|
|
)
|
|
if config.get("apiserver.version.required")
|
|
else re.compile(
|
|
r"^/?(v(?P<endpoint_version>\d+\.?\d+)/)?(?P<endpoint_name>[a-zA-Z_]\w+\.[a-zA-Z_]\w+)/?$"
|
|
)
|
|
)
|
|
"""
|
|
Endpoint structure expressions. We have two expressions, one with optional version part.
|
|
Constraints for the first (strict) expression:
|
|
1. May start with a leading '/'
|
|
2. Followed by a version number (int or float) preceded by a leading 'v'
|
|
3. Followed by a '/'
|
|
4. Followed by a service name, which must start with an english letter (lower or upper case) or underscore,
|
|
and followed by any number of alphanumeric or underscore characters
|
|
5. Followed by a '.'
|
|
6. Followed by an action name, which must start with an english letter (lower or upper case) or underscore,
|
|
and followed by any number of alphanumeric or underscore characters
|
|
7. May end with a leading '/'
|
|
|
|
The second (optional version) expression does not require steps 2 and 3.
|
|
"""
|
|
|
|
_return_stack = config.get("apiserver.return_stack")
|
|
""" return stack trace on error """
|
|
|
|
_return_stack_on_code = parse_return_stack_on_code(
|
|
config.get("apiserver.return_stack_on_code", {})
|
|
)
|
|
""" if 'return_stack' is true and error contains a return code, return stack trace only for these error codes """
|
|
|
|
_credentials = config["secure.credentials.apiserver"]
|
|
""" Api Server credentials used for intra-service communication """
|
|
|
|
_token = None
|
|
""" Token for internal calls """
|
|
|
|
@classmethod
|
|
def _load_from_path(
|
|
cls,
|
|
root_module: Path,
|
|
module_prefix: Optional[str] = None,
|
|
predicate: Optional[Callable[[Path], bool]] = None,
|
|
):
|
|
log.info(f"Loading services from {str(root_module.absolute())}")
|
|
sub_module = None
|
|
for sub_module in root_module.glob("*"):
|
|
if predicate and not predicate(sub_module):
|
|
continue
|
|
|
|
if (
|
|
sub_module.is_file()
|
|
and sub_module.suffix == ".py"
|
|
and not sub_module.stem == "__init__"
|
|
):
|
|
import_module(
|
|
".".join(
|
|
filter(None, (module_prefix, root_module.stem, sub_module.stem))
|
|
)
|
|
)
|
|
|
|
if sub_module.is_dir() and not sub_module.stem == "__pycache__":
|
|
import_module(
|
|
".".join(
|
|
filter(None, (module_prefix, root_module.stem, sub_module.stem))
|
|
)
|
|
)
|
|
|
|
# leave no trace of the 'sub_module' local
|
|
del sub_module
|
|
|
|
cls._max_version = max(
|
|
cls._max_version,
|
|
max(
|
|
ep.min_version
|
|
for ep in cast(Iterable[Endpoint], chain(*cls._endpoints.values()))
|
|
),
|
|
)
|
|
|
|
@classmethod
|
|
def load(cls, root_module="services"):
|
|
cls._load_from_path(
|
|
root_module=Path(__file__).parents[1] / root_module,
|
|
module_prefix="apiserver",
|
|
)
|
|
|
|
@classmethod
|
|
def register(cls, endpoint: Endpoint):
|
|
if cls._endpoints.get(endpoint.name):
|
|
if any(
|
|
ep.min_version == endpoint.min_version
|
|
for ep in cls._endpoints[endpoint.name]
|
|
):
|
|
raise Exception(
|
|
f"Trying to register an existing endpoint. name={endpoint.name}, version={endpoint.min_version}"
|
|
)
|
|
else:
|
|
cls._endpoints[endpoint.name].append(endpoint)
|
|
else:
|
|
cls._endpoints[endpoint.name] = [endpoint]
|
|
|
|
cls._endpoints[endpoint.name].sort(key=lambda ep: ep.min_version, reverse=True)
|
|
|
|
@classmethod
|
|
def endpoint_names(cls):
|
|
return sorted(cls._endpoints.keys())
|
|
|
|
@classmethod
|
|
def endpoints_summary(cls):
|
|
return {
|
|
"endpoints": {
|
|
name: list(map(Endpoint.to_dict, eps))
|
|
for name, eps in cls._endpoints.items()
|
|
},
|
|
"models": {},
|
|
}
|
|
|
|
@classmethod
|
|
def max_endpoint_version(cls) -> PartialVersion:
|
|
return cls._max_version
|
|
|
|
@classmethod
|
|
def _get_endpoint(cls, name, version) -> Optional[Endpoint]:
|
|
versions = cls._endpoints.get(name)
|
|
if not versions:
|
|
return None
|
|
try:
|
|
return next(ep for ep in versions if ep.min_version <= version)
|
|
except StopIteration:
|
|
# no appropriate version found
|
|
return None
|
|
|
|
@classmethod
|
|
def _resolve_endpoint_from_call(cls, call: APICall) -> Optional[Endpoint]:
|
|
endpoint = cls._get_endpoint(
|
|
call.endpoint_name, call.requested_endpoint_version
|
|
)
|
|
if endpoint is None:
|
|
call.log_api = False
|
|
call.set_error_result(
|
|
msg=(
|
|
f"Unable to find endpoint for name {call.endpoint_name} "
|
|
f"and version {call.requested_endpoint_version}"
|
|
),
|
|
code=404,
|
|
subcode=0,
|
|
)
|
|
return
|
|
|
|
call.actual_endpoint_version = endpoint.min_version
|
|
call.requires_authorization = endpoint.authorize
|
|
return endpoint
|
|
|
|
@classmethod
|
|
def parse_endpoint_path(cls, path: str) -> Tuple[PartialVersion, str]:
|
|
""" Parse endpoint version, service and action from request path. """
|
|
m = cls._endpoint_exp.match(path)
|
|
if not m:
|
|
raise MalformedPathError("Invalid request path %s" % path)
|
|
endpoint_name = m.group("endpoint_name")
|
|
version = m.group("endpoint_version")
|
|
if version is None:
|
|
# If endpoint is available, use the max version
|
|
version = cls._max_version
|
|
else:
|
|
try:
|
|
version = PartialVersion(version)
|
|
except ValueError as e:
|
|
raise errors.bad_request.RequestPathHasInvalidVersion(
|
|
version=version, reason=e
|
|
)
|
|
if cls._check_max_version and version > cls._max_version:
|
|
raise InvalidVersionError(
|
|
f"Invalid API version (max. supported version is {cls._max_version})"
|
|
)
|
|
return version, endpoint_name
|
|
|
|
@classmethod
|
|
def _should_return_stack(cls, code: int, subcode: int) -> bool:
|
|
if not cls._return_stack or code not in cls._return_stack_on_code:
|
|
return False
|
|
if subcode is None:
|
|
# Code in dict, but no subcode. We'll allow it.
|
|
return True
|
|
subcode_list = cls._return_stack_on_code.get(code)
|
|
if subcode_list is None:
|
|
# if the code is there but we don't have any subcode list, always return stack
|
|
return True
|
|
return subcode in subcode_list
|
|
|
|
@classmethod
|
|
def _get_identity(
|
|
cls, call: APICall, endpoint: Endpoint = None, ignore_error: bool = False
|
|
) -> Optional[Identity]:
|
|
authorize = endpoint and endpoint.authorize
|
|
if ignore_error or not authorize:
|
|
try:
|
|
return call.identity
|
|
except Exception:
|
|
return None
|
|
return call.identity
|
|
|
|
@classmethod
|
|
def _get_company(
|
|
cls, call: APICall, endpoint: Endpoint = None, ignore_error: bool = False
|
|
) -> Optional[str]:
|
|
identity = cls._get_identity(call, endpoint=endpoint, ignore_error=ignore_error)
|
|
return None if identity is None else identity.company
|
|
|
|
@classmethod
|
|
def handle_call(cls, call: APICall, load_data_callback: Callable = None):
|
|
company = None
|
|
try:
|
|
if call.failed:
|
|
raise CallFailedError()
|
|
|
|
endpoint = cls._resolve_endpoint_from_call(call)
|
|
|
|
if call.failed:
|
|
raise CallFailedError()
|
|
|
|
validate_auth(endpoint, call)
|
|
validate_role(endpoint, call)
|
|
if validate_impersonation(endpoint, call):
|
|
# if impersonating, validate role again
|
|
validate_role(endpoint, call)
|
|
|
|
if load_data_callback:
|
|
load_data_callback(call)
|
|
if call.failed:
|
|
raise CallFailedError()
|
|
|
|
validate_data(call, endpoint)
|
|
|
|
if call.failed:
|
|
raise CallFailedError()
|
|
|
|
# In case call does not require authorization, parsing the identity.company might raise an exception
|
|
company = cls._get_company(call, endpoint)
|
|
|
|
with translate_errors_context():
|
|
ret = endpoint.func(call, company, call.data_model)
|
|
|
|
# allow endpoints to return dict or model (instead of setting them explicitly on the call)
|
|
if ret is not None:
|
|
if isinstance(ret, jsonmodels.models.Base):
|
|
call.result.data_model = ret
|
|
elif isinstance(ret, dict):
|
|
call.result.data = ret
|
|
|
|
except APIError as ex:
|
|
# report stack trace only for gene
|
|
include_stack = cls._return_stack and cls._should_return_stack(
|
|
ex.code, ex.subcode
|
|
)
|
|
call.set_error_result(
|
|
code=ex.code,
|
|
subcode=ex.subcode,
|
|
msg=str(ex),
|
|
include_stack=include_stack,
|
|
)
|
|
except CallFailedError:
|
|
# Do nothing, let 'finally' wrap up
|
|
pass
|
|
except Exception as ex:
|
|
log.exception(ex)
|
|
call.set_error_result(
|
|
code=500, subcode=0, msg=str(ex), include_stack=cls._return_stack
|
|
)
|
|
finally:
|
|
content, content_type = call.get_response()
|
|
call.mark_end()
|
|
|
|
console_msg = f"Returned {call.result.code} for {call.endpoint_name} in {call.duration}ms"
|
|
if call.result.code < 300:
|
|
log.info(console_msg)
|
|
else:
|
|
console_msg = f"{console_msg}, msg={call.result.msg}"
|
|
if call.result.code < 500:
|
|
log.warn(console_msg)
|
|
else:
|
|
log.error(console_msg)
|
|
|
|
return content, content_type, company
|