Add support for cookies extensions

This commit is contained in:
allegroai 2024-03-18 15:46:07 +02:00
parent 33528870ae
commit 92a4e56c1f

View File

@ -21,6 +21,11 @@ log = config.logger(__file__)
class RequestHandlers: class RequestHandlers:
_request_strip_prefix = config.get("apiserver.request.strip_prefix", None) _request_strip_prefix = config.get("apiserver.request.strip_prefix", None)
_server_header = config.get("apiserver.response.headers.server", "clearml") _server_header = config.get("apiserver.response.headers.server", "clearml")
_custom_cookie_settings = {
c["name"]: c["settings"]
for c in config.get("apiserver.auth.custom_cookies", {}).values()
if c.get("enabled") and c.get("settings")
}
def before_request(self): def before_request(self):
if request.method == "OPTIONS": if request.method == "OPTIONS":
@ -29,7 +34,10 @@ class RequestHandlers:
return return
if request.content_encoding: if request.content_encoding:
return f"Content encoding is not supported ({request.content_encoding})", 415 return (
f"Content encoding is not supported ({request.content_encoding})",
415,
)
try: try:
call = self._create_api_call(request) call = self._create_api_call(request)
@ -70,7 +78,10 @@ class RequestHandlers:
if call.result.cookies: if call.result.cookies:
for key, value in call.result.cookies.items(): for key, value in call.result.cookies.items():
kwargs = config.get("apiserver.auth.cookies").copy() kwargs = (
self._custom_cookie_settings.get(key)
or config.get("apiserver.auth.cookies")
).copy()
if value is None: if value is None:
# Removing a cookie # Removing a cookie
kwargs["max_age"] = 0 kwargs["max_age"] = 0
@ -87,7 +98,9 @@ class RequestHandlers:
if company: if company:
try: try:
# use no default value to allow setting a null domain as well # use no default value to allow setting a null domain as well
kwargs["domain"] = config.get(f"apiserver.auth.cookies_domain_override.{company}") kwargs["domain"] = config.get(
f"apiserver.auth.cookies_domain_override.{company}"
)
except KeyError: except KeyError:
pass pass
@ -114,7 +127,11 @@ class RequestHandlers:
return v return v
for k, v in md.lists(): for k, v in md.lists():
v = [convert_value(x) for x in v] if (len(v) > 1 or k.endswith("[]")) else convert_value(v[0]) v = (
[convert_value(x) for x in v]
if (len(v) > 1 or k.endswith("[]"))
else convert_value(v[0])
)
nested_set(body, k.rstrip("[]").split("."), v) nested_set(body, k.rstrip("[]").split("."), v)
def _update_call_data(self, call, req): def _update_call_data(self, call, req):
@ -149,9 +166,7 @@ class RequestHandlers:
return call return call
def _get_session_auth_cookie(self, req): def _get_session_auth_cookie(self, req):
return req.cookies.get( return req.cookies.get(config.get("apiserver.auth.session_auth_cookie_name"))
config.get("apiserver.auth.session_auth_cookie_name")
)
def _create_api_call(self, req): def _create_api_call(self, req):
call = None call = None