From 17cd48dada6333a1e1c5655f50a0ba5ffe1b45a4 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Sun, 13 Feb 2022 19:35:35 +0200 Subject: [PATCH] Add support for override cookie domains Support for community invitation alarms Remove duplicate property Add query optimizations --- apiserver/config/default/apiserver.conf | 5 ++++ apiserver/database/model/base.py | 26 ++++++++++++------- apiserver/schema/services/projects.conf | 4 --- apiserver/server_init/request_handlers.py | 18 ++++++++----- apiserver/service_repo/service_repo.py | 19 ++++++++++---- .../tests/automated/test_entity_ordering.py | 2 +- 6 files changed, 49 insertions(+), 25 deletions(-) diff --git a/apiserver/config/default/apiserver.conf b/apiserver/config/default/apiserver.conf index 4eedb52..51ab237 100644 --- a/apiserver/config/default/apiserver.conf +++ b/apiserver/config/default/apiserver.conf @@ -79,6 +79,11 @@ max_age: 99999999999 } + # provide a cookie domain override per company +# cookies_domain_override { +# : +# } + # # A list of fixed users # # Note: password may be bcrypt-hashed (generate using `python -c 'import bcrypt; print(bcrypt.hashpw("password", bcrypt.gensalt()))'`) # fixed_users { diff --git a/apiserver/database/model/base.py b/apiserver/database/model/base.py index a8770e0..23cbe4d 100644 --- a/apiserver/database/model/base.py +++ b/apiserver/database/model/base.py @@ -901,6 +901,9 @@ class GetMixin(PropsMixin): search_text = parameters.get(cls._search_text_key) order_by = cls.validate_order_by(parameters=parameters, search_text=search_text) start, size = cls.validate_paging(parameters=parameters) + if size is not None and size <= 0: + return [] + include, exclude = cls.split_projection( cls.get_projection(parameters, override_projection) ) @@ -937,18 +940,23 @@ class GetMixin(PropsMixin): # add paging ret = [] - for qs in query_sets: - qs_size = qs.count() - if qs_size < start: - start -= qs_size - continue + last_set = len(query_sets) - 1 + for i, qs in enumerate(query_sets): + last_size = len(ret) ret.extend( - obj.to_proper_dict(only=include) for obj in qs.skip(start).limit(size) + obj.to_proper_dict(only=include) + for obj in (qs.skip(start) if start else qs).limit(size) ) - if len(ret) >= size: + added = len(ret) - last_size + + if added > 0: + start = 0 + size = max(0, size - added) + elif i != last_set: + start -= min(start, qs.count()) + + if size <= 0: break - start = 0 - size -= len(ret) return ret diff --git a/apiserver/schema/services/projects.conf b/apiserver/schema/services/projects.conf index fd5b38f..d3f5f89 100644 --- a/apiserver/schema/services/projects.conf +++ b/apiserver/schema/services/projects.conf @@ -562,10 +562,6 @@ update { description: "Project name. Unique within the company." type: string } - description { - description: "Project description. " - type: string - } description { description: "Project description" type: string diff --git a/apiserver/server_init/request_handlers.py b/apiserver/server_init/request_handlers.py index bc7203d..4883cbf 100644 --- a/apiserver/server_init/request_handlers.py +++ b/apiserver/server_init/request_handlers.py @@ -29,7 +29,7 @@ class RequestHandlers: try: call = self._create_api_call(request) load_data_callback = partial(self._load_call_data, req=request) - content, content_type = ServiceRepo.handle_call( + content, content_type, company = ServiceRepo.handle_call( call, load_data_callback=load_data_callback ) @@ -51,14 +51,20 @@ class RequestHandlers: if call.result.cookies: for key, value in call.result.cookies.items(): - kwargs = config.get("apiserver.auth.cookies") + kwargs = config.get("apiserver.auth.cookies").copy() + if company: + try: + # use no default value to allow setting a null domain as well + kwargs["domain"] = config.get(f"apiserver.auth.cookies_domain_override.{company}") + except KeyError: + pass + if value is None: - kwargs = kwargs.copy() kwargs["max_age"] = 0 kwargs["expires"] = 0 - response.set_cookie(key, "", **kwargs) - else: - response.set_cookie(key, value, **kwargs) + value = "" + + response.set_cookie(key, value, **kwargs) return response except Exception as ex: diff --git a/apiserver/service_repo/service_repo.py b/apiserver/service_repo/service_repo.py index bb338de..491c94f 100644 --- a/apiserver/service_repo/service_repo.py +++ b/apiserver/service_repo/service_repo.py @@ -10,6 +10,7 @@ from apiserver.apierrors import APIError, errors from apiserver.config_repo import config from apiserver.utilities.partial_version import PartialVersion from .apicall import APICall +from .auth import Identity from .endpoint import Endpoint from .errors import MalformedPathError, InvalidVersionError, CallFailedError from .util import parse_return_stack_on_code @@ -233,19 +234,27 @@ class ServiceRepo(object): return subcode in subcode_list @classmethod - def _get_company( + def _get_identity( cls, call: APICall, endpoint: Endpoint = None, ignore_error: bool = False - ) -> Optional[str]: + ) -> Optional[Identity]: authorize = endpoint and endpoint.authorize if ignore_error or not authorize: try: - return call.identity.company + return call.identity except Exception: return None - return call.identity.company + return call.identity + + @classmethod + def _get_company( + cls, call: APICall, endpoint: Endpoint = None, ignore_error: bool = False + ) -> Optional[str]: + identity = cls._get_identity(call, endpoint=endpoint, ignore_error=ignore_error) + return None if identity is None else identity.company @classmethod def handle_call(cls, call: APICall, load_data_callback: Callable = None): + company = None try: if call.failed: raise CallFailedError() @@ -316,4 +325,4 @@ class ServiceRepo(object): else: log.error(console_msg) - return content, content_type + return content, content_type, company diff --git a/apiserver/tests/automated/test_entity_ordering.py b/apiserver/tests/automated/test_entity_ordering.py index 2966b51..b7c2e88 100644 --- a/apiserver/tests/automated/test_entity_ordering.py +++ b/apiserver/tests/automated/test_entity_ordering.py @@ -38,7 +38,7 @@ class TestEntityOrdering(TestService): self._assertGetTasksWithOrdering(order_by=order_field, page=0, page_size=20) field_vals = [] - page_size = 2 + page_size = 4 num_pages = 5 for page in range(num_pages): paged_tasks = self._get_page_tasks(