clearml-server/apiserver/schema.py
2021-01-05 16:22:34 +02:00

268 lines
8.6 KiB
Python

"""
Objects representing schema entities
"""
import json
import re
from operator import attrgetter
from pathlib import Path
from typing import Mapping, Sequence
import attr
from boltons.dictutils import subdict
from pyhocon import ConfigFactory
from config import config
from service_repo.base import PartialVersion
# Path of this module file; SchemaReader derives its default schema root from it
HERE = Path(__file__)
# Module logger obtained from the project configuration object
log = config.logger(__file__)
# Wildcard role: an "allow_roles" list containing this value admits every role
ALL_ROLES = "*"
class EndpointSchema:
    """
    Schema of one endpoint at one specific API version
    """

    REQUEST_KEY = "request"
    RESPONSE_KEY = "response"
    BATCH_REQUEST_KEY = "batch_request"
    DEFINITIONS_KEY = "definitions"

    def __init__(
        self,
        service_name: str,
        action_name: str,
        version: PartialVersion,
        schema: dict,
        definitions: dict = None,
    ):
        """
        Build the request (or batch request) and response schemas of an endpoint
        :param service_name: name of the service the endpoint belongs to
        :param action_name: name of the endpoint action
        :param version: version of this endpoint implementation
        :param schema: endpoint schema; must hold a "response" entry and either
            a "request" or a "batch_request" entry
        :param definitions: service-level definitions attached to every schema
        :raises RuntimeError: if schema has neither "request" nor "batch_request"
        """
        self.service_name = service_name
        self.action_name = action_name
        self.full_name = f"{service_name}.{action_name}"
        self.version = version
        self.definitions = definitions

        has_request = self.REQUEST_KEY in schema
        if not has_request and self.BATCH_REQUEST_KEY not in schema:
            raise RuntimeError(
                f"endpoint {self.full_name} version {self.version} "
                "has no request or batch_request schema",
                schema,
            )

        # A plain request takes precedence: the batch schema is only built
        # when no "request" entry exists
        self.request_schema = (
            self._attach_definitions(schema[self.REQUEST_KEY]) if has_request else None
        )
        self.batch_request_schema = (
            None
            if has_request
            else self._attach_definitions(schema[self.BATCH_REQUEST_KEY])
        )

        self.response_schema = {
            **schema[self.RESPONSE_KEY],
            "definitions": self.definitions,
        }

    def _attach_definitions(self, section: dict) -> dict:
        """
        Return a copy of section with the service definitions stored under
        DEFINITIONS_KEY (replacing any value already present under that key)
        """
        return {**section, self.DEFINITIONS_KEY: self.definitions}
class EndpointVersionsGroup:
    """
    All versioned implementations of a single endpoint, together with the
    access attributes (internal / allow_roles / authorize) they share
    """

    endpoints: Sequence[EndpointSchema]
    allow_roles: Sequence[str]
    internal: bool
    authorize: bool

    def __repr__(self):
        return (
            f"{type(self).__name__}<{self.full_name}, "
            f"versions={tuple(e.version for e in self.endpoints)}>"
        )

    def __init__(
        self,
        service_name: str,
        action_name: str,
        conf: dict,
        definitions: dict = None,
        defaults: dict = None,
    ):
        """
        Represents multiple implementations of a single endpoint, discriminated by API version
        :param service_name: name of containing service
        :param action_name: name of action
        :param conf: mapping between minimum version to endpoint schema; may also
            hold the "internal", "allow_roles" and "authorize" attributes.
            The mapping is not modified.
        :param definitions: service definitions
        :param defaults: service defaults, used for attributes missing from conf
        :raises ValueError: if a non-attribute key of conf is not of the form
            "<digits>.<digits>"
        :raises KeyError: if an attribute is missing from both conf and defaults
        """
        self.service_name = service_name
        self.action_name = action_name
        self.full_name = f"{service_name}.{action_name}"
        self.definitions = definitions or {}
        self.defaults = defaults or {}

        # Work on a shallow copy: the attribute keys are popped below so that
        # only version keys remain, and the caller's mapping must stay intact
        # (otherwise re-parsing the same configuration would lose the overrides)
        conf = dict(conf)
        self.internal = self._pop_attr_with_default(conf, "internal")
        self.allow_roles = self._pop_attr_with_default(conf, "allow_roles")
        self.authorize = self._pop_attr_with_default(conf, "authorize")

        def parse_version(version):
            # Every key left in conf is expected to be a "<major>.<minor>" version
            if not re.match(r"^\d+\.\d+$", version):
                raise ValueError(
                    f"Encountered unrecognized key {version!r} in {self.service_name}.{self.action_name}"
                )
            return PartialVersion(version)

        # Endpoints are kept in ascending version order
        self.endpoints = sorted(
            (
                EndpointSchema(
                    service_name=self.service_name,
                    action_name=self.action_name,
                    version=parse_version(version),
                    schema=endpoint_conf,
                    definitions=self.definitions,
                )
                for version, endpoint_conf in conf.items()
            ),
            key=attrgetter("version"),
        )

    def allows(self, role):
        """
        Return True if role is listed in allow_roles or allow_roles holds the
        ALL_ROLES wildcard
        """
        return ALL_ROLES in self.allow_roles or role in self.allow_roles

    def _pop_attr_with_default(self, conf, attr):
        """
        Remove attr from conf and return its value, falling back to the
        service defaults when conf does not define it
        :raises KeyError: if attr is in neither conf nor the defaults
        """
        # The default is looked up only when needed, so an explicit value in
        # conf is honoured even if the defaults do not define the attribute
        if attr in conf:
            return conf.pop(attr)
        return self.defaults[attr]

    def get_for_version(self, min_version: PartialVersion):
        """
        Return the first endpoint, in ascending version order, whose version
        is greater than or equal to min_version
        :raises ValueError: if the group has no endpoints, or if min_version is
            higher than every endpoint version in the group
        """
        if not self.endpoints:
            raise ValueError(f"endpoint group {self} has no versions")
        for endpoint in self.endpoints:
            if min_version <= endpoint.version:
                return endpoint
        raise ValueError(
            f"min_version {min_version} is higher than highest version in group {self}"
        )
