Add filtering on child projects in projects.get_all_ex

This commit is contained in:
allegroai 2023-03-23 19:06:49 +02:00
parent 2fb9288a6c
commit 74200a24bd
10 changed files with 122 additions and 38 deletions

View File

@ -1,4 +1,5 @@
from jsonmodels import models, fields
from jsonmodels.fields import EmbeddedField
from apiserver.apimodels import ListField, ActualEnumField, DictField
from apiserver.apimodels.organization import TagsRequest
@ -56,6 +57,10 @@ class ProjectModelMetadataValuesRequest(MultiProjectRequest):
allow_public = fields.BoolField(default=True)
class ChildrenCondition(models.Base):
system_tags = fields.ListField([str])
class ProjectsGetRequest(models.Base):
include_dataset_stats = fields.BoolField(default=False)
include_stats = fields.BoolField(default=False)
@ -68,3 +73,4 @@ class ProjectsGetRequest(models.Base):
shallow_search = fields.BoolField(default=False)
search_hidden = fields.BoolField(default=False)
allow_public = fields.BoolField(default=True)
children_condition = EmbeddedField(ChildrenCondition)

View File

@ -571,7 +571,7 @@ class ProjectBLL:
search_hidden: bool = False,
filter_: Mapping[str, Any] = None,
users: Sequence[str] = None,
user_active_project_ids: Sequence[str] = None,
selected_project_ids: Sequence[str] = None,
) -> Tuple[Dict[str, dict], Dict[str, dict]]:
if not project_ids:
return {}, {}
@ -581,7 +581,7 @@ class ProjectBLL:
project_ids,
_only=("id", "name"),
search_hidden=search_hidden,
allowed_ids=user_active_project_ids,
allowed_ids=selected_project_ids,
)
if include_children
else {}
@ -753,46 +753,65 @@ class ProjectBLL:
return tags, system_tags
@classmethod
def get_projects_with_active_user(
def get_projects_with_selected_children(
cls,
company: str,
users: Sequence[str],
users: Sequence[str] = None,
project_ids: Optional[Sequence[str]] = None,
allow_public: bool = True,
children_condition: Mapping[str, Any] = None,
) -> Tuple[Sequence[str], Sequence[str]]:
"""
Get the projects ids where user created any tasks including all the parents of these projects
Get the projects ids matching children_condition (if passed) or where the passed user created any tasks
including all the parents of these projects
If project ids are specified then filter the results by these project ids
"""
query = Q(user__in=users)
if not (users or children_condition):
raise errors.bad_request.ValidationError(
"Either active users or children_condition should be specified"
)
if allow_public:
query &= get_company_or_none_constraint(company)
projects_query = Project.prepare_query(
company, parameters=children_condition, allow_public=allow_public
)
if children_condition:
contained_entities_query = None
else:
query &= Q(company=company)
contained_entities_query = (
get_company_or_none_constraint(company)
if allow_public
else Q(company=company)
)
if users:
user_query = Q(user__in=users)
projects_query &= user_query
if contained_entities_query:
contained_entities_query &= user_query
user_projects_query = query
if project_ids:
ids_with_children = _ids_with_children(project_ids)
query &= Q(project__in=ids_with_children)
user_projects_query &= Q(id__in=ids_with_children)
projects_query &= Q(id__in=ids_with_children)
if contained_entities_query:
contained_entities_query &= Q(project__in=ids_with_children)
res = {p.id for p in Project.objects(user_projects_query).only("id")}
for cls_ in (Task, Model):
res |= set(cls_.objects(query).distinct(field="project"))
res = {p.id for p in Project.objects(projects_query).only("id")}
if contained_entities_query:
for cls_ in (Task, Model):
res |= set(cls_.objects(contained_entities_query).distinct(field="project"))
res = list(res)
if not res:
return res, res
user_active_project_ids = _ids_with_parents(res)
selected_project_ids = _ids_with_parents(res)
filtered_ids = (
list(set(user_active_project_ids) & set(project_ids))
list(set(selected_project_ids) & set(project_ids))
if project_ids
else list(user_active_project_ids)
else list(selected_project_ids)
)
return filtered_ids, user_active_project_ids
return filtered_ids, selected_project_ids
@classmethod
def get_task_parents(

View File

@ -41,10 +41,6 @@
# controls whether FieldDoesNotExist exception will be raised for any extra attribute existing in stored data
# but not declared in a data model
strict: false
aggregate {
allow_disk_use: true
}
}
elastic {

View File

@ -2,3 +2,8 @@ max_page_size: 500
# expiration time in seconds for the redis scroll states in get_many family of apis
scroll_state_expiration_seconds: 600
allow_disk_use {
# sort: true
aggregate: true
}

View File

@ -17,7 +17,7 @@ from typing import (
from boltons.iterutils import first, partition
from dateutil.parser import parse as parse_datetime
from mongoengine import Q, Document, ListField, StringField, IntField
from mongoengine import Q, Document, ListField, StringField, IntField, QuerySet
from pymongo.command_cursor import CommandCursor
from apiserver.apierrors import errors, APIError
@ -39,7 +39,7 @@ from apiserver.redis_manager import redman
from apiserver.utilities.dicts import project_dict, exclude_fields_from_dict
log = config.logger("dbmodel")
mongo_conf = config.get("services._mongo")
ACCESS_REGEX = re.compile(r"^(?P<prefix>>=|>|<=|<)?(?P<value>.*)$")
ACCESS_MODIFIER = {">=": "gte", ">": "gt", "<=": "lte", "<": "lt"}
@ -158,7 +158,9 @@ class GetMixin(PropsMixin):
def _get_op(self, v: str, translate: bool = False) -> Optional[str]:
try:
op = (
v[len(self.op_prefix) :] if v and v.startswith(self.op_prefix) else None
v[len(self.op_prefix) :]
if v and v.startswith(self.op_prefix)
else None
)
if translate:
tup = self._ops.get(op, None)
@ -166,7 +168,9 @@ class GetMixin(PropsMixin):
return op
except AttributeError:
raise errors.bad_request.FieldsValueError(
"invalid value type, string expected", field=self._field, value=str(v)
"invalid value type, string expected",
field=self._field,
value=str(v),
)
def _key(self, v) -> Optional[Union[str, bool]]:
@ -233,8 +237,8 @@ class GetMixin(PropsMixin):
cls._cache_manager = RedisCacheManager(
state_class=cls.GetManyScrollState,
redis=redman.connection("apiserver"),
expiration_interval=config.get(
"services._mongo.scroll_state_expiration_seconds", 600
expiration_interval=mongo_conf.get(
"scroll_state_expiration_seconds", 600
),
)
@ -451,7 +455,9 @@ class GetMixin(PropsMixin):
raise
except Exception as ex:
raise errors.bad_request.FieldsValueError(
"failed parsing query field", error=str(ex), **({"field": field} if field else {})
"failed parsing query field",
error=str(ex),
**({"field": field} if field else {}),
)
return query & RegexQ(**dict_query)
@ -570,7 +576,7 @@ class GetMixin(PropsMixin):
if start is not None:
return start, cls.validate_scroll_size(parameters)
max_page_size = config.get("services._mongo.max_page_size", 500)
max_page_size = mongo_conf.get("max_page_size", 500)
page = parameters.get("page", default_page)
if page is not None and page < 0:
raise errors.bad_request.ValidationError("page must be >=0", field="page")
@ -880,6 +886,13 @@ class GetMixin(PropsMixin):
return cls._get_many_no_company(query=_query, override_projection=projection)
@staticmethod
def _get_qs_with_ordering(qs: QuerySet, order_by: Sequence):
disk_use_setting = mongo_conf.get("allow_disk_use.sort", None)
if disk_use_setting is not None:
qs = qs.allow_disk_use(disk_use_setting)
return qs.order_by(*order_by)
@classmethod
def _get_many_no_company(
cls: Union["GetMixin", Document],
@ -1173,7 +1186,7 @@ class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin):
kwargs.update(
allowDiskUse=allow_disk_use
if allow_disk_use is not None
else config.get("apiserver.mongo.aggregate.allow_disk_use", True)
else mongo_conf.get("allow_disk_use.aggregate", True)
)
return cls.objects.aggregate(pipeline, **kwargs)

View File

@ -19,6 +19,7 @@ from apiserver.database.fields import (
SafeSortedListField,
EmbeddedDocumentListField,
NullableStringField,
NoneType,
)
from apiserver.database.model import AttributedDocument
from apiserver.database.model.base import ProperDictMixin, GetMixin
@ -89,7 +90,9 @@ class Artifact(EmbeddedDocument):
content_size = LongField()
timestamp = LongField()
type_data = EmbeddedDocumentField(ArtifactTypeData)
display_data = SafeSortedListField(ListField(UnionField((int, float, str))))
display_data = SafeSortedListField(
ListField(UnionField((int, float, str, NoneType)))
)
class ParamsItem(EmbeddedDocument, ProperDictMixin):
@ -231,6 +234,7 @@ class Task(AttributedDocument):
range_fields=("started", "active_duration", "last_metrics.*", "last_iteration"),
datetime_fields=("status_changed", "last_update"),
pattern_fields=("name", "comment", "report"),
fields=("execution.queue", "runtime.*", "models.input.model"),
)
id = StringField(primary_key=True)

View File

@ -620,6 +620,19 @@ get_all_ex {
}
}
}
"2.24": ${get_all_ex."2.23"} {
request.properties.children_condition {
description: The filter that any of the child projects should match in order that the parent will be included
type: object
properties {
system_tags {
description: The list of system tags to match from
type: string
}
}
additionalProperties: true
}
}
}
update {
"2.1" {

View File

@ -76,7 +76,7 @@ def get_entities_count(call: APICall, company, request: EntitiesCountRequest):
requested_ids = data.get("id")
if isinstance(requested_ids, str):
requested_ids = [requested_ids]
ids, _ = project_bll.get_projects_with_active_user(
ids, _ = project_bll.get_projects_with_selected_children(
company=company,
users=request.active_users,
project_ids=requested_ids,

View File

@ -114,13 +114,16 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
_adjust_search_parameters(
data, shallow_search=request.shallow_search,
)
user_active_project_ids = None
if request.active_users:
ids, user_active_project_ids = project_bll.get_projects_with_active_user(
selected_project_ids = None
if request.active_users or request.children_condition:
ids, selected_project_ids = project_bll.get_projects_with_selected_children(
company=company_id,
users=request.active_users,
project_ids=requested_ids,
allow_public=allow_public,
children_condition=request.children_condition.to_struct()
if request.children_condition
else None,
)
if not ids:
return {"projects": []}
@ -158,7 +161,7 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
search_hidden=request.search_hidden,
filter_=request.include_stats_filter,
users=request.active_users,
user_active_project_ids=user_active_project_ids,
selected_project_ids=selected_project_ids,
)
for project in projects:

View File

@ -33,6 +33,31 @@ class TestSubProjects(TestService):
).projects[0]
self.assertEqual(data.dataset_stats, {"file_count": 2, "total_size": 1000})
def test_query_children(self):
test_root_name = "TestQueryChildren"
test_root = self._temp_project(name=test_root_name)
child_with_tag = self._temp_project(
name=f"{test_root_name}/Project1/WithTag", system_tags=["test"]
)
child_without_tag = self._temp_project(name=f"{test_root_name}/Project2/WithoutTag")
projects = self.api.projects.get_all_ex(parent=[test_root], shallow_search=True).projects
self.assertEqual({p.basename for p in projects}, {"Project1", "Project2"})
projects = self.api.projects.get_all_ex(
parent=[test_root], children_condition={"system_tags": ["test"]}, shallow_search=True
).projects
self.assertEqual({p.basename for p in projects}, {"Project1"})
projects = self.api.projects.get_all_ex(
parent=[projects[0].id], children_condition={"system_tags": ["test"]}, shallow_search=True
).projects
self.assertEqual(projects[0].id, child_with_tag)
projects = self.api.projects.get_all_ex(
parent=[test_root], children_condition={"system_tags": ["not existent"]}, shallow_search=True
).projects
self.assertEqual(len(projects), 0)
def test_project_aggregations(self):
"""This test requires user with user_auth_only... credentials in db"""
user2_client = APIClient(