Support querying by children_type in projects.get_all_ex

This commit is contained in:
allegroai 2023-03-23 19:07:42 +02:00
parent 74200a24bd
commit 6664c6237e
8 changed files with 197 additions and 106 deletions

View File

@ -1,5 +1,6 @@
from enum import Enum
from jsonmodels import models, fields from jsonmodels import models, fields
from jsonmodels.fields import EmbeddedField
from apiserver.apimodels import ListField, ActualEnumField, DictField from apiserver.apimodels import ListField, ActualEnumField, DictField
from apiserver.apimodels.organization import TagsRequest from apiserver.apimodels.organization import TagsRequest
@ -61,6 +62,12 @@ class ChildrenCondition(models.Base):
system_tags = fields.ListField([str]) system_tags = fields.ListField([str])
class ProjectChildrenType(Enum):
pipeline = "pipeline"
report = "report"
dataset = "dataset"
class ProjectsGetRequest(models.Base): class ProjectsGetRequest(models.Base):
include_dataset_stats = fields.BoolField(default=False) include_dataset_stats = fields.BoolField(default=False)
include_stats = fields.BoolField(default=False) include_stats = fields.BoolField(default=False)
@ -73,4 +80,4 @@ class ProjectsGetRequest(models.Base):
shallow_search = fields.BoolField(default=False) shallow_search = fields.BoolField(default=False)
search_hidden = fields.BoolField(default=False) search_hidden = fields.BoolField(default=False)
allow_public = fields.BoolField(default=True) allow_public = fields.BoolField(default=True)
children_condition = EmbeddedField(ChildrenCondition) children_type = ActualEnumField(ProjectChildrenType)

View File

