Add a filters parameter for passing user-defined list filters to all get_all_ex APIs

This commit is contained in:
allegroai 2023-11-17 09:36:58 +02:00
parent 388dd1b01f
commit cc0129a800
12 changed files with 421 additions and 44 deletions

View File

@ -99,3 +99,4 @@ class ProjectsGetRequest(models.Base):
allow_public = fields.BoolField(default=True)
children_type = ActualEnumField(ProjectChildrenType)
children_tags = fields.ListField(str)
children_tags_filter = DictField()

View File

@ -551,7 +551,10 @@ class ProjectBLL:
@classmethod
def get_dataset_stats(
cls, company: str, project_ids: Sequence[str], users: Sequence[str] = None,
cls,
company: str,
project_ids: Sequence[str],
users: Sequence[str] = None,
) -> Dict[str, dict]:
if not project_ids:
return {}
@ -585,7 +588,9 @@ class ProjectBLL:
@staticmethod
def _get_projects_children(
project_ids: Sequence[str], search_hidden: bool, allowed_ids: Sequence[str],
project_ids: Sequence[str],
search_hidden: bool,
allowed_ids: Sequence[str],
) -> Tuple[ProjectsChildren, Set[str]]:
child_projects = _get_sub_projects(
project_ids,
@ -629,7 +634,9 @@ class ProjectBLL:
project_ids_with_children = set(project_ids)
if include_children:
child_projects, children_ids = cls._get_projects_children(
project_ids, search_hidden=True, allowed_ids=selected_project_ids,
project_ids,
search_hidden=True,
allowed_ids=selected_project_ids,
)
project_ids_with_children |= children_ids
@ -903,6 +910,7 @@ class ProjectBLL:
allow_public: bool = True,
children_type: ProjectChildrenType = None,
children_tags: Sequence[str] = None,
children_tags_filter: dict = None,
) -> Tuple[Sequence[str], Sequence[str]]:
"""
Get the projects ids matching children_condition (if passed) or where the passed user created any tasks
@ -923,11 +931,15 @@ class ProjectBLL:
query &= Q(user__in=users)
project_query = None
child_query = (
query & GetMixin.get_list_field_query("tags", children_tags)
if children_tags
else query
)
if children_tags_filter:
child_query = query & GetMixin.get_list_filter_query(
"tags", children_tags_filter
)
elif children_tags:
child_query = query & GetMixin.get_list_field_query("tags", children_tags)
else:
child_query = query
if children_type == ProjectChildrenType.dataset:
child_queries = {
Project: child_query
@ -1087,39 +1099,54 @@ class ProjectBLL:
or_conditions = []
for field, field_filter in filter_.items():
if not (
field_filter
and isinstance(field_filter, list)
and all(isinstance(t, str) for t in field_filter)
):
if not (field_filter and isinstance(field_filter, (list, dict))):
raise errors.bad_request.ValidationError(
f"List of strings expected for the field: {field}"
f"Non empty list or dictionary expected for the field: {field}"
)
helper = GetMixin.NewListFieldBucketHelper(
field, data=field_filter, legacy=True
)
field_conditions = {}
for action, values in helper.actions.items():
value = list(set(values))
for key in reversed(action.split("__")):
value = {f"${key}": value}
field_conditions.update(value)
if (
helper.explicit_operator
and helper.global_operator == Q.OR
and len(field_conditions) > 1
):
or_conditions.append(
[{field: {op: cond}} for op, cond in field_conditions.items()]
if isinstance(field_filter, list):
if not all(isinstance(t, str) for t in field_filter):
raise errors.bad_request.ValidationError(
f"Only string values are allowed in the list filter: {field}"
)
helper = GetMixin.NewListFieldBucketHelper(
field, data=field_filter, legacy=True
)
op = (
Q.OR
if helper.explicit_operator and helper.global_operator == Q.OR
else Q.AND
)
db_query = {op: helper.actions}
else:
conditions[field] = field_conditions
helper = GetMixin.ListQueryFilter.from_data(field, field_filter)
db_query = helper.db_query
for op, actions in db_query.items():
field_conditions = {}
for action, values in actions.items():
value = list(set(values))
for key in reversed(action.split("__")):
value = {f"${key}": value}
field_conditions.update(value)
if op == Q.OR and len(field_conditions) > 1:
or_conditions.append(
{
"$or": [
{field: {db_modifier: cond}}
for db_modifier, cond in field_conditions.items()
]
}
)
else:
conditions[field] = field_conditions
if or_conditions:
if len(or_conditions) == 1:
conditions["$or"] = next(iter(or_conditions))
conditions = next(iter(or_conditions))
else:
conditions["$and"] = [{"$or": c} for c in or_conditions]
conditions["$and"] = [c for c in or_conditions]
return conditions

View File

@ -197,9 +197,7 @@ class GetMixin(PropsMixin):
if self.global_operator is None:
self.global_operator = self.default_operator
def _get_next_term(
self, data: Sequence[str]
) -> Generator[Term, None, None]:
def _get_next_term(self, data: Sequence[str]) -> Generator[Term, None, None]:
unary_operator = None
for value in data:
if value is None:
@ -233,12 +231,18 @@ class GetMixin(PropsMixin):
operator = self._operators.get(value)
if operator is None:
raise FieldsValueError(
"Unsupported operator", field=self._field, operator=value,
"Unsupported operator",
field=self._field,
operator=value,
)
yield self.Term(operator=operator)
continue
if not unary_operator and self._support_legacy and value.startswith("-"):
if (
not unary_operator
and self._support_legacy
and value.startswith("-")
):
value = value[1:]
if not value:
raise FieldsValueError(
@ -403,12 +407,25 @@ class GetMixin(PropsMixin):
parameters = {
k: cls._get_fixed_field_value(k, v) for k, v in parameters.items()
}
# "filters" carries the structured per-field list filters (field name -> filter definition)
filters = parameters.pop("filters", {})
if not isinstance(filters, dict):
    # Report the offending parameter by name and state the type that is actually
    # required: the check above accepts dictionaries only
    raise FieldsValueError(
        "invalid value type, dictionary expected",
        field="filters",
        value=str(filters),
    )
opts = parameters_options
for field in opts.pattern_fields:
pattern = parameters.pop(field, None)
if pattern:
dict_query[field] = RegexWrapper(pattern)
for field, data in cls._pop_matching_params(
patterns=opts.list_fields, parameters=filters
).items():
query &= cls.get_list_filter_query(field, data)
parameters.pop(field, None)
for field, data in cls._pop_matching_params(
patterns=opts.list_fields, parameters=parameters
).items():
@ -532,6 +549,135 @@ class GetMixin(PropsMixin):
return q
@attr.s(auto_attribs=True)
class ListQueryFilter:
    """
    Deserialize filters data and build the db_query object that represents it with the
    corresponding mongoengine operations.

    Each part has include and exclude lists that map to mongoengine operations as follows:
    "any"
        - include -> 'in'
        - exclude -> 'not__all'
        - combined by 'or' operation
    "all"
        - include -> 'all'
        - exclude -> 'nin'
        - combined by 'and' operation
    "op" optional parameter for combining the "any" and "all" parts. Can be "and" or "or".
    The default is "and"

    After construction, db_query maps Q.OR / Q.AND to a dictionary of
    {mongoengine modifier: list of unique values} for the non-empty parts only.
    """

    _and_op = "and"
    _or_op = "or"
    _allowed_op = [_and_op, _or_op]
    # Intentionally not annotated: with auto_attribs=True an annotated name becomes an attrs
    # field (and an __init__ argument), which would let the filter data override this table
    _db_modifiers = {
        (Q.OR, True): "in",
        (Q.OR, False): "not__all",
        (Q.AND, True): "all",
        (Q.AND, False): "nin",
    }

    @attr.s(auto_attribs=True)
    class ListFilter:
        """The include/exclude value lists of a single "any" or "all" part"""

        # Factories give every instance its own list instead of one shared default object
        include: Sequence[str] = attr.Factory(list)
        exclude: Sequence[str] = attr.Factory(list)

        @classmethod
        def from_dict(cls, d: Mapping):
            """Build the part from its dictionary form; None (part not passed) stays None"""
            if d is None:
                return None
            return cls(**d)

    any: ListFilter = attr.ib(converter=ListFilter.from_dict, default=None)
    all: ListFilter = attr.ib(converter=ListFilter.from_dict, default=None)
    op: str = attr.ib(default="and")
    db_query: dict = attr.ib(init=False)

    # noinspection PyUnresolvedReferences
    @op.validator
    def op_validator(self, _, value):
        """Reject any operator other than "and" / "or" """
        if value not in self._allowed_op:
            raise ValueError(
                f"Invalid list query filter operator: {value}. "
                f"Should be one of {str(self._allowed_op)}"
            )

    @property
    def and_op(self) -> bool:
        """True when the "any" and "all" parts should be combined with 'and'"""
        return self.op == self._and_op

    def __attrs_post_init__(self):
        """Translate the "any" / "all" parts into db_query, skipping empty parts and lists"""
        self.db_query = {}
        for op, conditions in ((Q.OR, self.any), (Q.AND, self.all)):
            if not conditions:
                continue
            operations = {}
            for vals, include in (
                (conditions.include, True),
                (conditions.exclude, False),
            ):
                if not vals:
                    continue
                # duplicates are dropped; the order of the values is not preserved
                operations[self._db_modifiers[(op, include)]] = list(set(vals))

            self.db_query[op] = operations

    @classmethod
    def from_data(cls, field, data: Mapping):
        """
        Create the filter for the passed field from its dictionary representation

        :param field: the name of the filtered field. Used for error reporting only
        :param data: dictionary with the optional "any", "all" and "op" keys
        :raise errors.bad_request.ValidationError: if data is not a dictionary
            or the filter cannot be built from it
        """
        if not isinstance(data, dict):
            raise errors.bad_request.ValidationError(
                "invalid filter for field, dictionary expected",
                field=field,
                value=str(data),
            )

        try:
            return cls(**data)
        except Exception as ex:
            # Any construction failure (unknown key, invalid operator, malformed part)
            # is reported to the caller as a validation error; keep the original cause
            raise errors.bad_request.ValidationError(
                field=field,
                value=str(ex),
            ) from ex
@classmethod
def get_list_filter_query(
    cls, field: str, data: Mapping
) -> Union[RegexQ, RegexQCombination]:
    """
    Build the query for a list field from its structured filter definition

    :param field: the field name. Dots are replaced with the mongoengine "__" separator
    :param data: the filter dictionary, parsed by ListQueryFilter.from_data
    :return: an empty RegexQ if no data was passed or it produced no conditions,
        the single condition if there is only one, otherwise a combination of the
        "any" and "all" parts according to the filter operator
    """
    if not data:
        return RegexQ()

    parsed = cls.ListQueryFilter.from_data(field, data)
    field_path = field.replace(".", "__")

    parts = []
    for part_op, modifiers in parsed.db_query.items():
        # one query term per non-empty modifier ('in', 'all', 'nin', 'not__all')
        terms = [
            RegexQ(**{f"{field_path}__{modifier}": values})
            for modifier, values in (modifiers or {}).items()
            if values
        ]
        if not terms:
            continue
        if len(terms) > 1:
            # several terms of the same part are joined by the operation of that part
            parts.append(RegexQCombination(operation=part_op, children=terms))
        else:
            parts.append(terms[0])

    if not parts:
        return RegexQ()
    if len(parts) == 1:
        return parts[0]

    joining_op = Q.AND if parsed.and_op else Q.OR
    return RegexQCombination(operation=joining_op, children=parts)
@classmethod
def get_list_field_query(cls, field: str, data: Sequence[Optional[str]]) -> RegexQ:
"""
@ -640,7 +786,7 @@ class GetMixin(PropsMixin):
@classmethod
def get_projection(cls, parameters, override_projection=None, **__):
""" Extract a projection list from the provided dictionary. Supports an override projection. """
"""Extract a projection list from the provided dictionary. Supports an override projection."""
if override_projection is not None:
return override_projection
if not parameters:
@ -654,7 +800,8 @@ class GetMixin(PropsMixin):
"""Return include and exclude lists based on passed projection and class definition"""
if projection:
include, exclude = partition(
projection, key=lambda x: x[0] != ProjectionHelper.exclusion_prefix,
projection,
key=lambda x: x[0] != ProjectionHelper.exclusion_prefix,
)
else:
include, exclude = [], []
@ -901,7 +1048,9 @@ class GetMixin(PropsMixin):
projection_fields=projection_fields,
)
return cls.get_data_with_scroll_support(
query_dict=query_dict, data_getter=data_getter, ret_params=ret_params,
query_dict=query_dict,
data_getter=data_getter,
ret_params=ret_params,
)
return cls._get_many_no_company(
@ -914,7 +1063,9 @@ class GetMixin(PropsMixin):
@classmethod
def get_many_public(
cls, query: Q = None, projection: Collection[str] = None,
cls,
query: Q = None,
projection: Collection[str] = None,
):
"""
Fetch all public documents matching a provided query.
@ -1207,7 +1358,7 @@ class UpdateMixin(object):
class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin):
""" Provide convenience methods for a subclass of mongoengine.Document """
"""Provide convenience methods for a subclass of mongoengine.Document"""
@classmethod
def aggregate(

View File

@ -1,3 +1,43 @@
# Structured filter for a list field: an optional 'any' part, an optional 'all' part
# and the operation ('op') used to combine them when both parts are provided
field_filter {
type: object
description: Filter on a field that includes combination of 'any' or 'all' included and excluded terms
properties {
# Terms of the 'any' part are combined with 'or'
any {
type: object
description: All the terms in 'any' condition are combined with 'or' operation
properties {
# The key is quoted on purpose: an unquoted include at a key position
# is parsed by HOCON as an include statement
"include" {
type: array
items {type: string}
}
exclude {
type: array
items {type: string}
}
}
}
# Terms of the 'all' part are combined with 'and'
all {
type: object
description: All the terms in 'all' condition are combined with 'and' operation
properties {
# Quoted for the same reason as in the 'any' part above
"include" {
type: array
items {type: string}
}
exclude {
type: array
items {type: string}
}
}
}
# How the 'any' and 'all' parts are combined; only relevant when both are passed
op {
type: string
description: The operation between 'any' and 'all' parts of the filter if both are provided
default: and
enum: [and, or]
}
}
}
metadata_item {
type: object
properties {

View File

@ -261,6 +261,14 @@ get_all_ex {
}
}
}
"999.0": ${get_all_ex."2.23"} {
request.properties {
filters {
type: object
additionalProperties: ${_definitions.field_filter}
}
}
}
}
get_all {
"2.1" {

View File

@ -1,5 +1,6 @@
_description: "Provides support for defining Projects containing Tasks, Models and Dataset Versions."
_definitions {
include "_common.conf"
multi_field_pattern_data {
type: object
properties {
@ -660,6 +661,15 @@ get_all_ex {
items {type: string}
}
}
"999.0": ${get_all_ex."2.25"} {
request.properties {
filters {
type: object
additionalProperties: ${_definitions.field_filter}
}
children_tags_filter: ${_definitions.field_filter}
}
}
}
update {
"2.1" {

View File

@ -159,6 +159,14 @@ get_all_ex {
default: false
}
}
"999.0": ${get_all_ex."2.21"} {
request.properties {
filters {
type: object
additionalProperties: ${_definitions.field_filter}
}
}
}
}
get_all {
"2.4" {

View File

@ -720,6 +720,14 @@ get_all_ex {
default: false
}
}
"999.0": ${get_all_ex."2.26"} {
request.properties {
filters {
type: object
additionalProperties: ${_definitions.field_filter}
}
}
}
}
get_tags {
"2.23" {

View File

@ -190,6 +190,14 @@ get_all_ex {
}
}
}
"999.0": ${get_all_ex."2.23"} {
request.properties {
filters {
type: object
additionalProperties: ${_definitions.field_filter}
}
}
}
}
get_all {
"2.1" {

View File

@ -108,7 +108,13 @@ def _get_project_stats_filter(
if request.include_stats_filter or not request.children_type:
return request.include_stats_filter, request.search_hidden
stats_filter = {"tags": request.children_tags} if request.children_tags else {}
if request.children_tags_filter:
stats_filter = {"tags": request.children_tags_filter}
elif request.children_tags:
stats_filter = {"tags": request.children_tags}
else:
stats_filter = {}
if request.children_type == ProjectChildrenType.pipeline:
return (
{
@ -153,6 +159,7 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
allow_public=allow_public,
children_type=request.children_type,
children_tags=request.children_tags,
children_tags_filter=request.children_tags_filter,
)
if not ids:
return {"projects": []}

View File

@ -0,0 +1,67 @@
from apiserver.apierrors import errors
from apiserver.tests.automated import TestService
class TestGetAllExFilters(TestService):
    """Tests for the structured "filters" parameter of the get_all_ex calls"""

    def test_list_filters(self):
        """Check the 'any' / 'all' include and exclude filters on the task tags"""
        tags = ["a", "b", "c", "d"]
        # tasks[i] is created with the first i tags: [], [a], [a, b], [a, b, c], [a, b, c, d]
        tasks = [self._temp_task(tags=tags[:i]) for i in range(len(tags) + 1)]

        # invalid params check
        with self.api.raises(errors.bad_request.ValidationError):
            self.api.tasks.get_all_ex(filters={"tags": {"test": ["1"]}})

        # test any condition
        res = self.api.tasks.get_all_ex(
            id=tasks, filters={"tags": {"any": {"include": ["a", "b"]}}}
        ).tasks
        self.assertEqual(set(tasks[1:]), set(t.id for t in res))

        res = self.api.tasks.get_all_ex(
            id=tasks, filters={"tags": {"any": {"exclude": ["c", "d"]}}}
        ).tasks
        self.assertEqual(set(tasks[:-1]), set(t.id for t in res))

        res = self.api.tasks.get_all_ex(
            id=tasks,
            filters={
                "tags": {"any": {"include": ["a", "b"], "exclude": ["c", "d"]}}
            },
        ).tasks
        self.assertEqual(set(tasks), set(t.id for t in res))

        # test all condition
        res = self.api.tasks.get_all_ex(
            id=tasks, filters={"tags": {"all": {"include": ["a", "b"]}}}
        ).tasks
        self.assertEqual(set(tasks[2:]), set(t.id for t in res))

        res = self.api.tasks.get_all_ex(
            id=tasks, filters={"tags": {"all": {"exclude": ["c", "d"]}}}
        ).tasks
        self.assertEqual(set(tasks[:-2]), set(t.id for t in res))

        res = self.api.tasks.get_all_ex(
            id=tasks,
            filters={
                "tags": {"all": {"include": ["a", "b"], "exclude": ["c", "d"]}}
            },
        ).tasks
        self.assertEqual([tasks[2]], [t.id for t in res])

        # test combination
        res = self.api.tasks.get_all_ex(
            id=tasks,
            filters={
                "tags": {
                    "any": {"include": ["a", "b"]},
                    "all": {"exclude": ["c", "d"]},
                }
            },
        ).tasks
        self.assertEqual(set(tasks[1:-2]), set(t.id for t in res))

    def _temp_task(self, **kwargs):
        """Create a temporary task with default name and type unless passed in kwargs"""
        self.update_missing(
            kwargs,
            name="test get_all_ex filters",
            type="training",
        )
        # NOTE(review): the keyword was misspelled "delete_paramse", which sent the cleanup
        # options along with the task fields; "delete_params" is assumed to be the keyword
        # that TestService.create_temp expects - confirm against its definition
        return self.create_temp(
            "tasks",
            **kwargs,
            delete_params=dict(can_fail=True, force=True),
        )

View File

@ -64,6 +64,20 @@ class TestSubProjects(TestService):
self.assertEqual(p.basename, "project2")
self.assertEqual(p.stats.active.total_tasks, 2)
# new filter
projects = self.api.projects.get_all_ex(
parent=[test_root],
children_type="report",
children_tags_filter={"any": {"include": ["test1", "test2"]}},
shallow_search=True,
include_stats=True,
check_own_contents=True,
).projects
self.assertEqual(len(projects), 1)
p = projects[0]
self.assertEqual(p.basename, "project2")
self.assertEqual(p.stats.active.total_tasks, 2)
projects = self.api.projects.get_all_ex(
parent=[test_root],
children_type="report",
@ -77,6 +91,20 @@ class TestSubProjects(TestService):
self.assertEqual(p.basename, "project2")
self.assertEqual(p.stats.active.total_tasks, 1)
# new filter
projects = self.api.projects.get_all_ex(
parent=[test_root],
children_type="report",
children_tags_filter={"all": {"include": ["test1", "test2"]}},
shallow_search=True,
include_stats=True,
check_own_contents=True,
).projects
self.assertEqual(len(projects), 1)
p = projects[0]
self.assertEqual(p.basename, "project2")
self.assertEqual(p.stats.active.total_tasks, 1)
projects = self.api.projects.get_all_ex(
parent=[test_root],
children_type="report",
@ -102,6 +130,20 @@ class TestSubProjects(TestService):
for p in projects:
self.assertEqual(p.stats.active.total_tasks, 1)
# new filter
projects = self.api.projects.get_all_ex(
parent=[test_root],
children_type="report",
children_tags_filter={"all": {"exclude": ["test1", "test2"]}},
shallow_search=True,
include_stats=True,
check_own_contents=True,
).projects
self.assertEqual(len(projects), 1)
p = projects[0]
self.assertEqual(p.basename, "project1")
self.assertEqual(p.stats.active.total_tasks, 1)
def test_query_children(self):
test_root_name = "TestQueryChildren"
test_root = self._temp_project(name=test_root_name)