Support filtering by children tags in projects.get_all_ex

This commit is contained in:
allegroai 2023-05-25 19:19:10 +03:00
parent e99817b28b
commit 818496236b
5 changed files with 108 additions and 24 deletions

View File

@ -82,3 +82,4 @@ class ProjectsGetRequest(models.Base):
search_hidden = fields.BoolField(default=False) search_hidden = fields.BoolField(default=False)
allow_public = fields.BoolField(default=True) allow_public = fields.BoolField(default=True)
children_type = ActualEnumField(ProjectChildrenType) children_type = ActualEnumField(ProjectChildrenType)
children_tags = fields.ListField(str)

View File

@ -16,7 +16,6 @@ from typing import (
Any, Any,
) )
from boltons.iterutils import partition
from mongoengine import Q, Document from mongoengine import Q, Document
from apiserver import database from apiserver import database
@ -58,7 +57,7 @@ class ProjectBLL:
@classmethod @classmethod
def merge_project( def merge_project(
cls, company, source_id: str, destination_id: str cls, company: str, source_id: str, destination_id: str
) -> Tuple[int, int, Set[str]]: ) -> Tuple[int, int, Set[str]]:
""" """
Move all the tasks and sub projects from the source project to the destination Move all the tasks and sub projects from the source project to the destination
@ -901,6 +900,7 @@ class ProjectBLL:
project_ids: Optional[Sequence[str]] = None, project_ids: Optional[Sequence[str]] = None,
allow_public: bool = True, allow_public: bool = True,
children_type: ProjectChildrenType = None, children_type: ProjectChildrenType = None,
children_tags: Sequence[str] = None,
) -> Tuple[Sequence[str], Sequence[str]]: ) -> Tuple[Sequence[str], Sequence[str]]:
""" """
Get the projects ids matching children_condition (if passed) or where the passed user created any tasks Get the projects ids matching children_condition (if passed) or where the passed user created any tasks
@ -921,15 +921,20 @@ class ProjectBLL:
query &= Q(user__in=users) query &= Q(user__in=users)
project_query = None project_query = None
child_query = (
query & GetMixin.get_list_field_query("tags", children_tags)
if children_tags
else query
)
if children_type == ProjectChildrenType.dataset: if children_type == ProjectChildrenType.dataset:
child_queries = { child_queries = {
Project: query Project: child_query
& Q(system_tags__in=[dataset_tag], basename__ne=datasets_project_name) & Q(system_tags__in=[dataset_tag], basename__ne=datasets_project_name)
} }
elif children_type == ProjectChildrenType.pipeline: elif children_type == ProjectChildrenType.pipeline:
child_queries = {Task: query & Q(system_tags__in=[pipeline_tag])} child_queries = {Task: child_query & Q(system_tags__in=[pipeline_tag])}
elif children_type == ProjectChildrenType.report: elif children_type == ProjectChildrenType.report:
child_queries = {Task: query & Q(system_tags__in=[reports_tag])} child_queries = {Task: child_query & Q(system_tags__in=[reports_tag])}
else: else:
project_query = query project_query = query
child_queries = {entity_cls: query for entity_cls in cls.child_classes} child_queries = {entity_cls: query for entity_cls in cls.child_classes}
@ -1065,10 +1070,11 @@ class ProjectBLL:
raise errors.bad_request.ValidationError( raise errors.bad_request.ValidationError(
f"List of strings expected for the field: {field}" f"List of strings expected for the field: {field}"
) )
exclude, include = partition(field_filter, lambda x: x.startswith("-")) helper = GetMixin.ListFieldBucketHelper(field, legacy=True)
actions = helper.get_actions(field_filter)
conditions[field] = { conditions[field] = {
**({"$in": include} if include else {}), f"${action}": list(set(actions[action]))
**({"$nin": [e[1:] for e in exclude]} if exclude else {}), for action in filter(None, actions)
} }
return conditions return conditions

View File

