Unify server request handlers

This commit is contained in:
allegroai
2021-01-05 18:28:43 +02:00
parent ca890c7ae8
commit 3c8e27dc94
7 changed files with 255 additions and 196 deletions

View File

@@ -8,7 +8,7 @@ from .endpoint import EndpointFunc, Endpoint
from .service_repo import ServiceRepo
__all__ = ["APICall", "endpoint"]
__all__ = ["ServiceRepo", "APICall", "endpoint"]
LegacyEndpointFunc = Callable[[APICall], None]

View File

@@ -1,12 +1,14 @@
import time
import types
from traceback import format_exc
from typing import Type, Optional
from typing import Type, Optional, Union, Tuple
import attr
from jsonmodels import models
from six import string_types
from apiserver import database
from apiserver.config_repo import config
from apiserver.timing_context import TimingContext, TimingStats
from apiserver.utilities import json
from apiserver.utilities.partial_version import PartialVersion
@@ -18,6 +20,19 @@ from .schema_validator import SchemaValidator
JSON_CONTENT_TYPE = "application/json"
@attr.s
class Redirect:
url = attr.ib(type=str)
code = attr.ib(
type=int,
default=302,
validator=attr.validators.in_((301, 302, 303, 305, 307, 308)),
)
def empty(self) -> bool:
return not (self.url and self.code)
class DataContainer(object):
""" Data container that supports raw data (dict or a list of batched dicts) and a data model """
@@ -47,7 +62,7 @@ class DataContainer(object):
self._update_data_model()
@property
def data(self):
def data(self) -> dict:
return self._data or {}
@data.setter
@@ -164,7 +179,16 @@ class DataContainer(object):
class APICallResult(DataContainer):
def __init__(self, data=None, code=200, subcode=0, msg="OK", traceback=""):
def __init__(
self,
data=None,
code=200,
subcode=0,
msg="OK",
traceback="",
error_data=None,
cookies=None,
):
super(APICallResult, self).__init__(data)
self._code = code
self._subcode = subcode
@@ -172,18 +196,18 @@ class APICallResult(DataContainer):
self._traceback = traceback
self._extra = None
self._filename = None
self._headers = {}
self._cookies = {}
self._error_data = error_data or {}
self._cookies = cookies or {}
self._redirect = None
def get_log_entry(self):
res = dict(
return dict(
msg=self.msg,
code=self.code,
subcode=self.subcode,
traceback=self._traceback,
extra=self._extra,
)
return res
def copy_from(self, result):
self._code = result.code
@@ -242,13 +266,34 @@ class APICallResult(DataContainer):
self._filename = value
@property
def headers(self):
return self._headers
def error_data(self):
return self._error_data
@error_data.setter
def error_data(self, value):
self._error_data = value
@property
def cookies(self):
return self._cookies
def set_auth_cookie(self, value):
self.cookies[config.get("apiserver.auth.session_auth_cookie_name")] = value
@property
def redirect(self):
return self._redirect
@redirect.setter
def redirect(self, value: Union[Redirect, str, Tuple[str, int], list]):
if isinstance(value, str):
self._redirect = Redirect(url=value)
elif isinstance(value, (tuple, list)):
url, code, *_ = value
self._redirect = Redirect(url=url, code=code)
else:
self._redirect = value
class MissingIdentity(Exception):
pass
@@ -305,7 +350,8 @@ class APICall(DataContainer):
headers=None,
files=None,
trx=None,
**kwargs,
host=None,
auth_cookie=None,
):
super(APICall, self).__init__(data=data, batched_data=batched_data)
@@ -330,6 +376,8 @@ class APICall(DataContainer):
if trx:
self.set_header(self._transaction_headers, trx)
self._requires_authorization = True
self._host = host
self._auth_cookie = auth_cookie
@property
def id(self):
@@ -388,10 +436,10 @@ class APICall(DataContainer):
def real_ip(self):
""" Obtain visitor's IP address """
return (
self.get_header(self.HEADER_FORWARDED_FOR)
or self.get_header(self.HEADER_REAL_IP)
or self._remote_addr
or "untrackable"
self.get_header(self.HEADER_FORWARDED_FOR)
or self.get_header(self.HEADER_REAL_IP)
or self._remote_addr
or "untrackable"
)
@property
@@ -511,13 +559,21 @@ class APICall(DataContainer):
else:
self.clear_header(self._async_headers)
@property
def host(self):
return self._host
@property
def auth_cookie(self):
return self._auth_cookie
def mark_end(self):
self._end_ts = time.time()
self._duration = int((self._end_ts - self._start_ts) * 1000)
self.stats = TimingStats.aggregate()
def get_response(self):
def make_version_number(version):
def make_version_number(version: PartialVersion):
"""
Client versions <=2.0 expect expect endpoint versions in float format, otherwise throwing an exception
"""
@@ -549,6 +605,7 @@ class APICall(DataContainer):
"result_subcode": self.result.subcode,
"result_msg": self.result.msg,
"error_stack": self.result.traceback,
"error_data": self.result.error_data,
},
"data": self.result.data,
}
@@ -557,10 +614,11 @@ class APICall(DataContainer):
try:
res = json.dumps(res)
except Exception as ex:
# JSON serialization may fail, probably problem with data so pop it and try again
if not self.result.data:
# JSON serialization may fail, probably problem with data or error_data so pop it and try again
if not (self.result.data or self.result.error_data):
raise
self.result.data = None
self.result.error_data = None
msg = "Error serializing response data: " + str(ex)
self.set_error_result(
code=500, subcode=0, msg=msg, include_stack=False
@@ -569,8 +627,16 @@ class APICall(DataContainer):
return res, self.content_type
def set_error_result(self, msg, code=500, subcode=0, include_stack=False):
def set_error_result(
self, msg, code=500, subcode=0, include_stack=False, error_data=None
):
tb = format_exc() if include_stack else None
self._result = APICallResult(
data=self._result.data, code=code, subcode=subcode, msg=msg, traceback=tb
data=self._result.data,
code=code,
subcode=subcode,
msg=msg,
traceback=tb,
error_data=error_data,
cookies=self._result.cookies,
)