Refactor APICall and schema validation

This commit is contained in:
allegroai 2021-01-05 18:30:59 +02:00
parent 23736efbc3
commit bdf6c353bd
7 changed files with 68 additions and 58 deletions

View File

@ -0,0 +1,3 @@
from .schema_reader import EndpointSchema, EndpointVersionsGroup, SchemaReader, Schema
__all__ = [EndpointSchema, EndpointVersionsGroup, SchemaReader, Schema]

View File

@ -248,9 +248,9 @@ def remove_description(dct):
pass
def main():
def main(here: str):
args = parse_args()
meta = load_hocon(os.path.dirname(__file__) + "/meta.conf")
meta = load_hocon(here + "/meta.conf")
validator_for(meta).check_schema(meta)
driver = LazyDriver()
@ -300,4 +300,4 @@ def main():
if __name__ == "__main__":
main()
main(here=os.path.dirname(__file__))

View File

@ -5,7 +5,7 @@ import json
import re
from operator import attrgetter
from pathlib import Path
from typing import Mapping, Sequence
from typing import Mapping, Sequence, Type
import attr
from boltons.dictutils import subdict
@ -14,7 +14,6 @@ from pyhocon import ConfigFactory
from apiserver.config_repo import config
from apiserver.utilities.partial_version import PartialVersion
HERE = Path(__file__)
log = config.logger(__file__)
@ -120,7 +119,7 @@ class EndpointVersionsGroup:
self.endpoints = sorted(
(
EndpointSchema(
SchemaReader.endpoint_schema_cls(
service_name=self.service_name,
action_name=self.action_name,
version=parse_version(version),
@ -168,7 +167,7 @@ class Service:
self.defaults = {**api_defaults, **conf.pop("_default", {})}
self.definitions = conf.pop("_definitions", None)
self.endpoint_groups: Mapping[str, EndpointVersionsGroup] = {
endpoint_name: EndpointVersionsGroup(
endpoint_name: SchemaReader.endpoint_versions_group_cls(
service_name=self.name,
action_name=endpoint_name,
conf=endpoint_conf,
@ -179,10 +178,30 @@ class Service:
}
class Schema:
services: Mapping[str, Service]
def __init__(self, services: dict, api_defaults: dict):
"""
Represents the entire API schema
:param services: services schema
:param api_defaults: default values of service configuration
"""
self.api_defaults = api_defaults
self.services = {
name: SchemaReader.service_cls(name, conf, api_defaults=self.api_defaults)
for name, conf in services.items()
}
@attr.s()
class SchemaReader:
root: Path = attr.ib(default=HERE.parent / "schema/services", converter=Path)
cache_path: Path = attr.ib(default=None)
service_cls: Type[Service] = Service
endpoint_versions_group_cls: Type[EndpointVersionsGroup] = EndpointVersionsGroup
endpoint_schema_cls: Type[EndpointSchema] = EndpointSchema
root: Path = Path(__file__).parent / "services"
cache_path: Path = None
def __attrs_post_init__(self):
if not self.cache_path:
@ -246,22 +265,3 @@ class SchemaReader:
log.exception(f"failed cache file to {self.cache_path}")
return Schema(services, api_defaults)
class Schema:
services: Mapping[str, Service]
def __init__(self, services: dict, api_defaults: dict):
"""
Represents the entire API schema
:param services: services schema
:param api_defaults: default values of service configuration
"""
self.api_defaults = api_defaults
self.services = {
name: Service(name, conf, api_defaults=self.api_defaults)
for name, conf in services.items()
}
schema = SchemaReader().get_schema()

View File

@ -9,11 +9,8 @@ from six import string_types
from apiserver import database
from apiserver.config_repo import config
from apiserver.timing_context import TimingContext, TimingStats
from apiserver.utilities import json
from apiserver.utilities.partial_version import PartialVersion
from .auth import Identity
from .auth import Payload as AuthPayload
from .errors import CallParsingError
from .schema_validator import SchemaValidator
@ -311,6 +308,8 @@ class APICall(DataContainer):
HEADER_FORWARDED_FOR = "X-Forwarded-For"
""" Standard headers """
_call_result_cls = APICallResult
_transaction_headers = _get_headers("Trx")
""" Transaction ID """
@ -376,7 +375,7 @@ class APICall(DataContainer):
self._log_api = True
if headers:
self._headers.update(headers)
self._result = APICallResult()
self._result = self._call_result_cls()
self._auth = None
self._impersonation = None
if trx:
@ -471,8 +470,6 @@ class APICall(DataContainer):
@auth.setter
def auth(self, value):
if value:
assert isinstance(value, AuthPayload)
self._auth = value
@property
@ -497,12 +494,10 @@ class APICall(DataContainer):
@impersonation.setter
def impersonation(self, value):
if value:
assert isinstance(value, AuthPayload)
self._impersonation = value
@property
def identity(self) -> Identity:
def identity(self):
if self.impersonation:
if not self.impersonation.identity:
raise Exception("Missing impersonate identity")
@ -543,7 +538,10 @@ class APICall(DataContainer):
@property
def worker(self):
return self.get_header(self._worker_headers, "<unknown>")
return self.get_worker(default="<unknown>")
def get_worker(self, default=None):
return self.get_header(self._worker_headers, default)
@property
def authorization(self):
@ -576,10 +574,16 @@ class APICall(DataContainer):
def mark_end(self):
self._end_ts = time.time()
self._duration = int((self._end_ts - self._start_ts) * 1000)
self.stats = TimingStats.aggregate()
def get_response(self):
def make_version_number(version: PartialVersion):
def get_response(self, include_stack: bool = False) -> Tuple[Union[dict, str], str]:
"""
Get the response for this call.
:param include_stack: If True, stack trace stored in this call's result should
be included in the response (default is False)
:return: Response data (encoded according to self.content_type) and the data's content type
"""
def make_version_number(version: PartialVersion) -> Union[None, float, str]:
"""
Client versions <=2.0 expect expect endpoint versions in float format, otherwise throwing an exception
"""
@ -610,26 +614,25 @@ class APICall(DataContainer):
"result_code": self.result.code,
"result_subcode": self.result.subcode,
"result_msg": self.result.msg,
"error_stack": self.result.traceback,
"error_stack": self.result.traceback if include_stack else None,
"error_data": self.result.error_data,
},
"data": self.result.data,
}
if self.content_type.lower() == JSON_CONTENT_TYPE:
with TimingContext("json", "serialization"):
try:
res = json.dumps(res)
except Exception as ex:
# JSON serialization may fail, probably problem with data or error_data so pop it and try again
if not (self.result.data or self.result.error_data):
raise
self.result.data = None
self.result.error_data = None
msg = "Error serializing response data: " + str(ex)
self.set_error_result(
code=500, subcode=0, msg=msg, include_stack=False
)
return self.get_response()
try:
res = json.dumps(res)
except Exception as ex:
# JSON serialization may fail, probably problem with data or error_data so pop it and try again
if not (self.result.data or self.result.error_data):
raise
self.result.data = None
self.result.error_data = None
msg = "Error serializing response data: " + str(ex)
self.set_error_result(
code=500, subcode=0, msg=msg, include_stack=False
)
return self.get_response()
return res, self.content_type
@ -637,7 +640,7 @@ class APICall(DataContainer):
self, msg, code=500, subcode=0, include_stack=False, error_data=None
):
tb = format_exc() if include_stack else None
self._result = APICallResult(
self._result = self._call_result_cls(
data=self._result.data,
code=code,
subcode=subcode,

View File

@ -4,7 +4,7 @@ from boltons.iterutils import remap
from jsonmodels import models
from jsonmodels.errors import FieldNotSupported
from apiserver.schema import schema
from apiserver.services_schema import schema
from apiserver.utilities.partial_version import PartialVersion
from .apicall import APICall
from .schema_validator import SchemaValidator

View File

@ -282,6 +282,7 @@ class ServiceRepo(object):
finally:
content, content_type = call.get_response()
call.mark_end()
console_msg = f"Returned {call.result.code} for {call.endpoint_name} in {call.duration}ms"
if call.result.code < 300:
log.info(console_msg)

View File

@ -0,0 +1,3 @@
from apiserver.schema import SchemaReader
schema = SchemaReader().get_schema()