From 818496236b6303aa8bcf5793aff8819329465b9b Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Thu, 25 May 2023 19:19:10 +0300 Subject: [PATCH] Support filtering by children tags in projects.get_all_ex --- apiserver/apimodels/projects.py | 1 + apiserver/bll/project/project_bll.py | 22 ++++--- apiserver/schema/services/projects.conf | 7 +++ apiserver/services/projects.py | 45 +++++++++------ apiserver/tests/automated/test_subprojects.py | 57 +++++++++++++++++++ 5 files changed, 108 insertions(+), 24 deletions(-) diff --git a/apiserver/apimodels/projects.py b/apiserver/apimodels/projects.py index 339201b..53a5897 100644 --- a/apiserver/apimodels/projects.py +++ b/apiserver/apimodels/projects.py @@ -82,3 +82,4 @@ class ProjectsGetRequest(models.Base): search_hidden = fields.BoolField(default=False) allow_public = fields.BoolField(default=True) children_type = ActualEnumField(ProjectChildrenType) + children_tags = fields.ListField(str) diff --git a/apiserver/bll/project/project_bll.py b/apiserver/bll/project/project_bll.py index 347f26f..2f24013 100644 --- a/apiserver/bll/project/project_bll.py +++ b/apiserver/bll/project/project_bll.py @@ -16,7 +16,6 @@ from typing import ( Any, ) -from boltons.iterutils import partition from mongoengine import Q, Document from apiserver import database @@ -58,7 +57,7 @@ class ProjectBLL: @classmethod def merge_project( - cls, company, source_id: str, destination_id: str + cls, company: str, source_id: str, destination_id: str ) -> Tuple[int, int, Set[str]]: """ Move all the tasks and sub projects from the source project to the destination @@ -901,6 +900,7 @@ class ProjectBLL: project_ids: Optional[Sequence[str]] = None, allow_public: bool = True, children_type: ProjectChildrenType = None, + children_tags: Sequence[str] = None, ) -> Tuple[Sequence[str], Sequence[str]]: """ Get the projects ids matching children_condition (if passed) or where the passed user created any tasks @@ -921,15 +921,20 @@ class ProjectBLL: query &= Q(user__in=users) project_query = None + child_query = ( + query & GetMixin.get_list_field_query("tags", children_tags) + if children_tags + else query + ) if children_type == ProjectChildrenType.dataset: child_queries = { - Project: query + Project: child_query & Q(system_tags__in=[dataset_tag], basename__ne=datasets_project_name) } elif children_type == ProjectChildrenType.pipeline: - child_queries = {Task: query & Q(system_tags__in=[pipeline_tag])} + child_queries = {Task: child_query & Q(system_tags__in=[pipeline_tag])} elif children_type == ProjectChildrenType.report: - child_queries = {Task: query & Q(system_tags__in=[reports_tag])} + child_queries = {Task: child_query & Q(system_tags__in=[reports_tag])} else: project_query = query child_queries = {entity_cls: query for entity_cls in cls.child_classes} @@ -1065,10 +1070,11 @@ class ProjectBLL: raise errors.bad_request.ValidationError( f"List of strings expected for the field: {field}" ) - exclude, include = partition(field_filter, lambda x: x.startswith("-")) + helper = GetMixin.ListFieldBucketHelper(field, legacy=True) + actions = helper.get_actions(field_filter) conditions[field] = { - **({"$in": include} if include else {}), - **({"$nin": [e[1:] for e in exclude]} if exclude else {}), + f"${action}": list(set(actions[action])) + for action in filter(None, actions) } return conditions diff --git a/apiserver/schema/services/projects.conf b/apiserver/schema/services/projects.conf index a4c4e46..9a4b3db 100644 --- a/apiserver/schema/services/projects.conf +++ b/apiserver/schema/services/projects.conf @@ -653,6 +653,13 @@ get_all_ex { enum: [pipeline, report, dataset] } } + "999.0": ${get_all_ex."2.24"} { + request.properties.children_tags { + description: "The list of tag values to filter children by. Takes effect only if children_type is set. Use 'null' value to specify empty tags. Use '__Snot' value to specify that the following value should be excluded" + type: array + items {type: string} + } + } } update { "2.1" { diff --git a/apiserver/services/projects.py b/apiserver/services/projects.py index 1a3161f..97453c0 100644 --- a/apiserver/services/projects.py +++ b/apiserver/services/projects.py @@ -7,12 +7,12 @@ from apiserver.apierrors import errors from apiserver.apierrors.errors.bad_request import InvalidProjectId from apiserver.apimodels.base import UpdateResponse, MakePublicRequest, IdResponse from apiserver.apimodels.projects import ( + DeleteRequest, GetParamsRequest, ProjectTagsRequest, ProjectTaskParentsRequest, ProjectHyperparamValuesRequest, ProjectsGetRequest, - DeleteRequest, MoveRequest, MergeRequest, ProjectRequest, @@ -99,19 +99,31 @@ def _adjust_search_parameters(data: dict, shallow_search: bool): data["parent"] = [None] -def _get_project_stats_filter(request: ProjectsGetRequest) -> Tuple[Optional[dict], bool]: +def _get_project_stats_filter( + request: ProjectsGetRequest, +) -> Tuple[Optional[dict], bool]: if request.include_stats_filter or not request.children_type: return request.include_stats_filter, request.search_hidden + stats_filter = {"tags": request.children_tags} if request.children_tags else {} if request.children_type == ProjectChildrenType.pipeline: - return {"system_tags": [pipeline_tag], "type": [TaskType.controller]}, True + return ( + { + **stats_filter, + "system_tags": [pipeline_tag], + "type": [TaskType.controller], + }, + True, + ) if request.children_type == ProjectChildrenType.report: - return {"system_tags": [reports_tag], "type": [TaskType.report]}, True - - return request.include_stats_filter, request.search_hidden + return ( + {**stats_filter, "system_tags": [reports_tag], "type": [TaskType.report]}, + True, + ) + return stats_filter, request.search_hidden -@endpoint("projects.get_all_ex", request_data_model=ProjectsGetRequest) +@endpoint("projects.get_all_ex") def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest): data = call.data conform_tag_fields(call, data) @@ -137,6 +149,7 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest): project_ids=requested_ids, allow_public=allow_public, children_type=request.children_type, + children_tags=request.children_tags, ) if not ids: return {"projects": []} @@ -174,19 +187,20 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest): conform_output_tags(call, projects) project_ids = list({project["id"] for project in projects}) + stats_filter, stats_search_hidden = _get_project_stats_filter(request) if request.check_own_contents: if request.children_type == ProjectChildrenType.dataset: contents = project_bll.calc_own_datasets( company=company_id, project_ids=project_ids, - filter_=request.include_stats_filter, + filter_=stats_filter, users=request.active_users, ) else: contents = project_bll.calc_own_contents( company=company_id, project_ids=project_ids, - filter_=_get_project_stats_filter(request)[0], + filter_=stats_filter, users=request.active_users, ) @@ -199,19 +213,18 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest): company=company_id, project_ids=project_ids, include_children=request.stats_with_children, - filter_=request.include_stats_filter, + filter_=stats_filter, users=request.active_users, selected_project_ids=selected_project_ids, ) else: - filter_, search_hidden = _get_project_stats_filter(request) stats, children = project_bll.get_project_stats( company=company_id, project_ids=project_ids, specific_state=request.stats_for_state, include_children=request.stats_with_children, - search_hidden=search_hidden, - filter_=filter_, + search_hidden=stats_search_hidden, + filter_=stats_filter, users=request.active_users, selected_project_ids=selected_project_ids, ) @@ -348,7 +361,7 @@ def delete(call: APICall, company_id: str, request: DeleteRequest): "projects.get_unique_metric_variants", request_data_model=GetUniqueMetricsRequest ) def get_unique_metric_variants( - call: APICall, company_id: str, request: GetUniqueMetricsRequest, + call: APICall, company_id: str, request: GetUniqueMetricsRequest ): metrics = project_queries.get_unique_metric_variants( @@ -361,7 +374,7 @@ def get_unique_metric_variants( call.result.data = {"metrics": metrics} -@endpoint("projects.get_model_metadata_keys",) +@endpoint("projects.get_model_metadata_keys") def get_model_metadata_keys(call: APICall, company_id: str, request: GetParamsRequest): total, remaining, keys = project_queries.get_model_metadata_keys( company_id, @@ -505,7 +518,7 @@ def get_task_parents( call: APICall, company_id: str, request: ProjectTaskParentsRequest ): call.result.data = { - "parents": project_bll.get_task_parents( + "parents": ProjectBLL.get_task_parents( company_id, projects=request.projects, include_subprojects=request.include_subprojects, diff --git a/apiserver/tests/automated/test_subprojects.py b/apiserver/tests/automated/test_subprojects.py index 1e5d5ee..2b41fb2 100644 --- a/apiserver/tests/automated/test_subprojects.py +++ b/apiserver/tests/automated/test_subprojects.py @@ -33,6 +33,63 @@ class TestSubProjects(TestService): ).projects[0] self.assertEqual(data.dataset_stats, {"file_count": 2, "total_size": 1000}) + def test_query_children_system_tags(self): + test_root_name = "TestQueryChildrenTags" + test_root = self._temp_project(name=test_root_name) + project1 = self._temp_project(name=f"{test_root_name}/project1") + project2 = self._temp_project(name=f"{test_root_name}/project2") + self._temp_report(name="test report", project=project1) + self._temp_report(name="test report", project=project2, tags=["test1", "test2"]) + self._temp_report(name="test report", project=project2, tags=["test1"]) + + projects = self.api.projects.get_all_ex( + parent=[test_root], + children_type="report", + shallow_search=True, + include_stats=True, + check_own_contents=True, + ).projects + self.assertEqual(len(projects), 2) + + projects = self.api.projects.get_all_ex( + parent=[test_root], + children_type="report", + children_tags=["test1", "test2"], + shallow_search=True, + include_stats=True, + check_own_contents=True, + ).projects + self.assertEqual(len(projects), 1) + p = projects[0] + self.assertEqual(p.basename, "project2") + self.assertEqual(p.stats.active.total_tasks, 2) + + projects = self.api.projects.get_all_ex( + parent=[test_root], + children_type="report", + children_tags=["__$all", "test1", "test2"], + shallow_search=True, + include_stats=True, + check_own_contents=True, + ).projects + self.assertEqual(len(projects), 1) + p = projects[0] + self.assertEqual(p.basename, "project2") + self.assertEqual(p.stats.active.total_tasks, 1) + + projects = self.api.projects.get_all_ex( + parent=[test_root], + children_type="report", + children_tags=["-test1", "-test2"], + shallow_search=True, + include_stats=True, + check_own_contents=True, + ).projects + self.assertEqual(len(projects), 1) + p = projects[0] + self.assertEqual(p.basename, "project1") + self.assertEqual(p.stats.active.total_tasks, 1) + def test_query_children(self): test_root_name = "TestQueryChildren" test_root = self._temp_project(name=test_root_name)