Add support for override cookie domains

Support for community invitation alarms
Remove duplicate property
Add query optimizations
This commit is contained in:
allegroai 2022-02-13 19:35:35 +02:00
parent ea3b6e955f
commit 17cd48dada
6 changed files with 49 additions and 25 deletions

View File

@ -79,6 +79,11 @@
max_age: 99999999999 max_age: 99999999999
} }
# provide a cookie domain override per company
# cookies_domain_override {
# <company-id>: <domain>
# }
# # A list of fixed users # # A list of fixed users
# # Note: password may be bcrypt-hashed (generate using `python -c 'import bcrypt; print(bcrypt.hashpw("password", bcrypt.gensalt()))'`) # # Note: password may be bcrypt-hashed (generate using `python -c 'import bcrypt; print(bcrypt.hashpw("password", bcrypt.gensalt()))'`)
# fixed_users { # fixed_users {

View File

@ -901,6 +901,9 @@ class GetMixin(PropsMixin):
search_text = parameters.get(cls._search_text_key) search_text = parameters.get(cls._search_text_key)
order_by = cls.validate_order_by(parameters=parameters, search_text=search_text) order_by = cls.validate_order_by(parameters=parameters, search_text=search_text)
start, size = cls.validate_paging(parameters=parameters) start, size = cls.validate_paging(parameters=parameters)
if size is not None and size <= 0:
return []
include, exclude = cls.split_projection( include, exclude = cls.split_projection(
cls.get_projection(parameters, override_projection) cls.get_projection(parameters, override_projection)
) )
@ -937,18 +940,23 @@ class GetMixin(PropsMixin):
# add paging # add paging
ret = [] ret = []
for qs in query_sets: last_set = len(query_sets) - 1
qs_size = qs.count() for i, qs in enumerate(query_sets):
if qs_size < start: last_size = len(ret)
start -= qs_size
continue
ret.extend( ret.extend(
obj.to_proper_dict(only=include) for obj in qs.skip(start).limit(size) obj.to_proper_dict(only=include)
for obj in (qs.skip(start) if start else qs).limit(size)
) )
if len(ret) >= size: added = len(ret) - last_size
break
if added > 0:
start = 0 start = 0
size -= len(ret) size = max(0, size - added)
elif i != last_set:
start -= min(start, qs.count())
if size <= 0:
break
return ret return ret

View File

@ -562,10 +562,6 @@ update {
description: "Project name. Unique within the company." description: "Project name. Unique within the company."
type: string type: string
} }
description {
description: "Project description. "
type: string
}
description { description {
description: "Project description" description: "Project description"
type: string type: string

View File

@ -29,7 +29,7 @@ class RequestHandlers:
try: try:
call = self._create_api_call(request) call = self._create_api_call(request)
load_data_callback = partial(self._load_call_data, req=request) load_data_callback = partial(self._load_call_data, req=request)
content, content_type = ServiceRepo.handle_call( content, content_type, company = ServiceRepo.handle_call(
call, load_data_callback=load_data_callback call, load_data_callback=load_data_callback
) )
@ -51,13 +51,19 @@ class RequestHandlers:
if call.result.cookies: if call.result.cookies:
for key, value in call.result.cookies.items(): for key, value in call.result.cookies.items():
kwargs = config.get("apiserver.auth.cookies") kwargs = config.get("apiserver.auth.cookies").copy()
if company:
try:
# use no default value to allow setting a null domain as well
kwargs["domain"] = config.get(f"apiserver.auth.cookies_domain_override.{company}")
except KeyError:
pass
if value is None: if value is None:
kwargs = kwargs.copy()
kwargs["max_age"] = 0 kwargs["max_age"] = 0
kwargs["expires"] = 0 kwargs["expires"] = 0
response.set_cookie(key, "", **kwargs) value = ""
else:
response.set_cookie(key, value, **kwargs) response.set_cookie(key, value, **kwargs)
return response return response

View File

@ -10,6 +10,7 @@ from apiserver.apierrors import APIError, errors
from apiserver.config_repo import config from apiserver.config_repo import config
from apiserver.utilities.partial_version import PartialVersion from apiserver.utilities.partial_version import PartialVersion
from .apicall import APICall from .apicall import APICall
from .auth import Identity
from .endpoint import Endpoint from .endpoint import Endpoint
from .errors import MalformedPathError, InvalidVersionError, CallFailedError from .errors import MalformedPathError, InvalidVersionError, CallFailedError
from .util import parse_return_stack_on_code from .util import parse_return_stack_on_code
@ -233,19 +234,27 @@ class ServiceRepo(object):
return subcode in subcode_list return subcode in subcode_list
@classmethod @classmethod
def _get_company( def _get_identity(
cls, call: APICall, endpoint: Endpoint = None, ignore_error: bool = False cls, call: APICall, endpoint: Endpoint = None, ignore_error: bool = False
) -> Optional[str]: ) -> Optional[Identity]:
authorize = endpoint and endpoint.authorize authorize = endpoint and endpoint.authorize
if ignore_error or not authorize: if ignore_error or not authorize:
try: try:
return call.identity.company return call.identity
except Exception: except Exception:
return None return None
return call.identity.company return call.identity
@classmethod
def _get_company(
cls, call: APICall, endpoint: Endpoint = None, ignore_error: bool = False
) -> Optional[str]:
identity = cls._get_identity(call, endpoint=endpoint, ignore_error=ignore_error)
return None if identity is None else identity.company
@classmethod @classmethod
def handle_call(cls, call: APICall, load_data_callback: Callable = None): def handle_call(cls, call: APICall, load_data_callback: Callable = None):
company = None
try: try:
if call.failed: if call.failed:
raise CallFailedError() raise CallFailedError()
@ -316,4 +325,4 @@ class ServiceRepo(object):
else: else:
log.error(console_msg) log.error(console_msg)
return content, content_type return content, content_type, company

View File

@ -38,7 +38,7 @@ class TestEntityOrdering(TestService):
self._assertGetTasksWithOrdering(order_by=order_field, page=0, page_size=20) self._assertGetTasksWithOrdering(order_by=order_field, page=0, page_size=20)
field_vals = [] field_vals = []
page_size = 2 page_size = 4
num_pages = 5 num_pages = 5
for page in range(num_pages): for page in range(num_pages):
paged_tasks = self._get_page_tasks( paged_tasks = self._get_page_tasks(