clearml-server/apiserver/server_init/request_handlers.py

215 lines
8.4 KiB
Python
Raw Normal View History

import unicodedata
from functools import partial
2021-01-05 16:28:43 +00:00
from flask import request, Response, redirect
from werkzeug.datastructures import ImmutableMultiDict
2021-01-05 16:28:43 +00:00
from werkzeug.exceptions import BadRequest
from werkzeug.urls import url_quote
2021-01-05 16:28:43 +00:00
from apiserver.apierrors import APIError
2021-01-05 16:28:43 +00:00
from apiserver.apierrors.base import BaseError
from apiserver.config_repo import config
from apiserver.service_repo import ServiceRepo, APICall
from apiserver.service_repo.auth import AuthType, Token
2021-01-05 16:28:43 +00:00
from apiserver.service_repo.errors import PathParsingError
from apiserver.utilities import json
from apiserver.utilities.dicts import nested_set
2021-01-05 16:28:43 +00:00
# Module-level logger for this file, obtained from the server configuration object
log = config.logger(__file__)
class RequestHandlers:
    """
    Flask request hooks that dispatch HTTP requests to the API service layer.

    before_request() does the actual work: it builds an APICall from the Flask
    request, hands it to ServiceRepo.handle_call() and converts the call result
    (content, redirect, download filename, cookies) into a Flask response.
    """

    # Optional path prefix removed from the request path before endpoint parsing
    _request_strip_prefix = config.get("apiserver.request.strip_prefix", None)
    # Value written to the "server" header of every response (see after_request)
    _server_header = config.get("apiserver.response.headers.server", "clearml")
    # Characters allowed in an unquoted header parameter value (RFC 7230 "tchar")
    _header_token_chars = frozenset(
        "abcdefghijklmnopqrstuvwxyz"
        "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
        "0123456789"
        "!#$%&'*+-.^_`|~"
    )
    # RFC 5987 attr-char punctuation; quote() never encodes letters, digits or "_.-~"
    _rfc5987_safe_chars = "!#$&+-.^_`|~"

    def before_app_first_request(self):
        """Hook for one-time initialization before the first request; currently a no-op."""
        pass

    def before_request(self):
        """
        Handle the current Flask request through the API service layer.

        - OPTIONS requests get an empty 200 response.
        - Paths containing "/static/" return None (no response is produced here).
        - Requests with a Content-Encoding are rejected with 415.
        - Anything else is converted to an APICall, handled by ServiceRepo and
          turned into a redirect or a Response, with any cookies from the call
          result set on it.

        Unexpected errors are logged and answered with a plain 500 response.
        """
        if request.method == "OPTIONS":
            return "", 200
        if "/static/" in request.path:
            return
        if request.content_encoding:
            return f"Content encoding is not supported ({request.content_encoding})", 415

        try:
            call = self._create_api_call(request)
            # Call data is loaded lazily by ServiceRepo through this callback
            load_data_callback = partial(self._load_call_data, req=request)
            content, content_type, company = ServiceRepo.handle_call(
                call, load_data_callback=load_data_callback
            )

            if call.result.redirect:
                response = redirect(call.result.redirect.url, call.result.redirect.code)
            else:
                headers = None
                if call.result.filename:
                    headers = {
                        "Content-Disposition": self._content_disposition(
                            call.result.filename
                        )
                    }
                response = Response(
                    content,
                    mimetype=content_type,
                    status=call.result.code,
                    headers=headers,
                )

            if call.result.cookies:
                cookie_defaults = config.get("apiserver.auth.cookies")
                for key, value in call.result.cookies.items():
                    kwargs = cookie_defaults.copy()
                    if value is None:
                        # Removing a cookie: expire it immediately with an empty value
                        kwargs["max_age"] = 0
                        kwargs["expires"] = 0
                        value = ""
                    elif not company:
                        # Setting a cookie, let's try to figure out the company.
                        # Best effort only: a value that is not a decodable token is ignored.
                        # noinspection PyBroadException
                        try:
                            company = Token.decode_identity(value).company
                        except Exception:
                            pass
                    if company:
                        try:
                            # use no default value to allow setting a null domain as well
                            kwargs["domain"] = config.get(
                                f"apiserver.auth.cookies_domain_override.{company}"
                            )
                        except KeyError:
                            pass
                    response.set_cookie(key, value, **kwargs)

            return response
        except Exception as ex:
            log.exception(f"Failed processing request {request.url}: {ex}")
            return f"Failed processing request {request.url}", 500

    def after_request(self, response):
        """Set the configured server header on the outgoing response and return it."""
        response.headers["server"] = self._server_header
        return response

    @classmethod
    def _quote_header_param(cls, value: str) -> str:
        """
        Format value as an HTTP header parameter value.

        Returns the value unchanged when it is a non-empty token, otherwise a
        quoted-string with backslashes and double quotes escaped.
        """
        if value and set(value) <= cls._header_token_chars:
            return value
        escaped = value.replace("\\", "\\\\").replace('"', '\\"')
        return f'"{escaped}"'

    @classmethod
    def _content_disposition(cls, filename: str) -> str:
        """
        Build an "attachment" Content-Disposition header value for filename.

        An ASCII name is sent as the ``filename`` parameter, quoted when needed.
        A non-ASCII name is sent as an RFC 5987 ``filename*`` parameter (UTF-8,
        percent-encoded), plus an ASCII approximation in ``filename`` for
        clients that do not understand ``filename*``.
        """
        from urllib.parse import quote

        try:
            filename.encode("ascii")
        except UnicodeEncodeError:
            # Decompose accented characters and drop whatever is left non-ASCII
            simple = unicodedata.normalize("NFKD", filename)
            simple = simple.encode("ascii", "ignore").decode("ascii")
            # safe = RFC 5987 attr-char; everything else is percent-encoded
            quoted = quote(filename, safe=cls._rfc5987_safe_chars)
            filenames = (
                f"filename={cls._quote_header_param(simple)}; "
                f"filename*=UTF-8''{quoted}"
            )
        else:
            filenames = f"filename={cls._quote_header_param(filename)}"
        return "attachment; " + filenames

    @staticmethod
    def _apply_multi_dict(body: dict, md: ImmutableMultiDict):
        """
        Merge the fields of md (query args or form data) into body, in place.

        Dotted keys ("a.b.c") are written as nested values via nested_set. A key
        ending with "[]", or one that appears more than once, yields a list of
        values; otherwise its single value is used. Each value is converted:
        decimal numbers to int/float, "true"/"True"/"TRUE" and
        "false"/"False"/"FALSE" to bool, anything else stays a string.
        """

        def convert_value(v: str):
            # isdecimal(), not isdigit(): characters such as "²" pass isdigit()
            # but are rejected by int()/float() with ValueError
            if v.replace(".", "", 1).isdecimal():
                return float(v) if "." in v else int(v)
            if v in ("true", "True", "TRUE"):
                return True
            if v in ("false", "False", "FALSE"):
                return False
            return v

        for k, v in md.lists():
            is_list = k.endswith("[]")
            if is_list:
                # Strip exactly the "[]" suffix; rstrip("[]") would also remove
                # any other trailing '[' / ']' characters from the key
                k = k[:-2]
            v = [convert_value(x) for x in v] if (len(v) > 1 or is_list) else convert_value(v[0])
            nested_set(body, k.split("."), v)

    def _update_call_data(self, call, req):
        """
        Use request payload/form to fill call data or batched data.

        For "application/json-lines" requests, every line must be a JSON object;
        the parsed objects are stored in call.batched_data. Otherwise the JSON body
        (an empty dict when there is none) is updated with the query args and form
        fields and stored in call.data.

        :raises BadRequest: a json-lines item is not valid JSON or not an object
        """
        if req.content_type == "application/json-lines":
            items = []
            for i, line in enumerate(req.data.splitlines()):
                try:
                    event = json.loads(line)
                except ValueError as e:
                    # Raise directly: Flask's on_json_loading_failed() replaces the
                    # message with a generic one outside of debug mode
                    raise BadRequest(f"{e} in batch item #{i}") from e
                if not isinstance(event, dict):
                    raise BadRequest(
                        f"json lines must contain objects, found: {type(event).__name__}"
                    )
                items.append(event)
            call.batched_data = items
        else:
            body = (req.get_json(force=True, silent=False) if req.data else None) or {}
            if req.args:
                self._apply_multi_dict(body, req.args)
            if req.form:
                self._apply_multi_dict(body, req.form)
            call.data = body

    def _call_or_empty_with_error(self, call, req, msg, code=500, subcode=0):
        """
        Set an error result on call and return it.

        If call is None, a bare APICall with an empty endpoint name is created
        from the request's remote address, headers and files first.
        """
        call = call or APICall(
            "", remote_addr=req.remote_addr, headers=dict(req.headers), files=req.files
        )
        call.set_error_result(msg=msg, code=code, subcode=subcode)
        return call

    def _create_api_call(self, req):
        """
        Build an APICall for the Flask request.

        The endpoint version and name are parsed from the request path, after
        removing the configured strip prefix. Failures do not raise; they are
        recorded on the returned call: 400 for path-parsing errors and BadRequest,
        the error's own code/subcode for BaseError, and 500 (logged) for
        anything else.
        """
        call = None
        try:
            # Parse the request path
            path = req.path
            if self._request_strip_prefix and path.startswith(self._request_strip_prefix):
                path = path[len(self._request_strip_prefix):]
            endpoint_version, endpoint_name = ServiceRepo.parse_endpoint_path(path)

            # Resolve authorization: if cookies contain an authorization token, use it as a starting point.
            # in any case, request headers always take precedence.
            auth_cookie = req.cookies.get(
                config.get("apiserver.auth.session_auth_cookie_name")
            )
            headers = (
                {}
                if not auth_cookie
                else {"Authorization": f"{AuthType.bearer_token} {auth_cookie}"}
            )
            # add (possibly override with) the request headers
            headers.update(req.headers.items())

            # Construct call instance
            call = APICall(
                endpoint_name=endpoint_name,
                remote_addr=req.remote_addr,
                endpoint_version=endpoint_version,
                headers=headers,
                files=req.files,
                host=req.host,
                auth_cookie=auth_cookie,
            )
        except PathParsingError as ex:
            call = self._call_or_empty_with_error(call, req, ex.args[0], 400)
            call.log_api = False
        except BadRequest as ex:
            call = self._call_or_empty_with_error(call, req, ex.description, 400)
        except BaseError as ex:
            call = self._call_or_empty_with_error(
                call, req, ex.msg, ex.code, ex.subcode
            )
        except Exception as ex:
            log.exception("Error creating call")
            call = self._call_or_empty_with_error(
                call, req, ex.args[0] if ex.args else type(ex).__name__, 500
            )

        return call

    def _load_call_data(self, call: APICall, req):
        """
        Update call data from request.

        Errors are not raised; they are recorded on the call with
        set_error_result(). BadRequest becomes a 400 with its description, and
        BaseError/APIError keep their own code and subcode. Any other exception
        is logged and recorded using its first argument (or its type name) as
        the message.
        """
        try:
            self._update_call_data(call, req)
        except BadRequest as ex:
            call.set_error_result(msg=ex.description, code=400)
        except BaseError as ex:
            call.set_error_result(msg=ex.msg, code=ex.code, subcode=ex.subcode)
        except APIError as ex:
            call.set_error_result(
                msg=ex.msg, code=ex.code, subcode=ex.subcode, error_data=ex.error_data
            )
        except Exception as ex:
            # Unexpected failure: keep the traceback, as _create_api_call does
            log.exception("Error loading call data")
            call.set_error_result(msg=ex.args[0] if ex.args else type(ex).__name__)