Support querying by children_type in projects.get_all_ex

This commit is contained in:
allegroai 2023-03-23 19:07:42 +02:00
parent 74200a24bd
commit 6664c6237e
8 changed files with 197 additions and 106 deletions

View File

@ -1,5 +1,6 @@
from enum import Enum
from jsonmodels import models, fields
from jsonmodels.fields import EmbeddedField
from apiserver.apimodels import ListField, ActualEnumField, DictField
from apiserver.apimodels.organization import TagsRequest
@ -61,6 +62,12 @@ class ChildrenCondition(models.Base):
system_tags = fields.ListField([str])
class ProjectChildrenType(Enum):
pipeline = "pipeline"
report = "report"
dataset = "dataset"
class ProjectsGetRequest(models.Base):
include_dataset_stats = fields.BoolField(default=False)
include_stats = fields.BoolField(default=False)
@ -73,4 +80,4 @@ class ProjectsGetRequest(models.Base):
shallow_search = fields.BoolField(default=False)
search_hidden = fields.BoolField(default=False)
allow_public = fields.BoolField(default=True)
children_condition = EmbeddedField(ChildrenCondition)
children_type = ActualEnumField(ProjectChildrenType)

View File

@ -22,6 +22,7 @@ from mongoengine import Q, Document
from apiserver import database
from apiserver.apierrors import errors
from apiserver.apimodels.projects import ProjectChildrenType
from apiserver.config_repo import config
from apiserver.database.model import EntityVisibility, AttributedDocument
from apiserver.database.model.base import GetMixin
@ -44,9 +45,15 @@ from .sub_projects import (
log = config.logger(__file__)
max_depth = config.get("services.projects.sub_projects.max_depth", 10)
reports_project_name = ".reports"
reports_tag = "reports"
dataset_tag = "dataset"
pipeline_tag = "pipeline"
class ProjectBLL:
child_classes = (Task, Model)
@classmethod
def merge_project(
cls, company, source_id: str, destination_id: str
@ -81,7 +88,7 @@ class ProjectBLL:
)
moved_entities = 0
for entity_type in (Task, Model):
for entity_type in cls.child_classes:
moved_entities += entity_type.objects(
company=company,
project=source_id,
@ -724,7 +731,7 @@ class ProjectBLL:
projects_query &= Q(id__in=project_ids)
res = set(Project.objects(projects_query).distinct(field="user"))
for cls_ in (Task, Model):
for cls_ in cls.child_classes:
res |= set(cls_.objects(query).distinct(field="user"))
return res
@ -759,46 +766,49 @@ class ProjectBLL:
users: Sequence[str] = None,
project_ids: Optional[Sequence[str]] = None,
allow_public: bool = True,
children_condition: Mapping[str, Any] = None,
children_type: ProjectChildrenType = None,
) -> Tuple[Sequence[str], Sequence[str]]:
"""
Get the projects ids matching children_condition (if passed) or where the passed user created any tasks
Get the projects ids with children matching children_type (if passed) or created by the passed user
including all the parents of these projects
If project ids are specified then filter the results by these project ids
"""
if not (users or children_condition):
if not (users or children_type):
raise errors.bad_request.ValidationError(
"Either active users or children_condition should be specified"
"Either active users or children_type should be specified"
)
projects_query = Project.prepare_query(
company, parameters=children_condition, allow_public=allow_public
query = (
get_company_or_none_constraint(company)
if allow_public
else Q(company=company)
)
if children_condition:
contained_entities_query = None
else:
contained_entities_query = (
get_company_or_none_constraint(company)
if allow_public
else Q(company=company)
)
if users:
user_query = Q(user__in=users)
projects_query &= user_query
if contained_entities_query:
contained_entities_query &= user_query
query &= Q(user__in=users)
if children_type == ProjectChildrenType.dataset:
project_query = query & Q(system_tags__in=[dataset_tag])
entity_queries = {}
elif children_type == ProjectChildrenType.pipeline:
project_query = query & Q(system_tags__in=[pipeline_tag])
entity_queries = {}
elif children_type == ProjectChildrenType.report:
project_query = None
entity_queries = {Task: query & Q(system_tags__in=[reports_tag])}
else:
project_query = query
entity_queries = {entity_cls: query for entity_cls in cls.child_classes}
if project_ids:
ids_with_children = _ids_with_children(project_ids)
projects_query &= Q(id__in=ids_with_children)
if contained_entities_query:
contained_entities_query &= Q(project__in=ids_with_children)
if project_query:
project_query &= Q(id__in=ids_with_children)
for entity_cls in entity_queries:
entity_queries[entity_cls] &= Q(project__in=ids_with_children)
res = {p.id for p in Project.objects(projects_query).only("id")}
if contained_entities_query:
for cls_ in (Task, Model):
res |= set(cls_.objects(contained_entities_query).distinct(field="project"))
res = {p.id for p in Project.objects(project_query).only("id")} if project_query else set()
for cls_, query_ in entity_queries.items():
res |= set(cls_.objects(query_).distinct(field="project"))
res = list(res)
if not res:

View File

@ -1,3 +1,4 @@
from collections import defaultdict
from typing import Tuple, Set, Sequence
import attr
@ -15,6 +16,7 @@ from apiserver.database.model import EntityVisibility
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task, ArtifactModes, TaskType
from .project_bll import ProjectBLL
from .sub_projects import _ids_with_children
log = config.logger(__file__)
@ -40,9 +42,9 @@ def validate_project_delete(company: str, project_id: str):
is_pipeline = "pipeline" in (project.system_tags or [])
project_ids = _ids_with_children([project_id])
ret = {}
for cls in (Task, Model):
for cls in ProjectBLL.child_classes:
ret[f"{cls.__name__.lower()}s"] = cls.objects(project__in=project_ids).count()
for cls in (Task, Model):
for cls in ProjectBLL.child_classes:
query = dict(
project__in=project_ids, system_tags__nin=[EntityVisibility.archived.value]
)
@ -98,9 +100,10 @@ def delete_project(
)
if not delete_contents:
for cls in (Model, Task):
updated_count = cls.objects(project__in=project_ids).update(project=None)
res = DeleteProjectResult(disassociated_tasks=updated_count)
disassociated = defaultdict(int)
for cls in ProjectBLL.child_classes:
disassociated[cls] = cls.objects(project__in=project_ids).update(project=None)
res = DeleteProjectResult(disassociated_tasks=disassociated[Task])
else:
deleted_models, model_event_urls, model_urls = _delete_models(
company=company, projects=project_ids

View File

@ -621,16 +621,10 @@ get_all_ex {
}
}
"2.24": ${get_all_ex."2.23"} {
request.properties.children_condition {
description: The filter that any of the child projects should match in order that the parent will be included
type: object
properties {
system_tags {
description: The list of system tags to match from
type: string
}
}
additionalProperties: true
request.properties.children_type {
description: If specified that only the projects under which the entities of this type can be found will be returned
type: string
enum: [pipeline, report, dataset]
}
}
}

View File

@ -1,4 +1,4 @@
from typing import Sequence
from typing import Sequence, Optional
import attr
from mongoengine import Q
@ -18,9 +18,11 @@ from apiserver.apimodels.projects import (
ProjectOrNoneRequest,
ProjectRequest,
ProjectModelMetadataValuesRequest,
ProjectChildrenType,
)
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL, ProjectQueries
from apiserver.bll.project.project_bll import dataset_tag, pipeline_tag, reports_tag
from apiserver.bll.project.project_cleanup import (
delete_project,
validate_project_delete,
@ -28,6 +30,7 @@ from apiserver.bll.project.project_cleanup import (
from apiserver.database.errors import translate_errors_context
from apiserver.database.model import EntityVisibility
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import TaskType
from apiserver.database.utils import (
parse_from_call,
get_company_or_none_constraint,
@ -96,6 +99,17 @@ def _adjust_search_parameters(data: dict, shallow_search: bool):
data["parent"] = [None]
def _get_filter_from_children_type(type_: ProjectChildrenType) -> Optional[dict]:
if type_ == ProjectChildrenType.dataset:
return {"system_tags": [dataset_tag], "type": [TaskType.data_processing]}
if type_ == ProjectChildrenType.pipeline:
return {"system_tags": [pipeline_tag], "type": [TaskType.controller]}
if type_ == ProjectChildrenType.report:
return {"system_tags": [reports_tag], "type": [TaskType.report]}
return None
@endpoint("projects.get_all_ex", request_data_model=ProjectsGetRequest)
def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
data = call.data
@ -115,15 +129,13 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
data, shallow_search=request.shallow_search,
)
selected_project_ids = None
if request.active_users or request.children_condition:
if request.active_users or request.children_type:
ids, selected_project_ids = project_bll.get_projects_with_selected_children(
company=company_id,
users=request.active_users,
project_ids=requested_ids,
allow_public=allow_public,
children_condition=request.children_condition.to_struct()
if request.children_condition
else None,
children_type=request.children_type,
)
if not ids:
return {"projects": []}
@ -140,33 +152,42 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
if not projects:
return {"projects": projects, **ret_params}
project_ids = list({project["id"] for project in projects})
if request.check_own_contents:
contents = project_bll.calc_own_contents(
company=company_id,
project_ids=project_ids,
filter_=request.include_stats_filter,
users=request.active_users,
)
for project in projects:
project.update(**contents.get(project["id"], {}))
conform_output_tags(call, projects)
if request.include_stats:
stats, children = project_bll.get_project_stats(
company=company_id,
project_ids=project_ids,
specific_state=request.stats_for_state,
include_children=request.stats_with_children,
search_hidden=request.search_hidden,
filter_=request.include_stats_filter,
users=request.active_users,
selected_project_ids=selected_project_ids,
)
project_ids = list({project["id"] for project in projects})
for project in projects:
project["stats"] = stats[project["id"]]
project["sub_projects"] = children[project["id"]]
if request.check_own_contents or request.include_stats:
if request.children_type and not request.include_stats_filter:
filter_ = _get_filter_from_children_type(request.children_type)
search_hidden = True if filter_ else request.search_hidden
else:
filter_ = request.include_stats_filter
search_hidden = request.search_hidden
if request.check_own_contents:
contents = project_bll.calc_own_contents(
company=company_id,
project_ids=project_ids,
filter_=filter_,
users=request.active_users,
)
for project in projects:
project.update(**contents.get(project["id"], {}))
if request.include_stats:
stats, children = project_bll.get_project_stats(
company=company_id,
project_ids=project_ids,
specific_state=request.stats_for_state,
include_children=request.stats_with_children,
search_hidden=search_hidden,
filter_=filter_,
users=request.active_users,
selected_project_ids=selected_project_ids,
)
for project in projects:
project["stats"] = stats[project["id"]]
project["sub_projects"] = children[project["id"]]
if request.include_dataset_stats:
dataset_stats = project_bll.get_dataset_stats(

View File

@ -16,6 +16,7 @@ from apiserver.apimodels.reports import (
)
from apiserver.apierrors import errors
from apiserver.apimodels.base import UpdateResponse
from apiserver.bll.project.project_bll import reports_project_name, reports_tag
from apiserver.services.utils import process_include_subprojects, sort_tags_response
from apiserver.bll.organization import OrgBLL
from apiserver.bll.project import ProjectBLL
@ -42,8 +43,6 @@ project_bll = ProjectBLL()
task_bll = TaskBLL()
reports_project_name = ".reports"
reports_tag = "reports"
update_fields = {
"name",
"tags",
@ -80,7 +79,9 @@ def update_report(call: APICall, company_id: str, request: UpdateReportRequest):
if not partial_update_dict:
return UpdateResponse(updated=0)
allowed_for_published = set(partial_update_dict.keys()).issubset({"tags", "name", "comment"})
allowed_for_published = set(partial_update_dict.keys()).issubset(
{"tags", "name", "comment"}
)
if task.status != TaskStatus.created and not allowed_for_published:
raise errors.bad_request.InvalidTaskStatus(
expected=TaskStatus.created, status=task.status

View File

@ -36,27 +36,56 @@ class TestSubProjects(TestService):
def test_query_children(self):
test_root_name = "TestQueryChildren"
test_root = self._temp_project(name=test_root_name)
child_with_tag = self._temp_project(
name=f"{test_root_name}/Project1/WithTag", system_tags=["test"]
dataset_project = self._temp_project(
name=f"{test_root_name}/Project1/Dataset", system_tags=["dataset"]
)
child_without_tag = self._temp_project(name=f"{test_root_name}/Project2/WithoutTag")
projects = self.api.projects.get_all_ex(parent=[test_root], shallow_search=True).projects
self.assertEqual({p.basename for p in projects}, {"Project1", "Project2"})
self._temp_task(
name="dataset task",
type="data_processing",
system_tags=["dataset"],
project=dataset_project,
)
self._temp_task(name="regular task", project=dataset_project)
pipeline_project = self._temp_project(
name=f"{test_root_name}/Project2/Pipeline", system_tags=["pipeline"]
)
self._temp_task(
name="pipeline task",
type="controller",
system_tags=["pipeline"],
project=pipeline_project,
)
self._temp_task(name="regular task", project=pipeline_project)
report_project = self._temp_project(name=f"{test_root_name}/Project3")
self._temp_report(name="test report", project=report_project)
self._temp_task(name="regular task", project=report_project)
projects = self.api.projects.get_all_ex(
parent=[test_root], children_condition={"system_tags": ["test"]}, shallow_search=True
parent=[test_root], shallow_search=True, include_stats=True
).projects
self.assertEqual({p.basename for p in projects}, {"Project1"})
projects = self.api.projects.get_all_ex(
parent=[projects[0].id], children_condition={"system_tags": ["test"]}, shallow_search=True
).projects
self.assertEqual(projects[0].id, child_with_tag)
self.assertEqual(
{p.basename for p in projects}, {f"Project{idx+1}" for idx in range(3)}
)
for p in projects:
self.assertEqual(
p.stats.active.total_tasks,
2
if p.basename in ("Project1", "Project2")
else 1
)
projects = self.api.projects.get_all_ex(
parent=[test_root], children_condition={"system_tags": ["not existent"]}, shallow_search=True
).projects
self.assertEqual(len(projects), 0)
for i, type_ in enumerate(("dataset", "pipeline", "report")):
projects = self.api.projects.get_all_ex(
parent=[test_root],
children_type=type_,
shallow_search=True,
include_stats=True,
).projects
self.assertEqual({p.basename for p in projects}, {f"Project{i+1}"})
p = projects[0]
self.assertEqual(
p.stats.active.total_tasks, 1
)
def test_project_aggregations(self):
"""This test requires user with user_auth_only... credentials in db"""
@ -323,12 +352,21 @@ class TestSubProjects(TestService):
**kwargs,
)
def _temp_task(self, client=None, **kwargs):
def _temp_report(self, name, **kwargs):
return self.create_temp(
"reports",
name=name,
object_name="task",
delete_params=self.delete_params,
**kwargs,
)
def _temp_task(self, client=None, name=None, type=None, **kwargs):
return self.create_temp(
"tasks",
delete_params=self.delete_params,
type="testing",
name=db_id(),
type=type or "testing",
name=name or db_id(),
input=dict(view=dict()),
client=client,
**kwargs,

View File

@ -50,7 +50,9 @@ class TestTasksResetDelete(TestService):
self.assertEqual(res.urls.artifact_urls, [])
task = self.new_task()
(_, published_model_urls), (model, draft_model_urls) = self.create_task_models(task)
(_, published_model_urls), (model, draft_model_urls) = self.create_task_models(
task
)
artifact_urls = self.send_artifacts(task)
event_urls = self.send_debug_image_events(task)
event_urls.update(self.send_plot_events(task))
@ -74,7 +76,12 @@ class TestTasksResetDelete(TestService):
self.api.tasks.reset(task=task, force=True)
# test urls
task, (published_model_urls, draft_model_urls), artifact_urls, event_urls = self.create_task_with_data()
(
task,
(published_model_urls, draft_model_urls),
artifact_urls,
event_urls,
) = self.create_task_with_data()
res = self.api.tasks.reset(task=task, force=True, return_file_urls=True)
self.assertEqual(set(res.urls.model_urls), draft_model_urls)
self.assertEqual(set(res.urls.event_urls), event_urls)
@ -101,13 +108,18 @@ class TestTasksResetDelete(TestService):
# with delete_contents flag
project = self.new_project()
task, (published_model_urls, draft_model_urls), artifact_urls, event_urls = self.create_task_with_data(
project=project
)
(
task,
(published_model_urls, draft_model_urls),
artifact_urls,
event_urls,
) = self.create_task_with_data(project=project)
res = self.api.projects.delete(
project=project, force=True, delete_contents=True
)
self.assertEqual(set(res.urls.model_urls), published_model_urls | draft_model_urls)
self.assertEqual(
set(res.urls.model_urls), published_model_urls | draft_model_urls
)
self.assertEqual(res.deleted, 1)
self.assertEqual(res.disassociated_tasks, 0)
self.assertEqual(res.deleted_tasks, 1)
@ -121,7 +133,9 @@ class TestTasksResetDelete(TestService):
self, **kwargs
) -> Tuple[str, Tuple[Set[str], Set[str]], Set[str], Set[str]]:
task = self.new_task(**kwargs)
(_, published_model_urls), (model, draft_model_urls) = self.create_task_models(task, **kwargs)
(_, published_model_urls), (model, draft_model_urls) = self.create_task_models(
task, **kwargs
)
artifact_urls = self.send_artifacts(task)
event_urls = self.send_debug_image_events(task)
event_urls.update(self.send_plot_events(task))
@ -172,7 +186,7 @@ class TestTasksResetDelete(TestService):
),
self.create_event(
model, "plot", 0, plot_str=f'{{"source": "{url2}"}}', model_event=True
)
),
]
self.send_batch(events)
return {url1, url2}
@ -181,7 +195,10 @@ class TestTasksResetDelete(TestService):
url_pattern = "url_{num}.txt"
events = [
self.create_event(
task, "training_debug_image", iteration, url=url_pattern.format(num=iteration)
task,
"training_debug_image",
iteration,
url=url_pattern.format(num=iteration),
)
for iteration in range(5)
]