@ -22,6 +22,7 @@ from mongoengine import Q, Document
from apiserver import database from apiserver import database
from apiserver.apierrors import errors from apiserver.apierrors import errors
from apiserver.apimodels.projects import ProjectChildrenType
from apiserver.config_repo import config from apiserver.config_repo import config
from apiserver.database.model import EntityVisibility, AttributedDocument from apiserver.database.model import EntityVisibility, AttributedDocument
from apiserver.database.model.base import GetMixin from apiserver.database.model.base import GetMixin
@ -44,9 +45,15 @@ from .sub_projects import (
log = config.logger(__file__) log = config.logger(__file__)
max_depth = config.get("services.projects.sub_projects.max_depth", 10) max_depth = config.get("services.projects.sub_projects.max_depth", 10)
reports_project_name = ".reports"
reports_tag = "reports"
dataset_tag = "dataset"
pipeline_tag = "pipeline"
class ProjectBLL: class ProjectBLL:
child_classes = (Task, Model)
@classmethod @classmethod
def merge_project( def merge_project(
cls, company, source_id: str, destination_id: str cls, company, source_id: str, destination_id: str
@ -81,7 +88,7 @@ class ProjectBLL:
) )
moved_entities = 0 moved_entities = 0
for entity_type in (Task, Model): for entity_type in cls.child_classes:
moved_entities += entity_type.objects( moved_entities += entity_type.objects(
company=company, company=company,
project=source_id, project=source_id,
@ -724,7 +731,7 @@ class ProjectBLL:
projects_query &= Q(id__in=project_ids) projects_query &= Q(id__in=project_ids)
res = set(Project.objects(projects_query).distinct(field="user")) res = set(Project.objects(projects_query).distinct(field="user"))
for cls_ in (Task, Model): for cls_ in cls.child_classes:
res |= set(cls_.objects(query).distinct(field="user")) res |= set(cls_.objects(query).distinct(field="user"))
return res return res
@ -759,46 +766,49 @@ class ProjectBLL:
users: Sequence[str] = None, users: Sequence[str] = None,
project_ids: Optional[Sequence[str]] = None, project_ids: Optional[Sequence[str]] = None,
allow_public: bool = True, allow_public: bool = True,
children_condition: Mapping[str, Any] = None, children_type: ProjectChildrenType = None,
) -> Tuple[Sequence[str], Sequence[str]]: ) -> Tuple[Sequence[str], Sequence[str]]:
""" """
Get the projects ids matching children_condition (if passed) or where the passed user created any tasks Get the projects ids with children matching children_type (if passed) or created by the passed user
including all the parents of these projects including all the parents of these projects
If project ids are specified then filter the results by these project ids If project ids are specified then filter the results by these project ids
""" """
if not (users or children_condition): if not (users or children_type):
raise errors.bad_request.ValidationError( raise errors.bad_request.ValidationError(
"Either active users or children_condition should be specified" "Either active users or children_type should be specified"
) )
projects_query = Project.prepare_query( query = (
company, parameters=children_condition, allow_public=allow_public
)
if children_condition:
contained_entities_query = None
else:
contained_entities_query = (
get_company_or_none_constraint(company) get_company_or_none_constraint(company)
if allow_public if allow_public
else Q(company=company) else Q(company=company)
) )
if users: if users:
user_query = Q(user__in=users) query &= Q(user__in=users)
projects_query &= user_query
if contained_entities_query: if children_type == ProjectChildrenType.dataset:
contained_entities_query &= user_query project_query = query & Q(system_tags__in=[dataset_tag])
entity_queries = {}
elif children_type == ProjectChildrenType.pipeline:
project_query = query & Q(system_tags__in=[pipeline_tag])
entity_queries = {}
elif children_type == ProjectChildrenType.report:
project_query = None
entity_queries = {Task: query & Q(system_tags__in=[reports_tag])}
else:
project_query = query
entity_queries = {entity_cls: query for entity_cls in cls.child_classes}
if project_ids: if project_ids:
ids_with_children = _ids_with_children(project_ids) ids_with_children = _ids_with_children(project_ids)
projects_query &= Q(id__in=ids_with_children) if project_query:
if contained_entities_query: project_query &= Q(id__in=ids_with_children)
contained_entities_query &= Q(project__in=ids_with_children) for entity_cls in entity_queries:
entity_queries[entity_cls] &= Q(project__in=ids_with_children)
res = {p.id for p in Project.objects(projects_query).only("id")} res = {p.id for p in Project.objects(project_query).only("id")} if project_query else set()
if contained_entities_query: for cls_, query_ in entity_queries.items():
for cls_ in (Task, Model): res |= set(cls_.objects(query_).distinct(field="project"))
res |= set(cls_.objects(contained_entities_query).distinct(field="project"))
res = list(res) res = list(res)
if not res: if not res:

View File

@ -1,3 +1,4 @@
from collections import defaultdict
from typing import Tuple, Set, Sequence from typing import Tuple, Set, Sequence
import attr import attr
@ -15,6 +16,7 @@ from apiserver.database.model import EntityVisibility
from apiserver.database.model.model import Model from apiserver.database.model.model import Model
from apiserver.database.model.project import Project from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task, ArtifactModes, TaskType from apiserver.database.model.task.task import Task, ArtifactModes, TaskType
from .project_bll import ProjectBLL
from .sub_projects import _ids_with_children from .sub_projects import _ids_with_children
log = config.logger(__file__) log = config.logger(__file__)
@ -40,9 +42,9 @@ def validate_project_delete(company: str, project_id: str):
is_pipeline = "pipeline" in (project.system_tags or []) is_pipeline = "pipeline" in (project.system_tags or [])
project_ids = _ids_with_children([project_id]) project_ids = _ids_with_children([project_id])
ret = {} ret = {}
for cls in (Task, Model): for cls in ProjectBLL.child_classes:
ret[f"{cls.__name__.lower()}s"] = cls.objects(project__in=project_ids).count() ret[f"{cls.__name__.lower()}s"] = cls.objects(project__in=project_ids).count()
for cls in (Task, Model): for cls in ProjectBLL.child_classes:
query = dict( query = dict(
project__in=project_ids, system_tags__nin=[EntityVisibility.archived.value] project__in=project_ids, system_tags__nin=[EntityVisibility.archived.value]
) )
@ -98,9 +100,10 @@ def delete_project(
) )
if not delete_contents: if not delete_contents:
for cls in (Model, Task): disassociated = defaultdict(int)
updated_count = cls.objects(project__in=project_ids).update(project=None) for cls in ProjectBLL.child_classes:
res = DeleteProjectResult(disassociated_tasks=updated_count) disassociated[cls] = cls.objects(project__in=project_ids).update(project=None)
res = DeleteProjectResult(disassociated_tasks=disassociated[Task])
else: else:
deleted_models, model_event_urls, model_urls = _delete_models( deleted_models, model_event_urls, model_urls = _delete_models(
company=company, projects=project_ids company=company, projects=project_ids

View File

@ -621,16 +621,10 @@ get_all_ex {
} }
} }
"2.24": ${get_all_ex."2.23"} { "2.24": ${get_all_ex."2.23"} {
request.properties.children_condition { request.properties.children_type {
description: The filter that any of the child projects should match in order that the parent will be included description: If specified that only the projects under which the entities of this type can be found will be returned
type: object
properties {
system_tags {
description: The list of system tags to match from
type: string type: string
} enum: [pipeline, report, dataset]
}
additionalProperties: true
} }
} }
} }

