Add support for lists and nested fields in URL args and form

This commit is contained in:
allegroai 2022-02-13 19:52:05 +02:00
parent c4001b4037
commit 92fd98d5ad
6 changed files with 45 additions and 21 deletions

View File

@ -1220,7 +1220,7 @@ get_scalar_metric_data {
} }
} }
scalar_metrics_iter_raw { scalar_metrics_iter_raw {
"999.0" { "2.16" {
description: "Get raw data for a specific metric variants in the task" description: "Get raw data for a specific metric variants in the task"
request { request {
type: object type: object

View File

@ -539,7 +539,7 @@ get_all_ex {
description: "Scroll ID that can be used with the next calls to get_all to retrieve more data" description: "Scroll ID that can be used with the next calls to get_all to retrieve more data"
} }
} }
"999.0": ${get_all_ex."2.15"} { "2.16": ${get_all_ex."2.15"} {
request.properties.stats_with_children { request.properties.stats_with_children {
description: "If include_stats flag is set then this flag contols whether the child projects tasks are taken into statistics or not" description: "If include_stats flag is set then this flag contols whether the child projects tasks are taken into statistics or not"
type: boolean type: boolean

View File

@ -1,6 +1,7 @@
from functools import partial from functools import partial
from flask import request, Response, redirect from flask import request, Response, redirect
from werkzeug.datastructures import ImmutableMultiDict
from werkzeug.exceptions import BadRequest from werkzeug.exceptions import BadRequest
from apiserver.apierrors import APIError from apiserver.apierrors import APIError
@ -10,6 +11,7 @@ from apiserver.service_repo import ServiceRepo, APICall
from apiserver.service_repo.auth import AuthType, Token from apiserver.service_repo.auth import AuthType, Token
from apiserver.service_repo.errors import PathParsingError from apiserver.service_repo.errors import PathParsingError
from apiserver.utilities import json from apiserver.utilities import json
from apiserver.utilities.dicts import nested_set
log = config.logger(__file__) log = config.logger(__file__)
@ -79,6 +81,21 @@ class RequestHandlers:
log.exception(f"Failed processing request {request.url}: {ex}") log.exception(f"Failed processing request {request.url}: {ex}")
return f"Failed processing request {request.url}", 500 return f"Failed processing request {request.url}", 500
@staticmethod
def _apply_multi_dict(body: dict, md: ImmutableMultiDict):
def convert_value(v: str):
if v.replace(".", "", 1).isdigit():
return float(v) if "." in v else int(v)
if v in ("true", "True", "TRUE"):
return True
if v in ("false", "False", "FALSE"):
return False
return v
for k, v in md.lists():
v = [convert_value(x) for x in v] if (len(v) > 1 or k.endswith("[]")) else convert_value(v[0])
nested_set(body, k.rstrip("[]").split("."), v)
def _update_call_data(self, call, req): def _update_call_data(self, call, req):
""" Use request payload/form to fill call data or batched data """ """ Use request payload/form to fill call data or batched data """
if req.content_type == "application/json-lines": if req.content_type == "application/json-lines":
@ -96,23 +113,12 @@ class RequestHandlers:
req.on_json_loading_failed(msg) req.on_json_loading_failed(msg)
call.batched_data = items call.batched_data = items
else: else:
json_body = req.get_json(force=True, silent=False) if req.data else None body = (req.get_json(force=True, silent=False) if req.data else None) or {}
# merge form and args if req.args:
form = req.form.copy() self._apply_multi_dict(body, req.args)
form.update(req.args) if req.form:
form = form.to_dict() self._apply_multi_dict(body, req.form)
# convert string numbers to floats call.data = body
for key in form:
if form[key].replace(".", "", 1).isdigit():
if "." in form[key]:
form[key] = float(form[key])
else:
form[key] = int(form[key])
elif form[key].lower() == "true":
form[key] = True
elif form[key].lower() == "false":
form[key] = False
call.data = json_body or form or {}
def _call_or_empty_with_error(self, call, req, msg, code=500, subcode=0): def _call_or_empty_with_error(self, call, req, msg, code=500, subcode=0):
call = call or APICall( call = call or APICall(

View File

@ -310,6 +310,12 @@ class APICall(DataContainer):
_transaction_headers = _get_headers("Trx") _transaction_headers = _get_headers("Trx")
""" Transaction ID """ """ Transaction ID """
_redacted_headers = {
HEADER_AUTHORIZATION: " ",
"Cookie": "=",
}
""" Headers whose value should be redacted. Maps header name to partition char """
@property @property
def HEADER_TRANSACTION(self): def HEADER_TRANSACTION(self):
return self._transaction_headers[0] return self._transaction_headers[0]
@ -673,3 +679,15 @@ class APICall(DataContainer):
error_data=error_data, error_data=error_data,
cookies=self._result.cookies, cookies=self._result.cookies,
) )
def get_redacted_headers(self):
headers = self.headers.copy()
if not self.requires_authorization or self.auth:
# We won't log the authorization header if call shouldn't be authorized, or if it was successfully
# authorized. This means we'll only log authorization header for calls that failed to authorize (hopefully
# this will allow us to debug authorization errors).
for header, sep in self._redacted_headers.items():
if header in headers:
prefix, _, redact = headers[header].partition(sep)
headers[header] = prefix + sep + f"<{len(redact)} bytes redacted>"
return headers

View File

@ -38,7 +38,7 @@ class ServiceRepo(object):
"""If the check is set, parsing will fail for endpoint request with the version that is grater than the current """If the check is set, parsing will fail for endpoint request with the version that is grater than the current
maximum """ maximum """
_max_version = PartialVersion("2.15") _max_version = PartialVersion("2.16")
""" Maximum version number (the highest min_version value across all endpoints) """ """ Maximum version number (the highest min_version value across all endpoints) """
_endpoint_exp = ( _endpoint_exp = (

View File

@ -839,7 +839,7 @@ class ScalarMetricsIterRawScroll(Scroll):
) )
@endpoint("events.scalar_metrics_iter_raw", min_version="999.0") @endpoint("events.scalar_metrics_iter_raw", min_version="2.16")
def scalar_metrics_iter_raw( def scalar_metrics_iter_raw(
call: APICall, company_id: str, request: ScalarMetricsIterRawRequest call: APICall, company_id: str, request: ScalarMetricsIterRawRequest
): ):