Unify server request handlers

This commit is contained in:
allegroai 2021-01-05 18:28:43 +02:00
parent ca890c7ae8
commit 3c8e27dc94
7 changed files with 255 additions and 196 deletions

View File

@ -1 +0,0 @@
__version__ = "2.12.0"

View File

@ -1,5 +1,4 @@
from enum import Enum
from textwrap import shorten
from typing import Union, Type, Iterable
import jsonmodels.errors
@ -8,8 +7,6 @@ from jsonmodels import fields
from jsonmodels.fields import _LazyType, NotSet
from jsonmodels.models import Base as ModelBase
from jsonmodels.validators import Enum as EnumValidator
from luqum.exceptions import ParseError
from luqum.parser import parser
from mongoengine.base import BaseDocument
from validators import email as email_validator, domain as domain_validator
@ -35,25 +32,6 @@ class DomainField(fields.StringField):
raise errors.bad_request.InvalidDomainName()
def validate_lucene_query(value):
    """
    Check that ``value`` is a syntactically valid Lucene query.

    An empty string is accepted without being parsed. Any other value is handed
    to the luqum parser; a ParseError is re-raised as
    errors.bad_request.InvalidLuceneSyntax carrying the parser message and the
    query text shortened to at most 50 characters.
    """
    if value != "":
        try:
            parser.parse(value)
        except ParseError as e:
            snippet = shorten(value, 50, placeholder="...")
            raise errors.bad_request.InvalidLuceneSyntax(error=str(e), query=snippet)
class LuceneQueryField(fields.StringField):
    """String field whose non-None values must pass the Lucene syntax check."""

    def validate(self, value):
        """Run the base field validation, then validate_lucene_query unless value is None."""
        super().validate(value)
        if value is not None:
            validate_lucene_query(value)
def make_default(field_cls, default_value):
class _FieldWithDefault(field_cls):
def get_default_value(self):

View File

@ -1,161 +1,13 @@
from argparse import ArgumentParser
from flask import Flask, request, Response
from flask_compress import Compress
from flask_cors import CORS
from werkzeug.exceptions import BadRequest
from flask import Flask
from apiserver.apierrors.base import BaseError
from apiserver.app_sequence import AppSequence
from apiserver.config_repo import config
from apiserver.service_repo import ServiceRepo, APICall
from apiserver.service_repo.auth import AuthType
from apiserver.service_repo.errors import PathParsingError
from apiserver.timing_context import TimingContext
from apiserver.utilities import json
from apiserver.server_init.app_sequence import AppSequence
from apiserver.server_init.request_handlers import RequestHandlers
app = Flask(__name__, static_url_path="/static")
CORS(app, **config.get("apiserver.cors"))
Compress(app)
AppSequence(app).start()
log = config.logger(__file__)
@app.before_first_request
def before_app_first_request():
    """Hook registered to run before the first request; currently does nothing."""
@app.before_request
def before_request():
    """
    Serve every non-static request through the service repository.

    OPTIONS requests get an empty 200 reply and paths containing "/static/"
    return None. Anything else is turned into an API call, handled by
    ServiceRepo and converted into a Flask Response built from the call result
    (status code, optional download filename, extra headers and cookies).
    Any exception is logged and answered with a plain-text 500.
    """
    if request.method == "OPTIONS":
        return "", 200
    if "/static/" in request.path:
        return None

    try:
        call = create_api_call(request)
        content, content_type = ServiceRepo.handle_call(call)
        result = call.result

        # A filename on the result marks the payload as a downloadable attachment
        headers = (
            {"Content-Disposition": f"attachment; filename={result.filename}"}
            if result.filename
            else {}
        )
        headers.update(result.headers or {})

        response = Response(
            content, mimetype=content_type, status=result.code, headers=headers
        )

        for name, cookie in (result.cookies or {}).items():
            if cookie is None:
                # A None value means "drop this cookie": blank it and expire it
                response.set_cookie(name, "", expires=0)
                continue
            response.set_cookie(name, cookie, **config.get("apiserver.auth.cookies"))

        return response
    except Exception as ex:
        log.exception(f"Failed processing request {request.url}: {ex}")
        return f"Failed processing request {request.url}", 500
