Add check_own_contents flag for projects.get_all_ex

This commit is contained in:
allegroai 2021-05-03 18:12:44 +03:00
parent d60d6dfe99
commit 1cef03b8c2
5 changed files with 95 additions and 2 deletions

View File

@ -56,4 +56,5 @@ class ProjectsGetRequest(models.Base):
stats_for_state = ActualEnumField(EntityVisibility, default=EntityVisibility.active)
non_public = fields.BoolField(default=False)
active_users = fields.ListField(str)
check_own_contents = fields.BoolField(default=False)
shallow_search = fields.BoolField(default=False)

View File

@ -21,7 +21,7 @@ from mongoengine import Q, Document
from apiserver import database
from apiserver.apierrors import errors
from apiserver.config_repo import config
from apiserver.database.model import EntityVisibility
from apiserver.database.model import EntityVisibility, AttributedDocument
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task, TaskStatus, external_task_types
@ -672,3 +672,48 @@ class ProjectBLL:
project_ids = _ids_with_children(project_ids)
query &= Q(project__in=project_ids)
return Model.objects(query).distinct(field="framework")
@classmethod
def calc_own_contents(cls, company: str, project_ids: Sequence[str]) -> Dict[str, dict]:
    """
    Count the tasks and models that belong directly to each requested project.

    Separate aggregation calls are issued on Task and Model instead of a
    lookup aggregation on projects, in order not to hit memory limits on
    large tasks.

    :param company: company id; documents whose company is None or "" are
        matched as well (presumably public/shared entities - confirm)
    :param project_ids: ids of the projects whose own contents are counted
    :return: mapping of project id to
        {"own_tasks": <count>, "own_models": <count>}; a requested project
        with no matching documents gets zero counts. An empty dict is
        returned when no project ids are passed.
    """
    if not project_ids:
        return {}

    pipeline = [
        {
            "$match": {
                "company": {"$in": [None, "", company]},
                # $in needs an array value, so normalise any sequence to a list
                "project": {"$in": list(project_ids)},
            }
        },
        # keep only the grouping key so the group stage handles minimal documents
        {"$project": {"project": 1}},
        {"$group": {"_id": "$project", "count": {"$sum": 1}}},
    ]

    def get_aggregate_res(cls_: Type[AttributedDocument]) -> dict:
        """Run the counting pipeline on cls_ and return {project id: count}."""
        return {data["_id"]: data["count"] for data in cls_.aggregate(pipeline)}

    # the timing label names this operation (it used to be a copy-pasted,
    # unrelated "get_security_groups" label)
    with TimingContext("mongo", "projects_calc_own_contents"):
        tasks = get_aggregate_res(Task)
        models = get_aggregate_res(Model)

    return {
        pid: {
            "own_tasks": tasks.get(pid, 0),
            "own_models": models.get(pid, 0),
        }
        for pid in project_ids
    }

View File

@ -468,6 +468,23 @@ get_all_ex {
type: boolean
default: false
}
check_own_contents {
description: "If set to 'true' and project ids are passed to the query then for these projects their own tasks and models are counted"
type: boolean
default: false
}
}
}
response {
properties {
own_tasks {
description: "The amount of tasks under this project (without children projects). Returned if 'check_own_contents' flag is set in the request"
type: integer
}
own_models {
description: "The amount of models under this project (without children projects). Returned if 'check_own_contents' flag is set in the request"
type: integer
}
}
}
}

View File

@ -89,13 +89,15 @@ def _adjust_search_parameters(data: dict, shallow_search: bool):
def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
conform_tag_fields(call, call.data)
allow_public = not request.non_public
data = call.data
requested_ids = data.get("id")
with TimingContext("mongo", "projects_get_all"):
data = call.data
if request.active_users:
ids = project_bll.get_projects_with_active_user(
company=company_id,
users=request.active_users,
project_ids=data.get("id"),
project_ids=requested_ids,
allow_public=allow_public,
)
if not ids:
@ -109,6 +111,17 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
company=company_id, query_dict=data, allow_public=allow_public,
)
if request.check_own_contents and requested_ids:
existing_requested_ids = {
project["id"] for project in projects if project["id"] in requested_ids
}
if existing_requested_ids:
contents = project_bll.calc_own_contents(
company=company_id, project_ids=list(existing_requested_ids)
)
for project in projects:
project.update(**contents.get(project["id"], {}))
conform_output_tags(call, projects)
if not request.include_stats:
call.result.data = {"projects": projects}

View File

@ -169,6 +169,23 @@ class TestSubProjects(TestService):
self.api.projects.delete(project=project1, force=True)
def test_get_all_with_check_own_contents(self):
    """
    get_all_ex with check_own_contents=True reports per-project own counts.

    A project's own tasks/models are counted without its children: the
    tasks created under "project2x/project22" must not be attributed to
    "project2x".
    """
    populated_project, _ = self._temp_project_with_tasks(name="project1x")
    parent_only_project = self._temp_project(name="project2x")
    # child project with tasks; its contents must not leak into the parent
    self._temp_project_with_tasks(name="project2x/project22")
    self._temp_model(project=populated_project)

    returned_projects = self.api.projects.get_all_ex(
        id=[populated_project, parent_only_project], check_own_contents=True
    ).projects

    # (project id, expected own_tasks, expected own_models)
    expectations = (
        (populated_project, 2, 1),
        (parent_only_project, 0, 0),
    )
    for project_id, expected_tasks, expected_models in expectations:
        found = next(p for p in returned_projects if p.id == project_id)
        self.assertEqual(found.own_tasks, expected_tasks)
        self.assertEqual(found.own_models, expected_models)
def test_get_all_with_stats(self):
project4, _ = self._temp_project_with_tasks(name="project1/project3/project4")
project5, _ = self._temp_project_with_tasks(name="project1/project3/project5")