diff --git a/apiserver/database/model/base.py b/apiserver/database/model/base.py index bd7d625..967c9cb 100644 --- a/apiserver/database/model/base.py +++ b/apiserver/database/model/base.py @@ -28,7 +28,7 @@ from apiserver.database import Database from apiserver.database.errors import MakeGetAllQueryError from apiserver.database.projection import project_dict, ProjectionHelper from apiserver.database.props import PropsMixin -from apiserver.database.query import RegexQ, RegexWrapper +from apiserver.database.query import RegexQ, RegexWrapper, RegexQCombination from apiserver.database.utils import ( get_company_or_none_constraint, get_fields_choices, @@ -131,6 +131,8 @@ class GetMixin(PropsMixin): "nop": (default_mongo_op, False), "all": ("all", True), "and": ("all", True), + "any": (default_mongo_op, True), + "or": (default_mongo_op, True), } def __init__(self, legacy=False): @@ -138,6 +140,13 @@ class GetMixin(PropsMixin): self._sticky = False self._support_legacy = legacy + def _get_op(self, v: str, translate: bool = False) -> Optional[str]: + op = v[len(self.op_prefix):] if v and v.startswith(self.op_prefix) else None + if translate: + tup = self._ops.get(op, None) + return tup[0] if tup else None + return op + def _key(self, v) -> Optional[Union[str, bool]]: if v is None: self._current_op = None @@ -151,16 +160,27 @@ class GetMixin(PropsMixin): elif self._support_legacy and v.startswith(self._legacy_exclude_prefix): self._current_op = None return False - elif v.startswith(self.op_prefix): - self._current_op, self._sticky = self._ops.get( - v[len(self.op_prefix):], (self.default_mongo_op, self._sticky) - ) - return None + else: + op = self._get_op(v) + if op is not None: + self._current_op, self._sticky = self._ops.get( + op, (self.default_mongo_op, self._sticky) + ) + return None return self.default_mongo_op - def get_actions(self, data: Sequence[str]) -> Dict[str, List[Union[str, None]]]: + def get_actions(self, data: Sequence[str]) -> Tuple[Dict[str, List[Union[str, None]]], Optional[str]]: actions = {} + if not data: + return actions, None + + global_op = self._get_op(data[0], translate=True) + if global_op in ("in", "all"): + data = data[1:] + else: + global_op = None + for val in data: key = self._key(val) if key is None: @@ -169,7 +189,8 @@ class GetMixin(PropsMixin): key = self._legacy_exclude_mongo_op val = val[len(self._legacy_exclude_prefix) :] actions.setdefault(key, []).append(val) - return actions + + return actions, global_op or self.default_mongo_op get_all_query_options = QueryParameterOptions() @@ -285,7 +306,7 @@ class GetMixin(PropsMixin): Prepare a query object based on the provided query dictionary and various fields. NOTE: BE VERY CAREFUL WITH THIS CALL, as it allows creating queries that span across companies. - IMPLEMENTATION NOTE: Make sure that inside this function or the function it depends on RegexQ is always + IMPLEMENTATION NOTE: Make sure that inside this function or the functions it depends on RegexQ is always used instead of Q. Otherwise we can and up with some combination that is not processed according to RegexQ rules :param parameters_options: Specifies options for parsing the parameters (see ParametersOptions) @@ -425,11 +446,14 @@ class GetMixin(PropsMixin): """ if not isinstance(data, (list, tuple)): data = [data] - # raise MakeGetAllQueryError("expected list", field) - actions = cls.ListFieldBucketHelper(legacy=True).get_actions(data) + actions, global_op = cls.ListFieldBucketHelper(legacy=True).get_actions(data) + + default_op = cls.ListFieldBucketHelper.default_mongo_op + + # Handle `allow_empty` hack: controlled using `None` as a specific value in the default "in" action allow_empty = False - default_op_actions = actions.get(cls.ListFieldBucketHelper.default_mongo_op) + default_op_actions = actions.get(default_op) if default_op_actions and None in default_op_actions: allow_empty = True default_op_actions.remove(None) @@ -438,10 +462,17 @@ class GetMixin(PropsMixin): mongoengine_field = field.replace(".", "__") - q = RegexQ() - for action in filter(None, actions): - q &= RegexQ( - **{f"{mongoengine_field}__{action}": list(set(actions[action]))} + queries = [ + RegexQ(**{f"{mongoengine_field}__{action}": list(set(actions[action]))}) + for action in filter(None, actions) + ] + + if not queries: + q = RegexQ() + else: + q = RegexQCombination( + operation=RegexQ.AND if global_op is not default_op else RegexQ.OR, + children=queries ) if not allow_empty: diff --git a/apiserver/tests/automated/test_organization.py b/apiserver/tests/automated/test_organization.py index 67923c2..7c19de1 100644 --- a/apiserver/tests/automated/test_organization.py +++ b/apiserver/tests/automated/test_organization.py @@ -2,9 +2,6 @@ from apiserver.tests.automated import TestService class TestOrganization(TestService): - def setUp(self, version="2.12"): - super().setUp(version=version) - def test_get_user_companies(self): company = self.api.organization.get_user_companies().companies[0] self.assertEqual(len(company.owners), company.allocated) diff --git a/apiserver/tests/automated/test_paging_and_scrolling.py b/apiserver/tests/automated/test_paging_and_scrolling.py index e129277..af6c008 100644 --- a/apiserver/tests/automated/test_paging_and_scrolling.py +++ b/apiserver/tests/automated/test_paging_and_scrolling.py @@ -2,7 +2,7 @@ import math from apiserver.tests.automated import TestService -class TestEntityOrdering(TestService): +class TestPagingAndScrolling(TestService): name_prefix = f"Test paging " def setUp(self, **kwargs): @@ -13,7 +13,16 @@ class TestEntityOrdering(TestService): tasks = [ self._temp_task( name=f"{self.name_prefix}{i}", - hyperparams={"test": {"param": {"section": "test", "name": "param", "type": "str", "value": str(i)}}}, + hyperparams={ + "test": { + "param": { + "section": "test", + "name": "param", + "type": "str", + "value": str(i), + } + } + }, ) for i in range(18) ] @@ -24,10 +33,7 @@ class TestEntityOrdering(TestService): for page in range(0, math.ceil(len(self.task_ids) / page_size)): start = page * page_size expected_size = min(page_size, len(self.task_ids) - start) - tasks = self._get_tasks( - page=page, - page_size=page_size, - ).tasks + tasks = self._get_tasks(page=page, page_size=page_size,).tasks self.assertEqual(len(tasks), expected_size) for i, t in enumerate(tasks): self.assertEqual(t.name, f"{self.name_prefix}{start + i}") @@ -38,10 +44,7 @@ class TestEntityOrdering(TestService): for page in range(0, math.ceil(len(self.task_ids) / page_size)): start = page * page_size expected_size = min(page_size, len(self.task_ids) - start) - res = self._get_tasks( - size=page_size, - scroll_id=scroll_id, - ) + res = self._get_tasks(size=page_size, scroll_id=scroll_id,) self.assertTrue(res.scroll_id) scroll_id = res.scroll_id tasks = res.tasks @@ -50,24 +53,19 @@ class TestEntityOrdering(TestService): self.assertEqual(t.name, f"{self.name_prefix}{start + i}") # no more data in this scroll - tasks = self._get_tasks( - size=page_size, - scroll_id=scroll_id, - ).tasks + tasks = self._get_tasks(size=page_size, scroll_id=scroll_id,).tasks self.assertFalse(tasks) # refresh brings all tasks = self._get_tasks( - size=page_size, - scroll_id=scroll_id, - refresh_scroll=True, + size=page_size, scroll_id=scroll_id, refresh_scroll=True, ).tasks self.assertEqual([t.id for t in tasks], self.task_ids) def _get_tasks(self, **page_params): return self.api.tasks.get_all_ex( name="^Test paging ", - order_by=["hyperparams.param"], + order_by=["hyperparams.test.param.value"], **page_params, )