def _coerce_form_value(value):
    """Convert a form/query string to int, float or bool when it is written as one."""
    # isdecimal() only accepts characters that int()/float() can parse, whereas
    # isdigit() also accepts e.g. superscript digits ("²") that make int() raise
    if value.replace(".", "", 1).isdecimal():
        return float(value) if "." in value else int(value)
    lowered = value.lower()
    if lowered == "true":
        return True
    if lowered == "false":
        return False
    return value


def update_call_data(call, req):
    """
    Use request payload/form to fill call data or batched data.

    For the "application/json-lines" content type every line of the request body
    is parsed as JSON and must be an object; the parsed objects are stored in
    ``call.batched_data``. A line that fails to parse is reported through
    ``req.on_json_loading_failed`` with the item index added to the message.

    For any other content type ``call.data`` is set to the JSON body when there
    is a non-empty one, otherwise to the merged form and query arguments
    (query arguments override form fields), otherwise to an empty dict. Form
    and query values are strings, so unsigned numbers and true/false
    (any letter case) are converted to int/float/bool.

    :param call: API call object receiving ``data`` or ``batched_data``
    :param req: the incoming request
    :raises BadRequest: if a json-lines item is valid JSON but not an object
    """
    if req.content_type == "application/json-lines":
        items = []
        for i, line in enumerate(req.data.splitlines()):
            try:
                event = json.loads(line)
                if not isinstance(event, dict):
                    raise BadRequest(
                        f"json lines must contain objects, found: {type(event).__name__}"
                    )
                items.append(event)
            except ValueError as e:
                msg = f"{e} in batch item #{i}"
                req.on_json_loading_failed(msg)
        call.batched_data = items
        return

    json_body = req.get_json(force=True, silent=False) if req.data else None
    # merge form and args
    form = req.form.copy()
    form.update(req.args)
    # convert numeric and boolean strings to their native types
    form = {key: _coerce_form_value(value) for key, value in form.to_dict().items()}
    call.data = json_body or form or {}
def _call_or_empty_with_error(call, req, msg, code=500, subcode=0):
call = call or APICall(
"", remote_addr=req.remote_addr, headers=dict(req.headers), files=req.files
)
call.set_error_result(msg=msg, code=code, subcode=subcode)
return call
def create_api_call(req):
    """
    Build an APICall from the incoming request.

    The endpoint version and name are parsed from the request path. If the
    session auth cookie is present it seeds a bearer-token Authorization header,
    which the request's own headers may then override. The call data is filled
    in from the request body/form.

    Failures along the way do not propagate: they are recorded as an error
    result on the call (a placeholder call if none was created yet), with
    status 400 for path parsing and bad request errors, the error's own
    code/subcode for BaseError, and 500 for anything else.
    """
    call = None
    try:
        # Parse the request path
        endpoint_version, endpoint_name = ServiceRepo.parse_endpoint_path(req.path)

        # Resolve authorization: the auth cookie (if any) is only a starting
        # point, request headers always take precedence.
        cookie_name = config.get("apiserver.auth.session_auth_cookie_name")
        auth_cookie = req.cookies.get(cookie_name)
        headers = {}
        if auth_cookie:
            headers["Authorization"] = f"{AuthType.bearer_token} {auth_cookie}"
        headers.update(req.headers.items())

        # Construct call instance
        call = APICall(
            endpoint_name=endpoint_name,
            remote_addr=req.remote_addr,
            endpoint_version=endpoint_version,
            headers=headers,
            files=req.files,
        )

        # Update call data from request
        with TimingContext("preprocess", "update_call_data"):
            update_call_data(call, req)
    except PathParsingError as ex:
        call = _call_or_empty_with_error(call, req, ex.args[0], 400)
        call.log_api = False
    except BadRequest as ex:
        call = _call_or_empty_with_error(call, req, ex.description, 400)
    except BaseError as ex:
        call = _call_or_empty_with_error(call, req, ex.msg, ex.code, ex.subcode)
    except Exception as ex:
        log.exception("Error creating call")
        message = ex.args[0] if ex.args else type(ex).__name__
        call = _call_or_empty_with_error(call, req, message, 500)

    return call
