From c034c1a986c67a229bf8caa47c6801b890a8f2d5 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Mon, 3 May 2021 17:42:10 +0300 Subject: [PATCH] Add sub-projects support --- apiserver/apierrors/errors.conf | 4 +- apiserver/apimodels/projects.py | 21 +- apiserver/bll/model/__init__.py | 18 - apiserver/bll/organization/__init__.py | 37 +- apiserver/bll/organization/tags_cache.py | 3 +- apiserver/bll/project/__init__.py | 1 + apiserver/bll/project/project_bll.py | 350 ++++++++++++++++-- apiserver/bll/project/project_cleanup.py | 43 ++- apiserver/bll/project/sub_projects.py | 163 ++++++++ apiserver/bll/task/task_bll.py | 13 - .../config/default/services/projects.conf | 5 + apiserver/database/model/project.py | 8 +- apiserver/mongo/initialize/pre_populate.py | 23 +- apiserver/schema/services/projects.conf | 95 ++++- apiserver/services/models.py | 4 +- apiserver/services/projects.py | 114 ++++-- apiserver/services/tasks.py | 2 +- apiserver/tests/automated/test_subprojects.py | 240 ++++++++++++ 18 files changed, 967 insertions(+), 177 deletions(-) delete mode 100644 apiserver/bll/model/__init__.py create mode 100644 apiserver/bll/project/sub_projects.py create mode 100644 apiserver/tests/automated/test_subprojects.py diff --git a/apiserver/apierrors/errors.conf b/apiserver/apierrors/errors.conf index a1017eb..a78d780 100644 --- a/apiserver/apierrors/errors.conf +++ b/apiserver/apierrors/errors.conf @@ -63,7 +63,9 @@ 403: ["project_not_found", "project not found"] 405: ["project_has_models", "project has associated models"] 407: ["invalid_project_name", "invalid project name"] - 408: ["cannot_update_project_location", "cannot update project location"] + 408: ["cannot_update_project_location", "Cannot update project location. 
Use projects.move instead"] + 409: ["project_path_exceeds_max", "Project path exceed the maximum allowed depth"] + 410: ["project_source_and_destination_are_the_same", "Project has the same source and destination paths"] # Queues 701: ["invalid_queue_id", "invalid queue id"] diff --git a/apiserver/apimodels/projects.py b/apiserver/apimodels/projects.py index 2070576..454c5dc 100644 --- a/apiserver/apimodels/projects.py +++ b/apiserver/apimodels/projects.py @@ -5,16 +5,24 @@ from apiserver.apimodels.organization import TagsRequest from apiserver.database.model import EntityVisibility -class ProjectReq(models.Base): +class ProjectRequest(models.Base): project = fields.StringField(required=True) -class DeleteRequest(ProjectReq): +class MergeRequest(ProjectRequest): + destination_project = fields.StringField() + + +class MoveRequest(ProjectRequest): + new_location = fields.StringField() + + +class DeleteRequest(ProjectRequest): force = fields.BoolField(default=False) delete_contents = fields.BoolField(default=False) -class GetHyperParamReq(ProjectReq): +class GetHyperParamRequest(ProjectRequest): page = fields.IntField(default=0) page_size = fields.IntField(default=500) @@ -23,15 +31,15 @@ class ProjectTagsRequest(TagsRequest): projects = ListField(str) -class MultiProjectReq(models.Base): +class MultiProjectRequest(models.Base): projects = fields.ListField(str) -class ProjectTaskParentsRequest(MultiProjectReq): +class ProjectTaskParentsRequest(MultiProjectRequest): tasks_state = ActualEnumField(EntityVisibility) -class ProjectHyperparamValuesRequest(MultiProjectReq): +class ProjectHyperparamValuesRequest(MultiProjectRequest): section = fields.StringField(required=True) name = fields.StringField(required=True) allow_public = fields.BoolField(default=True) @@ -42,3 +50,4 @@ class ProjectsGetRequest(models.Base): stats_for_state = ActualEnumField(EntityVisibility, default=EntityVisibility.active) non_public = fields.BoolField(default=False) active_users = 
fields.ListField(str) + shallow_search = fields.BoolField(default=False) diff --git a/apiserver/bll/model/__init__.py b/apiserver/bll/model/__init__.py deleted file mode 100644 index 34278ec..0000000 --- a/apiserver/bll/model/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -from typing import Optional, Sequence - -from mongoengine import Q - -from apiserver.database.model.model import Model -from apiserver.database.utils import get_company_or_none_constraint - - -class ModelBLL: - def get_frameworks(self, company, project_ids: Optional[Sequence]) -> Sequence: - """ - Return the list of unique frameworks used by company and public models - If project ids passed then only models from these projects are considered - """ - query = get_company_or_none_constraint(company) - if project_ids: - query &= Q(project__in=project_ids) - return Model.objects(query).distinct(field="framework") diff --git a/apiserver/bll/organization/__init__.py b/apiserver/bll/organization/__init__.py index bb6baf1..9fd64e2 100644 --- a/apiserver/bll/organization/__init__.py +++ b/apiserver/bll/organization/__init__.py @@ -1,12 +1,8 @@ from collections import defaultdict from enum import Enum -from operator import itemgetter -from typing import Sequence, Dict, Optional - -from mongoengine import Q +from typing import Sequence, Dict from apiserver.config_repo import config -from apiserver.database.model import EntityVisibility from apiserver.database.model.model import Model from apiserver.database.model.task.task import Task from apiserver.redis_manager import redman @@ -65,34 +61,3 @@ class OrgBLL: def _get_tags_cache_for_entity(self, entity: Tags) -> _TagsCache: return self._task_tags if entity == Tags.Task else self._model_tags - - @classmethod - def get_parent_tasks( - cls, - company_id: str, - projects: Sequence[str], - state: Optional[EntityVisibility] = None, - ) -> Sequence[dict]: - """ - Get list of unique parent tasks sorted by task name for the passed company projects - If projects is None 
or empty then get parents for all the company tasks - """ - query = Q(company=company_id) - if projects: - query &= Q(project__in=projects) - if state == EntityVisibility.archived: - query &= Q(system_tags__in=[EntityVisibility.archived.value]) - elif state == EntityVisibility.active: - query &= Q(system_tags__nin=[EntityVisibility.archived.value]) - - parent_ids = set(Task.objects(query).distinct("parent")) - if not parent_ids: - return [] - - parents = Task.get_many_with_join( - company_id, - query=Q(id__in=parent_ids), - allow_public=True, - override_projection=("id", "name", "project.name"), - ) - return sorted(parents, key=itemgetter("name")) diff --git a/apiserver/bll/organization/tags_cache.py b/apiserver/bll/organization/tags_cache.py index 33e7e2c..7ed62ce 100644 --- a/apiserver/bll/organization/tags_cache.py +++ b/apiserver/bll/organization/tags_cache.py @@ -5,6 +5,7 @@ from mongoengine import Q from redis import Redis from apiserver.config_repo import config +from apiserver.bll.project import project_ids_with_children from apiserver.database.model.base import GetMixin from apiserver.database.model.model import Model from apiserver.database.model.task.task import Task @@ -40,7 +41,7 @@ class _TagsCache: if vals: query &= GetMixin.get_list_field_query(name, vals) if project: - query &= Q(project=project) + query &= Q(project__in=project_ids_with_children([project])) return self.db_cls.objects(query).distinct(field) diff --git a/apiserver/bll/project/__init__.py b/apiserver/bll/project/__init__.py index 0b8ab93..3ebcb19 100644 --- a/apiserver/bll/project/__init__.py +++ b/apiserver/bll/project/__init__.py @@ -1 +1,2 @@ from .project_bll import ProjectBLL +from .sub_projects import _ids_with_children as project_ids_with_children diff --git a/apiserver/bll/project/project_bll.py b/apiserver/bll/project/project_bll.py index dff97d5..236bb56 100644 --- a/apiserver/bll/project/project_bll.py +++ b/apiserver/bll/project/project_bll.py @@ -1,8 +1,20 @@ +import 
itertools from collections import defaultdict from datetime import datetime +from functools import reduce from itertools import groupby from operator import itemgetter -from typing import Sequence, Optional, Type, Tuple, Dict +from typing import ( + Sequence, + Optional, + Type, + Tuple, + Dict, + Set, + TypeVar, + Callable, + Mapping, +) from mongoengine import Q, Document @@ -12,35 +24,122 @@ from apiserver.config_repo import config from apiserver.database.model import EntityVisibility from apiserver.database.model.model import Model from apiserver.database.model.project import Project -from apiserver.database.model.task.task import Task, TaskStatus -from apiserver.database.utils import get_options +from apiserver.database.model.task.task import Task, TaskStatus, external_task_types +from apiserver.database.utils import get_options, get_company_or_none_constraint from apiserver.timing_context import TimingContext -from apiserver.tools import safe_get +from apiserver.utilities.dicts import nested_get +from .sub_projects import ( + _reposition_project_with_children, + _ensure_project, + _validate_project_name, + _update_subproject_names, + _save_under_parent, + _get_sub_projects, + _ids_with_children, + _ids_with_parents, +) log = config.logger(__file__) class ProjectBLL: @classmethod - def get_active_users( - cls, company, project_ids: Sequence, user_ids: Optional[Sequence] = None - ) -> set: + def merge_project( + cls, company, source_id: str, destination_id: str + ) -> Tuple[int, int, Set[str]]: """ - Get the set of user ids that created tasks/models in the given projects - If project_ids is empty then all projects are examined - If user_ids are passed then only subset of these users is returned + Move all the tasks and sub projects from the source project to the destination + Remove the source project + Return the amounts of moved entities and subprojects + set of all the affected project ids """ - with TimingContext("mongo", "active_users_in_projects"): - res 
= set() - query = Q(company=company) - if project_ids: - query &= Q(project__in=project_ids) - if user_ids: - query &= Q(user__in=user_ids) - for cls_ in (Task, Model): - res |= set(cls_.objects(query).distinct(field="user")) + with TimingContext("mongo", "move_project"): + if source_id == destination_id: + raise errors.bad_request.ProjectSourceAndDestinationAreTheSame( + parent=source_id + ) + source = Project.get(company, source_id) + destination = Project.get(company, destination_id) - return res + moved_entities = 0 + for entity_type in (Task, Model): + moved_entities += entity_type.objects( + company=company, + project=source_id, + system_tags__nin=[EntityVisibility.archived.value], + ).update(upsert=False, project=destination_id) + + moved_sub_projects = 0 + for child in Project.objects(company=company, parent=source_id): + _reposition_project_with_children(project=child, parent=destination) + moved_sub_projects += 1 + + affected = {source.id, *(source.path or [])} + source.delete() + + if destination: + destination.update(last_update=datetime.utcnow()) + affected.update({destination.id, *(destination.path or [])}) + + return moved_entities, moved_sub_projects, affected + + @classmethod + def move_project( + cls, company: str, user: str, project_id: str, new_location: str + ) -> Tuple[int, Set[str]]: + """ + Move project with its sub projects from its current location to the target one. + If the target location does not exist then it will be created. If it exists then + it should be writable. The source location should be writable too. 
+ Return the number of moved projects + set of all the affected project ids + """ + with TimingContext("mongo", "move_project"): + project = Project.get(company, project_id) + old_parent_id = project.parent + old_parent = ( + Project.get_for_writing(company=project.company, id=old_parent_id) + if old_parent_id + else None + ) + new_parent = _ensure_project(company=company, user=user, name=new_location) + new_parent_id = new_parent.id if new_parent else None + if old_parent_id == new_parent_id: + raise errors.bad_request.ProjectSourceAndDestinationAreTheSame( + location=new_parent.name if new_parent else "" + ) + + moved = _reposition_project_with_children(project, parent=new_parent) + + now = datetime.utcnow() + affected = set() + for p in filter(None, (old_parent, new_parent)): + p.update(last_update=now) + affected.update({p.id, *(p.path or [])}) + + return moved, affected + + @classmethod + def update(cls, company: str, project_id: str, **fields): + with TimingContext("mongo", "projects_update"): + project = Project.get_for_writing(company=company, id=project_id) + if not project: + raise errors.bad_request.InvalidProjectId(id=project_id) + + new_name = fields.pop("name", None) + if new_name: + new_name, new_location = _validate_project_name(new_name) + old_name, old_location = _validate_project_name(project.name) + if new_location != old_location: + raise errors.bad_request.CannotUpdateProjectLocation(name=new_name) + fields["name"] = new_name + + fields["last_update"] = datetime.utcnow() + updated = project.update(upsert=False, **fields) + + if new_name: + project.name = new_name + _update_subproject_names(project=project) + + return updated @classmethod def create( @@ -57,6 +156,7 @@ class ProjectBLL: Create a new project. 
Returns project ID """ + name, location = _validate_project_name(name) now = datetime.utcnow() project = Project( id=database.utils.id(), @@ -70,7 +170,11 @@ class ProjectBLL: created=now, last_update=now, ) - project.save() + parent = _ensure_project(company=company, user=user, name=location) + _save_under_parent(project=project, parent=parent) + if parent: + parent.update(last_update=now) + return project.id @classmethod @@ -98,6 +202,7 @@ class ProjectBLL: raise errors.bad_request.InvalidProjectId(id=project_id) return project_id + project_name, _ = _validate_project_name(project_name) project = Project.objects(company=company, name=project_name).only("id").first() if project: return project.id @@ -265,18 +370,48 @@ class ProjectBLL: return status_count_pipeline, runtime_pipeline + T = TypeVar("T") + + @staticmethod + def aggregate_project_data( + func: Callable[[T, T], T], + project_ids: Sequence[str], + child_projects: Mapping[str, Sequence[Project]], + data: Mapping[str, T], + ) -> Dict[str, T]: + """ + Given a list of project ids and data collected over these projects and their subprojects + For each project aggregates the data from all of its subprojects + """ + aggregated = {} + if not data: + return aggregated + for pid in project_ids: + relevant_projects = {p.id for p in child_projects.get(pid, [])} | {pid} + relevant_data = [data for p, data in data.items() if p in relevant_projects] + if not relevant_data: + continue + aggregated[pid] = reduce(func, relevant_data) + return aggregated + @classmethod def get_project_stats( cls, company: str, project_ids: Sequence[str], specific_state: Optional[EntityVisibility] = None, - ) -> Dict[str, dict]: + ) -> Tuple[Dict[str, dict], Dict[str, dict]]: if not project_ids: - return {} + return {}, {} + child_projects = _get_sub_projects(project_ids, _only=("id", "name")) + project_ids_with_children = set(project_ids) | { + c.id for c in itertools.chain.from_iterable(child_projects.values()) + } status_count_pipeline, 
runtime_pipeline = cls.make_projects_get_all_pipelines( - company, project_ids=project_ids, specific_state=specific_state + company, + project_ids=list(project_ids_with_children), + specific_state=specific_state, ) default_counts = dict.fromkeys(get_options(TaskStatus), 0) @@ -298,23 +433,58 @@ class ProjectBLL: } ) + def sum_status_count( + a: Mapping[str, Mapping], b: Mapping[str, Mapping] + ) -> Dict[str, dict]: + return { + section: { + status: nested_get(a, (section, status), 0) + + nested_get(b, (section, status), 0) + for status in set(a.get(section, {})) | set(b.get(section, {})) + } + for section in set(a) | set(b) + } + + status_count = cls.aggregate_project_data( + func=sum_status_count, + project_ids=project_ids, + child_projects=child_projects, + data=status_count, + ) + runtime = { result["_id"]: {k: v for k, v in result.items() if k != "_id"} for result in Task.aggregate(runtime_pipeline) } - def get_status_counts(project_id, section): - path = "/".join((project_id, section)) + def sum_runtime( + a: Mapping[str, Mapping], b: Mapping[str, Mapping] + ) -> Dict[str, dict]: return { - "total_runtime": safe_get(runtime, path, 0), - "status_count": safe_get(status_count, path, default_counts), + section: a.get(section, 0) + b.get(section, 0) + for section in set(a) | set(b) + } + + runtime = cls.aggregate_project_data( + func=sum_runtime, + project_ids=project_ids, + child_projects=child_projects, + data=runtime, + ) + + def get_status_counts(project_id, section): + return { + "total_runtime": nested_get(runtime, (project_id, section), 0), + "status_count": nested_get( + status_count, (project_id, section), default_counts + ), } report_for_states = [ s for s in EntityVisibility if not specific_state or specific_state == s ] - return { + stats = { project: { task_state.value: get_status_counts(project, task_state.value) for task_state in report_for_states @@ -322,6 +492,40 @@ class ProjectBLL: for project in project_ids } + children = { + project: sorted( + 
[{"id": c.id, "name": c.name} for c in child_projects.get(project, [])], + key=itemgetter("name"), + ) + for project in project_ids + } + return stats, children + + @classmethod + def get_active_users( + cls, + company, + project_ids: Sequence[str], + user_ids: Optional[Sequence[str]] = None, + ) -> set: + """ + Get the set of user ids that created tasks/models/dataviews in the given projects + If project_ids is empty then all projects are examined + If user_ids are passed then only subset of these users is returned + """ + with TimingContext("mongo", "active_users_in_projects"): + res = set() + query = Q(company=company) + if project_ids: + project_ids = _ids_with_children(project_ids) + query &= Q(project__in=project_ids) + if user_ids: + query &= Q(user__in=user_ids) + for cls_ in (Task, Model): + res |= set(cls_.objects(query).distinct(field="user")) + + return res + @classmethod def get_projects_with_active_user( cls, @@ -330,13 +534,83 @@ class ProjectBLL: project_ids: Optional[Sequence[str]] = None, allow_public: bool = True, ) -> Sequence[str]: - """Get the projects ids where user created any tasks""" - company = ( - {"company__in": [None, "", company]} - if allow_public - else {"company": company} - ) - projects = {"project__in": project_ids} if project_ids else {} - return Task.objects(**company, user__in=users, **projects).distinct( - field="project" + """ + Get the projects ids where user created any tasks including all the parents of these projects + If project ids are specified then filter the results by these project ids + """ + query = Q(user__in=users) + + if allow_public: + query &= get_company_or_none_constraint(company) + else: + query &= Q(company=company) + + if project_ids: + query &= Q(project__in=_ids_with_children(project_ids)) + + res = Task.objects(query).distinct(field="project") + if not res: + return res + + ids_with_parents = _ids_with_parents(res) + if project_ids: + return [pid for pid in ids_with_parents if pid in project_ids] + + 
return ids_with_parents + + @classmethod + def get_task_parents( + cls, + company_id: str, + projects: Sequence[str], + state: Optional[EntityVisibility] = None, + ) -> Sequence[dict]: + """ + Get list of unique parent tasks sorted by task name for the passed company projects + If projects is None or empty then get parents for all the company tasks + """ + query = Q(company=company_id) + if projects: + projects = _ids_with_children(projects) + query &= Q(project__in=projects) + if state == EntityVisibility.archived: + query &= Q(system_tags__in=[EntityVisibility.archived.value]) + elif state == EntityVisibility.active: + query &= Q(system_tags__nin=[EntityVisibility.archived.value]) + + parent_ids = set(Task.objects(query).distinct("parent")) + if not parent_ids: + return [] + + parents = Task.get_many_with_join( + company_id, + query=Q(id__in=parent_ids), + allow_public=True, + override_projection=("id", "name", "project.name"), ) + return sorted(parents, key=itemgetter("name")) + + @classmethod + def get_task_types(cls, company, project_ids: Optional[Sequence]) -> set: + """ + Return the list of unique task types used by company and public tasks + If project ids passed then only tasks from these projects are considered + """ + query = get_company_or_none_constraint(company) + if project_ids: + project_ids = _ids_with_children(project_ids) + query &= Q(project__in=project_ids) + res = Task.objects(query).distinct(field="type") + return set(res).intersection(external_task_types) + + @classmethod + def get_model_frameworks(cls, company, project_ids: Optional[Sequence]) -> Sequence: + """ + Return the list of unique frameworks used by company and public models + If project ids passed then only models from these projects are considered + """ + query = get_company_or_none_constraint(company) + if project_ids: + project_ids = _ids_with_children(project_ids) + query &= Q(project__in=project_ids) + return Model.objects(query).distinct(field="framework") diff --git 
a/apiserver/bll/project/project_cleanup.py b/apiserver/bll/project/project_cleanup.py index 1c5aa38..31d82cb 100644 --- a/apiserver/bll/project/project_cleanup.py +++ b/apiserver/bll/project/project_cleanup.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Tuple, Set +from typing import Tuple, Set, Sequence import attr @@ -16,6 +16,7 @@ from apiserver.database.model.model import Model from apiserver.database.model.project import Project from apiserver.database.model.task.task import Task, ArtifactModes from apiserver.timing_context import TimingContext +from .sub_projects import _ids_with_children log = config.logger(__file__) event_bll = EventBLL() @@ -32,18 +33,22 @@ class DeleteProjectResult: def delete_project( company: str, project_id: str, force: bool, delete_contents: bool -) -> DeleteProjectResult: - project = Project.get_for_writing(company=company, id=project_id) +) -> Tuple[DeleteProjectResult, Set[str]]: + project = Project.get_for_writing( + company=company, id=project_id, _only=("id", "path") + ) if not project: raise errors.bad_request.InvalidProjectId(id=project_id) + project_ids = _ids_with_children([project_id]) if not force: for cls, error in ( (Task, errors.bad_request.ProjectHasTasks), (Model, errors.bad_request.ProjectHasModels), ): non_archived = cls.objects( - project=project_id, system_tags__nin=[EntityVisibility.archived.value], + project__in=project_ids, + system_tags__nin=[EntityVisibility.archived.value], ).only("id") if non_archived: raise error("use force=true to delete", id=project_id) @@ -51,12 +56,14 @@ def delete_project( if not delete_contents: with TimingContext("mongo", "update_children"): for cls in (Model, Task): - updated_count = cls.objects(project=project_id).update(project=None) + updated_count = cls.objects(project__in=project_ids).update( + project=None + ) res = DeleteProjectResult(disassociated_tasks=updated_count) else: - deleted_models, model_urls = _delete_models(project=project_id) + 
deleted_models, model_urls = _delete_models(projects=project_ids) deleted_tasks, event_urls, artifact_urls = _delete_tasks( - company=company, project=project_id + company=company, projects=project_ids ) res = DeleteProjectResult( deleted_tasks=deleted_tasks, @@ -68,25 +75,27 @@ def delete_project( ), ) - res.deleted = Project.objects(id=project_id).delete() - return res + affected = {*project_ids, *(project.path or [])} + res.deleted = Project.objects(id__in=project_ids).delete() + + return res, affected -def _delete_tasks(company: str, project: str) -> Tuple[int, Set, Set]: +def _delete_tasks(company: str, projects: Sequence[str]) -> Tuple[int, Set, Set]: """ Delete only the task themselves and their non published version. Child models under the same project are deleted separately. Children tasks should be deleted in the same api call. If any child entities are left in another projects then updated their parent task to None """ - tasks = Task.objects(project=project).only("id", "execution__artifacts") + tasks = Task.objects(project__in=projects).only("id", "execution__artifacts") if not tasks: return 0, set(), set() task_ids = {t.id for t in tasks} with TimingContext("mongo", "delete_tasks_update_children"): - Task.objects(parent__in=task_ids, project__ne=project).update(parent=None) - Model.objects(task__in=task_ids, project__ne=project).update(task=None) + Task.objects(parent__in=task_ids, project__nin=projects).update(parent=None) + Model.objects(task__in=task_ids, project__nin=projects).update(task=None) event_urls, artifact_urls = set(), set() for task in tasks: @@ -106,18 +115,18 @@ def _delete_tasks(company: str, project: str) -> Tuple[int, Set, Set]: return deleted, event_urls, artifact_urls -def _delete_models(project: str) -> Tuple[int, Set[str]]: +def _delete_models(projects: Sequence[str]) -> Tuple[int, Set[str]]: """ Delete project models and update the tasks from other projects that reference them to reference None. 
""" with TimingContext("mongo", "delete_models"): - models = Model.objects(project=project).only("task", "id", "uri") + models = Model.objects(project__in=projects).only("task", "id", "uri") if not models: return 0, set() model_ids = {m.id for m in models} - Task.objects(execution__model__in=model_ids, project__ne=project).update( + Task.objects(execution__model__in=model_ids, project__nin=projects).update( execution__model=None ) @@ -125,7 +134,7 @@ def _delete_models(project: str) -> Tuple[int, Set[str]]: if model_tasks: now = datetime.utcnow() Task.objects( - id__in=model_tasks, project__ne=project, output__model__in=model_ids + id__in=model_tasks, project__nin=projects, output__model__in=model_ids ).update( output__model=None, output__error=f"model deleted on {now.isoformat()}", diff --git a/apiserver/bll/project/sub_projects.py b/apiserver/bll/project/sub_projects.py new file mode 100644 index 0000000..c80ee1d --- /dev/null +++ b/apiserver/bll/project/sub_projects.py @@ -0,0 +1,163 @@ +import itertools +from datetime import datetime +from typing import Tuple, Optional, Sequence, Mapping + +from apiserver import database +from apiserver.apierrors import errors +from apiserver.config_repo import config +from apiserver.database.model.project import Project + +name_separator = "/" +max_depth = config.get("services.projects.sub_projects.max_depth", 10) + + +def _validate_project_name(project_name: str) -> Tuple[str, str]: + """ + Remove redundant '/' characters. Ensure that the project name is not empty + and path to it is not larger then max_depth parameter. 
+ Return the cleaned up project name and location + """ + name_parts = list(filter(None, project_name.split(name_separator))) + if not name_parts: + raise errors.bad_request.InvalidProjectName(name=project_name) + + if len(name_parts) > max_depth: + raise errors.bad_request.ProjectPathExceedsMax(max_depth=max_depth) + + return name_separator.join(name_parts), name_separator.join(name_parts[:-1]) + + +def _ensure_project(company: str, user: str, name: str) -> Optional[Project]: + """ + Makes sure that the project with the given name exists + If needed auto-create the project and all the missing projects in the path to it + Return the project + """ + name = name.strip(name_separator) + if not name: + return None + + project = _get_writable_project_from_name(company, name) + if project: + return project + + now = datetime.utcnow() + name, location = _validate_project_name(name) + project = Project( + id=database.utils.id(), + user=user, + company=company, + created=now, + last_update=now, + name=name, + description="", + ) + parent = _ensure_project(company, user, location) + _save_under_parent(project=project, parent=parent) + if parent: + parent.update(last_update=now) + + return project + + +def _save_under_parent(project: Project, parent: Optional[Project]): + """ + Save the project under the given parent project or top level (parent=None) + Check that the project location matches the parent name + """ + location, _, _ = project.name.rpartition(name_separator) + if not parent: + if location: + raise ValueError( + f"Project location {location} does not match empty parent name" + ) + project.parent = None + project.path = [] + project.save() + return + + if location != parent.name: + raise ValueError( + f"Project location {location} does not match parent name {parent.name}" + ) + project.parent = parent.id + project.path = [*(parent.path or []), parent.id] + project.save() + + +def _get_writable_project_from_name( + company, + name, + _only: Optional[Sequence[str]] 
= ("id", "name", "path", "company", "parent"), +) -> Optional[Project]: + """ + Return a project from name. If the project not found then return None + """ + qs = Project.objects(company=company, name=name) + if _only: + qs = qs.only(*_only) + return qs.first() + + +def _get_sub_projects( + project_ids: Sequence[str], _only: Sequence[str] = ("id", "path") +) -> Mapping[str, Sequence[Project]]: + """ + Return the list of child projects of all the levels for the parent project ids + """ + qs = Project.objects(path__in=project_ids) + if _only: + _only = set(_only) | {"path"} + qs = qs.only(*_only) + subprojects = list(qs) + + return { + pid: [s for s in subprojects if pid in (s.path or [])] for pid in project_ids + } + + +def _ids_with_parents(project_ids: Sequence[str]) -> Sequence[str]: + """ + Return project ids with all the parent projects + """ + projects = Project.objects(id__in=project_ids).only("id", "path") + parent_ids = set(itertools.chain.from_iterable(p.path for p in projects if p.path)) + return list({*(p.id for p in projects), *parent_ids}) + + +def _ids_with_children(project_ids: Sequence[str]) -> Sequence[str]: + """ + Return project ids with the ids of all the subprojects + """ + subprojects = Project.objects(path__in=project_ids).only("id") + return list({*project_ids, *(child.id for child in subprojects)}) + + +def _update_subproject_names(project: Project, update_path: bool = False) -> int: + """ + Update sub project names when the base project name changes + Optionally update the paths + """ + child_projects = _get_sub_projects(project_ids=[project.id], _only=("id", "name")) + updated = 0 + for child in child_projects[project.id]: + child_suffix = name_separator.join( + child.name.split(name_separator)[len(project.path) + 1 :] + ) + updates = {"name": name_separator.join((project.name, child_suffix))} + if update_path: + updates["path"] = project.path + child.path[len(project.path) :] + updated += child.update(upsert=False, **updates) + + return 
updated + + +def _reposition_project_with_children(project: Project, parent: Project) -> int: + new_location = parent.name if parent else None + project.name = name_separator.join( + filter(None, (new_location, project.name.split(name_separator)[-1])) + ) + _save_under_parent(project, parent=parent) + + moved = 1 + _update_subproject_names(project=project, update_path=True) + return moved diff --git a/apiserver/bll/task/task_bll.py b/apiserver/bll/task/task_bll.py index d0e9d27..1335786 100644 --- a/apiserver/bll/task/task_bll.py +++ b/apiserver/bll/task/task_bll.py @@ -26,7 +26,6 @@ from apiserver.database.model.task.task import ( TaskStatusMessage, TaskSystemTags, ArtifactModes, - external_task_types, ) from apiserver.database.model import EntityVisibility from apiserver.database.utils import get_company_or_none_constraint, id as create_id @@ -51,18 +50,6 @@ class TaskBLL: self.events_es = events_es or es_factory.connect("events") self.redis: StrictRedis = redis or redman.connection("apiserver") - @classmethod - def get_types(cls, company, project_ids: Optional[Sequence]) -> set: - """ - Return the list of unique task types used by company and public tasks - If project ids passed then only tasks from these projects are considered - """ - query = get_company_or_none_constraint(company) - if project_ids: - query &= Q(project__in=project_ids) - res = Task.objects(query).distinct(field="type") - return set(res).intersection(external_task_types) - @staticmethod def get_task_with_access( task_id, company_id, only=None, allow_public=False, requires_write_access=False diff --git a/apiserver/config/default/services/projects.conf b/apiserver/config/default/services/projects.conf index 43b9eaf..e7614f4 100644 --- a/apiserver/config/default/services/projects.conf +++ b/apiserver/config/default/services/projects.conf @@ -10,4 +10,9 @@ featured { # default featured index for public projects not specified in the order public_default: 9999 +} + +sub_projects { + # the max sub 
project depth + max_depth: 10 } \ No newline at end of file diff --git a/apiserver/database/model/project.py b/apiserver/database/model/project.py index 5f456f7..e4ff6d4 100644 --- a/apiserver/database/model/project.py +++ b/apiserver/database/model/project.py @@ -1,4 +1,4 @@ -from mongoengine import StringField, DateTimeField, IntField +from mongoengine import StringField, DateTimeField, IntField, ListField from apiserver.database import Database, strict from apiserver.database.fields import StrippedStringField, SafeSortedListField @@ -10,13 +10,15 @@ class Project(AttributedDocument): get_all_query_options = GetMixin.QueryParameterOptions( pattern_fields=("name", "description"), - list_fields=("tags", "system_tags", "id"), + list_fields=("tags", "system_tags", "id", "parent", "path"), ) meta = { "db_alias": Database.backend, "strict": strict, "indexes": [ + "parent", + "path", ("company", "name"), { "name": "%s.project.main_text_index" % Database.backend, @@ -44,3 +46,5 @@ class Project(AttributedDocument): logo_url = StringField() logo_blob = StringField(exclude_by_default=True) company_origin = StringField(exclude_by_default=True) + parent = StringField(reference_field="Project") + path = ListField(StringField(required=True), exclude_by_default=True) diff --git a/apiserver/mongo/initialize/pre_populate.py b/apiserver/mongo/initialize/pre_populate.py index a6af6e8..0f71976 100644 --- a/apiserver/mongo/initialize/pre_populate.py +++ b/apiserver/mongo/initialize/pre_populate.py @@ -33,6 +33,7 @@ from mongoengine import Q from apiserver.bll.event import EventBLL from apiserver.bll.event.event_common import EventType +from apiserver.bll.project import project_ids_with_children from apiserver.bll.task.artifacts import get_artifact_id from apiserver.bll.task.param_utils import ( split_param_name, @@ -423,6 +424,22 @@ class PrePopulate: items.append(results[0]) return items + @classmethod + def _check_projects_hierarchy(cls, projects: Set[Project]): + """ + For any 
exported project all its parents up to the root should be present + """ + if not projects: + return + + project_ids = {p.id for p in projects} + orphans = [p.id for p in projects if p.parent and p.parent not in project_ids] + if not orphans: + return + + print(f"ERROR: the following projects are exported without their parents: {orphans}") + exit(1) + @classmethod def _resolve_entities( cls, @@ -434,6 +451,7 @@ class PrePopulate: if projects: print("Reading projects...") + projects = project_ids_with_children(projects) entities[cls.project_cls].update( cls._resolve_type(cls.project_cls, projects) ) @@ -463,6 +481,8 @@ class PrePopulate: project_ids = {p.id for p in entities[cls.project_cls]} entities[cls.project_cls].update(o for o in objs if o.id not in project_ids) + cls._check_projects_hierarchy(entities[cls.project_cls]) + model_ids = { model_id for task in entities[cls.task_cls] @@ -634,11 +654,12 @@ class PrePopulate: """ Export the requested experiments, projects and models and return the list of artifact files Always do the export on sorted items since the order of items influence hash + The projects should be sorted by name so that on import the hierarchy is correctly restored from top to bottom """ artifacts = [] now = datetime.utcnow() for cls_ in sorted(entities, key=attrgetter("__name__")): - items = sorted(entities[cls_], key=attrgetter("id")) + items = sorted(entities[cls_], key=attrgetter("name", "id")) if not items: continue base_filename = cls._get_base_filename(cls_) diff --git a/apiserver/schema/services/projects.conf b/apiserver/schema/services/projects.conf index 28c1650..6a76c98 100644 --- a/apiserver/schema/services/projects.conf +++ b/apiserver/schema/services/projects.conf @@ -167,10 +167,27 @@ _definitions { type: string } // extra properties - stats: { + stats { description: "Additional project stats" "$ref": "#/definitions/stats" } + sub_projects { + description: "The list of sub projects" + type: array + items { + type: object + 
properties { + id { + description: "Subproject ID" + type: string + } + name { + description: "Subproject name" + type: string + } + } + } + } } } metric_variant_result { @@ -405,6 +422,17 @@ get_all { } } } + "2.13": ${get_all."2.1"} { + request { + properties { + shallow_search { + description: "If set to 'true' then the search with the specified criteria is performed among top level projects only (or if parents specified, among the direct children of these parents). Otherwise the search is performed among all the company projects (or among all of the descendants of the specified parents)." + type: boolean + default: false + } + } + } + } } get_all_ex { internal: true @@ -438,6 +466,11 @@ get_all_ex { type: array items: {type: string} } + shallow_search { + description: "If set to 'true' then the search with the specified criteria is performed among top level projects only (or if parents specified, among the direct children of these parents). Otherwise the search is performed among all the company projects (or among all of the descendants of the specified parents)." 
+ type: boolean + default: false + } } } } } @@ -498,6 +531,66 @@ update { } } } +move { + "2.13" { + description: "Moves a project and all of its subprojects under a different location" + request { + type: object + required: [project] + properties { + project { + description: "Project id" + type: string + } + new_location { + description: "The new location for the project" + type: string + } + } + } + response { + type: object + properties { + moved { + description: "The number of projects moved" + type: integer + } + } + } + } +} +merge { + "2.13" { + description: "Moves all the source project's contents to the destination project and removes the source project" + request { + type: object + required: [project] + properties { + project { + description: "Project id" + type: string + } + destination_project { + description: "The ID of the destination project" + type: string + } + } + } + response { + type: object + properties { + moved_entities { + description: "The number of tasks, models and dataviews moved from the merged project into the destination" + type: integer + } + moved_projects { + description: "The number of child projects moved from the merged project into the destination" + type: integer + } + } + } + } +} delete { "2.1" { description: "Deletes a project" diff --git a/apiserver/services/models.py b/apiserver/services/models.py index 0878cce..1ac3138 100644 --- a/apiserver/services/models.py +++ b/apiserver/services/models.py @@ -16,7 +16,6 @@ from apiserver.apimodels.models import ( GetFrameworksRequest, DeleteModelRequest, ) -from apiserver.bll.model import ModelBLL from apiserver.bll.organization import OrgBLL, Tags from apiserver.bll.project import ProjectBLL from apiserver.bll.task import TaskBLL @@ -38,7 +37,6 @@ from apiserver.timing_context import TimingContext log = config.logger(__file__) org_bll = OrgBLL() -model_bll = ModelBLL() project_bll = ProjectBLL() @@ -131,7 +129,7 @@ def 
get_frameworks(call: APICall, company_id, request: GetFrameworksRequest): call.result.data = { "frameworks": sorted( - model_bll.get_frameworks(company_id, project_ids=request.projects) + project_bll.get_model_frameworks(company_id, project_ids=request.projects) ) } diff --git a/apiserver/services/projects.py b/apiserver/services/projects.py index 4cb0414..658f97a 100644 --- a/apiserver/services/projects.py +++ b/apiserver/services/projects.py @@ -1,4 +1,3 @@ -from datetime import datetime from typing import Sequence import attr @@ -8,13 +7,13 @@ from apiserver.apierrors import errors from apiserver.apierrors.errors.bad_request import InvalidProjectId from apiserver.apimodels.base import UpdateResponse, MakePublicRequest, IdResponse from apiserver.apimodels.projects import ( - GetHyperParamReq, - ProjectReq, + GetHyperParamRequest, + ProjectRequest, ProjectTagsRequest, ProjectTaskParentsRequest, ProjectHyperparamValuesRequest, ProjectsGetRequest, - DeleteRequest, + DeleteRequest, MoveRequest, MergeRequest, ) from apiserver.bll.organization import OrgBLL, Tags from apiserver.bll.project import ProjectBLL @@ -47,10 +46,6 @@ create_fields = { "default_output_destination": None, } -get_all_query_options = Project.QueryParameterOptions( - pattern_fields=("name", "description"), list_fields=("tags", "system_tags", "id"), -) - @endpoint("projects.get_by_id", required_fields=["project"]) def get_by_id(call): @@ -72,27 +67,46 @@ def get_by_id(call): call.result.data = {"project": project_dict} +def _adjust_search_parameters(data: dict, shallow_search: bool): + """ + 1. Make sure that there is no external query on path + 2. If not shallow_search and parent is provided then parent can be at any place in path + 3. 
If shallow_search and no parent provided then use a top level parent + """ + data.pop("path", None) + if not shallow_search: + if "parent" in data: + data["path"] = data.pop("parent") + return + + if "parent" not in data: + data["parent"] = [None, ""] + + @endpoint("projects.get_all_ex", request_data_model=ProjectsGetRequest) def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest): conform_tag_fields(call, call.data) allow_public = not request.non_public + shallow_search = request.shallow_search or request.include_stats with TimingContext("mongo", "projects_get_all"): + data = call.data if request.active_users: ids = project_bll.get_projects_with_active_user( company=company_id, users=request.active_users, - project_ids=call.data.get("id"), + project_ids=data.get("id"), allow_public=allow_public, ) if not ids: call.result.data = {"projects": []} return - call.data["id"] = ids + data["id"] = ids + + _adjust_search_parameters(data, shallow_search=shallow_search) projects = Project.get_many_with_join( company=company_id, - query_dict=call.data, - query_options=get_all_query_options, + query_dict=data, allow_public=allow_public, ) @@ -102,7 +116,7 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest): return project_ids = {project["id"] for project in projects} - stats = project_bll.get_project_stats( + stats, children = project_bll.get_project_stats( company=company_id, project_ids=list(project_ids), specific_state=request.stats_for_state, @@ -110,6 +124,7 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest): for project in projects: project["stats"] = stats[project["id"]] + project["sub_projects"] = children[project["id"]] call.result.data = {"projects": projects} @@ -117,12 +132,13 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest): @endpoint("projects.get_all") def get_all(call: APICall): conform_tag_fields(call, call.data) + data = call.data + 
_adjust_search_parameters(data, shallow_search=data.get("shallow_search", False)) with translate_errors_context(), TimingContext("mongo", "projects_get_all"): projects = Project.get_many( company=call.identity.company, - query_dict=call.data, - query_options=get_all_query_options, - parameters=call.data, + query_dict=data, + parameters=data, allow_public=True, ) conform_output_tags(call, projects) @@ -161,22 +177,15 @@ def update(call: APICall): :return: updated - `int` - number of projects updated fields - `[string]` - updated fields """ - project_id = call.data["project"] - - with translate_errors_context(): - project = Project.get_for_writing(company=call.identity.company, id=project_id) - if not project: - raise errors.bad_request.InvalidProjectId(id=project_id) - - fields = parse_from_call( - call.data, create_fields, Project.get_fields(), discard_none_values=False - ) - conform_tag_fields(call, fields, validate=True) - fields["last_update"] = datetime.utcnow() - with TimingContext("mongo", "projects_update"): - updated = project.update(upsert=False, **fields) - conform_output_tags(call, fields) - call.result.data_model = UpdateResponse(updated=updated, fields=fields) + fields = parse_from_call( + call.data, create_fields, Project.get_fields(), discard_none_values=False + ) + conform_tag_fields(call, fields, validate=True) + updated = ProjectBLL.update( + company=call.identity.company, project_id=call.data["project"], **fields + ) + conform_output_tags(call, fields) + call.result.data_model = UpdateResponse(updated=updated, fields=fields) def _reset_cached_tags(company: str, projects: Sequence[str]): @@ -184,20 +193,47 @@ def _reset_cached_tags(company: str, projects: Sequence[str]): org_bll.reset_tags(company, Tags.Model, projects=projects) +@endpoint("projects.move", request_data_model=MoveRequest) +def move(call: APICall, company: str, request: MoveRequest): + moved, affected_projects = ProjectBLL.move_project( + company=company, + user=call.identity.user, 
+ project_id=request.project, + new_location=request.new_location, + ) + _reset_cached_tags(company, projects=list(affected_projects)) + + call.result.data = {"moved": moved} + + +@endpoint("projects.merge", request_data_model=MergeRequest) +def merge(call: APICall, company: str, request: MergeRequest): + moved_entitites, moved_projects, affected_projects = ProjectBLL.merge_project( + company, source_id=request.project, destination_id=request.destination_project + ) + + _reset_cached_tags(company, projects=list(affected_projects)) + + call.result.data = { + "moved_entities": moved_entitites, + "moved_projects": moved_projects, + } + + @endpoint("projects.delete", request_data_model=DeleteRequest) def delete(call: APICall, company_id: str, request: DeleteRequest): - res = delete_project( + res, affected_projects = delete_project( company=company_id, project_id=request.project, force=request.force, delete_contents=request.delete_contents, ) - _reset_cached_tags(company_id, projects=[request.project]) + _reset_cached_tags(company_id, projects=list(affected_projects)) call.result.data = {**attr.asdict(res)} -@endpoint("projects.get_unique_metric_variants", request_data_model=ProjectReq) -def get_unique_metric_variants(call: APICall, company_id: str, request: ProjectReq): +@endpoint("projects.get_unique_metric_variants", request_data_model=ProjectRequest) +def get_unique_metric_variants(call: APICall, company_id: str, request: ProjectRequest): metrics = task_bll.get_unique_metric_variants( company_id, [request.project] if request.project else None @@ -209,9 +245,9 @@ def get_unique_metric_variants(call: APICall, company_id: str, request: ProjectR @endpoint( "projects.get_hyper_parameters", min_version="2.9", - request_data_model=GetHyperParamReq, + request_data_model=GetHyperParamRequest, ) -def get_hyper_parameters(call: APICall, company_id: str, request: GetHyperParamReq): +def get_hyper_parameters(call: APICall, company_id: str, request: GetHyperParamRequest): total, 
remaining, parameters = TaskBLL.get_aggregated_project_parameters( company_id, @@ -303,7 +339,7 @@ def get_task_parents( call: APICall, company_id: str, request: ProjectTaskParentsRequest ): call.result.data = { - "parents": org_bll.get_parent_tasks( + "parents": project_bll.get_task_parents( company_id, projects=request.projects, state=request.tasks_state ) } diff --git a/apiserver/services/tasks.py b/apiserver/services/tasks.py index 5d49ebb..93bc6d0 100644 --- a/apiserver/services/tasks.py +++ b/apiserver/services/tasks.py @@ -216,7 +216,7 @@ def get_all(call: APICall, company_id, _): @endpoint("tasks.get_types", request_data_model=GetTypesRequest) def get_types(call: APICall, company_id, request: GetTypesRequest): call.result.data = { - "types": list(task_bll.get_types(company_id, project_ids=request.projects)) + "types": list(project_bll.get_task_types(company_id, project_ids=request.projects)) } diff --git a/apiserver/tests/automated/test_subprojects.py b/apiserver/tests/automated/test_subprojects.py new file mode 100644 index 0000000..4a9446a --- /dev/null +++ b/apiserver/tests/automated/test_subprojects.py @@ -0,0 +1,240 @@ +from time import sleep +from typing import Sequence, Optional, Tuple + +from boltons.iterutils import first + +from apiserver.apierrors import errors +from apiserver.database.model import EntityVisibility +from apiserver.database.utils import id as db_id +from apiserver.tests.automated import TestService + + +class TestSubProjects(TestService): + def setUp(self, **kwargs): + super().setUp(version="2.13") + + def test_project_aggregations(self): + child = self._temp_project(name="Aggregation/Pr1") + project = self.api.projects.get_all_ex(name="^Aggregation$").projects[0].id + child_project = self.api.projects.get_all_ex(id=[child]).projects[0] + self.assertEqual(child_project.parent.id, project) + + user = self.api.users.get_current_user().user.id + + # test aggregations on project with empty subprojects + res = 
self.api.users.get_all_ex(active_in_projects=[project]) + self.assertEqual(res.users, []) + res = self.api.projects.get_all_ex(id=[project], active_users=[user]) + self.assertEqual(res.projects, []) + res = self.api.models.get_frameworks(projects=[project]) + self.assertEqual(res.frameworks, []) + res = self.api.tasks.get_types(projects=[project]) + self.assertEqual(res.types, []) + res = self.api.projects.get_task_parents(projects=[project]) + self.assertEqual(res.parents, []) + + # test aggregations with non-empty subprojects + task1 = self._temp_task(project=child) + self._temp_task(project=child, parent=task1) + framework = "Test framework" + self._temp_model(project=child, framework=framework) + res = self.api.users.get_all_ex(active_in_projects=[project]) + self._assert_ids(res.users, [user]) + res = self.api.projects.get_all_ex(id=[project], active_users=[user]) + self._assert_ids(res.projects, [project]) + res = self.api.projects.get_task_parents(projects=[project]) + self._assert_ids(res.parents, [task1]) + res = self.api.models.get_frameworks(projects=[project]) + self.assertEqual(res.frameworks, [framework]) + res = self.api.tasks.get_types(projects=[project]) + self.assertEqual(res.types, ["testing"]) + + def _assert_ids(self, actual: Sequence[dict], expected: Sequence[str]): + self.assertEqual([a["id"] for a in actual], expected) + + def test_project_operations(self): + # create + with self.api.raises(errors.bad_request.InvalidProjectName): + self._temp_project(name="/") + project1 = self._temp_project(name="Root1/Pr1") + project1_child = self._temp_project(name="Root1/Pr1/Pr2") + with self.api.raises(errors.bad_request.ExpectedUniqueData): + self._temp_project(name="Root1/Pr1/Pr2") + + # update + with self.api.raises(errors.bad_request.CannotUpdateProjectLocation): + self.api.projects.update(project=project1, name="Root2/Pr2") + res = self.api.projects.update(project=project1, name="Root1/Pr2") + self.assertEqual(res.updated, 1) + res = 
self.api.projects.get_by_id(project=project1_child) + self.assertEqual(res.project.name, "Root1/Pr2/Pr2") + + # move + res = self.api.projects.move(project=project1, new_location="Root2") + self.assertEqual(res.moved, 2) + res = self.api.projects.get_by_id(project=project1_child) + self.assertEqual(res.project.name, "Root2/Pr2/Pr2") + + # merge + project_with_task, (active, archived) = self._temp_project_with_tasks( + "Root1/Pr3/Pr4" + ) + project1_parent = self._getProjectParent(project1) + self._assertTags(project1_parent, tags=[], system_tags=[]) + self._assertTags(project1_parent, tags=[], system_tags=[]) + project_with_task_parent = self._getProjectParent(project_with_task) + self._assertTags(project_with_task_parent) + # self._assertTags(project_id=None) + + merge_source = self.api.projects.get_by_id( + project=project_with_task + ).project.parent + res = self.api.projects.merge( + project=merge_source, destination_project=project1 + ) + self.assertEqual(res.moved_entities, 0) + self.assertEqual(res.moved_projects, 1) + res = self.api.projects.get_by_id(project=project_with_task) + self.assertEqual(res.project.name, "Root2/Pr2/Pr4") + with self.api.raises(errors.bad_request.InvalidProjectId): + self.api.projects.get_by_id(project=merge_source) + + self._assertTags(project1_parent) + self._assertTags(project1) + self._assertTags(project_with_task_parent, tags=[], system_tags=[]) + # self._assertTags(project_id=None) + + # delete + with self.api.raises(errors.bad_request.ProjectHasTasks): + self.api.projects.delete(project=project1) + res = self.api.projects.delete(project=project1, force=True) + self.assertEqual(res.deleted, 3) + self.assertEqual(res.disassociated_tasks, 2) + res = self.api.tasks.get_by_id(task=active).task + self.assertIsNone(res.get("project")) + for p_id in (project1, project1_child, project_with_task): + with self.api.raises(errors.bad_request.InvalidProjectId): + self.api.projects.get_by_id(project=p_id) + + 
self._assertTags(project1_parent, tags=[], system_tags=[]) + # self._assertTags(project_id=None, tags=[], system_tags=[]) + + def _getProjectParent(self, project_id: str): + return self.api.projects.get_all_ex(id=[project_id]).projects[0].parent.id + + def _assertTags( + self, + project_id: Optional[str], + tags: Sequence[str] = ("test",), + system_tags: Sequence[str] = (EntityVisibility.archived.value,), + ): + if project_id: + res = self.api.projects.get_task_tags( + projects=[project_id], include_system=True + ) + else: + res = self.api.organization.get_tags(include_system=True) + + self.assertEqual(set(res.tags), set(tags)) + self.assertEqual(set(res.system_tags), set(system_tags)) + + def test_get_all_search_options(self): + project1 = self._temp_project(name="project1") + project2 = self._temp_project(name="project1/project2") + self._temp_project(name="project3") + + # local search finds only at the specified level + res = self.api.projects.get_all_ex( + name="project1", shallow_search=True + ).projects + self.assertEqual([p.id for p in res], [project1]) + res = self.api.projects.get_all_ex(name="project1", parent=[project1]).projects + self.assertEqual([p.id for p in res], [project2]) + + # global search finds all or below the specified level + res = self.api.projects.get_all_ex(name="project1").projects + self.assertEqual(set(p.id for p in res), {project1, project2}) + project4 = self._temp_project(name="project1/project2/project1") + res = self.api.projects.get_all_ex(name="project1", parent=[project2]).projects + self.assertEqual([p.id for p in res], [project4]) + + self.api.projects.delete(project=project1, force=True) + + def test_get_all_with_stats(self): + project4, _ = self._temp_project_with_tasks(name="project1/project3/project4") + project5, _ = self._temp_project_with_tasks(name="project1/project3/project5") + project2 = self._temp_project(name="project2") + res = self.api.projects.get_all(shallow_search=True).projects + self.assertTrue(any(p 
for p in res if p.id == project2)) + self.assertFalse(any(p for p in res if p.id in [project4, project5])) + + project1 = first(p.id for p in res if p.name == "project1") + res = self.api.projects.get_all_ex( + id=[project1, project2], include_stats=True + ).projects + self.assertEqual(set(p.id for p in res), {project1, project2}) + res1 = next(p for p in res if p.id == project1) + self.assertEqual(res1.stats["active"]["status_count"]["created"], 0) + self.assertEqual(res1.stats["active"]["status_count"]["stopped"], 2) + self.assertEqual(res1.stats["active"]["total_runtime"], 2) + self.assertEqual( + {sp.name for sp in res1.sub_projects}, + { + "project1/project3", + "project1/project3/project4", + "project1/project3/project5", + }, + ) + res2 = next(p for p in res if p.id == project2) + self.assertEqual(res2.stats["active"]["status_count"]["created"], 0) + self.assertEqual(res2.stats["active"]["status_count"]["stopped"], 0) + self.assertEqual(res2.stats["active"]["total_runtime"], 0) + self.assertEqual(res2.sub_projects, []) + + def _run_tasks(self, *tasks): + """Imitate 1 second of running""" + for task_id in tasks: + self.api.tasks.started(task=task_id) + sleep(1) + for task_id in tasks: + self.api.tasks.stopped(task=task_id) + + def _temp_project_with_tasks(self, name) -> Tuple[str, Tuple[str, str]]: + pr_id = self._temp_project(name=name) + task_active = self._temp_task(project=pr_id) + task_archived = self._temp_task( + project=pr_id, system_tags=[EntityVisibility.archived.value], tags=["test"] + ) + self._run_tasks(task_active, task_archived) + return pr_id, (task_active, task_archived) + + delete_params = dict(can_fail=True, force=True) + + def _temp_project(self, name, **kwargs): + return self.create_temp( + "projects", + delete_params=self.delete_params, + name=name, + description="", + **kwargs, + ) + + def _temp_task(self, **kwargs): + return self.create_temp( + "tasks", + delete_params=self.delete_params, + type="testing", + name=db_id(), + 
input=dict(view=dict()), + **kwargs, + ) + + def _temp_model(self, **kwargs): + return self.create_temp( + service="models", + delete_params=self.delete_params, + name="test", + uri="file:///a", + labels={}, + **kwargs, + )