class Service:
    """
    Schema of a single service: its defaults, definitions and endpoint groups
    """

    endpoint_groups: Mapping[str, EndpointVersionsGroup]

    def __init__(self, name: str, conf: dict, api_defaults: dict):
        """
        Parse the configuration of one service
        :param name: service name
        :param conf: service configuration; every key other than "_description",
            "_references", "_default" and "_definitions" is treated as an endpoint name
        :param api_defaults: API-wide default endpoint attributes; entries of the
            service "_default" section take precedence over them
        """
        self.name = name

        # Documentation-only sections are dropped; subdict returns a new mapping,
        # so the special sections popped below are not removed from the input
        endpoints_conf = subdict(conf, drop=("_description", "_references"))
        service_defaults = endpoints_conf.pop("_default", {})
        self.defaults = {**api_defaults, **service_defaults}
        self.definitions = endpoints_conf.pop("_definitions", None)

        groups = {}
        for action_name, action_conf in endpoints_conf.items():
            groups[action_name] = EndpointVersionsGroup(
                service_name=self.name,
                action_name=action_name,
                conf=action_conf,
                defaults=self.defaults,
                definitions=self.definitions,
            )
        self.endpoint_groups = groups
@attr.s()
class SchemaReader:
    """
    Loads the API schema from the service configuration files, keeping a JSON
    cache of the parsed result next to them
    """

    # Directory holding the service "*.conf" files
    root: Path = attr.ib(default=HERE.parent / "schema/services", converter=Path)
    # Location of the JSON cache; defaults to "<root>/_cache.json"
    cache_path: Path = attr.ib(default=None)

    def __attrs_post_init__(self):
        if not self.cache_path:
            self.cache_path = self.root / "_cache.json"
        else:
            # Accept plain strings as well, like the converter on root does
            self.cache_path = Path(self.cache_path)

    @staticmethod
    def mod_time(path):
        """
        return file modification time
        """
        return path.stat().st_mtime

    @staticmethod
    def read_file(path):
        """
        Parse a configuration file and return its content as a plain ordered dict
        """
        return ConfigFactory.parse_file(path).as_plain_ordered_dict()

    def get_schema(self):
        """
        Parse the API schema to schema object.
        Load from config files and write to cache file if possible.
        The cache is used only when it is at least as new as every "*.conf"
        file under root and lists exactly the current set of services.
        :return: the parsed Schema
        """
        conf_files = list(self.root.glob("*.conf"))
        # Files starting with "_" (e.g. the API defaults) are not services
        services = [path for path in conf_files if not path.name.startswith("_")]
        current_services_names = {path.stem for path in services}

        try:
            # Compare against all configuration files, not only the services:
            # the cache also stores the API defaults, so a change to an
            # underscore-prefixed file must invalidate it as well
            if self.mod_time(self.cache_path) >= max(map(self.mod_time, conf_files)):
                log.info("loading schema from cache")
                result = json.loads(self.cache_path.read_text())
                cached_services_names = set(result.pop("services_names", []))
                if cached_services_names == current_services_names:
                    return Schema(**result)
                else:
                    log.info(
                        f"found services files changed: "
                        f"added: {list(current_services_names - cached_services_names)}, "
                        f"removed: {list(cached_services_names - current_services_names)}"
                    )
        except (IOError, KeyError, TypeError, ValueError, AttributeError) as ex:
            # A missing, unreadable or malformed cache is not fatal: fall
            # through and rebuild it from the configuration files
            log.warning(f"failed loading cache: {ex}")

        log.info("regenerating schema cache")
        services = {path.stem: self.read_file(path) for path in services}
        api_defaults = self.read_file(self.root / "_api_defaults.conf")

        try:
            self.cache_path.write_text(
                json.dumps(
                    dict(
                        services_names=list(current_services_names),
                        services=services,
                        api_defaults=api_defaults,
                    )
                )
            )
        except IOError:
            # Best effort only: the schema is still returned without a cache
            log.exception(f"failed writing cache file to {self.cache_path}")

        return Schema(services, api_defaults)
class Schema:
    """
    The complete API schema: one Service object per service name
    """

    services: Mapping[str, Service]

    def __init__(self, services: dict, api_defaults: dict):
        """
        Build the schema of every service
        :param services: mapping of service name to service configuration
        :param api_defaults: API-wide default values handed to every service
        """
        self.api_defaults = api_defaults

        parsed_services = {}
        for service_name, service_conf in services.items():
            parsed_services[service_name] = Service(
                service_name, service_conf, api_defaults=self.api_defaults
            )
        self.services = parsed_services
# Module-level schema instance, built when this module is imported: read from
# the cache file when it is usable, otherwise parsed from the service files
schema = SchemaReader().get_schema()