AppSequence(app).start(request_handlers=RequestHandlers())
# =================== MAIN =======================

View File

@ -2,6 +2,8 @@ import atexit
from hashlib import md5
from flask import Flask
from flask_compress import Compress
from flask_cors import CORS
from semantic_version import Version
from apiserver.database import db
@ -19,6 +21,7 @@ from apiserver.mongo.initialize import (
check_mongo_empty,
get_last_server_version,
)
from apiserver.server_init.request_handlers import RequestHandlers
from apiserver.service_repo import ServiceRepo
from apiserver.sync import distributed_lock
from apiserver.updates import check_updates_thread
@ -31,15 +34,23 @@ class AppSequence:
def __init__(self, app: Flask):
self.app = app
def start(self):
def start(self, request_handlers: RequestHandlers):
log.info("################ API Server initializing #####################")
self._configure()
self._init_dbs()
self._load_services()
self._start_worker()
atexit.register(self._on_worker_stop)
self._attach_request_handlers(request_handlers)
def _attach_request_handlers(self, request_handlers: RequestHandlers):
self.app.before_first_request(request_handlers.before_app_first_request)
self.app.before_request(request_handlers.before_request)
def _configure(self):
CORS(self.app, **config.get("apiserver.cors"))
Compress(self.app)
self.app.config["SECRET_KEY"] = config.get(
"secure.http.session_secret.apiserver"
)

View File

