diff --git a/apiserver/apimodels/projects.py b/apiserver/apimodels/projects.py index 0cbfb3f..017de80 100644 --- a/apiserver/apimodels/projects.py +++ b/apiserver/apimodels/projects.py @@ -1,5 +1,6 @@ +from enum import Enum + from jsonmodels import models, fields -from jsonmodels.fields import EmbeddedField from apiserver.apimodels import ListField, ActualEnumField, DictField from apiserver.apimodels.organization import TagsRequest @@ -61,6 +62,12 @@ class ChildrenCondition(models.Base): system_tags = fields.ListField([str]) +class ProjectChildrenType(Enum): + pipeline = "pipeline" + report = "report" + dataset = "dataset" + + class ProjectsGetRequest(models.Base): include_dataset_stats = fields.BoolField(default=False) include_stats = fields.BoolField(default=False) @@ -73,4 +80,4 @@ class ProjectsGetRequest(models.Base): shallow_search = fields.BoolField(default=False) search_hidden = fields.BoolField(default=False) allow_public = fields.BoolField(default=True) - children_condition = EmbeddedField(ChildrenCondition) + children_type = ActualEnumField(ProjectChildrenType) diff --git a/apiserver/bll/project/project_bll.py b/apiserver/bll/project/project_bll.py index 2656f0c..5716463 100644 --- a/apiserver/bll/project/project_bll.py +++ b/apiserver/bll/project/project_bll.py @@ -22,6 +22,7 @@ from mongoengine import Q, Document from apiserver import database from apiserver.apierrors import errors +from apiserver.apimodels.projects import ProjectChildrenType from apiserver.config_repo import config from apiserver.database.model import EntityVisibility, AttributedDocument from apiserver.database.model.base import GetMixin @@ -44,9 +45,15 @@ from .sub_projects import ( log = config.logger(__file__) max_depth = config.get("services.projects.sub_projects.max_depth", 10) +reports_project_name = ".reports" +reports_tag = "reports" +dataset_tag = "dataset" +pipeline_tag = "pipeline" class ProjectBLL: + child_classes = (Task, Model) + @classmethod def merge_project( cls, company, source_id: str, destination_id: str @@ -81,7 +88,7 @@ class ProjectBLL: ) moved_entities = 0 - for entity_type in (Task, Model): + for entity_type in cls.child_classes: moved_entities += entity_type.objects( company=company, project=source_id, @@ -724,7 +731,7 @@ class ProjectBLL: projects_query &= Q(id__in=project_ids) res = set(Project.objects(projects_query).distinct(field="user")) - for cls_ in (Task, Model): + for cls_ in cls.child_classes: res |= set(cls_.objects(query).distinct(field="user")) return res @@ -759,46 +766,49 @@ class ProjectBLL: users: Sequence[str] = None, project_ids: Optional[Sequence[str]] = None, allow_public: bool = True, - children_condition: Mapping[str, Any] = None, + children_type: ProjectChildrenType = None, ) -> Tuple[Sequence[str], Sequence[str]]: """ - Get the projects ids matching children_condition (if passed) or where the passed user created any tasks + Get the projects ids with children matching children_type (if passed) or created by the passed user including all the parents of these projects If project ids are specified then filter the results by these project ids """ - if not (users or children_condition): + if not (users or children_type): raise errors.bad_request.ValidationError( - "Either active users or children_condition should be specified" + "Either active users or children_type should be specified" ) - projects_query = Project.prepare_query( - company, parameters=children_condition, allow_public=allow_public + query = ( + get_company_or_none_constraint(company) + if allow_public + else Q(company=company) ) - if children_condition: - contained_entities_query = None - else: - contained_entities_query = ( - get_company_or_none_constraint(company) - if allow_public - else Q(company=company) - ) - if users: - user_query = Q(user__in=users) - projects_query &= user_query - if contained_entities_query: - contained_entities_query &= user_query + query &= Q(user__in=users) + + if children_type == ProjectChildrenType.dataset: + project_query = query & Q(system_tags__in=[dataset_tag]) + entity_queries = {} + elif children_type == ProjectChildrenType.pipeline: + project_query = query & Q(system_tags__in=[pipeline_tag]) + entity_queries = {} + elif children_type == ProjectChildrenType.report: + project_query = None + entity_queries = {Task: query & Q(system_tags__in=[reports_tag])} + else: + project_query = query + entity_queries = {entity_cls: query for entity_cls in cls.child_classes} if project_ids: ids_with_children = _ids_with_children(project_ids) - projects_query &= Q(id__in=ids_with_children) - if contained_entities_query: - contained_entities_query &= Q(project__in=ids_with_children) + if project_query: + project_query &= Q(id__in=ids_with_children) + for entity_cls in entity_queries: + entity_queries[entity_cls] &= Q(project__in=ids_with_children) - res = {p.id for p in Project.objects(projects_query).only("id")} - if contained_entities_query: - for cls_ in (Task, Model): - res |= set(cls_.objects(contained_entities_query).distinct(field="project")) + res = {p.id for p in Project.objects(project_query).only("id")} if project_query else set() + for cls_, query_ in entity_queries.items(): + res |= set(cls_.objects(query_).distinct(field="project")) res = list(res) if not res: diff --git a/apiserver/bll/project/project_cleanup.py b/apiserver/bll/project/project_cleanup.py index 2ba33f6..1a2d601 100644 --- a/apiserver/bll/project/project_cleanup.py +++ b/apiserver/bll/project/project_cleanup.py @@ -1,3 +1,4 @@ +from collections import defaultdict from typing import Tuple, Set, Sequence import attr @@ -15,6 +16,7 @@ from apiserver.database.model import EntityVisibility from apiserver.database.model.model import Model from apiserver.database.model.project import Project from apiserver.database.model.task.task import Task, ArtifactModes, TaskType +from .project_bll import ProjectBLL from .sub_projects import _ids_with_children log = config.logger(__file__) @@ -40,9 +42,9 @@ def validate_project_delete(company: str, project_id: str): is_pipeline = "pipeline" in (project.system_tags or []) project_ids = _ids_with_children([project_id]) ret = {} - for cls in (Task, Model): + for cls in ProjectBLL.child_classes: ret[f"{cls.__name__.lower()}s"] = cls.objects(project__in=project_ids).count() - for cls in (Task, Model): + for cls in ProjectBLL.child_classes: query = dict( project__in=project_ids, system_tags__nin=[EntityVisibility.archived.value] ) @@ -98,9 +100,10 @@ def delete_project( ) if not delete_contents: - for cls in (Model, Task): - updated_count = cls.objects(project__in=project_ids).update(project=None) - res = DeleteProjectResult(disassociated_tasks=updated_count) + disassociated = defaultdict(int) + for cls in ProjectBLL.child_classes: + disassociated[cls] = cls.objects(project__in=project_ids).update(project=None) + res = DeleteProjectResult(disassociated_tasks=disassociated[Task]) else: deleted_models, model_event_urls, model_urls = _delete_models( company=company, projects=project_ids diff --git a/apiserver/schema/services/projects.conf b/apiserver/schema/services/projects.conf index 7ffe9df..95ee096 100644 --- a/apiserver/schema/services/projects.conf +++ b/apiserver/schema/services/projects.conf @@ -621,16 +621,10 @@ get_all_ex { } } "2.24": ${get_all_ex."2.23"} { - request.properties.children_condition { - description: The filter that any of the child projects should match in order that the parent will be included - type: object - properties { - system_tags { - description: The list of system tags to match from - type: string - } - } - additionalProperties: true + request.properties.children_type { + description: If specified that only the projects under which the entities of this type can be found will be returned + type: string + enum: [pipeline, report, dataset] } } } diff --git a/apiserver/services/projects.py b/apiserver/services/projects.py index 7d1c324..c17aa7c 100644 --- a/apiserver/services/projects.py +++ b/apiserver/services/projects.py @@ -1,4 +1,4 @@ -from typing import Sequence +from typing import Sequence, Optional import attr from mongoengine import Q @@ -18,9 +18,11 @@ from apiserver.apimodels.projects import ( ProjectOrNoneRequest, ProjectRequest, ProjectModelMetadataValuesRequest, + ProjectChildrenType, ) from apiserver.bll.organization import OrgBLL, Tags from apiserver.bll.project import ProjectBLL, ProjectQueries +from apiserver.bll.project.project_bll import dataset_tag, pipeline_tag, reports_tag from apiserver.bll.project.project_cleanup import ( delete_project, validate_project_delete, @@ -28,6 +30,7 @@ from apiserver.bll.project.project_cleanup import ( from apiserver.database.errors import translate_errors_context from apiserver.database.model import EntityVisibility from apiserver.database.model.project import Project +from apiserver.database.model.task.task import TaskType from apiserver.database.utils import ( parse_from_call, get_company_or_none_constraint, @@ -96,6 +99,17 @@ def _adjust_search_parameters(data: dict, shallow_search: bool): data["parent"] = [None] +def _get_filter_from_children_type(type_: ProjectChildrenType) -> Optional[dict]: + if type_ == ProjectChildrenType.dataset: + return {"system_tags": [dataset_tag], "type": [TaskType.data_processing]} + if type_ == ProjectChildrenType.pipeline: + return {"system_tags": [pipeline_tag], "type": [TaskType.controller]} + if type_ == ProjectChildrenType.report: + return {"system_tags": [reports_tag], "type": [TaskType.report]} + + return None + + @endpoint("projects.get_all_ex", request_data_model=ProjectsGetRequest) def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest): data = call.data @@ -115,15 +129,13 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest): data, shallow_search=request.shallow_search, ) selected_project_ids = None - if request.active_users or request.children_condition: + if request.active_users or request.children_type: ids, selected_project_ids = project_bll.get_projects_with_selected_children( company=company_id, users=request.active_users, project_ids=requested_ids, allow_public=allow_public, - children_condition=request.children_condition.to_struct() - if request.children_condition - else None, + children_type=request.children_type, ) if not ids: return {"projects": []} @@ -140,33 +152,42 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest): if not projects: return {"projects": projects, **ret_params} - project_ids = list({project["id"] for project in projects}) - if request.check_own_contents: - contents = project_bll.calc_own_contents( - company=company_id, - project_ids=project_ids, - filter_=request.include_stats_filter, - users=request.active_users, - ) - for project in projects: - project.update(**contents.get(project["id"], {})) - conform_output_tags(call, projects) - if request.include_stats: - stats, children = project_bll.get_project_stats( - company=company_id, - project_ids=project_ids, - specific_state=request.stats_for_state, - include_children=request.stats_with_children, - search_hidden=request.search_hidden, - filter_=request.include_stats_filter, - users=request.active_users, - selected_project_ids=selected_project_ids, - ) + project_ids = list({project["id"] for project in projects}) - for project in projects: - project["stats"] = stats[project["id"]] - project["sub_projects"] = children[project["id"]] + if request.check_own_contents or request.include_stats: + if request.children_type and not request.include_stats_filter: + filter_ = _get_filter_from_children_type(request.children_type) + search_hidden = True if filter_ else request.search_hidden + else: + filter_ = request.include_stats_filter + search_hidden = request.search_hidden + + if request.check_own_contents: + contents = project_bll.calc_own_contents( + company=company_id, + project_ids=project_ids, + filter_=filter_, + users=request.active_users, + ) + for project in projects: + project.update(**contents.get(project["id"], {})) + + if request.include_stats: + stats, children = project_bll.get_project_stats( + company=company_id, + project_ids=project_ids, + specific_state=request.stats_for_state, + include_children=request.stats_with_children, + search_hidden=search_hidden, + filter_=filter_, + users=request.active_users, + selected_project_ids=selected_project_ids, + ) + + for project in projects: + project["stats"] = stats[project["id"]] + project["sub_projects"] = children[project["id"]] if request.include_dataset_stats: dataset_stats = project_bll.get_dataset_stats( diff --git a/apiserver/services/reports.py b/apiserver/services/reports.py index 689d744..d29afb5 100644 --- a/apiserver/services/reports.py +++ b/apiserver/services/reports.py @@ -16,6 +16,7 @@ from apiserver.apimodels.reports import ( ) from apiserver.apierrors import errors from apiserver.apimodels.base import UpdateResponse +from apiserver.bll.project.project_bll import reports_project_name, reports_tag from apiserver.services.utils import process_include_subprojects, sort_tags_response from apiserver.bll.organization import OrgBLL from apiserver.bll.project import ProjectBLL @@ -42,8 +43,6 @@ project_bll = ProjectBLL() task_bll = TaskBLL() -reports_project_name = ".reports" -reports_tag = "reports" update_fields = { "name", "tags", @@ -80,7 +79,9 @@ def update_report(call: APICall, company_id: str, request: UpdateReportRequest): if not partial_update_dict: return UpdateResponse(updated=0) - allowed_for_published = set(partial_update_dict.keys()).issubset({"tags", "name", "comment"}) + allowed_for_published = set(partial_update_dict.keys()).issubset( + {"tags", "name", "comment"} + ) if task.status != TaskStatus.created and not allowed_for_published: raise errors.bad_request.InvalidTaskStatus( expected=TaskStatus.created, status=task.status diff --git a/apiserver/tests/automated/test_subprojects.py b/apiserver/tests/automated/test_subprojects.py index 44c327a..52b56ae 100644 --- a/apiserver/tests/automated/test_subprojects.py +++ b/apiserver/tests/automated/test_subprojects.py @@ -36,27 +36,56 @@ class TestSubProjects(TestService): def test_query_children(self): test_root_name = "TestQueryChildren" test_root = self._temp_project(name=test_root_name) - child_with_tag = self._temp_project( - name=f"{test_root_name}/Project1/WithTag", system_tags=["test"] + dataset_project = self._temp_project( + name=f"{test_root_name}/Project1/Dataset", system_tags=["dataset"] ) - child_without_tag = self._temp_project(name=f"{test_root_name}/Project2/WithoutTag") - - projects = self.api.projects.get_all_ex(parent=[test_root], shallow_search=True).projects - self.assertEqual({p.basename for p in projects}, {"Project1", "Project2"}) + self._temp_task( + name="dataset task", + type="data_processing", + system_tags=["dataset"], + project=dataset_project, + ) + self._temp_task(name="regular task", project=dataset_project) + pipeline_project = self._temp_project( + name=f"{test_root_name}/Project2/Pipeline", system_tags=["pipeline"] + ) + self._temp_task( + name="pipeline task", + type="controller", + system_tags=["pipeline"], + project=pipeline_project, + ) + self._temp_task(name="regular task", project=pipeline_project) + report_project = self._temp_project(name=f"{test_root_name}/Project3") + self._temp_report(name="test report", project=report_project) + self._temp_task(name="regular task", project=report_project) projects = self.api.projects.get_all_ex( - parent=[test_root], children_condition={"system_tags": ["test"]}, shallow_search=True + parent=[test_root], shallow_search=True, include_stats=True ).projects - self.assertEqual({p.basename for p in projects}, {"Project1"}) - projects = self.api.projects.get_all_ex( - parent=[projects[0].id], children_condition={"system_tags": ["test"]}, shallow_search=True - ).projects - self.assertEqual(projects[0].id, child_with_tag) + self.assertEqual( + {p.basename for p in projects}, {f"Project{idx+1}" for idx in range(3)} + ) + for p in projects: + self.assertEqual( + p.stats.active.total_tasks, + 2 + if p.basename in ("Project1", "Project2") + else 1 + ) - projects = self.api.projects.get_all_ex( - parent=[test_root], children_condition={"system_tags": ["not existent"]}, shallow_search=True - ).projects - self.assertEqual(len(projects), 0) + for i, type_ in enumerate(("dataset", "pipeline", "report")): + projects = self.api.projects.get_all_ex( + parent=[test_root], + children_type=type_, + shallow_search=True, + include_stats=True, + ).projects + self.assertEqual({p.basename for p in projects}, {f"Project{i+1}"}) + p = projects[0] + self.assertEqual( + p.stats.active.total_tasks, 1 + ) def test_project_aggregations(self): """This test requires user with user_auth_only... credentials in db""" @@ -323,12 +352,21 @@ class TestSubProjects(TestService): **kwargs, ) - def _temp_task(self, client=None, **kwargs): + def _temp_report(self, name, **kwargs): + return self.create_temp( + "reports", + name=name, + object_name="task", + delete_params=self.delete_params, + **kwargs, + ) + + def _temp_task(self, client=None, name=None, type=None, **kwargs): return self.create_temp( "tasks", delete_params=self.delete_params, - type="testing", - name=db_id(), + type=type or "testing", + name=name or db_id(), input=dict(view=dict()), client=client, **kwargs, diff --git a/apiserver/tests/automated/test_tasks_delete.py b/apiserver/tests/automated/test_tasks_delete.py index 1f4652e..e589eff 100644 --- a/apiserver/tests/automated/test_tasks_delete.py +++ b/apiserver/tests/automated/test_tasks_delete.py @@ -50,7 +50,9 @@ class TestTasksResetDelete(TestService): self.assertEqual(res.urls.artifact_urls, []) task = self.new_task() - (_, published_model_urls), (model, draft_model_urls) = self.create_task_models(task) + (_, published_model_urls), (model, draft_model_urls) = self.create_task_models( + task + ) artifact_urls = self.send_artifacts(task) event_urls = self.send_debug_image_events(task) event_urls.update(self.send_plot_events(task)) @@ -74,7 +76,12 @@ class TestTasksResetDelete(TestService): self.api.tasks.reset(task=task, force=True) # test urls - task, (published_model_urls, draft_model_urls), artifact_urls, event_urls = self.create_task_with_data() + ( + task, + (published_model_urls, draft_model_urls), + artifact_urls, + event_urls, + ) = self.create_task_with_data() res = self.api.tasks.reset(task=task, force=True, return_file_urls=True) self.assertEqual(set(res.urls.model_urls), draft_model_urls) self.assertEqual(set(res.urls.event_urls), event_urls) @@ -101,13 +108,18 @@ class TestTasksResetDelete(TestService): # with delete_contents flag project = self.new_project() - task, (published_model_urls, draft_model_urls), artifact_urls, event_urls = self.create_task_with_data( - project=project - ) + ( + task, + (published_model_urls, draft_model_urls), + artifact_urls, + event_urls, + ) = self.create_task_with_data(project=project) res = self.api.projects.delete( project=project, force=True, delete_contents=True ) - self.assertEqual(set(res.urls.model_urls), published_model_urls | draft_model_urls) + self.assertEqual( + set(res.urls.model_urls), published_model_urls | draft_model_urls + ) self.assertEqual(res.deleted, 1) self.assertEqual(res.disassociated_tasks, 0) self.assertEqual(res.deleted_tasks, 1) @@ -121,7 +133,9 @@ class TestTasksResetDelete(TestService): self, **kwargs ) -> Tuple[str, Tuple[Set[str], Set[str]], Set[str], Set[str]]: task = self.new_task(**kwargs) - (_, published_model_urls), (model, draft_model_urls) = self.create_task_models(task, **kwargs) + (_, published_model_urls), (model, draft_model_urls) = self.create_task_models( + task, **kwargs + ) artifact_urls = self.send_artifacts(task) event_urls = self.send_debug_image_events(task) event_urls.update(self.send_plot_events(task)) @@ -172,7 +186,7 @@ class TestTasksResetDelete(TestService): ), self.create_event( model, "plot", 0, plot_str=f'{{"source": "{url2}"}}', model_event=True - ) + ), ] self.send_batch(events) return {url1, url2} @@ -181,7 +195,10 @@ class TestTasksResetDelete(TestService): url_pattern = "url_{num}.txt" events = [ self.create_event( - task, "training_debug_image", iteration, url=url_pattern.format(num=iteration) + task, + "training_debug_image", + iteration, + url=url_pattern.format(num=iteration), ) for iteration in range(5) ]