Support filtering by children tags in projects.get_all_ex

This commit is contained in:
allegroai 2023-05-25 19:19:10 +03:00
parent e99817b28b
commit 818496236b
5 changed files with 108 additions and 24 deletions

View File

@ -82,3 +82,4 @@ class ProjectsGetRequest(models.Base):
search_hidden = fields.BoolField(default=False)
allow_public = fields.BoolField(default=True)
children_type = ActualEnumField(ProjectChildrenType)
children_tags = fields.ListField(str)

View File

@ -16,7 +16,6 @@ from typing import (
Any,
)
from boltons.iterutils import partition
from mongoengine import Q, Document
from apiserver import database
@ -58,7 +57,7 @@ class ProjectBLL:
@classmethod
def merge_project(
cls, company, source_id: str, destination_id: str
cls, company: str, source_id: str, destination_id: str
) -> Tuple[int, int, Set[str]]:
"""
Move all the tasks and sub projects from the source project to the destination
@ -901,6 +900,7 @@ class ProjectBLL:
project_ids: Optional[Sequence[str]] = None,
allow_public: bool = True,
children_type: ProjectChildrenType = None,
children_tags: Sequence[str] = None,
) -> Tuple[Sequence[str], Sequence[str]]:
"""
Get the projects ids matching children_condition (if passed) or where the passed user created any tasks
@ -921,15 +921,20 @@ class ProjectBLL:
query &= Q(user__in=users)
project_query = None
child_query = (
query & GetMixin.get_list_field_query("tags", children_tags)
if children_tags
else query
)
if children_type == ProjectChildrenType.dataset:
child_queries = {
Project: query
Project: child_query
& Q(system_tags__in=[dataset_tag], basename__ne=datasets_project_name)
}
elif children_type == ProjectChildrenType.pipeline:
child_queries = {Task: query & Q(system_tags__in=[pipeline_tag])}
child_queries = {Task: child_query & Q(system_tags__in=[pipeline_tag])}
elif children_type == ProjectChildrenType.report:
child_queries = {Task: query & Q(system_tags__in=[reports_tag])}
child_queries = {Task: child_query & Q(system_tags__in=[reports_tag])}
else:
project_query = query
child_queries = {entity_cls: query for entity_cls in cls.child_classes}
@ -1065,10 +1070,11 @@ class ProjectBLL:
raise errors.bad_request.ValidationError(
f"List of strings expected for the field: {field}"
)
exclude, include = partition(field_filter, lambda x: x.startswith("-"))
helper = GetMixin.ListFieldBucketHelper(field, legacy=True)
actions = helper.get_actions(field_filter)
conditions[field] = {
**({"$in": include} if include else {}),
**({"$nin": [e[1:] for e in exclude]} if exclude else {}),
f"${action}": list(set(actions[action]))
for action in filter(None, actions)
}
return conditions

View File