@ -0,0 +1,153 @@
from flask import request, Response, redirect
from werkzeug.exceptions import BadRequest
from apiserver.apierrors.base import BaseError
from apiserver.config_repo import config
from apiserver.service_repo import ServiceRepo, APICall
from apiserver.service_repo.auth import AuthType
from apiserver.service_repo.errors import PathParsingError
from apiserver.utilities import json
log = config.logger(__file__)
class RequestHandlers:
    """
    Flask request hooks that serve API requests through the service repository.

    ``before_request`` is the entry point: it turns the incoming Flask request
    into an API call object, has the service repository handle it and converts
    the call result into a Flask response (or redirect).
    """

    # Optional path prefix removed from the request path before endpoint parsing
    _request_strip_prefix = config.get("apiserver.request.strip_prefix", None)
    # Collaborator classes, always reached through self in this class
    # (presumably so that subclasses can substitute their own - confirm)
    _service_repo_cls = ServiceRepo
    _api_call_cls = APICall

    def before_app_first_request(self):
        """Hook for work to run before the first request; currently does nothing."""
        pass

    def before_request(self):
        """
        Handle the current Flask request as an API call.

        OPTIONS requests get an empty 200 reply and paths containing "/static/"
        return None. Otherwise the request is converted to an API call and
        handled by the service repository; the reply is either a redirect (when
        the call result carries one) or a Response built from the result's
        content, status code and optional download filename. Result cookies are
        then applied to the response. Any exception is logged and answered with
        a plain-text 500.
        """
        if request.method == "OPTIONS":
            return "", 200

        if "/static/" in request.path:
            return

        try:
            call = self._create_api_call(request)
            content, content_type = self._service_repo_cls.handle_call(call)

            if call.result.redirect:
                response = redirect(call.result.redirect.url, call.result.redirect.code)
            else:
                headers = None
                if call.result.filename:
                    # A filename marks the payload as a downloadable attachment
                    headers = {
                        "Content-Disposition": f"attachment; filename={call.result.filename}"
                    }

                response = Response(
                    content, mimetype=content_type, status=call.result.code, headers=headers
                )

            if call.result.cookies:
                cookie_settings = config.get("apiserver.auth.cookies")
                for key, value in call.result.cookies.items():
                    if value is None:
                        # A None value means "drop this cookie": send it blank and
                        # already expired, on a copy so the shared settings stay intact
                        expired = cookie_settings.copy()
                        expired["max_age"] = 0
                        expired["expires"] = 0
                        response.set_cookie(key, "", **expired)
                    else:
                        response.set_cookie(key, value, **cookie_settings)

            return response
        except Exception as ex:
            log.exception(f"Failed processing request {request.url}: {ex}")
            return f"Failed processing request {request.url}", 500

    @staticmethod
    def _coerce_form_value(value):
        """Convert a form/query string to int, float or bool when it is written as one."""
        # isdecimal() only accepts characters that int()/float() can parse, whereas
        # isdigit() also accepts e.g. superscript digits ("²") that make int() raise
        if value.replace(".", "", 1).isdecimal():
            return float(value) if "." in value else int(value)
        lowered = value.lower()
        if lowered == "true":
            return True
        if lowered == "false":
            return False
        return value

    def _update_call_data(self, call, req):
        """
        Use request payload/form to fill call data or batched data.

        For the "application/json-lines" content type every body line is parsed
        as JSON and must be an object; the objects are stored in
        ``call.batched_data``. A line that fails to parse is reported through
        ``req.on_json_loading_failed`` with the item index in the message.

        For any other content type ``call.data`` is the JSON body when there is
        a non-empty one, otherwise the merged form and query arguments (query
        arguments override form fields, unsigned numbers and true/false strings
        converted to int/float/bool), otherwise an empty dict.

        :raises BadRequest: if a json-lines item is valid JSON but not an object
        """
        if req.content_type == "application/json-lines":
            items = []
            for i, line in enumerate(req.data.splitlines()):
                try:
                    event = json.loads(line)
                    if not isinstance(event, dict):
                        raise BadRequest(
                            f"json lines must contain objects, found: {type(event).__name__}"
                        )
                    items.append(event)
                except ValueError as e:
                    msg = f"{e} in batch item #{i}"
                    req.on_json_loading_failed(msg)
            call.batched_data = items
            return

        json_body = req.get_json(force=True, silent=False) if req.data else None
        # merge form and args
        form = req.form.copy()
        form.update(req.args)
        # convert numeric and boolean strings to their native types
        form = {
            key: self._coerce_form_value(value)
            for key, value in form.to_dict().items()
        }
        call.data = json_body or form or {}

    def _call_or_empty_with_error(self, call, req, msg, code=500, subcode=0):
        """
        Put an error result on ``call`` and return it.

        When no call is given (falsy), a placeholder call with an empty endpoint
        name is built from the request's remote address, headers and files.
        """
        call = call or self._api_call_cls(
            "", remote_addr=req.remote_addr, headers=dict(req.headers), files=req.files
        )
        call.set_error_result(msg=msg, code=code, subcode=subcode)
        return call

    def _create_api_call(self, req):
        """
        Build an API call object from the incoming request.

        The configured strip prefix (if any) is removed from the start of the
        request path before the endpoint version and name are parsed from it.
        If the session auth cookie is present it seeds a bearer-token
        Authorization header, which the request's own headers may override.

        Failures do not propagate: they are recorded as an error result on the
        call (a placeholder call if none was created yet) - 400 for path parsing
        and bad request errors, the error's own code/subcode for BaseError and
        500 for anything else.
        """
        call = None
        try:
            # Parse the request path
            path = req.path
            if self._request_strip_prefix and path.startswith(self._request_strip_prefix):
                path = path[len(self._request_strip_prefix):]
            endpoint_version, endpoint_name = self._service_repo_cls.parse_endpoint_path(path)

            # Resolve authorization: if cookies contain an authorization token, use it as a starting point.
            # in any case, request headers always take precedence.
            auth_cookie = req.cookies.get(
                config.get("apiserver.auth.session_auth_cookie_name")
            )
            headers = (
                {}
                if not auth_cookie
                else {"Authorization": f"{AuthType.bearer_token} {auth_cookie}"}
            )
            headers.update(
                list(req.headers.items())
            )  # add (possibly override with) the headers

            # Construct call instance
            call = self._api_call_cls(
                endpoint_name=endpoint_name,
                remote_addr=req.remote_addr,
                endpoint_version=endpoint_version,
                headers=headers,
                files=req.files,
                host=req.host,
                auth_cookie=auth_cookie,
            )

            # Update call data from request
            self._update_call_data(call, req)

        except PathParsingError as ex:
            call = self._call_or_empty_with_error(call, req, ex.args[0], 400)
            call.log_api = False
        except BadRequest as ex:
            call = self._call_or_empty_with_error(call, req, ex.description, 400)
        except BaseError as ex:
            call = self._call_or_empty_with_error(call, req, ex.msg, ex.code, ex.subcode)
        except Exception as ex:
            log.exception("Error creating call")
            call = self._call_or_empty_with_error(
                call, req, ex.args[0] if ex.args else type(ex).__name__, 500
            )

        return call

