clearml-server/apiserver/service_repo/service_repo.py
2021-01-05 16:44:31 +02:00

297 lines
11 KiB
Python

import re
from importlib import import_module
from itertools import chain
from pathlib import Path
from typing import cast, Iterable, List, MutableMapping, Optional, Tuple
import jsonmodels.models
from apiserver import timing_context
from apiserver.apierrors import APIError
from apiserver.apierrors.errors.bad_request import RequestPathHasInvalidVersion
from apiserver.api_version import __version__ as _api_version_
from apiserver.config_repo import config
from apiserver.utilities.partial_version import PartialVersion
from .apicall import APICall
from .endpoint import Endpoint
from .errors import MalformedPathError, InvalidVersionError, CallFailedError
from .util import parse_return_stack_on_code
from .validators import validate_all
# Module-level logger, obtained from the project configuration helper and keyed by this file's path.
log = config.logger(__file__)
class ServiceRepo(object):
    """
    Class-level registry of API endpoints, together with the logic that parses a
    request path, resolves the matching Endpoint for an APICall, validates the
    call and dispatches it.

    All state lives in class attributes and every method is a classmethod, so
    the class is used directly and never instantiated by the code in this file.
    """

    _endpoints: MutableMapping[str, List[Endpoint]] = {}
    """
    Registered endpoints, in the format of {endpoint_name: [Endpoint, ...]}.
    Each list is kept sorted by min_version, highest first (see `register`).
    """

    _version_required = config.get("apiserver.version.required")
    """ If version is required, parsing will fail for endpoint paths that do not contain a valid version """

    _check_max_version = config.get("apiserver.version.check_max_version")
    """ If the check is set, parsing will fail for endpoint requests whose version is greater than the current
    maximum """

    _max_version = PartialVersion(".".join(_api_version_.split(".")[:2]))
    """
    Maximum version number. Starts as the first two dot-separated components of the API version and is raised by
    `load` to the highest min_version across all registered endpoints, if that is higher.
    """

    _endpoint_exp = (
        re.compile(
            r"^/?v(?P<endpoint_version>\d+\.?\d+)/(?P<endpoint_name>[a-zA-Z_]\w+\.[a-zA-Z_]\w+)/?$"
        )
        if _version_required
        else re.compile(
            r"^/?(v(?P<endpoint_version>\d+\.?\d+)/)?(?P<endpoint_name>[a-zA-Z_]\w+\.[a-zA-Z_]\w+)/?$"
        )
    )
    """
    Endpoint structure expression. One of two expressions is compiled, depending on whether a version is required.
    Constraints for the first (strict) expression:
    1. May start with a leading '/'
    2. Followed by a version number preceded by a leading 'v'. The pattern is digits, an optional '.', digits, so
       a version written without a '.' must have at least two digits (e.g. 'v2.8' and 'v10' match, 'v2' does not)
    3. Followed by a '/'
    4. Followed by a service name, which must start with an english letter (lower or upper case) or underscore,
       and followed by one or more alphanumeric or underscore characters
    5. Followed by a '.'
    6. Followed by an action name, which must start with an english letter (lower or upper case) or underscore,
       and followed by one or more alphanumeric or underscore characters
    7. May end with a trailing '/'
    The second (optional version) expression does not require steps 2 and 3.
    """

    _return_stack = config.get("apiserver.return_stack")
    """ return stack trace on error """

    _return_stack_on_code = parse_return_stack_on_code(
        config.get("apiserver.return_stack_on_code", {})
    )
    """ if 'return_stack' is true and error contains a return code, return stack trace only for these error codes """

    _credentials = config["secure.credentials.apiserver"]
    """ Api Server credentials used for intra-service communication """

    _token = None
    """ Token for internal calls """

    @classmethod
    def load(cls, root_module="services"):
        """
        Import every service module found directly under `root_module` and
        refresh the maximum supported version.

        Each non-`__init__` `.py` file and each sub-directory is imported as
        `<root_module stem>.<entry stem>` (presumably so that importing them
        registers their endpoints through `register` - confirm against the
        service modules).

        :param root_module: path of the directory holding the service modules
        """
        root_path = Path(root_module)
        for entry in root_path.glob("*"):
            is_module_file = (
                entry.is_file() and entry.suffix == ".py" and entry.stem != "__init__"
            )
            if is_module_file or entry.is_dir():
                import_module(f"{root_path.stem}.{entry.stem}")

        registered = cast(
            Iterable[Endpoint], chain.from_iterable(cls._endpoints.values())
        )
        # 'default' keeps the current maximum when nothing has been registered,
        # instead of failing on max() of an empty sequence
        cls._max_version = max(
            cls._max_version,
            max((ep.min_version for ep in registered), default=cls._max_version),
        )

    @classmethod
    def register(cls, endpoint):
        """
        Add an endpoint to the registry.

        :param endpoint: the Endpoint to register
        :raise Exception: if an endpoint with the same name and min_version
            is already registered
        """
        assert isinstance(endpoint, Endpoint)

        versions = cls._endpoints.get(endpoint.name)
        if versions:
            if any(ep.min_version == endpoint.min_version for ep in versions):
                raise Exception(
                    f"Trying to register an existing endpoint. name={endpoint.name}, version={endpoint.min_version}"
                )
            versions.append(endpoint)
        else:
            versions = [endpoint]
            cls._endpoints[endpoint.name] = versions

        # highest min_version first, so that lookups hit the newest matching implementation
        versions.sort(key=lambda ep: ep.min_version, reverse=True)

    @classmethod
    def endpoint_names(cls):
        """ Return the names of all registered endpoints, sorted. """
        return sorted(cls._endpoints)

    @classmethod
    def endpoints_summary(cls):
        """
        Return a dict with an "endpoints" entry mapping each endpoint name to
        the list of `Endpoint.to_dict` results for its versions, and an empty
        "models" entry.
        """
        return {
            "endpoints": {
                name: list(map(Endpoint.to_dict, eps))
                for name, eps in cls._endpoints.items()
            },
            "models": {},
        }

    @classmethod
    def max_endpoint_version(cls) -> PartialVersion:
        """ Return the current maximum version (see `_max_version`). """
        return cls._max_version

    @classmethod
    def _get_endpoint(cls, name, version) -> Optional[Endpoint]:
        """
        Return the first registered endpoint called `name` whose min_version
        does not exceed `version`, or None if the name is unknown or no
        version qualifies. Since versions are stored highest-first, this is
        the newest implementation available to the requested version.
        """
        versions = cls._endpoints.get(name)
        if not versions:
            return None
        return next((ep for ep in versions if ep.min_version <= version), None)

    @classmethod
    def _resolve_endpoint_from_call(cls, call: APICall) -> Optional[Endpoint]:
        """
        Find the endpoint matching the call's endpoint name and requested version.

        On success the call's `actual_endpoint_version` and
        `requires_authorization` are updated from the endpoint. On failure a
        404 error result is set on the call, its `log_api` flag is cleared and
        None is returned.
        """
        assert isinstance(call, APICall)

        endpoint = cls._get_endpoint(
            call.endpoint_name, call.requested_endpoint_version
        )
        if endpoint is None:
            call.log_api = False
            call.set_error_result(
                msg=(
                    f"Unable to find endpoint for name {call.endpoint_name} "
                    f"and version {call.requested_endpoint_version}"
                ),
                code=404,
                subcode=0,
            )
            return None

        assert isinstance(endpoint, Endpoint)
        call.actual_endpoint_version = endpoint.min_version
        call.requires_authorization = endpoint.authorize
        return endpoint

    @classmethod
    def parse_endpoint_path(cls, path: str) -> Tuple[PartialVersion, str]:
        """
        Parse endpoint version, service and action from request path.

        :param path: the request path
        :return: (version, endpoint_name); when the path carries no version the
            current maximum version is used
        :raise MalformedPathError: if the path does not match the endpoint expression
        :raise RequestPathHasInvalidVersion: if the version text cannot be parsed
        :raise InvalidVersionError: if max-version checking is enabled and the
            requested version exceeds the maximum
        """
        m = cls._endpoint_exp.match(path)
        if not m:
            raise MalformedPathError("Invalid request path %s" % path)

        endpoint_name = m.group("endpoint_name")
        version = m.group("endpoint_version")
        if version is None:
            # If endpoint is available, use the max version
            return cls._max_version, endpoint_name

        try:
            version = PartialVersion(version)
        except ValueError as e:
            raise RequestPathHasInvalidVersion(version=version, reason=e) from e

        if cls._check_max_version and version > cls._max_version:
            raise InvalidVersionError(
                f"Invalid API version (max. supported version is {cls._max_version})"
            )
        return version, endpoint_name

    @classmethod
    def _should_return_stack(cls, code: int, subcode: Optional[int]) -> bool:
        """
        Tell whether a stack trace should accompany an error with the given
        code and subcode.

        Returns False when stack traces are disabled or the code is not
        configured. Otherwise returns True when no subcode is given or no
        subcode list is configured for the code, and else whether the subcode
        is in that list.
        """
        if not cls._return_stack or code not in cls._return_stack_on_code:
            return False

        if subcode is None:
            # Code in dict, but no subcode. We'll allow it.
            return True

        subcode_list = cls._return_stack_on_code.get(code)
        if subcode_list is None:
            # if the code is there but we don't have any subcode list, always return stack
            return True

        return subcode in subcode_list

    @classmethod
    def _validate_call(cls, call: APICall) -> Optional[Endpoint]:
        """
        Resolve the call's endpoint and run all validators on it.

        :return: the resolved endpoint, or None if the call has failed
            after endpoint resolution
        """
        endpoint = cls._resolve_endpoint_from_call(call)
        if call.failed:
            return None
        validate_all(call, endpoint)
        return endpoint

    @classmethod
    def validate_call(cls, call: APICall):
        """ Resolve and validate the call, discarding the resolved endpoint. """
        cls._validate_call(call)

    @classmethod
    def _get_company(
        cls,
        call: APICall,
        endpoint: Optional[Endpoint] = None,
        ignore_error: bool = False,
    ) -> Optional[str]:
        """
        Return the company from the call's identity.

        When `ignore_error` is set, or the endpoint is missing or does not
        require authorization, any exception raised while reading the identity
        is suppressed and None is returned instead. Otherwise the exception
        propagates.
        """
        authorize = endpoint and endpoint.authorize
        if ignore_error or not authorize:
            # deliberate best-effort: the identity may be unavailable for such calls
            try:
                return call.identity.company
            except Exception:
                return None
        return call.identity.company

    @classmethod
    def handle_call(cls, call: APICall):
        """
        Resolve, validate and execute the call, then build its response.

        Errors are recorded on the call rather than raised: an APIError keeps
        its own code and subcode, any other Exception becomes a 500. The
        outcome is always logged, at a level chosen by the result code.

        :return: the (content, content_type) pair from `call.get_response()`
        """
        try:
            assert isinstance(call, APICall)

            if call.failed:
                raise CallFailedError()

            endpoint = cls._resolve_endpoint_from_call(call)

            if call.failed:
                raise CallFailedError()

            with timing_context.TimingContext("service_repo", "validate_call"):
                validate_all(call, endpoint)

            if call.failed:
                raise CallFailedError()

            # In case call does not require authorization, parsing the identity.company might raise an exception
            company = cls._get_company(call, endpoint)

            ret = endpoint.func(call, company, call.data_model)

            # allow endpoints to return dict or model (instead of setting them explicitly on the call)
            if ret is not None:
                if isinstance(ret, jsonmodels.models.Base):
                    call.result.data_model = ret
                elif isinstance(ret, dict):
                    call.result.data = ret

        except APIError as ex:
            # include a stack trace only if it is enabled and configured for this code/subcode
            include_stack = cls._return_stack and cls._should_return_stack(
                ex.code, ex.subcode
            )
            call.set_error_result(
                code=ex.code,
                subcode=ex.subcode,
                msg=str(ex),
                include_stack=include_stack,
            )
        except CallFailedError:
            # Do nothing, let 'finally' wrap up
            pass
        except Exception as ex:
            log.exception(ex)
            call.set_error_result(
                code=500, subcode=0, msg=str(ex), include_stack=cls._return_stack
            )
        finally:
            content, content_type = call.get_response()
            call.mark_end()
            console_msg = f"Returned {call.result.code} for {call.endpoint_name} in {call.duration}ms"
            if call.result.code < 300:
                log.info(console_msg)
            else:
                console_msg = f"{console_msg}, msg={call.result.msg}"
                if call.result.code < 500:
                    log.warning(console_msg)
                else:
                    log.error(console_msg)

        return content, content_type