@ -653,6 +653,13 @@ get_all_ex {
enum: [pipeline, report, dataset]
}
}
"999.0": ${get_all_ex."2.24"} {
request.properties.children_tags {
description: "The list of tag values to filter children by. Takes effect only if children_type is set. Use 'null' value to specify empty tags. Use '__$not' value to specify that the following value should be excluded"
type: array
items {type: string}
}
}
}
update {
"2.1" {

View File

@ -7,12 +7,12 @@ from apiserver.apierrors import errors
from apiserver.apierrors.errors.bad_request import InvalidProjectId
from apiserver.apimodels.base import UpdateResponse, MakePublicRequest, IdResponse
from apiserver.apimodels.projects import (
DeleteRequest,
GetParamsRequest,
ProjectTagsRequest,
ProjectTaskParentsRequest,
ProjectHyperparamValuesRequest,
ProjectsGetRequest,
DeleteRequest,
MoveRequest,
MergeRequest,
ProjectRequest,
@ -99,19 +99,31 @@ def _adjust_search_parameters(data: dict, shallow_search: bool):
data["parent"] = [None]
def _get_project_stats_filter(request: ProjectsGetRequest) -> Tuple[Optional[dict], bool]:
def _get_project_stats_filter(
request: ProjectsGetRequest,
) -> Tuple[Optional[dict], bool]:
if request.include_stats_filter or not request.children_type:
return request.include_stats_filter, request.search_hidden
stats_filter = {"tags": request.children_tags} if request.children_tags else {}
if request.children_type == ProjectChildrenType.pipeline:
return {"system_tags": [pipeline_tag], "type": [TaskType.controller]}, True
return (
{
**stats_filter,
"system_tags": [pipeline_tag],
"type": [TaskType.controller],
},
True,
)
if request.children_type == ProjectChildrenType.report:
return {"system_tags": [reports_tag], "type": [TaskType.report]}, True
return request.include_stats_filter, request.search_hidden
return (
{**stats_filter, "system_tags": [reports_tag], "type": [TaskType.report]},
True,
)
return stats_filter, request.search_hidden
@endpoint("projects.get_all_ex", request_data_model=ProjectsGetRequest)
@endpoint("projects.get_all_ex")
def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
data = call.data
conform_tag_fields(call, data)
@ -137,6 +149,7 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
project_ids=requested_ids,
allow_public=allow_public,
children_type=request.children_type,
children_tags=request.children_tags,
)
if not ids:
return {"projects": []}
@ -174,19 +187,20 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
conform_output_tags(call, projects)
project_ids = list({project["id"] for project in projects})
stats_filter, stats_search_hidden = _get_project_stats_filter(request)
if request.check_own_contents:
if request.children_type == ProjectChildrenType.dataset:
contents = project_bll.calc_own_datasets(
company=company_id,
project_ids=project_ids,
filter_=request.include_stats_filter,
filter_=stats_filter,
users=request.active_users,
)
else:
contents = project_bll.calc_own_contents(
company=company_id,
project_ids=project_ids,
filter_=_get_project_stats_filter(request)[0],
filter_=stats_filter,
users=request.active_users,
)
@ -199,19 +213,18 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
company=company_id,
project_ids=project_ids,
include_children=request.stats_with_children,
filter_=request.include_stats_filter,
filter_=stats_filter,
users=request.active_users,
selected_project_ids=selected_project_ids,
)
else:
filter_, search_hidden = _get_project_stats_filter(request)
stats, children = project_bll.get_project_stats(
company=company_id,
project_ids=project_ids,
specific_state=request.stats_for_state,
include_children=request.stats_with_children,
search_hidden=search_hidden,
filter_=filter_,
search_hidden=stats_search_hidden,
filter_=stats_filter,
users=request.active_users,
selected_project_ids=selected_project_ids,
)
@ -348,7 +361,7 @@ def delete(call: APICall, company_id: str, request: DeleteRequest):
"projects.get_unique_metric_variants", request_data_model=GetUniqueMetricsRequest
)
def get_unique_metric_variants(
call: APICall, company_id: str, request: GetUniqueMetricsRequest,
call: APICall, company_id: str, request: GetUniqueMetricsRequest
):
metrics = project_queries.get_unique_metric_variants(
@ -361,7 +374,7 @@ def get_unique_metric_variants(
call.result.data = {"metrics": metrics}
@endpoint("projects.get_model_metadata_keys",)
@endpoint("projects.get_model_metadata_keys")
def get_model_metadata_keys(call: APICall, company_id: str, request: GetParamsRequest):
total, remaining, keys = project_queries.get_model_metadata_keys(
company_id,
@ -505,7 +518,7 @@ def get_task_parents(
call: APICall, company_id: str, request: ProjectTaskParentsRequest
):
call.result.data = {
"parents": project_bll.get_task_parents(
"parents": ProjectBLL.get_task_parents(
company_id,
projects=request.projects,
include_subprojects=request.include_subprojects,

View File

@ -33,6 +33,63 @@ class TestSubProjects(TestService):
).projects[0]
self.assertEqual(data.dataset_stats, {"file_count": 2, "total_size": 1000})
def test_query_children_system_tags(self):
    """Check that projects.get_all_ex narrows report children by children_tags."""
    root_name = "TestQueryChildrenTags"
    root_id = self._temp_project(name=root_name)
    untagged_project = self._temp_project(name=f"{root_name}/project1")
    tagged_project = self._temp_project(name=f"{root_name}/project2")

    # project1 gets one report without tags; project2 gets two tagged reports
    self._temp_report(name="test report", project=untagged_project)
    self._temp_report(
        name="test report", project=tagged_project, tags=["test1", "test2"]
    )
    self._temp_report(name="test report", project=tagged_project, tags=["test1"])

    def query_projects(**extra):
        # Every query in this test shares these arguments; only children_tags varies
        return self.api.projects.get_all_ex(
            parent=[root_id],
            children_type="report",
            shallow_search=True,
            include_stats=True,
            check_own_contents=True,
            **extra,
        ).projects

    # Without a tag filter both sub projects are returned
    self.assertEqual(len(query_projects()), 2)

    # (children_tags, expected project basename, expected active total_tasks)
    cases = [
        (["test1", "test2"], "project2", 2),
        (["__$all", "test1", "test2"], "project2", 1),
        (["-test1", "-test2"], "project1", 1),
    ]
    for children_tags, expected_basename, expected_tasks in cases:
        found = query_projects(children_tags=children_tags)
        self.assertEqual(len(found), 1)
        match = found[0]
        self.assertEqual(match.basename, expected_basename)
        self.assertEqual(match.stats.active.total_tasks, expected_tasks)
def test_query_children(self):
test_root_name = "TestQueryChildren"
test_root = self._temp_project(name=test_root_name)