clearml-server/apiserver/schema/schema_reader.py

263 lines
8.5 KiB
Python
Raw Normal View History

2019-06-10 21:24:35 +00:00
"""
Objects representing schema entities
"""
import json
import re
from operator import attrgetter
from pathlib import Path
2021-01-05 16:47:32 +00:00
from typing import Mapping, Sequence
2019-06-10 21:24:35 +00:00
import attr
from boltons.dictutils import subdict
from pyhocon import ConfigFactory
2021-01-05 14:44:31 +00:00
from apiserver.config_repo import config
2021-01-05 14:28:49 +00:00
from apiserver.utilities.partial_version import PartialVersion
2019-06-10 21:24:35 +00:00
# Logger for this module, created through the shared config object.
log = config.logger(__file__)
# Wildcard role: EndpointVersionsGroup.allows() accepts every role when this
# value is present in the endpoint's allow_roles.
ALL_ROLES = "*"
class EndpointSchema:
    """
    Schema of a single endpoint at a single API version.

    Exposes ``request_schema`` or ``batch_request_schema`` (the other one stays
    ``None``) and ``response_schema``. Each of them is a copy of the matching
    section of the endpoint schema with the service definitions attached under
    ``DEFINITIONS_KEY``.
    """

    REQUEST_KEY = "request"
    RESPONSE_KEY = "response"
    BATCH_REQUEST_KEY = "batch_request"
    DEFINITIONS_KEY = "definitions"

    def __init__(
        self,
        service_name: str,
        action_name: str,
        version: PartialVersion,
        schema: dict,
        definitions: dict = None,
    ):
        """
        Class for interacting with the schema of a single endpoint

        :param service_name: name of containing service
        :param action_name: name of action
        :param version: endpoint version
        :param schema: endpoint schema
        :param definitions: service definitions
        :raises RuntimeError: if schema has neither a request nor a
            batch_request section
        :raises KeyError: if schema has no response section
        """
        self.service_name = service_name
        self.action_name = action_name
        self.full_name = f"{service_name}.{action_name}"
        self.version = version
        self.definitions = definitions

        self.request_schema = None
        self.batch_request_schema = None
        # A plain request takes precedence: batch_request is only considered
        # when the schema has no request section.
        if self.REQUEST_KEY in schema:
            self.request_schema = self._with_definitions(schema[self.REQUEST_KEY])
        elif self.BATCH_REQUEST_KEY in schema:
            self.batch_request_schema = self._with_definitions(
                schema[self.BATCH_REQUEST_KEY]
            )
        else:
            raise RuntimeError(
                f"endpoint {self.full_name} version {self.version} "
                "has no request or batch_request schema",
                schema,
            )

        self.response_schema = self._with_definitions(schema[self.RESPONSE_KEY])

    def _with_definitions(self, section: dict) -> dict:
        """
        Return a copy of a schema section with the service definitions attached

        The definitions are stored under DEFINITIONS_KEY and replace any entry
        of the same name already present in the section. The key is added even
        when the definitions are None.
        """
        return {**section, self.DEFINITIONS_KEY: self.definitions}
class EndpointVersionsGroup:
    """
    All versions of a single endpoint, together with its access attributes.

    ``endpoints`` holds one EndpointSchema per version key found in the
    endpoint configuration, sorted by ascending version.
    """

    endpoints: Sequence[EndpointSchema]
    allow_roles: Sequence[str]
    internal: bool
    authorize: bool

    def __repr__(self):
        return (
            f"{type(self).__name__}<{self.full_name}, "
            f"versions={tuple(e.version for e in self.endpoints)}>"
        )

    def __init__(
        self,
        service_name: str,
        action_name: str,
        conf: dict,
        definitions: dict = None,
        defaults: dict = None,
    ):
        """
        Represents multiple implementations of a single endpoint, discriminated by API version

        :param service_name: name of containing service
        :param action_name: name of action
        :param conf: mapping between minimum version to endpoint schema
        :param definitions: service definitions
        :param defaults: service defaults
        :raises KeyError: if one of internal/allow_roles/authorize is given
            neither in conf nor in defaults
        :raises ValueError: if conf contains a key that is neither one of the
            attributes above nor a "<major>.<minor>" version
        """
        self.service_name = service_name
        self.action_name = action_name
        self.full_name = f"{service_name}.{action_name}"
        self.definitions = definitions or {}
        self.defaults = defaults or {}

        # Work on a shallow copy: the attribute keys are popped below, and the
        # caller's mapping must stay intact so it can be parsed again.
        conf = dict(conf)
        self.internal = self._pop_attr_with_default(conf, "internal")
        self.allow_roles = self._pop_attr_with_default(conf, "allow_roles")
        self.authorize = self._pop_attr_with_default(conf, "authorize")

        def parse_version(version):
            if not re.match(r"^\d+\.\d+$", version):
                raise ValueError(
                    f"Encountered unrecognized key {version!r} in {self.service_name}.{self.action_name}"
                )
            return PartialVersion(version)

        # Every key left in conf is expected to be a version.
        self.endpoints = sorted(
            (
                EndpointSchema(
                    service_name=self.service_name,
                    action_name=self.action_name,
                    version=parse_version(version),
                    schema=endpoint_conf,
                    definitions=self.definitions,
                )
                for version, endpoint_conf in conf.items()
            ),
            key=attrgetter("version"),
        )

    def allows(self, role):
        """
        Return True if the role is listed in allow_roles or allow_roles contains the wildcard
        """
        return ALL_ROLES in self.allow_roles or role in self.allow_roles

    def _pop_attr_with_default(self, conf, attr):
        """
        Remove attr from conf and return its value, falling back to the defaults

        The defaults are consulted only when conf does not provide the
        attribute, so a missing default is an error only when it is needed.

        :raises KeyError: if attr is found neither in conf nor in the defaults
        """
        if attr in conf:
            return conf.pop(attr)
        return self.defaults[attr]

    def get_for_version(self, min_version: PartialVersion):
        """
        Return endpoint schema for version

        The endpoints are scanned in ascending version order and the first one
        whose version is not lower than min_version is returned.

        :raises ValueError: if the group has no endpoints, or if min_version is
            higher than the version of every endpoint in the group
        """
        if not self.endpoints:
            raise ValueError(f"endpoint group {self} has no versions")
        for endpoint in self.endpoints:
            if min_version <= endpoint.version:
                return endpoint
        raise ValueError(
            f"min_version {min_version} is higher than highest version in group {self}"
        )
