mirror of
https://github.com/clearml/clearml-server
synced 2025-06-09 18:25:38 +00:00
Refactor APICall and schema validation
This commit is contained in:
parent
23736efbc3
commit
bdf6c353bd
3
apiserver/schema/__init__.py
Normal file
3
apiserver/schema/__init__.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
from .schema_reader import EndpointSchema, EndpointVersionsGroup, SchemaReader, Schema
|
||||||
|
|
||||||
|
__all__ = [EndpointSchema, EndpointVersionsGroup, SchemaReader, Schema]
|
@ -248,9 +248,9 @@ def remove_description(dct):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main(here: str):
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
meta = load_hocon(os.path.dirname(__file__) + "/meta.conf")
|
meta = load_hocon(here + "/meta.conf")
|
||||||
validator_for(meta).check_schema(meta)
|
validator_for(meta).check_schema(meta)
|
||||||
|
|
||||||
driver = LazyDriver()
|
driver = LazyDriver()
|
||||||
@ -300,4 +300,4 @@ def main():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main(here=os.path.dirname(__file__))
|
||||||
|
@ -5,7 +5,7 @@ import json
|
|||||||
import re
|
import re
|
||||||
from operator import attrgetter
|
from operator import attrgetter
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Mapping, Sequence
|
from typing import Mapping, Sequence, Type
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
from boltons.dictutils import subdict
|
from boltons.dictutils import subdict
|
||||||
@ -14,7 +14,6 @@ from pyhocon import ConfigFactory
|
|||||||
from apiserver.config_repo import config
|
from apiserver.config_repo import config
|
||||||
from apiserver.utilities.partial_version import PartialVersion
|
from apiserver.utilities.partial_version import PartialVersion
|
||||||
|
|
||||||
HERE = Path(__file__)
|
|
||||||
|
|
||||||
log = config.logger(__file__)
|
log = config.logger(__file__)
|
||||||
|
|
||||||
@ -120,7 +119,7 @@ class EndpointVersionsGroup:
|
|||||||
|
|
||||||
self.endpoints = sorted(
|
self.endpoints = sorted(
|
||||||
(
|
(
|
||||||
EndpointSchema(
|
SchemaReader.endpoint_schema_cls(
|
||||||
service_name=self.service_name,
|
service_name=self.service_name,
|
||||||
action_name=self.action_name,
|
action_name=self.action_name,
|
||||||
version=parse_version(version),
|
version=parse_version(version),
|
||||||
@ -168,7 +167,7 @@ class Service:
|
|||||||
self.defaults = {**api_defaults, **conf.pop("_default", {})}
|
self.defaults = {**api_defaults, **conf.pop("_default", {})}
|
||||||
self.definitions = conf.pop("_definitions", None)
|
self.definitions = conf.pop("_definitions", None)
|
||||||
self.endpoint_groups: Mapping[str, EndpointVersionsGroup] = {
|
self.endpoint_groups: Mapping[str, EndpointVersionsGroup] = {
|
||||||
endpoint_name: EndpointVersionsGroup(
|
endpoint_name: SchemaReader.endpoint_versions_group_cls(
|
||||||
service_name=self.name,
|
service_name=self.name,
|
||||||
action_name=endpoint_name,
|
action_name=endpoint_name,
|
||||||
conf=endpoint_conf,
|
conf=endpoint_conf,
|
||||||
@ -179,10 +178,30 @@ class Service:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class Schema:
|
||||||
|
services: Mapping[str, Service]
|
||||||
|
|
||||||
|
def __init__(self, services: dict, api_defaults: dict):
|
||||||
|
"""
|
||||||
|
Represents the entire API schema
|
||||||
|
:param services: services schema
|
||||||
|
:param api_defaults: default values of service configuration
|
||||||
|
"""
|
||||||
|
self.api_defaults = api_defaults
|
||||||
|
self.services = {
|
||||||
|
name: SchemaReader.service_cls(name, conf, api_defaults=self.api_defaults)
|
||||||
|
for name, conf in services.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@attr.s()
|
@attr.s()
|
||||||
class SchemaReader:
|
class SchemaReader:
|
||||||
root: Path = attr.ib(default=HERE.parent / "schema/services", converter=Path)
|
service_cls: Type[Service] = Service
|
||||||
cache_path: Path = attr.ib(default=None)
|
endpoint_versions_group_cls: Type[EndpointVersionsGroup] = EndpointVersionsGroup
|
||||||
|
endpoint_schema_cls: Type[EndpointSchema] = EndpointSchema
|
||||||
|
|
||||||
|
root: Path = Path(__file__).parent / "services"
|
||||||
|
cache_path: Path = None
|
||||||
|
|
||||||
def __attrs_post_init__(self):
|
def __attrs_post_init__(self):
|
||||||
if not self.cache_path:
|
if not self.cache_path:
|
||||||
@ -246,22 +265,3 @@ class SchemaReader:
|
|||||||
log.exception(f"failed cache file to {self.cache_path}")
|
log.exception(f"failed cache file to {self.cache_path}")
|
||||||
|
|
||||||
return Schema(services, api_defaults)
|
return Schema(services, api_defaults)
|
||||||
|
|
||||||
|
|
||||||
class Schema:
|
|
||||||
services: Mapping[str, Service]
|
|
||||||
|
|
||||||
def __init__(self, services: dict, api_defaults: dict):
|
|
||||||
"""
|
|
||||||
Represents the entire API schema
|
|
||||||
:param services: services schema
|
|
||||||
:param api_defaults: default values of service configuration
|
|
||||||
"""
|
|
||||||
self.api_defaults = api_defaults
|
|
||||||
self.services = {
|
|
||||||
name: Service(name, conf, api_defaults=self.api_defaults)
|
|
||||||
for name, conf in services.items()
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
schema = SchemaReader().get_schema()
|
|
@ -9,11 +9,8 @@ from six import string_types
|
|||||||
|
|
||||||
from apiserver import database
|
from apiserver import database
|
||||||
from apiserver.config_repo import config
|
from apiserver.config_repo import config
|
||||||
from apiserver.timing_context import TimingContext, TimingStats
|
|
||||||
from apiserver.utilities import json
|
from apiserver.utilities import json
|
||||||
from apiserver.utilities.partial_version import PartialVersion
|
from apiserver.utilities.partial_version import PartialVersion
|
||||||
from .auth import Identity
|
|
||||||
from .auth import Payload as AuthPayload
|
|
||||||
from .errors import CallParsingError
|
from .errors import CallParsingError
|
||||||
from .schema_validator import SchemaValidator
|
from .schema_validator import SchemaValidator
|
||||||
|
|
||||||
@ -311,6 +308,8 @@ class APICall(DataContainer):
|
|||||||
HEADER_FORWARDED_FOR = "X-Forwarded-For"
|
HEADER_FORWARDED_FOR = "X-Forwarded-For"
|
||||||
""" Standard headers """
|
""" Standard headers """
|
||||||
|
|
||||||
|
_call_result_cls = APICallResult
|
||||||
|
|
||||||
_transaction_headers = _get_headers("Trx")
|
_transaction_headers = _get_headers("Trx")
|
||||||
""" Transaction ID """
|
""" Transaction ID """
|
||||||
|
|
||||||
@ -376,7 +375,7 @@ class APICall(DataContainer):
|
|||||||
self._log_api = True
|
self._log_api = True
|
||||||
if headers:
|
if headers:
|
||||||
self._headers.update(headers)
|
self._headers.update(headers)
|
||||||
self._result = APICallResult()
|
self._result = self._call_result_cls()
|
||||||
self._auth = None
|
self._auth = None
|
||||||
self._impersonation = None
|
self._impersonation = None
|
||||||
if trx:
|
if trx:
|
||||||
@ -471,8 +470,6 @@ class APICall(DataContainer):
|
|||||||
|
|
||||||
@auth.setter
|
@auth.setter
|
||||||
def auth(self, value):
|
def auth(self, value):
|
||||||
if value:
|
|
||||||
assert isinstance(value, AuthPayload)
|
|
||||||
self._auth = value
|
self._auth = value
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -497,12 +494,10 @@ class APICall(DataContainer):
|
|||||||
|
|
||||||
@impersonation.setter
|
@impersonation.setter
|
||||||
def impersonation(self, value):
|
def impersonation(self, value):
|
||||||
if value:
|
|
||||||
assert isinstance(value, AuthPayload)
|
|
||||||
self._impersonation = value
|
self._impersonation = value
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def identity(self) -> Identity:
|
def identity(self):
|
||||||
if self.impersonation:
|
if self.impersonation:
|
||||||
if not self.impersonation.identity:
|
if not self.impersonation.identity:
|
||||||
raise Exception("Missing impersonate identity")
|
raise Exception("Missing impersonate identity")
|
||||||
@ -543,7 +538,10 @@ class APICall(DataContainer):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def worker(self):
|
def worker(self):
|
||||||
return self.get_header(self._worker_headers, "<unknown>")
|
return self.get_worker(default="<unknown>")
|
||||||
|
|
||||||
|
def get_worker(self, default=None):
|
||||||
|
return self.get_header(self._worker_headers, default)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def authorization(self):
|
def authorization(self):
|
||||||
@ -576,10 +574,16 @@ class APICall(DataContainer):
|
|||||||
def mark_end(self):
|
def mark_end(self):
|
||||||
self._end_ts = time.time()
|
self._end_ts = time.time()
|
||||||
self._duration = int((self._end_ts - self._start_ts) * 1000)
|
self._duration = int((self._end_ts - self._start_ts) * 1000)
|
||||||
self.stats = TimingStats.aggregate()
|
|
||||||
|
|
||||||
def get_response(self):
|
def get_response(self, include_stack: bool = False) -> Tuple[Union[dict, str], str]:
|
||||||
def make_version_number(version: PartialVersion):
|
"""
|
||||||
|
Get the response for this call.
|
||||||
|
:param include_stack: If True, stack trace stored in this call's result should
|
||||||
|
be included in the response (default is False)
|
||||||
|
:return: Response data (encoded according to self.content_type) and the data's content type
|
||||||
|
"""
|
||||||
|
|
||||||
|
def make_version_number(version: PartialVersion) -> Union[None, float, str]:
|
||||||
"""
|
"""
|
||||||
Client versions <=2.0 expect expect endpoint versions in float format, otherwise throwing an exception
|
Client versions <=2.0 expect expect endpoint versions in float format, otherwise throwing an exception
|
||||||
"""
|
"""
|
||||||
@ -610,13 +614,12 @@ class APICall(DataContainer):
|
|||||||
"result_code": self.result.code,
|
"result_code": self.result.code,
|
||||||
"result_subcode": self.result.subcode,
|
"result_subcode": self.result.subcode,
|
||||||
"result_msg": self.result.msg,
|
"result_msg": self.result.msg,
|
||||||
"error_stack": self.result.traceback,
|
"error_stack": self.result.traceback if include_stack else None,
|
||||||
"error_data": self.result.error_data,
|
"error_data": self.result.error_data,
|
||||||
},
|
},
|
||||||
"data": self.result.data,
|
"data": self.result.data,
|
||||||
}
|
}
|
||||||
if self.content_type.lower() == JSON_CONTENT_TYPE:
|
if self.content_type.lower() == JSON_CONTENT_TYPE:
|
||||||
with TimingContext("json", "serialization"):
|
|
||||||
try:
|
try:
|
||||||
res = json.dumps(res)
|
res = json.dumps(res)
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
@ -637,7 +640,7 @@ class APICall(DataContainer):
|
|||||||
self, msg, code=500, subcode=0, include_stack=False, error_data=None
|
self, msg, code=500, subcode=0, include_stack=False, error_data=None
|
||||||
):
|
):
|
||||||
tb = format_exc() if include_stack else None
|
tb = format_exc() if include_stack else None
|
||||||
self._result = APICallResult(
|
self._result = self._call_result_cls(
|
||||||
data=self._result.data,
|
data=self._result.data,
|
||||||
code=code,
|
code=code,
|
||||||
subcode=subcode,
|
subcode=subcode,
|
||||||
|
@ -4,7 +4,7 @@ from boltons.iterutils import remap
|
|||||||
from jsonmodels import models
|
from jsonmodels import models
|
||||||
from jsonmodels.errors import FieldNotSupported
|
from jsonmodels.errors import FieldNotSupported
|
||||||
|
|
||||||
from apiserver.schema import schema
|
from apiserver.services_schema import schema
|
||||||
from apiserver.utilities.partial_version import PartialVersion
|
from apiserver.utilities.partial_version import PartialVersion
|
||||||
from .apicall import APICall
|
from .apicall import APICall
|
||||||
from .schema_validator import SchemaValidator
|
from .schema_validator import SchemaValidator
|
||||||
|
@ -282,6 +282,7 @@ class ServiceRepo(object):
|
|||||||
finally:
|
finally:
|
||||||
content, content_type = call.get_response()
|
content, content_type = call.get_response()
|
||||||
call.mark_end()
|
call.mark_end()
|
||||||
|
|
||||||
console_msg = f"Returned {call.result.code} for {call.endpoint_name} in {call.duration}ms"
|
console_msg = f"Returned {call.result.code} for {call.endpoint_name} in {call.duration}ms"
|
||||||
if call.result.code < 300:
|
if call.result.code < 300:
|
||||||
log.info(console_msg)
|
log.info(console_msg)
|
||||||
|
3
apiserver/services_schema.py
Normal file
3
apiserver/services_schema.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
from apiserver.schema import SchemaReader
|
||||||
|
|
||||||
|
schema = SchemaReader().get_schema()
|
Loading…
Reference in New Issue
Block a user