Add support for lists and nested fields in URL args and form

This commit is contained in:
allegroai 2022-02-13 19:52:05 +02:00
parent c4001b4037
commit 92fd98d5ad
6 changed files with 45 additions and 21 deletions

View File

@ -1220,7 +1220,7 @@ get_scalar_metric_data {
}
}
scalar_metrics_iter_raw {
"999.0" {
"2.16" {
description: "Get raw data for a specific metric variants in the task"
request {
type: object

View File

@ -539,7 +539,7 @@ get_all_ex {
description: "Scroll ID that can be used with the next calls to get_all to retrieve more data"
}
}
"999.0": ${get_all_ex."2.15"} {
"2.16": ${get_all_ex."2.15"} {
request.properties.stats_with_children {
description: "If include_stats flag is set then this flag contols whether the child projects tasks are taken into statistics or not"
type: boolean

View File

@ -1,6 +1,7 @@
from functools import partial
from flask import request, Response, redirect
from werkzeug.datastructures import ImmutableMultiDict
from werkzeug.exceptions import BadRequest
from apiserver.apierrors import APIError
@ -10,6 +11,7 @@ from apiserver.service_repo import ServiceRepo, APICall
from apiserver.service_repo.auth import AuthType, Token
from apiserver.service_repo.errors import PathParsingError
from apiserver.utilities import json
from apiserver.utilities.dicts import nested_set
log = config.logger(__file__)
@ -79,6 +81,21 @@ class RequestHandlers:
log.exception(f"Failed processing request {request.url}: {ex}")
return f"Failed processing request {request.url}", 500
@staticmethod
def _apply_multi_dict(body: dict, md: ImmutableMultiDict):
def convert_value(v: str):
if v.replace(".", "", 1).isdigit():
return float(v) if "." in v else int(v)
if v in ("true", "True", "TRUE"):
return True
if v in ("false", "False", "FALSE"):
return False
return v
for k, v in md.lists():
v = [convert_value(x) for x in v] if (len(v) > 1 or k.endswith("[]")) else convert_value(v[0])
nested_set(body, k.rstrip("[]").split("."), v)
def _update_call_data(self, call, req):
""" Use request payload/form to fill call data or batched data """
if req.content_type == "application/json-lines":
@ -96,23 +113,12 @@ class RequestHandlers:
req.on_json_loading_failed(msg)
call.batched_data = items
else:
json_body = req.get_json(force=True, silent=False) if req.data else None
# merge form and args
form = req.form.copy()
form.update(req.args)
form = form.to_dict()
# convert string numbers to floats
for key in form:
if form[key].replace(".", "", 1).isdigit():
if "." in form[key]:
form[key] = float(form[key])
else:
form[key] = int(form[key])
elif form[key].lower() == "true":
form[key] = True
elif form[key].lower() == "false":
form[key] = False
call.data = json_body or form or {}
body = (req.get_json(force=True, silent=False) if req.data else None) or {}
if req.args:
self._apply_multi_dict(body, req.args)
if req.form:
self._apply_multi_dict(body, req.form)
call.data = body
def _call_or_empty_with_error(self, call, req, msg, code=500, subcode=0):
call = call or APICall(

View File

@ -310,6 +310,12 @@ class APICall(DataContainer):
_transaction_headers = _get_headers("Trx")
""" Transaction ID """
_redacted_headers = {
HEADER_AUTHORIZATION: " ",
"Cookie": "=",
}
""" Headers whose value should be redacted. Maps header name to partition char """
@property
def HEADER_TRANSACTION(self):
return self._transaction_headers[0]
@ -673,3 +679,15 @@ class APICall(DataContainer):
error_data=error_data,
cookies=self._result.cookies,
)
def get_redacted_headers(self):
headers = self.headers.copy()
if not self.requires_authorization or self.auth:
# We won't log the authorization header if call shouldn't be authorized, or if it was successfully
# authorized. This means we'll only log authorization header for calls that failed to authorize (hopefully
# this will allow us to debug authorization errors).
for header, sep in self._redacted_headers.items():
if header in headers:
prefix, _, redact = headers[header].partition(sep)
headers[header] = prefix + sep + f"<{len(redact)} bytes redacted>"
return headers

View File

@ -38,7 +38,7 @@ class ServiceRepo(object):
"""If the check is set, parsing will fail for endpoint request with the version that is grater than the current
maximum """
_max_version = PartialVersion("2.15")
_max_version = PartialVersion("2.16")
""" Maximum version number (the highest min_version value across all endpoints) """
_endpoint_exp = (

View File

@ -839,7 +839,7 @@ class ScalarMetricsIterRawScroll(Scroll):
)
@endpoint("events.scalar_metrics_iter_raw", min_version="999.0")
@endpoint("events.scalar_metrics_iter_raw", min_version="2.16")
def scalar_metrics_iter_raw(
call: APICall, company_id: str, request: ScalarMetricsIterRawRequest
):