diff --git a/apiserver/apimodels/projects.py b/apiserver/apimodels/projects.py index a352c2a..933feb4 100644 --- a/apiserver/apimodels/projects.py +++ b/apiserver/apimodels/projects.py @@ -99,3 +99,4 @@ class ProjectsGetRequest(models.Base): allow_public = fields.BoolField(default=True) children_type = ActualEnumField(ProjectChildrenType) children_tags = fields.ListField(str) + children_tags_filter = DictField() diff --git a/apiserver/bll/project/project_bll.py b/apiserver/bll/project/project_bll.py index e55bb5e..0562df4 100644 --- a/apiserver/bll/project/project_bll.py +++ b/apiserver/bll/project/project_bll.py @@ -551,7 +551,10 @@ class ProjectBLL: @classmethod def get_dataset_stats( - cls, company: str, project_ids: Sequence[str], users: Sequence[str] = None, + cls, + company: str, + project_ids: Sequence[str], + users: Sequence[str] = None, ) -> Dict[str, dict]: if not project_ids: return {} @@ -585,7 +588,9 @@ class ProjectBLL: @staticmethod def _get_projects_children( - project_ids: Sequence[str], search_hidden: bool, allowed_ids: Sequence[str], + project_ids: Sequence[str], + search_hidden: bool, + allowed_ids: Sequence[str], ) -> Tuple[ProjectsChildren, Set[str]]: child_projects = _get_sub_projects( project_ids, @@ -629,7 +634,9 @@ class ProjectBLL: project_ids_with_children = set(project_ids) if include_children: child_projects, children_ids = cls._get_projects_children( - project_ids, search_hidden=True, allowed_ids=selected_project_ids, + project_ids, + search_hidden=True, + allowed_ids=selected_project_ids, ) project_ids_with_children |= children_ids @@ -903,6 +910,7 @@ class ProjectBLL: allow_public: bool = True, children_type: ProjectChildrenType = None, children_tags: Sequence[str] = None, + children_tags_filter: dict = None, ) -> Tuple[Sequence[str], Sequence[str]]: """ Get the projects ids matching children_condition (if passed) or where the passed user created any tasks @@ -923,11 +931,15 @@ class ProjectBLL: query &= Q(user__in=users) project_query = None - child_query = ( - query & GetMixin.get_list_field_query("tags", children_tags) - if children_tags - else query - ) + if children_tags_filter: + child_query = query & GetMixin.get_list_filter_query( + "tags", children_tags_filter + ) + elif children_tags: + child_query = query & GetMixin.get_list_field_query("tags", children_tags) + else: + child_query = query + if children_type == ProjectChildrenType.dataset: child_queries = { Project: child_query @@ -1087,39 +1099,54 @@ class ProjectBLL: or_conditions = [] for field, field_filter in filter_.items(): - if not ( - field_filter - and isinstance(field_filter, list) - and all(isinstance(t, str) for t in field_filter) - ): + if not (field_filter and isinstance(field_filter, (list, dict))): raise errors.bad_request.ValidationError( - f"List of strings expected for the field: {field}" + f"Non empty list or dictionary expected for the field: {field}" ) - helper = GetMixin.NewListFieldBucketHelper( - field, data=field_filter, legacy=True - ) - field_conditions = {} - for action, values in helper.actions.items(): - value = list(set(values)) - for key in reversed(action.split("__")): - value = {f"${key}": value} - field_conditions.update(value) - if ( - helper.explicit_operator - and helper.global_operator == Q.OR - and len(field_conditions) > 1 - ): - or_conditions.append( - [{field: {op: cond}} for op, cond in field_conditions.items()] + + if isinstance(field_filter, list): + if not all(isinstance(t, str) for t in field_filter): + raise errors.bad_request.ValidationError( + f"Only string values are allowed in the list filter: {field}" + ) + helper = GetMixin.NewListFieldBucketHelper( + field, data=field_filter, legacy=True ) + op = ( + Q.OR + if helper.explicit_operator and helper.global_operator == Q.OR + else Q.AND + ) + db_query = {op: helper.actions} else: - conditions[field] = field_conditions + helper = GetMixin.ListQueryFilter.from_data(field, field_filter) + db_query = helper.db_query + + for op, actions in db_query.items(): + field_conditions = {} + for action, values in actions.items(): + value = list(set(values)) + for key in reversed(action.split("__")): + value = {f"${key}": value} + field_conditions.update(value) + + if op == Q.OR and len(field_conditions) > 1: + or_conditions.append( + { + "$or": [ + {field: {db_modifier: cond}} + for db_modifier, cond in field_conditions.items() + ] + } + ) + else: + conditions[field] = field_conditions if or_conditions: if len(or_conditions) == 1: - conditions["$or"] = next(iter(or_conditions)) + conditions = next(iter(or_conditions)) else: - conditions["$and"] = [{"$or": c} for c in or_conditions] + conditions["$and"] = [c for c in or_conditions] return conditions diff --git a/apiserver/database/model/base.py b/apiserver/database/model/base.py index de3da1b..e3c5625 100644 --- a/apiserver/database/model/base.py +++ b/apiserver/database/model/base.py @@ -197,9 +197,7 @@ class GetMixin(PropsMixin): if self.global_operator is None: self.global_operator = self.default_operator - def _get_next_term( - self, data: Sequence[str] - ) -> Generator[Term, None, None]: + def _get_next_term(self, data: Sequence[str]) -> Generator[Term, None, None]: unary_operator = None for value in data: if value is None: @@ -233,12 +231,18 @@ class GetMixin(PropsMixin): operator = self._operators.get(value) if operator is None: raise FieldsValueError( - "Unsupported operator", field=self._field, operator=value, + "Unsupported operator", + field=self._field, + operator=value, ) yield self.Term(operator=operator) continue - if not unary_operator and self._support_legacy and value.startswith("-"): + if ( + not unary_operator + and self._support_legacy + and value.startswith("-") + ): value = value[1:] if not value: raise FieldsValueError( @@ -403,12 +407,25 @@ class GetMixin(PropsMixin): parameters = { k: cls._get_fixed_field_value(k, v) for k, v in parameters.items() } + filters = parameters.pop("filters", {}) + if not isinstance(filters, dict): + raise FieldsValueError( + "invalid value type, string expected", + field=filters, + value=str(filters), + ) opts = parameters_options for field in opts.pattern_fields: pattern = parameters.pop(field, None) if pattern: dict_query[field] = RegexWrapper(pattern) + for field, data in cls._pop_matching_params( + patterns=opts.list_fields, parameters=filters + ).items(): + query &= cls.get_list_filter_query(field, data) + parameters.pop(field, None) + for field, data in cls._pop_matching_params( patterns=opts.list_fields, parameters=parameters ).items(): @@ -532,6 +549,135 @@ class GetMixin(PropsMixin): return q + @attr.s(auto_attribs=True) + class ListQueryFilter: + """ + Deserialize filters data and build db_query object that represents it with the corresponding + mongo engine operations + Each part has include and exclude lists that map to mongoengine operations as following: + "any" + - include -> 'in' + - exclude -> 'not_all' + - combined by 'or' operation + "all" + - include -> 'all' + - exclude -> 'nin' + - combined by 'and' operation + "op" optional parameter for combining "and" and "all" parts. Can be "and" or "or". The default is "and" + """ + + _and_op = "and" + _or_op = "or" + _allowed_op = [_and_op, _or_op] + _db_modifiers: Mapping = { + (Q.OR, True): "in", + (Q.OR, False): "not__all", + (Q.AND, True): "all", + (Q.AND, False): "nin", + } + + @attr.s(auto_attribs=True) + class ListFilter: + include: Sequence[str] = [] + exclude: Sequence[str] = [] + + @classmethod + def from_dict(cls, d: Mapping): + if d is None: + return None + return cls(**d) + + any: ListFilter = attr.ib(converter=ListFilter.from_dict, default=None) + all: ListFilter = attr.ib(converter=ListFilter.from_dict, default=None) + op: str = attr.ib(default="and") + db_query: dict = attr.ib(init=False) + + # noinspection PyUnresolvedReferences + @op.validator + def op_validator(self, _, value): + if value not in self._allowed_op: + raise ValueError( + f"Invalid list query filter operator: {value}. " + f"Should be one of {str(self._allowed_op)}" + ) + + @property + def and_op(self) -> bool: + return self.op == self._and_op + + def __attrs_post_init__(self): + self.db_query = {} + for op, conditions in ((Q.OR, self.any), (Q.AND, self.all)): + if not conditions: + continue + + operations = {} + for vals, include in ( + (conditions.include, True), + (conditions.exclude, False), + ): + if not vals: + continue + operations[self._db_modifiers[(op, include)]] = list(set(vals)) + + self.db_query[op] = operations + + @classmethod + def from_data(cls, field, data: Mapping): + if not isinstance(data, dict): + raise errors.bad_request.ValidationError( + "invalid filter for field, dictionary expected", + field=field, + value=str(data), + ) + + try: + return cls(**data) + except Exception as ex: + raise errors.bad_request.ValidationError( + field=field, + value=str(ex), + ) + + @classmethod + def get_list_filter_query( + cls, field: str, data: Mapping + ) -> Union[RegexQ, RegexQCombination]: + if not data: + return RegexQ() + + filter_ = cls.ListQueryFilter.from_data(field, data) + + mongoengine_field = field.replace(".", "__") + queries = [] + for op, actions in filter_.db_query.items(): + if not actions: + continue + + ops = [] + for action, vals in actions.items(): + if not vals: + continue + + ops.append(RegexQ(**{f"{mongoengine_field}__{action}": vals})) + + if not ops: + continue + + if len(ops) == 1: + queries.extend(ops) + continue + + queries.append(RegexQCombination(operation=op, children=ops)) + + if not queries: + return RegexQ() + if len(queries) == 1: + return queries[0] + + operation = Q.AND if filter_.and_op else Q.OR + return RegexQCombination(operation=operation, children=queries) + @classmethod def get_list_field_query(cls, field: str, data: Sequence[Optional[str]]) -> RegexQ: """ @@ -640,7 +786,7 @@ class GetMixin(PropsMixin): @classmethod def get_projection(cls, parameters, override_projection=None, **__): - """ Extract a projection list from the provided dictionary. Supports an override projection. """ + """Extract a projection list from the provided dictionary. Supports an override projection.""" if override_projection is not None: return override_projection if not parameters: @@ -654,7 +800,8 @@ class GetMixin(PropsMixin): """Return include and exclude lists based on passed projection and class definition""" if projection: include, exclude = partition( - projection, key=lambda x: x[0] != ProjectionHelper.exclusion_prefix, + projection, + key=lambda x: x[0] != ProjectionHelper.exclusion_prefix, ) else: include, exclude = [], [] @@ -901,7 +1048,9 @@ class GetMixin(PropsMixin): projection_fields=projection_fields, ) return cls.get_data_with_scroll_support( - query_dict=query_dict, data_getter=data_getter, ret_params=ret_params, + query_dict=query_dict, + data_getter=data_getter, + ret_params=ret_params, ) return cls._get_many_no_company( @@ -914,7 +1063,9 @@ class GetMixin(PropsMixin): @classmethod def get_many_public( - cls, query: Q = None, projection: Collection[str] = None, + cls, + query: Q = None, + projection: Collection[str] = None, ): """ Fetch all public documents matching a provided query. @@ -1207,7 +1358,7 @@ class UpdateMixin(object): class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin): - """ Provide convenience methods for a subclass of mongoengine.Document """ + """Provide convenience methods for a subclass of mongoengine.Document""" @classmethod def aggregate( diff --git a/apiserver/schema/services/_common.conf b/apiserver/schema/services/_common.conf index 590a578..6982e52 100644 --- a/apiserver/schema/services/_common.conf +++ b/apiserver/schema/services/_common.conf @@ -1,3 +1,43 @@ +field_filter { + type: object + description: Filter on a field that includes combination of 'any' or 'all' included and excluded terms + properties { + any { + type: object + description: All the terms in 'any' condition are combined with 'or' operation + properties { + "include" { + type: array + items {type: string} + } + exclude { + type: array + items {type: string} + } + } + } + all { + type: object + description: All the terms in 'all' condition are combined with 'and' operation + properties { + "include" { + type: array + items {type: string} + } + exclude { + type: array + items {type: string} + } + } + } + op { + type: string + description: The operation between 'any' and 'all' parts of the filter if both are provided + default: and + enum: [and, or] + } + } +} metadata_item { type: object properties { diff --git a/apiserver/schema/services/models.conf b/apiserver/schema/services/models.conf index 502e3a2..a64c3e7 100644 --- a/apiserver/schema/services/models.conf +++ b/apiserver/schema/services/models.conf @@ -261,6 +261,14 @@ get_all_ex { } } } + "999.0": ${get_all_ex."2.23"} { + request.properties { + filters { + type: object + additionalProperties: ${_definitions.field_filter} + } + } + } } get_all { "2.1" { diff --git a/apiserver/schema/services/projects.conf b/apiserver/schema/services/projects.conf index 4821eab..7809b5d 100644 --- a/apiserver/schema/services/projects.conf +++ b/apiserver/schema/services/projects.conf @@ -1,5 +1,6 @@ _description: "Provides support for defining Projects containing Tasks, Models and Dataset Versions." _definitions { + include "_common.conf" multi_field_pattern_data { type: object properties { @@ -660,6 +661,15 @@ get_all_ex { items {type: string} } } + "999.0": ${get_all_ex."2.25"} { + request.properties { + filters { + type: object + additionalProperties: ${_definitions.field_filter} + } + children_tags_filter: ${_definitions.field_filter} + } + } } update { "2.1" { diff --git a/apiserver/schema/services/queues.conf b/apiserver/schema/services/queues.conf index 873ee06..ec744f4 100644 --- a/apiserver/schema/services/queues.conf +++ b/apiserver/schema/services/queues.conf @@ -159,6 +159,14 @@ get_all_ex { default: false } } + "999.0": ${get_all_ex."2.21"} { + request.properties { + filters { + type: object + additionalProperties: ${_definitions.field_filter} + } + } + } } get_all { "2.4" { diff --git a/apiserver/schema/services/reports.conf b/apiserver/schema/services/reports.conf index 4fe9b72..79c31c9 100644 --- a/apiserver/schema/services/reports.conf +++ b/apiserver/schema/services/reports.conf @@ -720,6 +720,14 @@ get_all_ex { default: false } } + "999.0": ${get_all_ex."2.26"} { + request.properties { + filters { + type: object + additionalProperties: ${_definitions.field_filter} + } + } + } } get_tags { "2.23" { diff --git a/apiserver/schema/services/tasks.conf b/apiserver/schema/services/tasks.conf index 9b6c6e0..f307007 100644 --- a/apiserver/schema/services/tasks.conf +++ b/apiserver/schema/services/tasks.conf @@ -190,6 +190,14 @@ get_all_ex { } } } + "999.0": ${get_all_ex."2.23"} { + request.properties { + filters { + type: object + additionalProperties: ${_definitions.field_filter} + } + } + } } get_all { "2.1" { diff --git a/apiserver/services/projects.py b/apiserver/services/projects.py index 965c360..b0d701b 100644 --- a/apiserver/services/projects.py +++ b/apiserver/services/projects.py @@ -108,7 +108,13 @@ def _get_project_stats_filter( if request.include_stats_filter or not request.children_type: return request.include_stats_filter, request.search_hidden - stats_filter = {"tags": request.children_tags} if request.children_tags else {} + if request.children_tags_filter: + stats_filter = {"tags": request.children_tags_filter} + elif request.children_tags: + stats_filter = {"tags": request.children_tags} + else: + stats_filter = {} + if request.children_type == ProjectChildrenType.pipeline: return ( { @@ -153,6 +159,7 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest): allow_public=allow_public, children_type=request.children_type, children_tags=request.children_tags, + children_tags_filter=request.children_tags_filter, ) if not ids: return {"projects": []} diff --git a/apiserver/tests/automated/test_get_all_ex_filters.py b/apiserver/tests/automated/test_get_all_ex_filters.py new file mode 100644 index 0000000..08a0979 --- /dev/null +++ b/apiserver/tests/automated/test_get_all_ex_filters.py @@ -0,0 +1,67 @@ +from apiserver.apierrors import errors +from apiserver.tests.automated import TestService + + +class TestGetAllExFilters(TestService): + def test_list_filters(self): + tags = ["a", "b", "c", "d"] + tasks = [self._temp_task(tags=tags[:i]) for i in range(len(tags) + 1)] + + # invalid params check + with self.api.raises(errors.bad_request.ValidationError): + self.api.tasks.get_all_ex(filters={"tags": {"test": ["1"]}}) + + # test any condition + res = self.api.tasks.get_all_ex( + id=tasks, filters={"tags": {"any": {"include": ["a", "b"]}}} + ).tasks + self.assertEqual(set(tasks[1:]), set(t.id for t in res)) + + res = self.api.tasks.get_all_ex( + id=tasks, filters={"tags": {"any": {"exclude": ["c", "d"]}}} + ).tasks + self.assertEqual(set(tasks[:-1]), set(t.id for t in res)) + + res = self.api.tasks.get_all_ex( + id=tasks, + filters={"tags": {"any": {"include": ["a", "b"], "exclude": ["c", "d"]}}}, + ).tasks + self.assertEqual(set(tasks), set(t.id for t in res)) + + # test all condition + res = self.api.tasks.get_all_ex( + id=tasks, filters={"tags": {"all": {"include": ["a", "b"]}}} + ).tasks + self.assertEqual(set(tasks[2:]), set(t.id for t in res)) + + res = self.api.tasks.get_all_ex( + id=tasks, filters={"tags": {"all": {"exclude": ["c", "d"]}}} + ).tasks + self.assertEqual(set(tasks[:-2]), set(t.id for t in res)) + + res = self.api.tasks.get_all_ex( + id=tasks, + filters={"tags": {"all": {"include": ["a", "b"], "exclude": ["c", "d"]}}}, + ).tasks + self.assertEqual([tasks[2]], [t.id for t in res]) + + # test combination + res = self.api.tasks.get_all_ex( + id=tasks, + filters={ + "tags": {"any": {"include": ["a", "b"]}, "all": {"exclude": ["c", "d"]}} + }, + ).tasks + self.assertEqual(set(tasks[1:-2]), set(t.id for t in res)) + + def _temp_task(self, **kwargs): + self.update_missing( + kwargs, + name="test get_all_ex filters", + type="training", + ) + return self.create_temp( + "tasks", + **kwargs, + delete_paramse=dict(can_fail=True, force=True), + ) diff --git a/apiserver/tests/automated/test_subprojects.py b/apiserver/tests/automated/test_subprojects.py index 9330ee7..0158391 100644 --- a/apiserver/tests/automated/test_subprojects.py +++ b/apiserver/tests/automated/test_subprojects.py @@ -64,6 +64,20 @@ class TestSubProjects(TestService): self.assertEqual(p.basename, "project2") self.assertEqual(p.stats.active.total_tasks, 2) + # new filter + projects = self.api.projects.get_all_ex( + parent=[test_root], + children_type="report", + children_tags_filter={"any": {"include": ["test1", "test2"]}}, + shallow_search=True, + include_stats=True, + check_own_contents=True, + ).projects + self.assertEqual(len(projects), 1) + p = projects[0] + self.assertEqual(p.basename, "project2") + self.assertEqual(p.stats.active.total_tasks, 2) + projects = self.api.projects.get_all_ex( parent=[test_root], children_type="report", @@ -77,6 +91,20 @@ class TestSubProjects(TestService): self.assertEqual(p.basename, "project2") self.assertEqual(p.stats.active.total_tasks, 1) + # new filter + projects = self.api.projects.get_all_ex( + parent=[test_root], + children_type="report", + children_tags_filter={"all": {"include": ["test1", "test2"]}}, + shallow_search=True, + include_stats=True, + check_own_contents=True, + ).projects + self.assertEqual(len(projects), 1) + p = projects[0] + self.assertEqual(p.basename, "project2") + self.assertEqual(p.stats.active.total_tasks, 1) + projects = self.api.projects.get_all_ex( parent=[test_root], children_type="report", @@ -102,6 +130,20 @@ class TestSubProjects(TestService): for p in projects: self.assertEqual(p.stats.active.total_tasks, 1) + # new filter + projects = self.api.projects.get_all_ex( + parent=[test_root], + children_type="report", + children_tags_filter={"all": {"exclude": ["test1", "test2"]}}, + shallow_search=True, + include_stats=True, + check_own_contents=True, + ).projects + self.assertEqual(len(projects), 1) + p = projects[0] + self.assertEqual(p.basename, "project1") + self.assertEqual(p.stats.active.total_tasks, 1) + def test_query_children(self): test_root_name = "TestQueryChildren" test_root = self._temp_project(name=test_root_name)