class Service:
    """
    Schema of one service: its defaults, its definitions and its endpoint groups.
    """

    endpoint_groups: Mapping[str, EndpointVersionsGroup]

    def __init__(self, name: str, conf: dict, api_defaults: dict):
        """
        Represents schema of one service

        :param name: name of service
        :param conf: service configuration, containing endpoint groups and other details
        :param api_defaults: API-wide endpoint attributes default values
        """
        self.name = name

        # Filtered view of the configuration without the "_description" and
        # "_references" entries; the remaining special keys are popped below.
        endpoints_conf = subdict(conf, drop=("_description", "_references"))

        # Service-level "_default" values override the API-wide ones.
        service_defaults = endpoints_conf.pop("_default", {})
        self.defaults = {**api_defaults, **service_defaults}
        self.definitions = endpoints_conf.pop("_definitions", None)

        # Whatever is left is treated as endpoint name -> endpoint configuration.
        groups = {}
        for endpoint_name, endpoint_conf in endpoints_conf.items():
            groups[endpoint_name] = EndpointVersionsGroup(
                service_name=self.name,
                action_name=endpoint_name,
                conf=endpoint_conf,
                defaults=self.defaults,
                definitions=self.definitions,
            )
        self.endpoint_groups: Mapping[str, EndpointVersionsGroup] = groups
2021-01-05 16:30:59 +00:00
class Schema:
    """
    The complete API schema: one Service object per service name.
    """

    services: Mapping[str, Service]

    def __init__(self, services: dict, api_defaults: dict):
        """
        Represents the entire API schema

        :param services: services schema
        :param api_defaults: default values of service configuration
        """
        self.api_defaults = api_defaults

        parsed_services = {}
        for service_name, service_conf in services.items():
            parsed_services[service_name] = Service(
                service_name, service_conf, api_defaults=self.api_defaults
            )
        self.services = parsed_services
2019-06-10 21:24:35 +00:00
@attr.s()
class SchemaReader:
    """
    Loads the API schema from the service ``*.conf`` files under ``root``,
    using a JSON cache file to avoid re-parsing unchanged sources.
    """

    root = Path(__file__).parent / "services"
    # NOTE(review): declared without attr.ib() and without auto_attribs, so
    # this looks like a plain class-level default rather than an init
    # argument - confirm against the attrs documentation.
    cache_path: Path = None

    def __attrs_post_init__(self):
        if not self.cache_path:
            self.cache_path = self.root / "_cache.json"

    @staticmethod
    def mod_time(path):
        """
        return file modification time
        """
        return path.stat().st_mtime

    @staticmethod
    def read_file(path):
        """
        parse a configuration file into a plain ordered dict
        """
        return ConfigFactory.parse_file(path).as_plain_ordered_dict()

    def get_schema(self):
        """
        Parse the API schema to schema object.
        Load from config files and write to cache file if possible.

        The cache is used only when it is at least as new as every service
        file and the API defaults file, and when it was generated from the
        same set of service names. Any failure to use the cache is logged and
        the schema is regenerated from the configuration files.
        """
        services = [
            service
            for service in self.root.glob("*.conf")
            if not service.name.startswith("_")
        ]
        api_defaults_path = self.root / "_api_defaults.conf"
        current_services_names = {path.stem for path in services}

        try:
            # The API defaults are stored in the cache as well, so their file
            # has to take part in the freshness check together with the
            # service files.
            sources = [*services, api_defaults_path]
            if self.mod_time(self.cache_path) >= max(map(self.mod_time, sources)):
                log.info("loading schema from cache")
                result = json.loads(self.cache_path.read_text())
                cached_services_names = set(result.pop("services_names", []))
                if cached_services_names == current_services_names:
                    return Schema(**result)
                else:
                    log.info(
                        f"found services files changed: "
                        f"added: {list(current_services_names - cached_services_names)}, "
                        f"removed: {list(cached_services_names - current_services_names)}"
                    )
        except (IOError, KeyError, TypeError, ValueError, AttributeError) as ex:
            log.warning(f"failed loading cache: {ex}")

        log.info("regenerating schema cache")
        services = {path.stem: self.read_file(path) for path in services}
        api_defaults = self.read_file(api_defaults_path)

        # Writing the cache is best effort: a failure is logged and the freshly
        # parsed schema is still returned.
        try:
            self.cache_path.write_text(
                json.dumps(
                    dict(
                        services_names=list(current_services_names),
                        services=services,
                        api_defaults=api_defaults,
                    )
                )
            )
        except IOError:
            log.exception(f"failed writing cache file to {self.cache_path}")

        return Schema(services, api_defaults)