From 74200a24bd3cca96fc6050c2dd38ca39df2f8803 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Thu, 23 Mar 2023 19:06:49 +0200 Subject: [PATCH] Add filtering on child projects in projects.get_all_ex --- apiserver/apimodels/projects.py | 6 ++ apiserver/bll/project/project_bll.py | 57 ++++++++++++------- apiserver/config/default/apiserver.conf | 4 -- apiserver/config/default/services/_mongo.conf | 5 ++ apiserver/database/model/base.py | 31 +++++++--- apiserver/database/model/task/task.py | 6 +- apiserver/schema/services/projects.conf | 13 +++++ apiserver/services/organization.py | 2 +- apiserver/services/projects.py | 11 ++-- apiserver/tests/automated/test_subprojects.py | 25 ++++++++ 10 files changed, 122 insertions(+), 38 deletions(-) diff --git a/apiserver/apimodels/projects.py b/apiserver/apimodels/projects.py index 755829a..0cbfb3f 100644 --- a/apiserver/apimodels/projects.py +++ b/apiserver/apimodels/projects.py @@ -1,4 +1,5 @@ from jsonmodels import models, fields +from jsonmodels.fields import EmbeddedField from apiserver.apimodels import ListField, ActualEnumField, DictField from apiserver.apimodels.organization import TagsRequest @@ -56,6 +57,10 @@ class ProjectModelMetadataValuesRequest(MultiProjectRequest): allow_public = fields.BoolField(default=True) +class ChildrenCondition(models.Base): + system_tags = fields.ListField([str]) + + class ProjectsGetRequest(models.Base): include_dataset_stats = fields.BoolField(default=False) include_stats = fields.BoolField(default=False) @@ -68,3 +73,4 @@ class ProjectsGetRequest(models.Base): shallow_search = fields.BoolField(default=False) search_hidden = fields.BoolField(default=False) allow_public = fields.BoolField(default=True) + children_condition = EmbeddedField(ChildrenCondition) diff --git a/apiserver/bll/project/project_bll.py b/apiserver/bll/project/project_bll.py index b623a7b..2656f0c 100644 --- a/apiserver/bll/project/project_bll.py +++ b/apiserver/bll/project/project_bll.py @@ -571,7 +571,7 @@ class ProjectBLL: search_hidden: bool = False, filter_: Mapping[str, Any] = None, users: Sequence[str] = None, - user_active_project_ids: Sequence[str] = None, + selected_project_ids: Sequence[str] = None, ) -> Tuple[Dict[str, dict], Dict[str, dict]]: if not project_ids: return {}, {} @@ -581,7 +581,7 @@ class ProjectBLL: project_ids, _only=("id", "name"), search_hidden=search_hidden, - allowed_ids=user_active_project_ids, + allowed_ids=selected_project_ids, ) if include_children else {} @@ -753,46 +753,65 @@ class ProjectBLL: return tags, system_tags @classmethod - def get_projects_with_active_user( + def get_projects_with_selected_children( cls, company: str, - users: Sequence[str], + users: Sequence[str] = None, project_ids: Optional[Sequence[str]] = None, allow_public: bool = True, + children_condition: Mapping[str, Any] = None, ) -> Tuple[Sequence[str], Sequence[str]]: """ - Get the projects ids where user created any tasks including all the parents of these projects + Get the projects ids matching children_condition (if passed) or where the passed user created any tasks + including all the parents of these projects If project ids are specified then filter the results by these project ids """ - query = Q(user__in=users) + if not (users or children_condition): + raise errors.bad_request.ValidationError( + "Either active users or children_condition should be specified" + ) - if allow_public: - query &= get_company_or_none_constraint(company) + projects_query = Project.prepare_query( + company, parameters=children_condition, allow_public=allow_public + ) + if children_condition: + contained_entities_query = None else: - query &= Q(company=company) + contained_entities_query = ( + get_company_or_none_constraint(company) + if allow_public + else Q(company=company) + ) + + if users: + user_query = Q(user__in=users) + projects_query &= user_query + if contained_entities_query: + contained_entities_query &= user_query - user_projects_query = query if project_ids: ids_with_children = _ids_with_children(project_ids) - query &= Q(project__in=ids_with_children) - user_projects_query &= Q(id__in=ids_with_children) + projects_query &= Q(id__in=ids_with_children) + if contained_entities_query: + contained_entities_query &= Q(project__in=ids_with_children) - res = {p.id for p in Project.objects(user_projects_query).only("id")} - for cls_ in (Task, Model): - res |= set(cls_.objects(query).distinct(field="project")) + res = {p.id for p in Project.objects(projects_query).only("id")} + if contained_entities_query: + for cls_ in (Task, Model): + res |= set(cls_.objects(contained_entities_query).distinct(field="project")) res = list(res) if not res: return res, res - user_active_project_ids = _ids_with_parents(res) + selected_project_ids = _ids_with_parents(res) filtered_ids = ( - list(set(user_active_project_ids) & set(project_ids)) + list(set(selected_project_ids) & set(project_ids)) if project_ids - else list(user_active_project_ids) + else list(selected_project_ids) ) - return filtered_ids, user_active_project_ids + return filtered_ids, selected_project_ids @classmethod def get_task_parents( diff --git a/apiserver/config/default/apiserver.conf b/apiserver/config/default/apiserver.conf index 4b5b67e..1b18911 100644 --- a/apiserver/config/default/apiserver.conf +++ b/apiserver/config/default/apiserver.conf @@ -41,10 +41,6 @@ # controls whether FieldDoesNotExist exception will be raised for any extra attribute existing in stored data # but not declared in a data model strict: false - - aggregate { - allow_disk_use: true - } } elastic { diff --git a/apiserver/config/default/services/_mongo.conf b/apiserver/config/default/services/_mongo.conf index ee9f49e..eada154 100644 --- a/apiserver/config/default/services/_mongo.conf +++ b/apiserver/config/default/services/_mongo.conf @@ -2,3 +2,8 @@ max_page_size: 500 # expiration time in seconds for the redis scroll states in get_many family of apis scroll_state_expiration_seconds: 600 + +allow_disk_use { + # sort: true + aggregate: true +} \ No newline at end of file diff --git a/apiserver/database/model/base.py b/apiserver/database/model/base.py index a9be188..f1583ab 100644 --- a/apiserver/database/model/base.py +++ b/apiserver/database/model/base.py @@ -17,7 +17,7 @@ from typing import ( from boltons.iterutils import first, partition from dateutil.parser import parse as parse_datetime -from mongoengine import Q, Document, ListField, StringField, IntField +from mongoengine import Q, Document, ListField, StringField, IntField, QuerySet from pymongo.command_cursor import CommandCursor from apiserver.apierrors import errors, APIError @@ -39,7 +39,7 @@ from apiserver.redis_manager import redman from apiserver.utilities.dicts import project_dict, exclude_fields_from_dict log = config.logger("dbmodel") - +mongo_conf = config.get("services._mongo") ACCESS_REGEX = re.compile(r"^(?P>=|>|<=|<)?(?P.*)$") ACCESS_MODIFIER = {">=": "gte", ">": "gt", "<=": "lte", "<": "lt"} @@ -158,7 +158,9 @@ class GetMixin(PropsMixin): def _get_op(self, v: str, translate: bool = False) -> Optional[str]: try: op = ( - v[len(self.op_prefix) :] if v and v.startswith(self.op_prefix) else None + v[len(self.op_prefix) :] + if v and v.startswith(self.op_prefix) + else None ) if translate: tup = self._ops.get(op, None) @@ -166,7 +168,9 @@ class GetMixin(PropsMixin): return op except AttributeError: raise errors.bad_request.FieldsValueError( - "invalid value type, string expected", field=self._field, value=str(v) + "invalid value type, string expected", + field=self._field, + value=str(v), ) def _key(self, v) -> Optional[Union[str, bool]]: @@ -233,8 +237,8 @@ class GetMixin(PropsMixin): cls._cache_manager = RedisCacheManager( state_class=cls.GetManyScrollState, redis=redman.connection("apiserver"), - expiration_interval=config.get( - "services._mongo.scroll_state_expiration_seconds", 600 + expiration_interval=mongo_conf.get( + "scroll_state_expiration_seconds", 600 ), ) @@ -451,7 +455,9 @@ class GetMixin(PropsMixin): raise except Exception as ex: raise errors.bad_request.FieldsValueError( - "failed parsing query field", error=str(ex), **({"field": field} if field else {}) + "failed parsing query field", + error=str(ex), + **({"field": field} if field else {}), ) return query & RegexQ(**dict_query) @@ -570,7 +576,7 @@ class GetMixin(PropsMixin): if start is not None: return start, cls.validate_scroll_size(parameters) - max_page_size = config.get("services._mongo.max_page_size", 500) + max_page_size = mongo_conf.get("max_page_size", 500) page = parameters.get("page", default_page) if page is not None and page < 0: raise errors.bad_request.ValidationError("page must be >=0", field="page") @@ -880,6 +886,13 @@ class GetMixin(PropsMixin): return cls._get_many_no_company(query=_query, override_projection=projection) + @staticmethod + def _get_qs_with_ordering(qs: QuerySet, order_by: Sequence): + disk_use_setting = mongo_conf.get("allow_disk_use.sort", None) + if disk_use_setting is not None: + qs = qs.allow_disk_use(disk_use_setting) + return qs.order_by(*order_by) + @classmethod def _get_many_no_company( cls: Union["GetMixin", Document], @@ -1173,7 +1186,7 @@ class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin): kwargs.update( allowDiskUse=allow_disk_use if allow_disk_use is not None - else config.get("apiserver.mongo.aggregate.allow_disk_use", True) + else mongo_conf.get("allow_disk_use.aggregate", True) ) return cls.objects.aggregate(pipeline, **kwargs) diff --git a/apiserver/database/model/task/task.py b/apiserver/database/model/task/task.py index b199120..d513398 100644 --- a/apiserver/database/model/task/task.py +++ b/apiserver/database/model/task/task.py @@ -19,6 +19,7 @@ from apiserver.database.fields import ( SafeSortedListField, EmbeddedDocumentListField, NullableStringField, + NoneType, ) from apiserver.database.model import AttributedDocument from apiserver.database.model.base import ProperDictMixin, GetMixin @@ -89,7 +90,9 @@ class Artifact(EmbeddedDocument): content_size = LongField() timestamp = LongField() type_data = EmbeddedDocumentField(ArtifactTypeData) - display_data = SafeSortedListField(ListField(UnionField((int, float, str)))) + display_data = SafeSortedListField( + ListField(UnionField((int, float, str, NoneType))) + ) class ParamsItem(EmbeddedDocument, ProperDictMixin): @@ -231,6 +234,7 @@ class Task(AttributedDocument): range_fields=("started", "active_duration", "last_metrics.*", "last_iteration"), datetime_fields=("status_changed", "last_update"), pattern_fields=("name", "comment", "report"), + fields=("execution.queue", "runtime.*", "models.input.model"), ) id = StringField(primary_key=True) diff --git a/apiserver/schema/services/projects.conf b/apiserver/schema/services/projects.conf index de0f80a..7ffe9df 100644 --- a/apiserver/schema/services/projects.conf +++ b/apiserver/schema/services/projects.conf @@ -620,6 +620,19 @@ get_all_ex { } } } + "2.24": ${get_all_ex."2.23"} { + request.properties.children_condition { + description: The filter that any of the child projects should match in order that the parent will be included + type: object + properties { + system_tags { + description: The list of system tags to match from + type: string + } + } + additionalProperties: true + } + } } update { "2.1" { diff --git a/apiserver/services/organization.py b/apiserver/services/organization.py index 30395c8..1a4f625 100644 --- a/apiserver/services/organization.py +++ b/apiserver/services/organization.py @@ -76,7 +76,7 @@ def get_entities_count(call: APICall, company, request: EntitiesCountRequest): requested_ids = data.get("id") if isinstance(requested_ids, str): requested_ids = [requested_ids] - ids, _ = project_bll.get_projects_with_active_user( + ids, _ = project_bll.get_projects_with_selected_children( company=company, users=request.active_users, project_ids=requested_ids, diff --git a/apiserver/services/projects.py b/apiserver/services/projects.py index 1bb4118..7d1c324 100644 --- a/apiserver/services/projects.py +++ b/apiserver/services/projects.py @@ -114,13 +114,16 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest): _adjust_search_parameters( data, shallow_search=request.shallow_search, ) - user_active_project_ids = None - if request.active_users: - ids, user_active_project_ids = project_bll.get_projects_with_active_user( + selected_project_ids = None + if request.active_users or request.children_condition: + ids, selected_project_ids = project_bll.get_projects_with_selected_children( company=company_id, users=request.active_users, project_ids=requested_ids, allow_public=allow_public, + children_condition=request.children_condition.to_struct() + if request.children_condition + else None, ) if not ids: return {"projects": []} @@ -158,7 +161,7 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest): search_hidden=request.search_hidden, filter_=request.include_stats_filter, users=request.active_users, - user_active_project_ids=user_active_project_ids, + selected_project_ids=selected_project_ids, ) for project in projects: diff --git a/apiserver/tests/automated/test_subprojects.py b/apiserver/tests/automated/test_subprojects.py index ebaa488..44c327a 100644 --- a/apiserver/tests/automated/test_subprojects.py +++ b/apiserver/tests/automated/test_subprojects.py @@ -33,6 +33,31 @@ class TestSubProjects(TestService): ).projects[0] self.assertEqual(data.dataset_stats, {"file_count": 2, "total_size": 1000}) + def test_query_children(self): + test_root_name = "TestQueryChildren" + test_root = self._temp_project(name=test_root_name) + child_with_tag = self._temp_project( + name=f"{test_root_name}/Project1/WithTag", system_tags=["test"] + ) + child_without_tag = self._temp_project(name=f"{test_root_name}/Project2/WithoutTag") + + projects = self.api.projects.get_all_ex(parent=[test_root], shallow_search=True).projects + self.assertEqual({p.basename for p in projects}, {"Project1", "Project2"}) + + projects = self.api.projects.get_all_ex( + parent=[test_root], children_condition={"system_tags": ["test"]}, shallow_search=True + ).projects + self.assertEqual({p.basename for p in projects}, {"Project1"}) + projects = self.api.projects.get_all_ex( + parent=[projects[0].id], children_condition={"system_tags": ["test"]}, shallow_search=True + ).projects + self.assertEqual(projects[0].id, child_with_tag) + + projects = self.api.projects.get_all_ex( + parent=[test_root], children_condition={"system_tags": ["not existent"]}, shallow_search=True + ).projects + self.assertEqual(len(projects), 0) + def test_project_aggregations(self): """This test requires user with user_auth_only... credentials in db""" user2_client = APIClient(