@ -653,6 +653,13 @@ get_all_ex {
enum: [pipeline, report, dataset] enum: [pipeline, report, dataset]
} }
} }
"999.0": ${get_all_ex."2.24"} {
request.properties.children_tags {
description: "The list of tag values to filter children by. Takes effect only if children_type is set. Use 'null' value to specify empty tags. Use '__Snot' value to specify that the following value should be excluded"
type: array
items {type: string}
}
}
} }
update { update {
"2.1" { "2.1" {

View File

@ -7,12 +7,12 @@ from apiserver.apierrors import errors
from apiserver.apierrors.errors.bad_request import InvalidProjectId from apiserver.apierrors.errors.bad_request import InvalidProjectId
from apiserver.apimodels.base import UpdateResponse, MakePublicRequest, IdResponse from apiserver.apimodels.base import UpdateResponse, MakePublicRequest, IdResponse
from apiserver.apimodels.projects import ( from apiserver.apimodels.projects import (
DeleteRequest,
GetParamsRequest, GetParamsRequest,
ProjectTagsRequest, ProjectTagsRequest,
ProjectTaskParentsRequest, ProjectTaskParentsRequest,
ProjectHyperparamValuesRequest, ProjectHyperparamValuesRequest,
ProjectsGetRequest, ProjectsGetRequest,
DeleteRequest,
MoveRequest, MoveRequest,
MergeRequest, MergeRequest,
ProjectRequest, ProjectRequest,
@ -99,19 +99,31 @@ def _adjust_search_parameters(data: dict, shallow_search: bool):
data["parent"] = [None] data["parent"] = [None]
def _get_project_stats_filter(request: ProjectsGetRequest) -> Tuple[Optional[dict], bool]: def _get_project_stats_filter(
request: ProjectsGetRequest,
) -> Tuple[Optional[dict], bool]:
if request.include_stats_filter or not request.children_type: if request.include_stats_filter or not request.children_type:
return request.include_stats_filter, request.search_hidden return request.include_stats_filter, request.search_hidden
stats_filter = {"tags": request.children_tags} if request.children_tags else {}
if request.children_type == ProjectChildrenType.pipeline: if request.children_type == ProjectChildrenType.pipeline:
return {"system_tags": [pipeline_tag], "type": [TaskType.controller]}, True return (
{
**stats_filter,
"system_tags": [pipeline_tag],
"type": [TaskType.controller],
},
True,
)
if request.children_type == ProjectChildrenType.report: if request.children_type == ProjectChildrenType.report:
return {"system_tags": [reports_tag], "type": [TaskType.report]}, True return (
{**stats_filter, "system_tags": [reports_tag], "type": [TaskType.report]},
return request.include_stats_filter, request.search_hidden True,
)
return stats_filter, request.search_hidden
@endpoint("projects.get_all_ex", request_data_model=ProjectsGetRequest) @endpoint("projects.get_all_ex")
def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest): def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
data = call.data data = call.data
conform_tag_fields(call, data) conform_tag_fields(call, data)
@ -137,6 +149,7 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
project_ids=requested_ids, project_ids=requested_ids,
allow_public=allow_public, allow_public=allow_public,
children_type=request.children_type, children_type=request.children_type,
children_tags=request.children_tags,
) )
if not ids: if not ids:
return {"projects": []} return {"projects": []}
@ -174,19 +187,20 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
conform_output_tags(call, projects) conform_output_tags(call, projects)
project_ids = list({project["id"] for project in projects}) project_ids = list({project["id"] for project in projects})
stats_filter, stats_search_hidden = _get_project_stats_filter(request)
if request.check_own_contents: if request.check_own_contents:
if request.children_type == ProjectChildrenType.dataset: if request.children_type == ProjectChildrenType.dataset:
contents = project_bll.calc_own_datasets( contents = project_bll.calc_own_datasets(
company=company_id, company=company_id,
project_ids=project_ids, project_ids=project_ids,
filter_=request.include_stats_filter, filter_=stats_filter,
users=request.active_users, users=request.active_users,
) )
else: else:
contents = project_bll.calc_own_contents( contents = project_bll.calc_own_contents(
company=company_id, company=company_id,
project_ids=project_ids, project_ids=project_ids,
filter_=_get_project_stats_filter(request)[0], filter_=stats_filter,
users=request.active_users, users=request.active_users,
) )
@ -199,19 +213,18 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
company=company_id, company=company_id,
project_ids=project_ids, project_ids=project_ids,
include_children=request.stats_with_children, include_children=request.stats_with_children,
filter_=request.include_stats_filter, filter_=stats_filter,
users=request.active_users, users=request.active_users,
selected_project_ids=selected_project_ids, selected_project_ids=selected_project_ids,
) )
else: else:
filter_, search_hidden = _get_project_stats_filter(request)
stats, children = project_bll.get_project_stats( stats, children = project_bll.get_project_stats(
company=company_id, company=company_id,
project_ids=project_ids, project_ids=project_ids,
specific_state=request.stats_for_state, specific_state=request.stats_for_state,
include_children=request.stats_with_children, include_children=request.stats_with_children,
search_hidden=search_hidden, search_hidden=stats_search_hidden,
filter_=filter_, filter_=stats_filter,
users=request.active_users, users=request.active_users,
selected_project_ids=selected_project_ids, selected_project_ids=selected_project_ids,
) )
@ -348,7 +361,7 @@ def delete(call: APICall, company_id: str, request: DeleteRequest):
"projects.get_unique_metric_variants", request_data_model=GetUniqueMetricsRequest "projects.get_unique_metric_variants", request_data_model=GetUniqueMetricsRequest
) )
def get_unique_metric_variants( def get_unique_metric_variants(
call: APICall, company_id: str, request: GetUniqueMetricsRequest, call: APICall, company_id: str, request: GetUniqueMetricsRequest
): ):
metrics = project_queries.get_unique_metric_variants( metrics = project_queries.get_unique_metric_variants(
@ -361,7 +374,7 @@ def get_unique_metric_variants(
call.result.data = {"metrics": metrics} call.result.data = {"metrics": metrics}
@endpoint("projects.get_model_metadata_keys",) @endpoint("projects.get_model_metadata_keys")
def get_model_metadata_keys(call: APICall, company_id: str, request: GetParamsRequest): def get_model_metadata_keys(call: APICall, company_id: str, request: GetParamsRequest):
total, remaining, keys = project_queries.get_model_metadata_keys( total, remaining, keys = project_queries.get_model_metadata_keys(
company_id, company_id,
@ -505,7 +518,7 @@ def get_task_parents(
call: APICall, company_id: str, request: ProjectTaskParentsRequest call: APICall, company_id: str, request: ProjectTaskParentsRequest
): ):
call.result.data = { call.result.data = {
"parents": project_bll.get_task_parents( "parents": ProjectBLL.get_task_parents(
company_id, company_id,
projects=request.projects, projects=request.projects,
include_subprojects=request.include_subprojects, include_subprojects=request.include_subprojects,

View File

@ -33,6 +33,63 @@ class TestSubProjects(TestService):
).projects[0] ).projects[0]
self.assertEqual(data.dataset_stats, {"file_count": 2, "total_size": 1000}) self.assertEqual(data.dataset_stats, {"file_count": 2, "total_size": 1000})
def test_query_children_system_tags(self):
test_root_name = "TestQueryChildrenTags"
test_root = self._temp_project(name=test_root_name)
project1 = self._temp_project(name=f"{test_root_name}/project1")
project2 = self._temp_project(name=f"{test_root_name}/project2")
self._temp_report(name="test report", project=project1)
self._temp_report(name="test report", project=project2, tags=["test1", "test2"])
self._temp_report(name="test report", project=project2, tags=["test1"])
projects = self.api.projects.get_all_ex(
parent=[test_root],
children_type="report",
shallow_search=True,
include_stats=True,
check_own_contents=True,
).projects
self.assertEqual(len(projects), 2)
projects = self.api.projects.get_all_ex(
parent=[test_root],
children_type="report",
children_tags=["test1", "test2"],
shallow_search=True,
include_stats=True,
check_own_contents=True,
).projects
self.assertEqual(len(projects), 1)
p = projects[0]
self.assertEqual(p.basename, "project2")
self.assertEqual(p.stats.active.total_tasks, 2)
projects = self.api.projects.get_all_ex(
parent=[test_root],
children_type="report",
children_tags=["__$all", "test1", "test2"],
shallow_search=True,
include_stats=True,
check_own_contents=True,
).projects
self.assertEqual(len(projects), 1)
p = projects[0]
self.assertEqual(p.basename, "project2")
self.assertEqual(p.stats.active.total_tasks, 1)
projects = self.api.projects.get_all_ex(
parent=[test_root],
children_type="report",
children_tags=["-test1", "-test2"],
shallow_search=True,
include_stats=True,
check_own_contents=True,
).projects
self.assertEqual(len(projects), 1)
p = projects[0]
self.assertEqual(p.basename, "project1")
self.assertEqual(p.stats.active.total_tasks, 1)
def test_query_children(self): def test_query_children(self):
test_root_name = "TestQueryChildren" test_root_name = "TestQueryChildren"
test_root = self._temp_project(name=test_root_name) test_root = self._temp_project(name=test_root_name)