View File

@ -8,7 +8,7 @@ from .endpoint import EndpointFunc, Endpoint
from .service_repo import ServiceRepo
__all__ = ["APICall", "endpoint"]
__all__ = ["ServiceRepo", "APICall", "endpoint"]
LegacyEndpointFunc = Callable[[APICall], None]

View File

@ -1,12 +1,14 @@
import time
import types
from traceback import format_exc
from typing import Type, Optional
from typing import Type, Optional, Union, Tuple
import attr
from jsonmodels import models
from six import string_types
from apiserver import database
from apiserver.config_repo import config
from apiserver.timing_context import TimingContext, TimingStats
from apiserver.utilities import json
from apiserver.utilities.partial_version import PartialVersion
@ -18,6 +20,19 @@ from .schema_validator import SchemaValidator
JSON_CONTENT_TYPE = "application/json"
@attr.s
class Redirect:
    """Redirect instruction: a target URL together with an HTTP redirect status code."""

    url = attr.ib(type=str)
    # Only the listed redirect status codes pass validation; 302 when none is given
    code = attr.ib(
        default=302,
        type=int,
        validator=attr.validators.in_((301, 302, 303, 305, 307, 308)),
    )

    def empty(self) -> bool:
        """Return True when either the url or the code is falsy."""
        return not self.url or not self.code
