mirror of
https://github.com/clearml/clearml-server
synced 2025-01-31 02:46:53 +00:00
Refactor APICall and schema validation
This commit is contained in:
parent
23736efbc3
commit
bdf6c353bd
3
apiserver/schema/__init__.py
Normal file
3
apiserver/schema/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from .schema_reader import EndpointSchema, EndpointVersionsGroup, SchemaReader, Schema
|
||||
|
||||
__all__ = [EndpointSchema, EndpointVersionsGroup, SchemaReader, Schema]
|
@ -248,9 +248,9 @@ def remove_description(dct):
|
||||
pass
|
||||
|
||||
|
||||
def main():
|
||||
def main(here: str):
|
||||
args = parse_args()
|
||||
meta = load_hocon(os.path.dirname(__file__) + "/meta.conf")
|
||||
meta = load_hocon(here + "/meta.conf")
|
||||
validator_for(meta).check_schema(meta)
|
||||
|
||||
driver = LazyDriver()
|
||||
@ -300,4 +300,4 @@ def main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main(here=os.path.dirname(__file__))
|
||||
|
@ -5,7 +5,7 @@ import json
|
||||
import re
|
||||
from operator import attrgetter
|
||||
from pathlib import Path
|
||||
from typing import Mapping, Sequence
|
||||
from typing import Mapping, Sequence, Type
|
||||
|
||||
import attr
|
||||
from boltons.dictutils import subdict
|
||||
@ -14,7 +14,6 @@ from pyhocon import ConfigFactory
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.utilities.partial_version import PartialVersion
|
||||
|
||||
HERE = Path(__file__)
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
@ -120,7 +119,7 @@ class EndpointVersionsGroup:
|
||||
|
||||
self.endpoints = sorted(
|
||||
(
|
||||
EndpointSchema(
|
||||
SchemaReader.endpoint_schema_cls(
|
||||
service_name=self.service_name,
|
||||
action_name=self.action_name,
|
||||
version=parse_version(version),
|
||||
@ -168,7 +167,7 @@ class Service:
|
||||
self.defaults = {**api_defaults, **conf.pop("_default", {})}
|
||||
self.definitions = conf.pop("_definitions", None)
|
||||
self.endpoint_groups: Mapping[str, EndpointVersionsGroup] = {
|
||||
endpoint_name: EndpointVersionsGroup(
|
||||
endpoint_name: SchemaReader.endpoint_versions_group_cls(
|
||||
service_name=self.name,
|
||||
action_name=endpoint_name,
|
||||
conf=endpoint_conf,
|
||||
@ -179,10 +178,30 @@ class Service:
|
||||
}
|
||||
|
||||
|
||||
class Schema:
|
||||
services: Mapping[str, Service]
|
||||
|
||||
def __init__(self, services: dict, api_defaults: dict):
|
||||
"""
|
||||
Represents the entire API schema
|
||||
:param services: services schema
|
||||
:param api_defaults: default values of service configuration
|
||||
"""
|
||||
self.api_defaults = api_defaults
|
||||
self.services = {
|
||||
name: SchemaReader.service_cls(name, conf, api_defaults=self.api_defaults)
|
||||
for name, conf in services.items()
|
||||
}
|
||||
|
||||
|
||||
@attr.s()
|
||||
class SchemaReader:
|
||||
root: Path = attr.ib(default=HERE.parent / "schema/services", converter=Path)
|
||||
cache_path: Path = attr.ib(default=None)
|
||||
service_cls: Type[Service] = Service
|
||||
endpoint_versions_group_cls: Type[EndpointVersionsGroup] = EndpointVersionsGroup
|
||||
endpoint_schema_cls: Type[EndpointSchema] = EndpointSchema
|
||||
|
||||
root: Path = Path(__file__).parent / "services"
|
||||
cache_path: Path = None
|
||||
|
||||
def __attrs_post_init__(self):
|
||||
if not self.cache_path:
|
||||
@ -246,22 +265,3 @@ class SchemaReader:
|
||||
log.exception(f"failed cache file to {self.cache_path}")
|
||||
|
||||
return Schema(services, api_defaults)
|
||||
|
||||
|
||||
class Schema:
|
||||
services: Mapping[str, Service]
|
||||
|
||||
def __init__(self, services: dict, api_defaults: dict):
|
||||
"""
|
||||
Represents the entire API schema
|
||||
:param services: services schema
|
||||
:param api_defaults: default values of service configuration
|
||||
"""
|
||||
self.api_defaults = api_defaults
|
||||
self.services = {
|
||||
name: Service(name, conf, api_defaults=self.api_defaults)
|
||||
for name, conf in services.items()
|
||||
}
|
||||
|
||||
|
||||
schema = SchemaReader().get_schema()
|
@ -9,11 +9,8 @@ from six import string_types
|
||||
|
||||
from apiserver import database
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.timing_context import TimingContext, TimingStats
|
||||
from apiserver.utilities import json
|
||||
from apiserver.utilities.partial_version import PartialVersion
|
||||
from .auth import Identity
|
||||
from .auth import Payload as AuthPayload
|
||||
from .errors import CallParsingError
|
||||
from .schema_validator import SchemaValidator
|
||||
|
||||
@ -311,6 +308,8 @@ class APICall(DataContainer):
|
||||
HEADER_FORWARDED_FOR = "X-Forwarded-For"
|
||||
""" Standard headers """
|
||||
|
||||
_call_result_cls = APICallResult
|
||||
|
||||
_transaction_headers = _get_headers("Trx")
|
||||
""" Transaction ID """
|
||||
|
||||
@ -376,7 +375,7 @@ class APICall(DataContainer):
|
||||
self._log_api = True
|
||||
if headers:
|
||||
self._headers.update(headers)
|
||||
self._result = APICallResult()
|
||||
self._result = self._call_result_cls()
|
||||
self._auth = None
|
||||
self._impersonation = None
|
||||
if trx:
|
||||
@ -471,8 +470,6 @@ class APICall(DataContainer):
|
||||
|
||||
@auth.setter
|
||||
def auth(self, value):
|
||||
if value:
|
||||
assert isinstance(value, AuthPayload)
|
||||
self._auth = value
|
||||
|
||||
@property
|
||||
@ -497,12 +494,10 @@ class APICall(DataContainer):
|
||||
|
||||
@impersonation.setter
|
||||
def impersonation(self, value):
|
||||
if value:
|
||||
assert isinstance(value, AuthPayload)
|
||||
self._impersonation = value
|
||||
|
||||
@property
|
||||
def identity(self) -> Identity:
|
||||
def identity(self):
|
||||
if self.impersonation:
|
||||
if not self.impersonation.identity:
|
||||
raise Exception("Missing impersonate identity")
|
||||
@ -543,7 +538,10 @@ class APICall(DataContainer):
|
||||
|
||||
@property
|
||||
def worker(self):
|
||||
return self.get_header(self._worker_headers, "<unknown>")
|
||||
return self.get_worker(default="<unknown>")
|
||||
|
||||
def get_worker(self, default=None):
|
||||
return self.get_header(self._worker_headers, default)
|
||||
|
||||
@property
|
||||
def authorization(self):
|
||||
@ -576,10 +574,16 @@ class APICall(DataContainer):
|
||||
def mark_end(self):
|
||||
self._end_ts = time.time()
|
||||
self._duration = int((self._end_ts - self._start_ts) * 1000)
|
||||
self.stats = TimingStats.aggregate()
|
||||
|
||||
def get_response(self):
|
||||
def make_version_number(version: PartialVersion):
|
||||
def get_response(self, include_stack: bool = False) -> Tuple[Union[dict, str], str]:
|
||||
"""
|
||||
Get the response for this call.
|
||||
:param include_stack: If True, stack trace stored in this call's result should
|
||||
be included in the response (default is False)
|
||||
:return: Response data (encoded according to self.content_type) and the data's content type
|
||||
"""
|
||||
|
||||
def make_version_number(version: PartialVersion) -> Union[None, float, str]:
|
||||
"""
|
||||
Client versions <=2.0 expect expect endpoint versions in float format, otherwise throwing an exception
|
||||
"""
|
||||
@ -610,26 +614,25 @@ class APICall(DataContainer):
|
||||
"result_code": self.result.code,
|
||||
"result_subcode": self.result.subcode,
|
||||
"result_msg": self.result.msg,
|
||||
"error_stack": self.result.traceback,
|
||||
"error_stack": self.result.traceback if include_stack else None,
|
||||
"error_data": self.result.error_data,
|
||||
},
|
||||
"data": self.result.data,
|
||||
}
|
||||
if self.content_type.lower() == JSON_CONTENT_TYPE:
|
||||
with TimingContext("json", "serialization"):
|
||||
try:
|
||||
res = json.dumps(res)
|
||||
except Exception as ex:
|
||||
# JSON serialization may fail, probably problem with data or error_data so pop it and try again
|
||||
if not (self.result.data or self.result.error_data):
|
||||
raise
|
||||
self.result.data = None
|
||||
self.result.error_data = None
|
||||
msg = "Error serializing response data: " + str(ex)
|
||||
self.set_error_result(
|
||||
code=500, subcode=0, msg=msg, include_stack=False
|
||||
)
|
||||
return self.get_response()
|
||||
try:
|
||||
res = json.dumps(res)
|
||||
except Exception as ex:
|
||||
# JSON serialization may fail, probably problem with data or error_data so pop it and try again
|
||||
if not (self.result.data or self.result.error_data):
|
||||
raise
|
||||
self.result.data = None
|
||||
self.result.error_data = None
|
||||
msg = "Error serializing response data: " + str(ex)
|
||||
self.set_error_result(
|
||||
code=500, subcode=0, msg=msg, include_stack=False
|
||||
)
|
||||
return self.get_response()
|
||||
|
||||
return res, self.content_type
|
||||
|
||||
@ -637,7 +640,7 @@ class APICall(DataContainer):
|
||||
self, msg, code=500, subcode=0, include_stack=False, error_data=None
|
||||
):
|
||||
tb = format_exc() if include_stack else None
|
||||
self._result = APICallResult(
|
||||
self._result = self._call_result_cls(
|
||||
data=self._result.data,
|
||||
code=code,
|
||||
subcode=subcode,
|
||||
|
@ -4,7 +4,7 @@ from boltons.iterutils import remap
|
||||
from jsonmodels import models
|
||||
from jsonmodels.errors import FieldNotSupported
|
||||
|
||||
from apiserver.schema import schema
|
||||
from apiserver.services_schema import schema
|
||||
from apiserver.utilities.partial_version import PartialVersion
|
||||
from .apicall import APICall
|
||||
from .schema_validator import SchemaValidator
|
||||
|
@ -282,6 +282,7 @@ class ServiceRepo(object):
|
||||
finally:
|
||||
content, content_type = call.get_response()
|
||||
call.mark_end()
|
||||
|
||||
console_msg = f"Returned {call.result.code} for {call.endpoint_name} in {call.duration}ms"
|
||||
if call.result.code < 300:
|
||||
log.info(console_msg)
|
||||
|
3
apiserver/services_schema.py
Normal file
3
apiserver/services_schema.py
Normal file
@ -0,0 +1,3 @@
|
||||
from apiserver.schema import SchemaReader
|
||||
|
||||
schema = SchemaReader().get_schema()
|
Loading…
Reference in New Issue
Block a user