From 92a4e56c1ff422f467fbba6257f73083028b7284 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Mon, 18 Mar 2024 15:46:07 +0200 Subject: [PATCH] Add support for cookies extensions --- apiserver/server_init/request_handlers.py | 31 +++++++++++++++++------ 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/apiserver/server_init/request_handlers.py b/apiserver/server_init/request_handlers.py index a8c9182..6e15d10 100644 --- a/apiserver/server_init/request_handlers.py +++ b/apiserver/server_init/request_handlers.py @@ -21,6 +21,11 @@ log = config.logger(__file__) class RequestHandlers: _request_strip_prefix = config.get("apiserver.request.strip_prefix", None) _server_header = config.get("apiserver.response.headers.server", "clearml") + _custom_cookie_settings = { + c["name"]: c["settings"] + for c in config.get("apiserver.auth.custom_cookies", {}).values() + if c.get("enabled") and c.get("settings") + } def before_request(self): if request.method == "OPTIONS": @@ -29,7 +34,10 @@ class RequestHandlers: return if request.content_encoding: - return f"Content encoding is not supported ({request.content_encoding})", 415 + return ( + f"Content encoding is not supported ({request.content_encoding})", + 415, + ) try: call = self._create_api_call(request) @@ -70,7 +78,10 @@ class RequestHandlers: if call.result.cookies: for key, value in call.result.cookies.items(): - kwargs = config.get("apiserver.auth.cookies").copy() + kwargs = ( + self._custom_cookie_settings.get(key) + or config.get("apiserver.auth.cookies") + ).copy() if value is None: # Removing a cookie kwargs["max_age"] = 0 @@ -87,7 +98,9 @@ class RequestHandlers: if company: try: # use no default value to allow setting a null domain as well - kwargs["domain"] = config.get(f"apiserver.auth.cookies_domain_override.{company}") + kwargs["domain"] = config.get( + f"apiserver.auth.cookies_domain_override.{company}" + ) except KeyError: pass @@ -114,11 +127,15 @@ class RequestHandlers: return v for k, v in md.lists(): - v = [convert_value(x) for x in v] if (len(v) > 1 or k.endswith("[]")) else convert_value(v[0]) + v = ( + [convert_value(x) for x in v] + if (len(v) > 1 or k.endswith("[]")) + else convert_value(v[0]) + ) nested_set(body, k.rstrip("[]").split("."), v) def _update_call_data(self, call, req): - """ Use request payload/form to fill call data or batched data """ + """Use request payload/form to fill call data or batched data""" if req.content_type == "application/json-lines": items = [] for i, line in enumerate(req.data.splitlines()): @@ -149,9 +166,7 @@ class RequestHandlers: return call def _get_session_auth_cookie(self, req): - return req.cookies.get( - config.get("apiserver.auth.session_auth_cookie_name") - ) + return req.cookies.get(config.get("apiserver.auth.session_auth_cookie_name")) def _create_api_call(self, req): call = None