class DataContainer(object):
""" Data container that supports raw data (dict or a list of batched dicts) and a data model """
@ -47,7 +62,7 @@ class DataContainer(object):
self._update_data_model()
@property
def data(self):
def data(self) -> dict:
return self._data or {}
@data.setter
@ -164,7 +179,16 @@ class DataContainer(object):
class APICallResult(DataContainer):
def __init__(self, data=None, code=200, subcode=0, msg="OK", traceback=""):
def __init__(
self,
data=None,
code=200,
subcode=0,
msg="OK",
traceback="",
error_data=None,
cookies=None,
):
super(APICallResult, self).__init__(data)
self._code = code
self._subcode = subcode
@ -172,18 +196,18 @@ class APICallResult(DataContainer):
self._traceback = traceback
self._extra = None
self._filename = None
self._headers = {}
self._cookies = {}
self._error_data = error_data or {}
self._cookies = cookies or {}
self._redirect = None
def get_log_entry(self):
res = dict(
return dict(
msg=self.msg,
code=self.code,
subcode=self.subcode,
traceback=self._traceback,
extra=self._extra,
)
return res
def copy_from(self, result):
self._code = result.code
@ -242,13 +266,34 @@ class APICallResult(DataContainer):
self._filename = value
@property
def headers(self):
return self._headers
def error_data(self):
return self._error_data
@error_data.setter
def error_data(self, value):
    """Replace the stored error data with ``value``."""
    self._error_data = value
@property
def cookies(self):
    """Return the cookies mapping held by this result (the same object, not a copy)."""
    return self._cookies
def set_auth_cookie(self, value):
    """Store ``value`` in the cookies under the configured session auth cookie name."""
    cookie_name = config.get("apiserver.auth.session_auth_cookie_name")
    self.cookies[cookie_name] = value
@property
def redirect(self):
    """Return the redirect stored on this result (None until one is set)."""
    return self._redirect
@redirect.setter
def redirect(self, value: Union[Redirect, str, Tuple[str, int], list]):
    """
    Set the redirect from one of several accepted forms.

    A string is taken as the URL (with the Redirect default code); a tuple or
    list supplies the URL and code as its first two items, any further items
    being ignored; any other value is stored unchanged.
    """
    if isinstance(value, str):
        value = Redirect(url=value)
    elif isinstance(value, (tuple, list)):
        target, status, *_ = value
        value = Redirect(url=target, code=status)
    self._redirect = value
class MissingIdentity(Exception):
    """Exception type for a missing identity (its raising sites are outside this view)."""
@ -305,7 +350,8 @@ class APICall(DataContainer):
headers=None,
files=None,
trx=None,
**kwargs,
host=None,
auth_cookie=None,
):
super(APICall, self).__init__(data=data, batched_data=batched_data)
@ -330,6 +376,8 @@ class APICall(DataContainer):
if trx:
self.set_header(self._transaction_headers, trx)
self._requires_authorization = True
self._host = host
self._auth_cookie = auth_cookie
@property
def id(self):
@ -388,10 +436,10 @@ class APICall(DataContainer):
def real_ip(self):
""" Obtain visitor's IP address """
return (
self.get_header(self.HEADER_FORWARDED_FOR)
or self.get_header(self.HEADER_REAL_IP)
or self._remote_addr
or "untrackable"
self.get_header(self.HEADER_FORWARDED_FOR)
or self.get_header(self.HEADER_REAL_IP)
or self._remote_addr
or "untrackable"
)
@property
@ -511,13 +559,21 @@ class APICall(DataContainer):
else:
self.clear_header(self._async_headers)
@property
def host(self):
    """Return the host value this call was created with."""
    return self._host
@property
def auth_cookie(self):
    """Return the auth cookie value this call was created with."""
    return self._auth_cookie
def mark_end(self):
    """Record the end time, the elapsed time since start in whole milliseconds, and the aggregated timing stats."""
    finished = time.time()
    self._end_ts = finished
    self._duration = int((finished - self._start_ts) * 1000)
    self.stats = TimingStats.aggregate()
def get_response(self):
def make_version_number(version):
def make_version_number(version: PartialVersion):
"""
Client versions <=2.0 expect endpoint versions in float format, otherwise throwing an exception
"""
@ -549,6 +605,7 @@ class APICall(DataContainer):
"result_subcode": self.result.subcode,
"result_msg": self.result.msg,
"error_stack": self.result.traceback,
"error_data": self.result.error_data,
},
"data": self.result.data,
}
@ -557,10 +614,11 @@ class APICall(DataContainer):
try:
res = json.dumps(res)
except Exception as ex:
# JSON serialization may fail, probably problem with data so pop it and try again
if not self.result.data:
# JSON serialization may fail, probably problem with data or error_data so pop it and try again
if not (self.result.data or self.result.error_data):
raise
self.result.data = None
self.result.error_data = None
msg = "Error serializing response data: " + str(ex)
self.set_error_result(
code=500, subcode=0, msg=msg, include_stack=False
@ -569,8 +627,16 @@ class APICall(DataContainer):
return res, self.content_type
def set_error_result(self, msg, code=500, subcode=0, include_stack=False):
def set_error_result(
self, msg, code=500, subcode=0, include_stack=False, error_data=None
):
tb = format_exc() if include_stack else None
self._result = APICallResult(
data=self._result.data, code=code, subcode=subcode, msg=msg, traceback=tb
data=self._result.data,
code=code,
subcode=subcode,
msg=msg,
traceback=tb,
error_data=error_data,
cookies=self._result.cookies,
)