View File

@ -1,4 +1,4 @@
from typing import Sequence from typing import Sequence, Optional
import attr import attr
from mongoengine import Q from mongoengine import Q
@ -18,9 +18,11 @@ from apiserver.apimodels.projects import (
ProjectOrNoneRequest, ProjectOrNoneRequest,
ProjectRequest, ProjectRequest,
ProjectModelMetadataValuesRequest, ProjectModelMetadataValuesRequest,
ProjectChildrenType,
) )
from apiserver.bll.organization import OrgBLL, Tags from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL, ProjectQueries from apiserver.bll.project import ProjectBLL, ProjectQueries
from apiserver.bll.project.project_bll import dataset_tag, pipeline_tag, reports_tag
from apiserver.bll.project.project_cleanup import ( from apiserver.bll.project.project_cleanup import (
delete_project, delete_project,
validate_project_delete, validate_project_delete,
@ -28,6 +30,7 @@ from apiserver.bll.project.project_cleanup import (
from apiserver.database.errors import translate_errors_context from apiserver.database.errors import translate_errors_context
from apiserver.database.model import EntityVisibility from apiserver.database.model import EntityVisibility
from apiserver.database.model.project import Project from apiserver.database.model.project import Project
from apiserver.database.model.task.task import TaskType
from apiserver.database.utils import ( from apiserver.database.utils import (
parse_from_call, parse_from_call,
get_company_or_none_constraint, get_company_or_none_constraint,
@ -96,6 +99,17 @@ def _adjust_search_parameters(data: dict, shallow_search: bool):
data["parent"] = [None] data["parent"] = [None]
def _get_filter_from_children_type(type_: ProjectChildrenType) -> Optional[dict]:
if type_ == ProjectChildrenType.dataset:
return {"system_tags": [dataset_tag], "type": [TaskType.data_processing]}
if type_ == ProjectChildrenType.pipeline:
return {"system_tags": [pipeline_tag], "type": [TaskType.controller]}
if type_ == ProjectChildrenType.report:
return {"system_tags": [reports_tag], "type": [TaskType.report]}
return None
@endpoint("projects.get_all_ex", request_data_model=ProjectsGetRequest) @endpoint("projects.get_all_ex", request_data_model=ProjectsGetRequest)
def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest): def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
data = call.data data = call.data
@ -115,15 +129,13 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
data, shallow_search=request.shallow_search, data, shallow_search=request.shallow_search,
) )
selected_project_ids = None selected_project_ids = None
if request.active_users or request.children_condition: if request.active_users or request.children_type:
ids, selected_project_ids = project_bll.get_projects_with_selected_children( ids, selected_project_ids = project_bll.get_projects_with_selected_children(
company=company_id, company=company_id,
users=request.active_users, users=request.active_users,
project_ids=requested_ids, project_ids=requested_ids,
allow_public=allow_public, allow_public=allow_public,
children_condition=request.children_condition.to_struct() children_type=request.children_type,
if request.children_condition
else None,
) )
if not ids: if not ids:
return {"projects": []} return {"projects": []}
@ -140,26 +152,35 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
if not projects: if not projects:
return {"projects": projects, **ret_params} return {"projects": projects, **ret_params}
conform_output_tags(call, projects)
project_ids = list({project["id"] for project in projects}) project_ids = list({project["id"] for project in projects})
if request.check_own_contents or request.include_stats:
if request.children_type and not request.include_stats_filter:
filter_ = _get_filter_from_children_type(request.children_type)
search_hidden = True if filter_ else request.search_hidden
else:
filter_ = request.include_stats_filter
search_hidden = request.search_hidden
if request.check_own_contents: if request.check_own_contents:
contents = project_bll.calc_own_contents( contents = project_bll.calc_own_contents(
company=company_id, company=company_id,
project_ids=project_ids, project_ids=project_ids,
filter_=request.include_stats_filter, filter_=filter_,
users=request.active_users, users=request.active_users,
) )
for project in projects: for project in projects:
project.update(**contents.get(project["id"], {})) project.update(**contents.get(project["id"], {}))
conform_output_tags(call, projects)
if request.include_stats: if request.include_stats:
stats, children = project_bll.get_project_stats( stats, children = project_bll.get_project_stats(
company=company_id, company=company_id,
project_ids=project_ids, project_ids=project_ids,
specific_state=request.stats_for_state, specific_state=request.stats_for_state,
include_children=request.stats_with_children, include_children=request.stats_with_children,
search_hidden=request.search_hidden, search_hidden=search_hidden,
filter_=request.include_stats_filter, filter_=filter_,
users=request.active_users, users=request.active_users,
selected_project_ids=selected_project_ids, selected_project_ids=selected_project_ids,
) )

View File

@ -16,6 +16,7 @@ from apiserver.apimodels.reports import (
) )
from apiserver.apierrors import errors from apiserver.apierrors import errors
from apiserver.apimodels.base import UpdateResponse from apiserver.apimodels.base import UpdateResponse
from apiserver.bll.project.project_bll import reports_project_name, reports_tag
from apiserver.services.utils import process_include_subprojects, sort_tags_response from apiserver.services.utils import process_include_subprojects, sort_tags_response
from apiserver.bll.organization import OrgBLL from apiserver.bll.organization import OrgBLL
from apiserver.bll.project import ProjectBLL from apiserver.bll.project import ProjectBLL
@ -42,8 +43,6 @@ project_bll = ProjectBLL()
task_bll = TaskBLL() task_bll = TaskBLL()
reports_project_name = ".reports"
reports_tag = "reports"
update_fields = { update_fields = {
"name", "name",
"tags", "tags",
@ -80,7 +79,9 @@ def update_report(call: APICall, company_id: str, request: UpdateReportRequest):
if not partial_update_dict: if not partial_update_dict:
return UpdateResponse(updated=0) return UpdateResponse(updated=0)
allowed_for_published = set(partial_update_dict.keys()).issubset({"tags", "name", "comment"}) allowed_for_published = set(partial_update_dict.keys()).issubset(
{"tags", "name", "comment"}
)
if task.status != TaskStatus.created and not allowed_for_published: if task.status != TaskStatus.created and not allowed_for_published:
raise errors.bad_request.InvalidTaskStatus( raise errors.bad_request.InvalidTaskStatus(
expected=TaskStatus.created, status=task.status expected=TaskStatus.created, status=task.status

View File

@ -36,27 +36,56 @@ class TestSubProjects(TestService):
def test_query_children(self): def test_query_children(self):
test_root_name = "TestQueryChildren" test_root_name = "TestQueryChildren"
test_root = self._temp_project(name=test_root_name) test_root = self._temp_project(name=test_root_name)
child_with_tag = self._temp_project( dataset_project = self._temp_project(
name=f"{test_root_name}/Project1/WithTag", system_tags=["test"] name=f"{test_root_name}/Project1/Dataset", system_tags=["dataset"]
) )
child_without_tag = self._temp_project(name=f"{test_root_name}/Project2/WithoutTag") self._temp_task(
name="dataset task",
projects = self.api.projects.get_all_ex(parent=[test_root], shallow_search=True).projects type="data_processing",
self.assertEqual({p.basename for p in projects}, {"Project1", "Project2"}) system_tags=["dataset"],
project=dataset_project,
)
self._temp_task(name="regular task", project=dataset_project)
pipeline_project = self._temp_project(
name=f"{test_root_name}/Project2/Pipeline", system_tags=["pipeline"]
)
self._temp_task(
name="pipeline task",
type="controller",
system_tags=["pipeline"],
project=pipeline_project,
)
self._temp_task(name="regular task", project=pipeline_project)
report_project = self._temp_project(name=f"{test_root_name}/Project3")
self._temp_report(name="test report", project=report_project)
self._temp_task(name="regular task", project=report_project)
projects = self.api.projects.get_all_ex( projects = self.api.projects.get_all_ex(
parent=[test_root], children_condition={"system_tags": ["test"]}, shallow_search=True parent=[test_root], shallow_search=True, include_stats=True
).projects ).projects
self.assertEqual({p.basename for p in projects}, {"Project1"}) self.assertEqual(
projects = self.api.projects.get_all_ex( {p.basename for p in projects}, {f"Project{idx+1}" for idx in range(3)}
parent=[projects[0].id], children_condition={"system_tags": ["test"]}, shallow_search=True )
).projects for p in projects:
self.assertEqual(projects[0].id, child_with_tag) self.assertEqual(
p.stats.active.total_tasks,
2
if p.basename in ("Project1", "Project2")
else 1
)
for i, type_ in enumerate(("dataset", "pipeline", "report")):
projects = self.api.projects.get_all_ex( projects = self.api.projects.get_all_ex(
parent=[test_root], children_condition={"system_tags": ["not existent"]}, shallow_search=True parent=[test_root],
children_type=type_,
shallow_search=True,
include_stats=True,
).projects ).projects
self.assertEqual(len(projects), 0) self.assertEqual({p.basename for p in projects}, {f"Project{i+1}"})
p = projects[0]
self.assertEqual(
p.stats.active.total_tasks, 1
)
def test_project_aggregations(self): def test_project_aggregations(self):
"""This test requires user with user_auth_only... credentials in db""" """This test requires user with user_auth_only... credentials in db"""
@ -323,12 +352,21 @@ class TestSubProjects(TestService):
**kwargs, **kwargs,
) )
def _temp_task(self, client=None, **kwargs): def _temp_report(self, name, **kwargs):
return self.create_temp(
"reports",
name=name,
object_name="task",
delete_params=self.delete_params,
**kwargs,
)
def _temp_task(self, client=None, name=None, type=None, **kwargs):
return self.create_temp( return self.create_temp(
"tasks", "tasks",
delete_params=self.delete_params, delete_params=self.delete_params,
type="testing", type=type or "testing",
name=db_id(), name=name or db_id(),
input=dict(view=dict()), input=dict(view=dict()),
client=client, client=client,
**kwargs, **kwargs,

View File

@ -50,7 +50,9 @@ class TestTasksResetDelete(TestService):
self.assertEqual(res.urls.artifact_urls, []) self.assertEqual(res.urls.artifact_urls, [])
task = self.new_task() task = self.new_task()
(_, published_model_urls), (model, draft_model_urls) = self.create_task_models(task) (_, published_model_urls), (model, draft_model_urls) = self.create_task_models(
task
)
artifact_urls = self.send_artifacts(task) artifact_urls = self.send_artifacts(task)
event_urls = self.send_debug_image_events(task) event_urls = self.send_debug_image_events(task)
event_urls.update(self.send_plot_events(task)) event_urls.update(self.send_plot_events(task))
@ -74,7 +76,12 @@ class TestTasksResetDelete(TestService):
self.api.tasks.reset(task=task, force=True) self.api.tasks.reset(task=task, force=True)
# test urls # test urls
task, (published_model_urls, draft_model_urls), artifact_urls, event_urls = self.create_task_with_data() (
task,
(published_model_urls, draft_model_urls),
artifact_urls,
event_urls,
) = self.create_task_with_data()
res = self.api.tasks.reset(task=task, force=True, return_file_urls=True) res = self.api.tasks.reset(task=task, force=True, return_file_urls=True)
self.assertEqual(set(res.urls.model_urls), draft_model_urls) self.assertEqual(set(res.urls.model_urls), draft_model_urls)
self.assertEqual(set(res.urls.event_urls), event_urls) self.assertEqual(set(res.urls.event_urls), event_urls)
@ -101,13 +108,18 @@ class TestTasksResetDelete(TestService):
# with delete_contents flag # with delete_contents flag
project = self.new_project() project = self.new_project()
task, (published_model_urls, draft_model_urls), artifact_urls, event_urls = self.create_task_with_data( (
project=project task,
) (published_model_urls, draft_model_urls),
artifact_urls,
event_urls,
) = self.create_task_with_data(project=project)
res = self.api.projects.delete( res = self.api.projects.delete(
project=project, force=True, delete_contents=True project=project, force=True, delete_contents=True
) )
self.assertEqual(set(res.urls.model_urls), published_model_urls | draft_model_urls) self.assertEqual(
set(res.urls.model_urls), published_model_urls | draft_model_urls
)
self.assertEqual(res.deleted, 1) self.assertEqual(res.deleted, 1)
self.assertEqual(res.disassociated_tasks, 0) self.assertEqual(res.disassociated_tasks, 0)
self.assertEqual(res.deleted_tasks, 1) self.assertEqual(res.deleted_tasks, 1)
@ -121,7 +133,9 @@ class TestTasksResetDelete(TestService):
self, **kwargs self, **kwargs
) -> Tuple[str, Tuple[Set[str], Set[str]], Set[str], Set[str]]: ) -> Tuple[str, Tuple[Set[str], Set[str]], Set[str], Set[str]]:
task = self.new_task(**kwargs) task = self.new_task(**kwargs)
(_, published_model_urls), (model, draft_model_urls) = self.create_task_models(task, **kwargs) (_, published_model_urls), (model, draft_model_urls) = self.create_task_models(
task, **kwargs
)
artifact_urls = self.send_artifacts(task) artifact_urls = self.send_artifacts(task)
event_urls = self.send_debug_image_events(task) event_urls = self.send_debug_image_events(task)
event_urls.update(self.send_plot_events(task)) event_urls.update(self.send_plot_events(task))
@ -172,7 +186,7 @@ class TestTasksResetDelete(TestService):
), ),
self.create_event( self.create_event(
model, "plot", 0, plot_str=f'{{"source": "{url2}"}}', model_event=True model, "plot", 0, plot_str=f'{{"source": "{url2}"}}', model_event=True
) ),
] ]
self.send_batch(events) self.send_batch(events)
return {url1, url2} return {url1, url2}
@ -181,7 +195,10 @@ class TestTasksResetDelete(TestService):
url_pattern = "url_{num}.txt" url_pattern = "url_{num}.txt"
events = [ events = [
self.create_event( self.create_event(
task, "training_debug_image", iteration, url=url_pattern.format(num=iteration) task,
"training_debug_image",
iteration,
url=url_pattern.format(num=iteration),
) )
for iteration in range(5) for iteration in range(5)
] ]