Refactor APICall and schema validation

This commit is contained in:
allegroai 2021-01-05 18:30:59 +02:00
parent 23736efbc3
commit bdf6c353bd
7 changed files with 68 additions and 58 deletions

View File

@ -0,0 +1,3 @@
from .schema_reader import EndpointSchema, EndpointVersionsGroup, SchemaReader, Schema
__all__ = [EndpointSchema, EndpointVersionsGroup, SchemaReader, Schema]

View File

@ -248,9 +248,9 @@ def remove_description(dct):
pass pass
def main(): def main(here: str):
args = parse_args() args = parse_args()
meta = load_hocon(os.path.dirname(__file__) + "/meta.conf") meta = load_hocon(here + "/meta.conf")
validator_for(meta).check_schema(meta) validator_for(meta).check_schema(meta)
driver = LazyDriver() driver = LazyDriver()
@ -300,4 +300,4 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
main() main(here=os.path.dirname(__file__))

View File

@ -5,7 +5,7 @@ import json
import re import re
from operator import attrgetter from operator import attrgetter
from pathlib import Path from pathlib import Path
from typing import Mapping, Sequence from typing import Mapping, Sequence, Type
import attr import attr
from boltons.dictutils import subdict from boltons.dictutils import subdict
@ -14,7 +14,6 @@ from pyhocon import ConfigFactory
from apiserver.config_repo import config from apiserver.config_repo import config
from apiserver.utilities.partial_version import PartialVersion from apiserver.utilities.partial_version import PartialVersion
HERE = Path(__file__)
log = config.logger(__file__) log = config.logger(__file__)
@ -120,7 +119,7 @@ class EndpointVersionsGroup:
self.endpoints = sorted( self.endpoints = sorted(
( (
EndpointSchema( SchemaReader.endpoint_schema_cls(
service_name=self.service_name, service_name=self.service_name,
action_name=self.action_name, action_name=self.action_name,
version=parse_version(version), version=parse_version(version),
@ -168,7 +167,7 @@ class Service:
self.defaults = {**api_defaults, **conf.pop("_default", {})} self.defaults = {**api_defaults, **conf.pop("_default", {})}
self.definitions = conf.pop("_definitions", None) self.definitions = conf.pop("_definitions", None)
self.endpoint_groups: Mapping[str, EndpointVersionsGroup] = { self.endpoint_groups: Mapping[str, EndpointVersionsGroup] = {
endpoint_name: EndpointVersionsGroup( endpoint_name: SchemaReader.endpoint_versions_group_cls(
service_name=self.name, service_name=self.name,
action_name=endpoint_name, action_name=endpoint_name,
conf=endpoint_conf, conf=endpoint_conf,
@ -179,10 +178,30 @@ class Service:
} }
class Schema:
services: Mapping[str, Service]
def __init__(self, services: dict, api_defaults: dict):
"""
Represents the entire API schema
:param services: services schema
:param api_defaults: default values of service configuration
"""
self.api_defaults = api_defaults
self.services = {
name: SchemaReader.service_cls(name, conf, api_defaults=self.api_defaults)
for name, conf in services.items()
}
@attr.s() @attr.s()
class SchemaReader: class SchemaReader:
root: Path = attr.ib(default=HERE.parent / "schema/services", converter=Path) service_cls: Type[Service] = Service
cache_path: Path = attr.ib(default=None) endpoint_versions_group_cls: Type[EndpointVersionsGroup] = EndpointVersionsGroup
endpoint_schema_cls: Type[EndpointSchema] = EndpointSchema
root: Path = Path(__file__).parent / "services"
cache_path: Path = None
def __attrs_post_init__(self): def __attrs_post_init__(self):
if not self.cache_path: if not self.cache_path:
@ -246,22 +265,3 @@ class SchemaReader:
log.exception(f"failed cache file to {self.cache_path}") log.exception(f"failed cache file to {self.cache_path}")
return Schema(services, api_defaults) return Schema(services, api_defaults)
class Schema:
services: Mapping[str, Service]
def __init__(self, services: dict, api_defaults: dict):
"""
Represents the entire API schema
:param services: services schema
:param api_defaults: default values of service configuration
"""
self.api_defaults = api_defaults
self.services = {
name: Service(name, conf, api_defaults=self.api_defaults)
for name, conf in services.items()
}
schema = SchemaReader().get_schema()

View File

@ -9,11 +9,8 @@ from six import string_types
from apiserver import database from apiserver import database
from apiserver.config_repo import config from apiserver.config_repo import config
from apiserver.timing_context import TimingContext, TimingStats
from apiserver.utilities import json from apiserver.utilities import json
from apiserver.utilities.partial_version import PartialVersion from apiserver.utilities.partial_version import PartialVersion
from .auth import Identity
from .auth import Payload as AuthPayload
from .errors import CallParsingError from .errors import CallParsingError
from .schema_validator import SchemaValidator from .schema_validator import SchemaValidator
@ -311,6 +308,8 @@ class APICall(DataContainer):
HEADER_FORWARDED_FOR = "X-Forwarded-For" HEADER_FORWARDED_FOR = "X-Forwarded-For"
""" Standard headers """ """ Standard headers """
_call_result_cls = APICallResult
_transaction_headers = _get_headers("Trx") _transaction_headers = _get_headers("Trx")
""" Transaction ID """ """ Transaction ID """
@ -376,7 +375,7 @@ class APICall(DataContainer):
self._log_api = True self._log_api = True
if headers: if headers:
self._headers.update(headers) self._headers.update(headers)
self._result = APICallResult() self._result = self._call_result_cls()
self._auth = None self._auth = None
self._impersonation = None self._impersonation = None
if trx: if trx:
@ -471,8 +470,6 @@ class APICall(DataContainer):
@auth.setter @auth.setter
def auth(self, value): def auth(self, value):
if value:
assert isinstance(value, AuthPayload)
self._auth = value self._auth = value
@property @property
@ -497,12 +494,10 @@ class APICall(DataContainer):
@impersonation.setter @impersonation.setter
def impersonation(self, value): def impersonation(self, value):
if value:
assert isinstance(value, AuthPayload)
self._impersonation = value self._impersonation = value
@property @property
def identity(self) -> Identity: def identity(self):
if self.impersonation: if self.impersonation:
if not self.impersonation.identity: if not self.impersonation.identity:
raise Exception("Missing impersonate identity") raise Exception("Missing impersonate identity")
@ -543,7 +538,10 @@ class APICall(DataContainer):
@property @property
def worker(self): def worker(self):
return self.get_header(self._worker_headers, "<unknown>") return self.get_worker(default="<unknown>")
def get_worker(self, default=None):
return self.get_header(self._worker_headers, default)
@property @property
def authorization(self): def authorization(self):
@ -576,10 +574,16 @@ class APICall(DataContainer):
def mark_end(self): def mark_end(self):
self._end_ts = time.time() self._end_ts = time.time()
self._duration = int((self._end_ts - self._start_ts) * 1000) self._duration = int((self._end_ts - self._start_ts) * 1000)
self.stats = TimingStats.aggregate()
def get_response(self): def get_response(self, include_stack: bool = False) -> Tuple[Union[dict, str], str]:
def make_version_number(version: PartialVersion): """
Get the response for this call.
:param include_stack: If True, stack trace stored in this call's result should
be included in the response (default is False)
:return: Response data (encoded according to self.content_type) and the data's content type
"""
def make_version_number(version: PartialVersion) -> Union[None, float, str]:
""" """
Client versions <=2.0 expect expect endpoint versions in float format, otherwise throwing an exception Client versions <=2.0 expect expect endpoint versions in float format, otherwise throwing an exception
""" """
@ -610,13 +614,12 @@ class APICall(DataContainer):
"result_code": self.result.code, "result_code": self.result.code,
"result_subcode": self.result.subcode, "result_subcode": self.result.subcode,
"result_msg": self.result.msg, "result_msg": self.result.msg,
"error_stack": self.result.traceback, "error_stack": self.result.traceback if include_stack else None,
"error_data": self.result.error_data, "error_data": self.result.error_data,
}, },
"data": self.result.data, "data": self.result.data,
} }
if self.content_type.lower() == JSON_CONTENT_TYPE: if self.content_type.lower() == JSON_CONTENT_TYPE:
with TimingContext("json", "serialization"):
try: try:
res = json.dumps(res) res = json.dumps(res)
except Exception as ex: except Exception as ex:
@ -637,7 +640,7 @@ class APICall(DataContainer):
self, msg, code=500, subcode=0, include_stack=False, error_data=None self, msg, code=500, subcode=0, include_stack=False, error_data=None
): ):
tb = format_exc() if include_stack else None tb = format_exc() if include_stack else None
self._result = APICallResult( self._result = self._call_result_cls(
data=self._result.data, data=self._result.data,
code=code, code=code,
subcode=subcode, subcode=subcode,

View File

@ -4,7 +4,7 @@ from boltons.iterutils import remap
from jsonmodels import models from jsonmodels import models
from jsonmodels.errors import FieldNotSupported from jsonmodels.errors import FieldNotSupported
from apiserver.schema import schema from apiserver.services_schema import schema
from apiserver.utilities.partial_version import PartialVersion from apiserver.utilities.partial_version import PartialVersion
from .apicall import APICall from .apicall import APICall
from .schema_validator import SchemaValidator from .schema_validator import SchemaValidator

View File

@ -282,6 +282,7 @@ class ServiceRepo(object):
finally: finally:
content, content_type = call.get_response() content, content_type = call.get_response()
call.mark_end() call.mark_end()
console_msg = f"Returned {call.result.code} for {call.endpoint_name} in {call.duration}ms" console_msg = f"Returned {call.result.code} for {call.endpoint_name} in {call.duration}ms"
if call.result.code < 300: if call.result.code < 300:
log.info(console_msg) log.info(console_msg)

View File

@ -0,0 +1,3 @@
from apiserver.schema import SchemaReader
schema = SchemaReader().get_schema()