mirror of
https://github.com/clearml/clearml-server
synced 2025-01-30 18:36:52 +00:00
Unify server request handlers
This commit is contained in:
parent
ca890c7ae8
commit
3c8e27dc94
@ -1 +0,0 @@
|
||||
__version__ = "2.12.0"
|
@ -1,5 +1,4 @@
|
||||
from enum import Enum
|
||||
from textwrap import shorten
|
||||
from typing import Union, Type, Iterable
|
||||
|
||||
import jsonmodels.errors
|
||||
@ -8,8 +7,6 @@ from jsonmodels import fields
|
||||
from jsonmodels.fields import _LazyType, NotSet
|
||||
from jsonmodels.models import Base as ModelBase
|
||||
from jsonmodels.validators import Enum as EnumValidator
|
||||
from luqum.exceptions import ParseError
|
||||
from luqum.parser import parser
|
||||
from mongoengine.base import BaseDocument
|
||||
from validators import email as email_validator, domain as domain_validator
|
||||
|
||||
@ -35,25 +32,6 @@ class DomainField(fields.StringField):
|
||||
raise errors.bad_request.InvalidDomainName()
|
||||
|
||||
|
||||
def validate_lucene_query(value):
|
||||
if value == "":
|
||||
return
|
||||
try:
|
||||
parser.parse(value)
|
||||
except ParseError as e:
|
||||
raise errors.bad_request.InvalidLuceneSyntax(
|
||||
error=str(e), query=shorten(value, 50, placeholder="...")
|
||||
)
|
||||
|
||||
|
||||
class LuceneQueryField(fields.StringField):
|
||||
def validate(self, value):
|
||||
super(LuceneQueryField, self).validate(value)
|
||||
if value is None:
|
||||
return
|
||||
validate_lucene_query(value)
|
||||
|
||||
|
||||
def make_default(field_cls, default_value):
|
||||
class _FieldWithDefault(field_cls):
|
||||
def get_default_value(self):
|
||||
|
@ -1,161 +1,13 @@
|
||||
from argparse import ArgumentParser
|
||||
|
||||
from flask import Flask, request, Response
|
||||
from flask_compress import Compress
|
||||
from flask_cors import CORS
|
||||
from werkzeug.exceptions import BadRequest
|
||||
from flask import Flask
|
||||
|
||||
from apiserver.apierrors.base import BaseError
|
||||
from apiserver.app_sequence import AppSequence
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.service_repo import ServiceRepo, APICall
|
||||
from apiserver.service_repo.auth import AuthType
|
||||
from apiserver.service_repo.errors import PathParsingError
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.utilities import json
|
||||
from apiserver.server_init.app_sequence import AppSequence
|
||||
from apiserver.server_init.request_handlers import RequestHandlers
|
||||
|
||||
app = Flask(__name__, static_url_path="/static")
|
||||
CORS(app, **config.get("apiserver.cors"))
|
||||
Compress(app)
|
||||
AppSequence(app).start()
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
@app.before_first_request
|
||||
def before_app_first_request():
|
||||
pass
|
||||
|
||||
|
||||
@app.before_request
|
||||
def before_request():
|
||||
if request.method == "OPTIONS":
|
||||
return "", 200
|
||||
if "/static/" in request.path:
|
||||
return
|
||||
|
||||
try:
|
||||
call = create_api_call(request)
|
||||
content, content_type = ServiceRepo.handle_call(call)
|
||||
headers = {}
|
||||
if call.result.filename:
|
||||
headers[
|
||||
"Content-Disposition"
|
||||
] = f"attachment; filename={call.result.filename}"
|
||||
|
||||
if call.result.headers:
|
||||
headers.update(call.result.headers)
|
||||
|
||||
response = Response(
|
||||
content, mimetype=content_type, status=call.result.code, headers=headers
|
||||
)
|
||||
|
||||
if call.result.cookies:
|
||||
for key, value in call.result.cookies.items():
|
||||
if value is None:
|
||||
response.set_cookie(key, "", expires=0)
|
||||
else:
|
||||
response.set_cookie(
|
||||
key, value, **config.get("apiserver.auth.cookies")
|
||||
)
|
||||
|
||||
return response
|
||||
except Exception as ex:
|
||||
log.exception(f"Failed processing request {request.url}: {ex}")
|
||||
return f"Failed processing request {request.url}", 500
|
||||
|
||||
|
||||
def update_call_data(call, req):
|
||||
""" Use request payload/form to fill call data or batched data """
|
||||
if req.content_type == "application/json-lines":
|
||||
items = []
|
||||
for i, line in enumerate(req.data.splitlines()):
|
||||
try:
|
||||
event = json.loads(line)
|
||||
if not isinstance(event, dict):
|
||||
raise BadRequest(
|
||||
f"json lines must contain objects, found: {type(event).__name__}"
|
||||
)
|
||||
items.append(event)
|
||||
except ValueError as e:
|
||||
msg = f"{e} in batch item #{i}"
|
||||
req.on_json_loading_failed(msg)
|
||||
call.batched_data = items
|
||||
else:
|
||||
json_body = req.get_json(force=True, silent=False) if req.data else None
|
||||
# merge form and args
|
||||
form = req.form.copy()
|
||||
form.update(req.args)
|
||||
form = form.to_dict()
|
||||
# convert string numbers to floats
|
||||
for key in form:
|
||||
if form[key].replace(".", "", 1).isdigit():
|
||||
if "." in form[key]:
|
||||
form[key] = float(form[key])
|
||||
else:
|
||||
form[key] = int(form[key])
|
||||
elif form[key].lower() == "true":
|
||||
form[key] = True
|
||||
elif form[key].lower() == "false":
|
||||
form[key] = False
|
||||
call.data = json_body or form or {}
|
||||
|
||||
|
||||
def _call_or_empty_with_error(call, req, msg, code=500, subcode=0):
|
||||
call = call or APICall(
|
||||
"", remote_addr=req.remote_addr, headers=dict(req.headers), files=req.files
|
||||
)
|
||||
call.set_error_result(msg=msg, code=code, subcode=subcode)
|
||||
return call
|
||||
|
||||
|
||||
def create_api_call(req):
|
||||
call = None
|
||||
try:
|
||||
# Parse the request path
|
||||
endpoint_version, endpoint_name = ServiceRepo.parse_endpoint_path(req.path)
|
||||
|
||||
# Resolve authorization: if cookies contain an authorization token, use it as a starting point.
|
||||
# in any case, request headers always take precedence.
|
||||
auth_cookie = req.cookies.get(
|
||||
config.get("apiserver.auth.session_auth_cookie_name")
|
||||
)
|
||||
headers = (
|
||||
{}
|
||||
if not auth_cookie
|
||||
else {"Authorization": f"{AuthType.bearer_token} {auth_cookie}"}
|
||||
)
|
||||
headers.update(
|
||||
list(req.headers.items())
|
||||
) # add (possibly override with) the headers
|
||||
|
||||
# Construct call instance
|
||||
call = APICall(
|
||||
endpoint_name=endpoint_name,
|
||||
remote_addr=req.remote_addr,
|
||||
endpoint_version=endpoint_version,
|
||||
headers=headers,
|
||||
files=req.files,
|
||||
)
|
||||
|
||||
# Update call data from request
|
||||
with TimingContext("preprocess", "update_call_data"):
|
||||
update_call_data(call, req)
|
||||
|
||||
except PathParsingError as ex:
|
||||
call = _call_or_empty_with_error(call, req, ex.args[0], 400)
|
||||
call.log_api = False
|
||||
except BadRequest as ex:
|
||||
call = _call_or_empty_with_error(call, req, ex.description, 400)
|
||||
except BaseError as ex:
|
||||
call = _call_or_empty_with_error(call, req, ex.msg, ex.code, ex.subcode)
|
||||
except Exception as ex:
|
||||
log.exception("Error creating call")
|
||||
call = _call_or_empty_with_error(
|
||||
call, req, ex.args[0] if ex.args else type(ex).__name__, 500
|
||||
)
|
||||
|
||||
return call
|
||||
AppSequence(app).start(request_handlers=RequestHandlers())
|
||||
|
||||
|
||||
# =================== MAIN =======================
|
||||
|
@ -2,6 +2,8 @@ import atexit
|
||||
from hashlib import md5
|
||||
|
||||
from flask import Flask
|
||||
from flask_compress import Compress
|
||||
from flask_cors import CORS
|
||||
from semantic_version import Version
|
||||
|
||||
from apiserver.database import db
|
||||
@ -19,6 +21,7 @@ from apiserver.mongo.initialize import (
|
||||
check_mongo_empty,
|
||||
get_last_server_version,
|
||||
)
|
||||
from apiserver.server_init.request_handlers import RequestHandlers
|
||||
from apiserver.service_repo import ServiceRepo
|
||||
from apiserver.sync import distributed_lock
|
||||
from apiserver.updates import check_updates_thread
|
||||
@ -31,15 +34,23 @@ class AppSequence:
|
||||
def __init__(self, app: Flask):
|
||||
self.app = app
|
||||
|
||||
def start(self):
|
||||
def start(self, request_handlers: RequestHandlers):
|
||||
log.info("################ API Server initializing #####################")
|
||||
self._configure()
|
||||
self._init_dbs()
|
||||
self._load_services()
|
||||
self._start_worker()
|
||||
atexit.register(self._on_worker_stop)
|
||||
self._attach_request_handlers(request_handlers)
|
||||
|
||||
def _attach_request_handlers(self, request_handlers: RequestHandlers):
|
||||
self.app.before_first_request(request_handlers.before_app_first_request)
|
||||
self.app.before_request(request_handlers.before_request)
|
||||
|
||||
def _configure(self):
|
||||
CORS(self.app, **config.get("apiserver.cors"))
|
||||
Compress(self.app)
|
||||
|
||||
self.app.config["SECRET_KEY"] = config.get(
|
||||
"secure.http.session_secret.apiserver"
|
||||
)
|
153
apiserver/server_init/request_handlers.py
Normal file
153
apiserver/server_init/request_handlers.py
Normal file
@ -0,0 +1,153 @@
|
||||
from flask import request, Response, redirect
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from apiserver.apierrors.base import BaseError
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.service_repo import ServiceRepo, APICall
|
||||
from apiserver.service_repo.auth import AuthType
|
||||
from apiserver.service_repo.errors import PathParsingError
|
||||
from apiserver.utilities import json
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class RequestHandlers:
|
||||
_request_strip_prefix = config.get("apiserver.request.strip_prefix", None)
|
||||
_service_repo_cls = ServiceRepo
|
||||
_api_call_cls = APICall
|
||||
|
||||
def before_app_first_request(self):
|
||||
pass
|
||||
|
||||
def before_request(self):
|
||||
if request.method == "OPTIONS":
|
||||
return "", 200
|
||||
if "/static/" in request.path:
|
||||
return
|
||||
|
||||
try:
|
||||
call = self._create_api_call(request)
|
||||
content, content_type = self._service_repo_cls.handle_call(call)
|
||||
|
||||
if call.result.redirect:
|
||||
response = redirect(call.result.redirect.url, call.result.redirect.code)
|
||||
else:
|
||||
headers = None
|
||||
if call.result.filename:
|
||||
headers = {
|
||||
"Content-Disposition": f"attachment; filename={call.result.filename}"
|
||||
}
|
||||
|
||||
response = Response(
|
||||
content, mimetype=content_type, status=call.result.code, headers=headers
|
||||
)
|
||||
|
||||
if call.result.cookies:
|
||||
for key, value in call.result.cookies.items():
|
||||
kwargs = config.get("apiserver.auth.cookies")
|
||||
if value is None:
|
||||
kwargs = kwargs.copy()
|
||||
kwargs['max_age'] = 0
|
||||
kwargs['expires'] = 0
|
||||
response.set_cookie(key, "", **kwargs)
|
||||
else:
|
||||
response.set_cookie(
|
||||
key, value, **kwargs
|
||||
)
|
||||
|
||||
return response
|
||||
except Exception as ex:
|
||||
log.exception(f"Failed processing request {request.url}: {ex}")
|
||||
return f"Failed processing request {request.url}", 500
|
||||
|
||||
def _update_call_data(self, call, req):
|
||||
""" Use request payload/form to fill call data or batched data """
|
||||
if req.content_type == "application/json-lines":
|
||||
items = []
|
||||
for i, line in enumerate(req.data.splitlines()):
|
||||
try:
|
||||
event = json.loads(line)
|
||||
if not isinstance(event, dict):
|
||||
raise BadRequest(
|
||||
f"json lines must contain objects, found: {type(event).__name__}"
|
||||
)
|
||||
items.append(event)
|
||||
except ValueError as e:
|
||||
msg = f"{e} in batch item #{i}"
|
||||
req.on_json_loading_failed(msg)
|
||||
call.batched_data = items
|
||||
else:
|
||||
json_body = req.get_json(force=True, silent=False) if req.data else None
|
||||
# merge form and args
|
||||
form = req.form.copy()
|
||||
form.update(req.args)
|
||||
form = form.to_dict()
|
||||
# convert string numbers to floats
|
||||
for key in form:
|
||||
if form[key].replace(".", "", 1).isdigit():
|
||||
if "." in form[key]:
|
||||
form[key] = float(form[key])
|
||||
else:
|
||||
form[key] = int(form[key])
|
||||
elif form[key].lower() == "true":
|
||||
form[key] = True
|
||||
elif form[key].lower() == "false":
|
||||
form[key] = False
|
||||
call.data = json_body or form or {}
|
||||
|
||||
def _call_or_empty_with_error(self, call, req, msg, code=500, subcode=0):
|
||||
call = call or self._api_call_cls(
|
||||
"", remote_addr=req.remote_addr, headers=dict(req.headers), files=req.files
|
||||
)
|
||||
call.set_error_result(msg=msg, code=code, subcode=subcode)
|
||||
return call
|
||||
|
||||
def _create_api_call(self, req):
|
||||
call = None
|
||||
try:
|
||||
# Parse the request path
|
||||
path = req.path
|
||||
if self._request_strip_prefix and path.startswith(self._request_strip_prefix):
|
||||
path = path[len(self._request_strip_prefix):]
|
||||
endpoint_version, endpoint_name = self._service_repo_cls.parse_endpoint_path(path)
|
||||
|
||||
# Resolve authorization: if cookies contain an authorization token, use it as a starting point.
|
||||
# in any case, request headers always take precedence.
|
||||
auth_cookie = req.cookies.get(
|
||||
config.get("apiserver.auth.session_auth_cookie_name")
|
||||
)
|
||||
headers = (
|
||||
{}
|
||||
if not auth_cookie
|
||||
else {"Authorization": f"{AuthType.bearer_token} {auth_cookie}"}
|
||||
)
|
||||
headers.update(
|
||||
list(req.headers.items())
|
||||
) # add (possibly override with) the headers
|
||||
|
||||
# Construct call instance
|
||||
call = self._api_call_cls(
|
||||
endpoint_name=endpoint_name,
|
||||
remote_addr=req.remote_addr,
|
||||
endpoint_version=endpoint_version,
|
||||
headers=headers,
|
||||
files=req.files,
|
||||
host=req.host,
|
||||
auth_cookie=auth_cookie,
|
||||
)
|
||||
|
||||
# Update call data from request
|
||||
self._update_call_data(call, req)
|
||||
|
||||
except PathParsingError as ex:
|
||||
call = self._call_or_empty_with_error(call, req, ex.args[0], 400)
|
||||
call.log_api = False
|
||||
except BadRequest as ex:
|
||||
call = self._call_or_empty_with_error(call, req, ex.description, 400)
|
||||
except BaseError as ex:
|
||||
call = self._call_or_empty_with_error(call, req, ex.msg, ex.code, ex.subcode)
|
||||
except Exception as ex:
|
||||
log.exception("Error creating call")
|
||||
call = self._call_or_empty_with_error(call, req, ex.args[0] if ex.args else type(ex).__name__, 500)
|
||||
|
||||
return call
|
@ -8,7 +8,7 @@ from .endpoint import EndpointFunc, Endpoint
|
||||
from .service_repo import ServiceRepo
|
||||
|
||||
|
||||
__all__ = ["APICall", "endpoint"]
|
||||
__all__ = ["ServiceRepo", "APICall", "endpoint"]
|
||||
|
||||
|
||||
LegacyEndpointFunc = Callable[[APICall], None]
|
||||
|
@ -1,12 +1,14 @@
|
||||
import time
|
||||
import types
|
||||
from traceback import format_exc
|
||||
from typing import Type, Optional
|
||||
from typing import Type, Optional, Union, Tuple
|
||||
|
||||
import attr
|
||||
from jsonmodels import models
|
||||
from six import string_types
|
||||
|
||||
from apiserver import database
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.timing_context import TimingContext, TimingStats
|
||||
from apiserver.utilities import json
|
||||
from apiserver.utilities.partial_version import PartialVersion
|
||||
@ -18,6 +20,19 @@ from .schema_validator import SchemaValidator
|
||||
JSON_CONTENT_TYPE = "application/json"
|
||||
|
||||
|
||||
@attr.s
|
||||
class Redirect:
|
||||
url = attr.ib(type=str)
|
||||
code = attr.ib(
|
||||
type=int,
|
||||
default=302,
|
||||
validator=attr.validators.in_((301, 302, 303, 305, 307, 308)),
|
||||
)
|
||||
|
||||
def empty(self) -> bool:
|
||||
return not (self.url and self.code)
|
||||
|
||||
|
||||
class DataContainer(object):
|
||||
""" Data container that supports raw data (dict or a list of batched dicts) and a data model """
|
||||
|
||||
@ -47,7 +62,7 @@ class DataContainer(object):
|
||||
self._update_data_model()
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
def data(self) -> dict:
|
||||
return self._data or {}
|
||||
|
||||
@data.setter
|
||||
@ -164,7 +179,16 @@ class DataContainer(object):
|
||||
|
||||
|
||||
class APICallResult(DataContainer):
|
||||
def __init__(self, data=None, code=200, subcode=0, msg="OK", traceback=""):
|
||||
def __init__(
|
||||
self,
|
||||
data=None,
|
||||
code=200,
|
||||
subcode=0,
|
||||
msg="OK",
|
||||
traceback="",
|
||||
error_data=None,
|
||||
cookies=None,
|
||||
):
|
||||
super(APICallResult, self).__init__(data)
|
||||
self._code = code
|
||||
self._subcode = subcode
|
||||
@ -172,18 +196,18 @@ class APICallResult(DataContainer):
|
||||
self._traceback = traceback
|
||||
self._extra = None
|
||||
self._filename = None
|
||||
self._headers = {}
|
||||
self._cookies = {}
|
||||
self._error_data = error_data or {}
|
||||
self._cookies = cookies or {}
|
||||
self._redirect = None
|
||||
|
||||
def get_log_entry(self):
|
||||
res = dict(
|
||||
return dict(
|
||||
msg=self.msg,
|
||||
code=self.code,
|
||||
subcode=self.subcode,
|
||||
traceback=self._traceback,
|
||||
extra=self._extra,
|
||||
)
|
||||
return res
|
||||
|
||||
def copy_from(self, result):
|
||||
self._code = result.code
|
||||
@ -242,13 +266,34 @@ class APICallResult(DataContainer):
|
||||
self._filename = value
|
||||
|
||||
@property
|
||||
def headers(self):
|
||||
return self._headers
|
||||
def error_data(self):
|
||||
return self._error_data
|
||||
|
||||
@error_data.setter
|
||||
def error_data(self, value):
|
||||
self._error_data = value
|
||||
|
||||
@property
|
||||
def cookies(self):
|
||||
return self._cookies
|
||||
|
||||
def set_auth_cookie(self, value):
|
||||
self.cookies[config.get("apiserver.auth.session_auth_cookie_name")] = value
|
||||
|
||||
@property
|
||||
def redirect(self):
|
||||
return self._redirect
|
||||
|
||||
@redirect.setter
|
||||
def redirect(self, value: Union[Redirect, str, Tuple[str, int], list]):
|
||||
if isinstance(value, str):
|
||||
self._redirect = Redirect(url=value)
|
||||
elif isinstance(value, (tuple, list)):
|
||||
url, code, *_ = value
|
||||
self._redirect = Redirect(url=url, code=code)
|
||||
else:
|
||||
self._redirect = value
|
||||
|
||||
|
||||
class MissingIdentity(Exception):
|
||||
pass
|
||||
@ -305,7 +350,8 @@ class APICall(DataContainer):
|
||||
headers=None,
|
||||
files=None,
|
||||
trx=None,
|
||||
**kwargs,
|
||||
host=None,
|
||||
auth_cookie=None,
|
||||
):
|
||||
super(APICall, self).__init__(data=data, batched_data=batched_data)
|
||||
|
||||
@ -330,6 +376,8 @@ class APICall(DataContainer):
|
||||
if trx:
|
||||
self.set_header(self._transaction_headers, trx)
|
||||
self._requires_authorization = True
|
||||
self._host = host
|
||||
self._auth_cookie = auth_cookie
|
||||
|
||||
@property
|
||||
def id(self):
|
||||
@ -388,10 +436,10 @@ class APICall(DataContainer):
|
||||
def real_ip(self):
|
||||
""" Obtain visitor's IP address """
|
||||
return (
|
||||
self.get_header(self.HEADER_FORWARDED_FOR)
|
||||
or self.get_header(self.HEADER_REAL_IP)
|
||||
or self._remote_addr
|
||||
or "untrackable"
|
||||
self.get_header(self.HEADER_FORWARDED_FOR)
|
||||
or self.get_header(self.HEADER_REAL_IP)
|
||||
or self._remote_addr
|
||||
or "untrackable"
|
||||
)
|
||||
|
||||
@property
|
||||
@ -511,13 +559,21 @@ class APICall(DataContainer):
|
||||
else:
|
||||
self.clear_header(self._async_headers)
|
||||
|
||||
@property
|
||||
def host(self):
|
||||
return self._host
|
||||
|
||||
@property
|
||||
def auth_cookie(self):
|
||||
return self._auth_cookie
|
||||
|
||||
def mark_end(self):
|
||||
self._end_ts = time.time()
|
||||
self._duration = int((self._end_ts - self._start_ts) * 1000)
|
||||
self.stats = TimingStats.aggregate()
|
||||
|
||||
def get_response(self):
|
||||
def make_version_number(version):
|
||||
def make_version_number(version: PartialVersion):
|
||||
"""
|
||||
Client versions <=2.0 expect expect endpoint versions in float format, otherwise throwing an exception
|
||||
"""
|
||||
@ -549,6 +605,7 @@ class APICall(DataContainer):
|
||||
"result_subcode": self.result.subcode,
|
||||
"result_msg": self.result.msg,
|
||||
"error_stack": self.result.traceback,
|
||||
"error_data": self.result.error_data,
|
||||
},
|
||||
"data": self.result.data,
|
||||
}
|
||||
@ -557,10 +614,11 @@ class APICall(DataContainer):
|
||||
try:
|
||||
res = json.dumps(res)
|
||||
except Exception as ex:
|
||||
# JSON serialization may fail, probably problem with data so pop it and try again
|
||||
if not self.result.data:
|
||||
# JSON serialization may fail, probably problem with data or error_data so pop it and try again
|
||||
if not (self.result.data or self.result.error_data):
|
||||
raise
|
||||
self.result.data = None
|
||||
self.result.error_data = None
|
||||
msg = "Error serializing response data: " + str(ex)
|
||||
self.set_error_result(
|
||||
code=500, subcode=0, msg=msg, include_stack=False
|
||||
@ -569,8 +627,16 @@ class APICall(DataContainer):
|
||||
|
||||
return res, self.content_type
|
||||
|
||||
def set_error_result(self, msg, code=500, subcode=0, include_stack=False):
|
||||
def set_error_result(
|
||||
self, msg, code=500, subcode=0, include_stack=False, error_data=None
|
||||
):
|
||||
tb = format_exc() if include_stack else None
|
||||
self._result = APICallResult(
|
||||
data=self._result.data, code=code, subcode=subcode, msg=msg, traceback=tb
|
||||
data=self._result.data,
|
||||
code=code,
|
||||
subcode=subcode,
|
||||
msg=msg,
|
||||
traceback=tb,
|
||||
error_data=error_data,
|
||||
cookies=self._result.cookies,
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user