mirror of
https://github.com/clearml/clearml-server
synced 2025-04-05 13:35:02 +00:00
Add sub-projects support
This commit is contained in:
parent
1b49da8748
commit
c034c1a986
@ -63,7 +63,9 @@
|
||||
403: ["project_not_found", "project not found"]
|
||||
405: ["project_has_models", "project has associated models"]
|
||||
407: ["invalid_project_name", "invalid project name"]
|
||||
408: ["cannot_update_project_location", "cannot update project location"]
|
||||
408: ["cannot_update_project_location", "Cannot update project location. Use projects.move instead"]
|
||||
409: ["project_path_exceeds_max", "Project path exceed the maximum allowed depth"]
|
||||
410: ["project_source_and_destination_are_the_same", "Project has the same source and destination paths"]
|
||||
|
||||
# Queues
|
||||
701: ["invalid_queue_id", "invalid queue id"]
|
||||
|
@ -5,16 +5,24 @@ from apiserver.apimodels.organization import TagsRequest
|
||||
from apiserver.database.model import EntityVisibility
|
||||
|
||||
|
||||
class ProjectReq(models.Base):
|
||||
class ProjectRequest(models.Base):
|
||||
project = fields.StringField(required=True)
|
||||
|
||||
|
||||
class DeleteRequest(ProjectReq):
|
||||
class MergeRequest(ProjectRequest):
|
||||
destination_project = fields.StringField()
|
||||
|
||||
|
||||
class MoveRequest(ProjectRequest):
|
||||
new_location = fields.StringField()
|
||||
|
||||
|
||||
class DeleteRequest(ProjectRequest):
|
||||
force = fields.BoolField(default=False)
|
||||
delete_contents = fields.BoolField(default=False)
|
||||
|
||||
|
||||
class GetHyperParamReq(ProjectReq):
|
||||
class GetHyperParamRequest(ProjectRequest):
|
||||
page = fields.IntField(default=0)
|
||||
page_size = fields.IntField(default=500)
|
||||
|
||||
@ -23,15 +31,15 @@ class ProjectTagsRequest(TagsRequest):
|
||||
projects = ListField(str)
|
||||
|
||||
|
||||
class MultiProjectReq(models.Base):
|
||||
class MultiProjectRequest(models.Base):
|
||||
projects = fields.ListField(str)
|
||||
|
||||
|
||||
class ProjectTaskParentsRequest(MultiProjectReq):
|
||||
class ProjectTaskParentsRequest(MultiProjectRequest):
|
||||
tasks_state = ActualEnumField(EntityVisibility)
|
||||
|
||||
|
||||
class ProjectHyperparamValuesRequest(MultiProjectReq):
|
||||
class ProjectHyperparamValuesRequest(MultiProjectRequest):
|
||||
section = fields.StringField(required=True)
|
||||
name = fields.StringField(required=True)
|
||||
allow_public = fields.BoolField(default=True)
|
||||
@ -42,3 +50,4 @@ class ProjectsGetRequest(models.Base):
|
||||
stats_for_state = ActualEnumField(EntityVisibility, default=EntityVisibility.active)
|
||||
non_public = fields.BoolField(default=False)
|
||||
active_users = fields.ListField(str)
|
||||
shallow_search = fields.BoolField(default=False)
|
||||
|
@ -1,18 +0,0 @@
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from mongoengine import Q
|
||||
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.utils import get_company_or_none_constraint
|
||||
|
||||
|
||||
class ModelBLL:
|
||||
def get_frameworks(self, company, project_ids: Optional[Sequence]) -> Sequence:
|
||||
"""
|
||||
Return the list of unique frameworks used by company and public models
|
||||
If project ids passed then only models from these projects are considered
|
||||
"""
|
||||
query = get_company_or_none_constraint(company)
|
||||
if project_ids:
|
||||
query &= Q(project__in=project_ids)
|
||||
return Model.objects(query).distinct(field="framework")
|
@ -1,12 +1,8 @@
|
||||
from collections import defaultdict
|
||||
from enum import Enum
|
||||
from operator import itemgetter
|
||||
from typing import Sequence, Dict, Optional
|
||||
|
||||
from mongoengine import Q
|
||||
from typing import Sequence, Dict
|
||||
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.redis_manager import redman
|
||||
@ -65,34 +61,3 @@ class OrgBLL:
|
||||
|
||||
def _get_tags_cache_for_entity(self, entity: Tags) -> _TagsCache:
|
||||
return self._task_tags if entity == Tags.Task else self._model_tags
|
||||
|
||||
@classmethod
|
||||
def get_parent_tasks(
|
||||
cls,
|
||||
company_id: str,
|
||||
projects: Sequence[str],
|
||||
state: Optional[EntityVisibility] = None,
|
||||
) -> Sequence[dict]:
|
||||
"""
|
||||
Get list of unique parent tasks sorted by task name for the passed company projects
|
||||
If projects is None or empty then get parents for all the company tasks
|
||||
"""
|
||||
query = Q(company=company_id)
|
||||
if projects:
|
||||
query &= Q(project__in=projects)
|
||||
if state == EntityVisibility.archived:
|
||||
query &= Q(system_tags__in=[EntityVisibility.archived.value])
|
||||
elif state == EntityVisibility.active:
|
||||
query &= Q(system_tags__nin=[EntityVisibility.archived.value])
|
||||
|
||||
parent_ids = set(Task.objects(query).distinct("parent"))
|
||||
if not parent_ids:
|
||||
return []
|
||||
|
||||
parents = Task.get_many_with_join(
|
||||
company_id,
|
||||
query=Q(id__in=parent_ids),
|
||||
allow_public=True,
|
||||
override_projection=("id", "name", "project.name"),
|
||||
)
|
||||
return sorted(parents, key=itemgetter("name"))
|
||||
|
@ -5,6 +5,7 @@ from mongoengine import Q
|
||||
from redis import Redis
|
||||
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.bll.project import project_ids_with_children
|
||||
from apiserver.database.model.base import GetMixin
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.task.task import Task
|
||||
@ -40,7 +41,7 @@ class _TagsCache:
|
||||
if vals:
|
||||
query &= GetMixin.get_list_field_query(name, vals)
|
||||
if project:
|
||||
query &= Q(project=project)
|
||||
query &= Q(project__in=project_ids_with_children([project]))
|
||||
|
||||
return self.db_cls.objects(query).distinct(field)
|
||||
|
||||
|
@ -1 +1,2 @@
|
||||
from .project_bll import ProjectBLL
|
||||
from .sub_projects import _ids_with_children as project_ids_with_children
|
||||
|
@ -1,8 +1,20 @@
|
||||
import itertools
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from functools import reduce
|
||||
from itertools import groupby
|
||||
from operator import itemgetter
|
||||
from typing import Sequence, Optional, Type, Tuple, Dict
|
||||
from typing import (
|
||||
Sequence,
|
||||
Optional,
|
||||
Type,
|
||||
Tuple,
|
||||
Dict,
|
||||
Set,
|
||||
TypeVar,
|
||||
Callable,
|
||||
Mapping,
|
||||
)
|
||||
|
||||
from mongoengine import Q, Document
|
||||
|
||||
@ -12,35 +24,122 @@ from apiserver.config_repo import config
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.task.task import Task, TaskStatus
|
||||
from apiserver.database.utils import get_options
|
||||
from apiserver.database.model.task.task import Task, TaskStatus, external_task_types
|
||||
from apiserver.database.utils import get_options, get_company_or_none_constraint
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.tools import safe_get
|
||||
from apiserver.utilities.dicts import nested_get
|
||||
from .sub_projects import (
|
||||
_reposition_project_with_children,
|
||||
_ensure_project,
|
||||
_validate_project_name,
|
||||
_update_subproject_names,
|
||||
_save_under_parent,
|
||||
_get_sub_projects,
|
||||
_ids_with_children,
|
||||
_ids_with_parents,
|
||||
)
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class ProjectBLL:
|
||||
@classmethod
|
||||
def get_active_users(
|
||||
cls, company, project_ids: Sequence, user_ids: Optional[Sequence] = None
|
||||
) -> set:
|
||||
def merge_project(
|
||||
cls, company, source_id: str, destination_id: str
|
||||
) -> Tuple[int, int, Set[str]]:
|
||||
"""
|
||||
Get the set of user ids that created tasks/models in the given projects
|
||||
If project_ids is empty then all projects are examined
|
||||
If user_ids are passed then only subset of these users is returned
|
||||
Move all the tasks and sub projects from the source project to the destination
|
||||
Remove the source project
|
||||
Return the amounts of moved entities and subprojects + set of all the affected project ids
|
||||
"""
|
||||
with TimingContext("mongo", "active_users_in_projects"):
|
||||
res = set()
|
||||
query = Q(company=company)
|
||||
if project_ids:
|
||||
query &= Q(project__in=project_ids)
|
||||
if user_ids:
|
||||
query &= Q(user__in=user_ids)
|
||||
for cls_ in (Task, Model):
|
||||
res |= set(cls_.objects(query).distinct(field="user"))
|
||||
with TimingContext("mongo", "move_project"):
|
||||
if source_id == destination_id:
|
||||
raise errors.bad_request.ProjectSourceAndDestinationAreTheSame(
|
||||
parent=source_id
|
||||
)
|
||||
source = Project.get(company, source_id)
|
||||
destination = Project.get(company, destination_id)
|
||||
|
||||
return res
|
||||
moved_entities = 0
|
||||
for entity_type in (Task, Model):
|
||||
moved_entities += entity_type.objects(
|
||||
company=company,
|
||||
project=source_id,
|
||||
system_tags__nin=[EntityVisibility.archived.value],
|
||||
).update(upsert=False, project=destination_id)
|
||||
|
||||
moved_sub_projects = 0
|
||||
for child in Project.objects(company=company, parent=source_id):
|
||||
_reposition_project_with_children(project=child, parent=destination)
|
||||
moved_sub_projects += 1
|
||||
|
||||
affected = {source.id, *(source.path or [])}
|
||||
source.delete()
|
||||
|
||||
if destination:
|
||||
destination.update(last_update=datetime.utcnow())
|
||||
affected.update({destination.id, *(destination.path or [])})
|
||||
|
||||
return moved_entities, moved_sub_projects, affected
|
||||
|
||||
@classmethod
|
||||
def move_project(
|
||||
cls, company: str, user: str, project_id: str, new_location: str
|
||||
) -> Tuple[int, Set[str]]:
|
||||
"""
|
||||
Move project with its sub projects from its current location to the target one.
|
||||
If the target location does not exist then it will be created. If it exists then
|
||||
it should be writable. The source location should be writable too.
|
||||
Return the number of moved projects + set of all the affected project ids
|
||||
"""
|
||||
with TimingContext("mongo", "move_project"):
|
||||
project = Project.get(company, project_id)
|
||||
old_parent_id = project.parent
|
||||
old_parent = (
|
||||
Project.get_for_writing(company=project.company, id=old_parent_id)
|
||||
if old_parent_id
|
||||
else None
|
||||
)
|
||||
new_parent = _ensure_project(company=company, user=user, name=new_location)
|
||||
new_parent_id = new_parent.id if new_parent else None
|
||||
if old_parent_id == new_parent_id:
|
||||
raise errors.bad_request.ProjectSourceAndDestinationAreTheSame(
|
||||
location=new_parent.name if new_parent else ""
|
||||
)
|
||||
|
||||
moved = _reposition_project_with_children(project, parent=new_parent)
|
||||
|
||||
now = datetime.utcnow()
|
||||
affected = set()
|
||||
for p in filter(None, (old_parent, new_parent)):
|
||||
p.update(last_update=now)
|
||||
affected.update({p.id, *(p.path or [])})
|
||||
|
||||
return moved, affected
|
||||
|
||||
@classmethod
|
||||
def update(cls, company: str, project_id: str, **fields):
|
||||
with TimingContext("mongo", "projects_update"):
|
||||
project = Project.get_for_writing(company=company, id=project_id)
|
||||
if not project:
|
||||
raise errors.bad_request.InvalidProjectId(id=project_id)
|
||||
|
||||
new_name = fields.pop("name", None)
|
||||
if new_name:
|
||||
new_name, new_location = _validate_project_name(new_name)
|
||||
old_name, old_location = _validate_project_name(project.name)
|
||||
if new_location != old_location:
|
||||
raise errors.bad_request.CannotUpdateProjectLocation(name=new_name)
|
||||
fields["name"] = new_name
|
||||
|
||||
fields["last_update"] = datetime.utcnow()
|
||||
updated = project.update(upsert=False, **fields)
|
||||
|
||||
if new_name:
|
||||
project.name = new_name
|
||||
_update_subproject_names(project=project)
|
||||
|
||||
return updated
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
@ -57,6 +156,7 @@ class ProjectBLL:
|
||||
Create a new project.
|
||||
Returns project ID
|
||||
"""
|
||||
name, location = _validate_project_name(name)
|
||||
now = datetime.utcnow()
|
||||
project = Project(
|
||||
id=database.utils.id(),
|
||||
@ -70,7 +170,11 @@ class ProjectBLL:
|
||||
created=now,
|
||||
last_update=now,
|
||||
)
|
||||
project.save()
|
||||
parent = _ensure_project(company=company, user=user, name=location)
|
||||
_save_under_parent(project=project, parent=parent)
|
||||
if parent:
|
||||
parent.update(last_update=now)
|
||||
|
||||
return project.id
|
||||
|
||||
@classmethod
|
||||
@ -98,6 +202,7 @@ class ProjectBLL:
|
||||
raise errors.bad_request.InvalidProjectId(id=project_id)
|
||||
return project_id
|
||||
|
||||
project_name, _ = _validate_project_name(project_name)
|
||||
project = Project.objects(company=company, name=project_name).only("id").first()
|
||||
if project:
|
||||
return project.id
|
||||
@ -265,18 +370,48 @@ class ProjectBLL:
|
||||
|
||||
return status_count_pipeline, runtime_pipeline
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@staticmethod
|
||||
def aggregate_project_data(
|
||||
func: Callable[[T, T], T],
|
||||
project_ids: Sequence[str],
|
||||
child_projects: Mapping[str, Sequence[Project]],
|
||||
data: Mapping[str, T],
|
||||
) -> Dict[str, T]:
|
||||
"""
|
||||
Given a list of project ids and data collected over these projects and their subprojects
|
||||
For each project aggregates the data from all of its subprojects
|
||||
"""
|
||||
aggregated = {}
|
||||
if not data:
|
||||
return aggregated
|
||||
for pid in project_ids:
|
||||
relevant_projects = {p.id for p in child_projects.get(pid, [])} | {pid}
|
||||
relevant_data = [data for p, data in data.items() if p in relevant_projects]
|
||||
if not relevant_data:
|
||||
continue
|
||||
aggregated[pid] = reduce(func, relevant_data)
|
||||
return aggregated
|
||||
|
||||
@classmethod
|
||||
def get_project_stats(
|
||||
cls,
|
||||
company: str,
|
||||
project_ids: Sequence[str],
|
||||
specific_state: Optional[EntityVisibility] = None,
|
||||
) -> Dict[str, dict]:
|
||||
) -> Tuple[Dict[str, dict], Dict[str, dict]]:
|
||||
if not project_ids:
|
||||
return {}
|
||||
return {}, {}
|
||||
|
||||
child_projects = _get_sub_projects(project_ids, _only=("id", "name"))
|
||||
project_ids_with_children = set(project_ids) | {
|
||||
c.id for c in itertools.chain.from_iterable(child_projects.values())
|
||||
}
|
||||
status_count_pipeline, runtime_pipeline = cls.make_projects_get_all_pipelines(
|
||||
company, project_ids=project_ids, specific_state=specific_state
|
||||
company,
|
||||
project_ids=list(project_ids_with_children),
|
||||
specific_state=specific_state,
|
||||
)
|
||||
|
||||
default_counts = dict.fromkeys(get_options(TaskStatus), 0)
|
||||
@ -298,23 +433,58 @@ class ProjectBLL:
|
||||
}
|
||||
)
|
||||
|
||||
def sum_status_count(
|
||||
a: Mapping[str, Mapping], b: Mapping[str, Mapping]
|
||||
) -> Dict[str, dict]:
|
||||
return {
|
||||
section: {
|
||||
status: nested_get(a, (section, status), 0)
|
||||
+ nested_get(b, (section, status), 0)
|
||||
for status in set(a.get(section, {})) | set(b.get(section, {}))
|
||||
}
|
||||
for section in set(a) | set(b)
|
||||
}
|
||||
|
||||
status_count = cls.aggregate_project_data(
|
||||
func=sum_status_count,
|
||||
project_ids=project_ids,
|
||||
child_projects=child_projects,
|
||||
data=status_count,
|
||||
)
|
||||
|
||||
runtime = {
|
||||
result["_id"]: {k: v for k, v in result.items() if k != "_id"}
|
||||
for result in Task.aggregate(runtime_pipeline)
|
||||
}
|
||||
|
||||
def get_status_counts(project_id, section):
|
||||
path = "/".join((project_id, section))
|
||||
def sum_runtime(
|
||||
a: Mapping[str, Mapping], b: Mapping[str, Mapping]
|
||||
) -> Dict[str, dict]:
|
||||
return {
|
||||
"total_runtime": safe_get(runtime, path, 0),
|
||||
"status_count": safe_get(status_count, path, default_counts),
|
||||
section: a.get(section, 0) + b.get(section, 0)
|
||||
for section in set(a) | set(b)
|
||||
}
|
||||
|
||||
runtime = cls.aggregate_project_data(
|
||||
func=sum_runtime,
|
||||
project_ids=project_ids,
|
||||
child_projects=child_projects,
|
||||
data=runtime,
|
||||
)
|
||||
|
||||
def get_status_counts(project_id, section):
|
||||
return {
|
||||
"total_runtime": nested_get(runtime, (project_id, section), 0),
|
||||
"status_count": nested_get(
|
||||
status_count, (project_id, section), default_counts
|
||||
),
|
||||
}
|
||||
|
||||
report_for_states = [
|
||||
s for s in EntityVisibility if not specific_state or specific_state == s
|
||||
]
|
||||
|
||||
return {
|
||||
stats = {
|
||||
project: {
|
||||
task_state.value: get_status_counts(project, task_state.value)
|
||||
for task_state in report_for_states
|
||||
@ -322,6 +492,40 @@ class ProjectBLL:
|
||||
for project in project_ids
|
||||
}
|
||||
|
||||
children = {
|
||||
project: sorted(
|
||||
[{"id": c.id, "name": c.name} for c in child_projects.get(project, [])],
|
||||
key=itemgetter("name"),
|
||||
)
|
||||
for project in project_ids
|
||||
}
|
||||
return stats, children
|
||||
|
||||
@classmethod
|
||||
def get_active_users(
|
||||
cls,
|
||||
company,
|
||||
project_ids: Sequence[str],
|
||||
user_ids: Optional[Sequence[str]] = None,
|
||||
) -> set:
|
||||
"""
|
||||
Get the set of user ids that created tasks/models/dataviews in the given projects
|
||||
If project_ids is empty then all projects are examined
|
||||
If user_ids are passed then only subset of these users is returned
|
||||
"""
|
||||
with TimingContext("mongo", "active_users_in_projects"):
|
||||
res = set()
|
||||
query = Q(company=company)
|
||||
if project_ids:
|
||||
project_ids = _ids_with_children(project_ids)
|
||||
query &= Q(project__in=project_ids)
|
||||
if user_ids:
|
||||
query &= Q(user__in=user_ids)
|
||||
for cls_ in (Task, Model):
|
||||
res |= set(cls_.objects(query).distinct(field="user"))
|
||||
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
def get_projects_with_active_user(
|
||||
cls,
|
||||
@ -330,13 +534,83 @@ class ProjectBLL:
|
||||
project_ids: Optional[Sequence[str]] = None,
|
||||
allow_public: bool = True,
|
||||
) -> Sequence[str]:
|
||||
"""Get the projects ids where user created any tasks"""
|
||||
company = (
|
||||
{"company__in": [None, "", company]}
|
||||
if allow_public
|
||||
else {"company": company}
|
||||
)
|
||||
projects = {"project__in": project_ids} if project_ids else {}
|
||||
return Task.objects(**company, user__in=users, **projects).distinct(
|
||||
field="project"
|
||||
"""
|
||||
Get the projects ids where user created any tasks including all the parents of these projects
|
||||
If project ids are specified then filter the results by these project ids
|
||||
"""
|
||||
query = Q(user__in=users)
|
||||
|
||||
if allow_public:
|
||||
query &= get_company_or_none_constraint(company)
|
||||
else:
|
||||
query &= Q(company=company)
|
||||
|
||||
if project_ids:
|
||||
query &= Q(project__in=_ids_with_children(project_ids))
|
||||
|
||||
res = Task.objects(query).distinct(field="project")
|
||||
if not res:
|
||||
return res
|
||||
|
||||
ids_with_parents = _ids_with_parents(res)
|
||||
if project_ids:
|
||||
return [pid for pid in ids_with_parents if pid in project_ids]
|
||||
|
||||
return ids_with_parents
|
||||
|
||||
@classmethod
|
||||
def get_task_parents(
|
||||
cls,
|
||||
company_id: str,
|
||||
projects: Sequence[str],
|
||||
state: Optional[EntityVisibility] = None,
|
||||
) -> Sequence[dict]:
|
||||
"""
|
||||
Get list of unique parent tasks sorted by task name for the passed company projects
|
||||
If projects is None or empty then get parents for all the company tasks
|
||||
"""
|
||||
query = Q(company=company_id)
|
||||
if projects:
|
||||
projects = _ids_with_children(projects)
|
||||
query &= Q(project__in=projects)
|
||||
if state == EntityVisibility.archived:
|
||||
query &= Q(system_tags__in=[EntityVisibility.archived.value])
|
||||
elif state == EntityVisibility.active:
|
||||
query &= Q(system_tags__nin=[EntityVisibility.archived.value])
|
||||
|
||||
parent_ids = set(Task.objects(query).distinct("parent"))
|
||||
if not parent_ids:
|
||||
return []
|
||||
|
||||
parents = Task.get_many_with_join(
|
||||
company_id,
|
||||
query=Q(id__in=parent_ids),
|
||||
allow_public=True,
|
||||
override_projection=("id", "name", "project.name"),
|
||||
)
|
||||
return sorted(parents, key=itemgetter("name"))
|
||||
|
||||
@classmethod
|
||||
def get_task_types(cls, company, project_ids: Optional[Sequence]) -> set:
|
||||
"""
|
||||
Return the list of unique task types used by company and public tasks
|
||||
If project ids passed then only tasks from these projects are considered
|
||||
"""
|
||||
query = get_company_or_none_constraint(company)
|
||||
if project_ids:
|
||||
project_ids = _ids_with_children(project_ids)
|
||||
query &= Q(project__in=project_ids)
|
||||
res = Task.objects(query).distinct(field="type")
|
||||
return set(res).intersection(external_task_types)
|
||||
|
||||
@classmethod
|
||||
def get_model_frameworks(cls, company, project_ids: Optional[Sequence]) -> Sequence:
|
||||
"""
|
||||
Return the list of unique frameworks used by company and public models
|
||||
If project ids passed then only models from these projects are considered
|
||||
"""
|
||||
query = get_company_or_none_constraint(company)
|
||||
if project_ids:
|
||||
project_ids = _ids_with_children(project_ids)
|
||||
query &= Q(project__in=project_ids)
|
||||
return Model.objects(query).distinct(field="framework")
|
||||
|
@ -1,5 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import Tuple, Set
|
||||
from typing import Tuple, Set, Sequence
|
||||
|
||||
import attr
|
||||
|
||||
@ -16,6 +16,7 @@ from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.task.task import Task, ArtifactModes
|
||||
from apiserver.timing_context import TimingContext
|
||||
from .sub_projects import _ids_with_children
|
||||
|
||||
log = config.logger(__file__)
|
||||
event_bll = EventBLL()
|
||||
@ -32,18 +33,22 @@ class DeleteProjectResult:
|
||||
|
||||
def delete_project(
|
||||
company: str, project_id: str, force: bool, delete_contents: bool
|
||||
) -> DeleteProjectResult:
|
||||
project = Project.get_for_writing(company=company, id=project_id)
|
||||
) -> Tuple[DeleteProjectResult, Set[str]]:
|
||||
project = Project.get_for_writing(
|
||||
company=company, id=project_id, _only=("id", "path")
|
||||
)
|
||||
if not project:
|
||||
raise errors.bad_request.InvalidProjectId(id=project_id)
|
||||
|
||||
project_ids = _ids_with_children([project_id])
|
||||
if not force:
|
||||
for cls, error in (
|
||||
(Task, errors.bad_request.ProjectHasTasks),
|
||||
(Model, errors.bad_request.ProjectHasModels),
|
||||
):
|
||||
non_archived = cls.objects(
|
||||
project=project_id, system_tags__nin=[EntityVisibility.archived.value],
|
||||
project__in=project_ids,
|
||||
system_tags__nin=[EntityVisibility.archived.value],
|
||||
).only("id")
|
||||
if non_archived:
|
||||
raise error("use force=true to delete", id=project_id)
|
||||
@ -51,12 +56,14 @@ def delete_project(
|
||||
if not delete_contents:
|
||||
with TimingContext("mongo", "update_children"):
|
||||
for cls in (Model, Task):
|
||||
updated_count = cls.objects(project=project_id).update(project=None)
|
||||
updated_count = cls.objects(project__in=project_ids).update(
|
||||
project=None
|
||||
)
|
||||
res = DeleteProjectResult(disassociated_tasks=updated_count)
|
||||
else:
|
||||
deleted_models, model_urls = _delete_models(project=project_id)
|
||||
deleted_models, model_urls = _delete_models(projects=project_ids)
|
||||
deleted_tasks, event_urls, artifact_urls = _delete_tasks(
|
||||
company=company, project=project_id
|
||||
company=company, projects=project_ids
|
||||
)
|
||||
res = DeleteProjectResult(
|
||||
deleted_tasks=deleted_tasks,
|
||||
@ -68,25 +75,27 @@ def delete_project(
|
||||
),
|
||||
)
|
||||
|
||||
res.deleted = Project.objects(id=project_id).delete()
|
||||
return res
|
||||
affected = {*project_ids, *(project.path or [])}
|
||||
res.deleted = Project.objects(id__in=project_ids).delete()
|
||||
|
||||
return res, affected
|
||||
|
||||
|
||||
def _delete_tasks(company: str, project: str) -> Tuple[int, Set, Set]:
|
||||
def _delete_tasks(company: str, projects: Sequence[str]) -> Tuple[int, Set, Set]:
|
||||
"""
|
||||
Delete only the task themselves and their non published version.
|
||||
Child models under the same project are deleted separately.
|
||||
Children tasks should be deleted in the same api call.
|
||||
If any child entities are left in another projects then updated their parent task to None
|
||||
"""
|
||||
tasks = Task.objects(project=project).only("id", "execution__artifacts")
|
||||
tasks = Task.objects(project__in=projects).only("id", "execution__artifacts")
|
||||
if not tasks:
|
||||
return 0, set(), set()
|
||||
|
||||
task_ids = {t.id for t in tasks}
|
||||
with TimingContext("mongo", "delete_tasks_update_children"):
|
||||
Task.objects(parent__in=task_ids, project__ne=project).update(parent=None)
|
||||
Model.objects(task__in=task_ids, project__ne=project).update(task=None)
|
||||
Task.objects(parent__in=task_ids, project__nin=projects).update(parent=None)
|
||||
Model.objects(task__in=task_ids, project__nin=projects).update(task=None)
|
||||
|
||||
event_urls, artifact_urls = set(), set()
|
||||
for task in tasks:
|
||||
@ -106,18 +115,18 @@ def _delete_tasks(company: str, project: str) -> Tuple[int, Set, Set]:
|
||||
return deleted, event_urls, artifact_urls
|
||||
|
||||
|
||||
def _delete_models(project: str) -> Tuple[int, Set[str]]:
|
||||
def _delete_models(projects: Sequence[str]) -> Tuple[int, Set[str]]:
|
||||
"""
|
||||
Delete project models and update the tasks from other projects
|
||||
that reference them to reference None.
|
||||
"""
|
||||
with TimingContext("mongo", "delete_models"):
|
||||
models = Model.objects(project=project).only("task", "id", "uri")
|
||||
models = Model.objects(project__in=projects).only("task", "id", "uri")
|
||||
if not models:
|
||||
return 0, set()
|
||||
|
||||
model_ids = {m.id for m in models}
|
||||
Task.objects(execution__model__in=model_ids, project__ne=project).update(
|
||||
Task.objects(execution__model__in=model_ids, project__nin=projects).update(
|
||||
execution__model=None
|
||||
)
|
||||
|
||||
@ -125,7 +134,7 @@ def _delete_models(project: str) -> Tuple[int, Set[str]]:
|
||||
if model_tasks:
|
||||
now = datetime.utcnow()
|
||||
Task.objects(
|
||||
id__in=model_tasks, project__ne=project, output__model__in=model_ids
|
||||
id__in=model_tasks, project__nin=projects, output__model__in=model_ids
|
||||
).update(
|
||||
output__model=None,
|
||||
output__error=f"model deleted on {now.isoformat()}",
|
||||
|
163
apiserver/bll/project/sub_projects.py
Normal file
163
apiserver/bll/project/sub_projects.py
Normal file
@ -0,0 +1,163 @@
|
||||
import itertools
|
||||
from datetime import datetime
|
||||
from typing import Tuple, Optional, Sequence, Mapping
|
||||
|
||||
from apiserver import database
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model.project import Project
|
||||
|
||||
name_separator = "/"
|
||||
max_depth = config.get("services.projects.sub_projects.max_depth", 10)
|
||||
|
||||
|
||||
def _validate_project_name(project_name: str) -> Tuple[str, str]:
|
||||
"""
|
||||
Remove redundant '/' characters. Ensure that the project name is not empty
|
||||
and path to it is not larger then max_depth parameter.
|
||||
Return the cleaned up project name and location
|
||||
"""
|
||||
name_parts = list(filter(None, project_name.split(name_separator)))
|
||||
if not name_parts:
|
||||
raise errors.bad_request.InvalidProjectName(name=project_name)
|
||||
|
||||
if len(name_parts) > max_depth:
|
||||
raise errors.bad_request.ProjectPathExceedsMax(max_depth=max_depth)
|
||||
|
||||
return name_separator.join(name_parts), name_separator.join(name_parts[:-1])
|
||||
|
||||
|
||||
def _ensure_project(company: str, user: str, name: str) -> Optional[Project]:
|
||||
"""
|
||||
Makes sure that the project with the given name exists
|
||||
If needed auto-create the project and all the missing projects in the path to it
|
||||
Return the project
|
||||
"""
|
||||
name = name.strip(name_separator)
|
||||
if not name:
|
||||
return None
|
||||
|
||||
project = _get_writable_project_from_name(company, name)
|
||||
if project:
|
||||
return project
|
||||
|
||||
now = datetime.utcnow()
|
||||
name, location = _validate_project_name(name)
|
||||
project = Project(
|
||||
id=database.utils.id(),
|
||||
user=user,
|
||||
company=company,
|
||||
created=now,
|
||||
last_update=now,
|
||||
name=name,
|
||||
description="",
|
||||
)
|
||||
parent = _ensure_project(company, user, location)
|
||||
_save_under_parent(project=project, parent=parent)
|
||||
if parent:
|
||||
parent.update(last_update=now)
|
||||
|
||||
return project
|
||||
|
||||
|
||||
def _save_under_parent(project: Project, parent: Optional[Project]):
|
||||
"""
|
||||
Save the project under the given parent project or top level (parent=None)
|
||||
Check that the project location matches the parent name
|
||||
"""
|
||||
location, _, _ = project.name.rpartition(name_separator)
|
||||
if not parent:
|
||||
if location:
|
||||
raise ValueError(
|
||||
f"Project location {location} does not match empty parent name"
|
||||
)
|
||||
project.parent = None
|
||||
project.path = []
|
||||
project.save()
|
||||
return
|
||||
|
||||
if location != parent.name:
|
||||
raise ValueError(
|
||||
f"Project location {location} does not match parent name {parent.name}"
|
||||
)
|
||||
project.parent = parent.id
|
||||
project.path = [*(parent.path or []), parent.id]
|
||||
project.save()
|
||||
|
||||
|
||||
def _get_writable_project_from_name(
|
||||
company,
|
||||
name,
|
||||
_only: Optional[Sequence[str]] = ("id", "name", "path", "company", "parent"),
|
||||
) -> Optional[Project]:
|
||||
"""
|
||||
Return a project from name. If the project not found then return None
|
||||
"""
|
||||
qs = Project.objects(company=company, name=name)
|
||||
if _only:
|
||||
qs = qs.only(*_only)
|
||||
return qs.first()
|
||||
|
||||
|
||||
def _get_sub_projects(
|
||||
project_ids: Sequence[str], _only: Sequence[str] = ("id", "path")
|
||||
) -> Mapping[str, Sequence[Project]]:
|
||||
"""
|
||||
Return the list of child projects of all the levels for the parent project ids
|
||||
"""
|
||||
qs = Project.objects(path__in=project_ids)
|
||||
if _only:
|
||||
_only = set(_only) | {"path"}
|
||||
qs = qs.only(*_only)
|
||||
subprojects = list(qs)
|
||||
|
||||
return {
|
||||
pid: [s for s in subprojects if pid in (s.path or [])] for pid in project_ids
|
||||
}
|
||||
|
||||
|
||||
def _ids_with_parents(project_ids: Sequence[str]) -> Sequence[str]:
|
||||
"""
|
||||
Return project ids with all the parent projects
|
||||
"""
|
||||
projects = Project.objects(id__in=project_ids).only("id", "path")
|
||||
parent_ids = set(itertools.chain.from_iterable(p.path for p in projects if p.path))
|
||||
return list({*(p.id for p in projects), *parent_ids})
|
||||
|
||||
|
||||
def _ids_with_children(project_ids: Sequence[str]) -> Sequence[str]:
|
||||
"""
|
||||
Return project ids with the ids of all the subprojects
|
||||
"""
|
||||
subprojects = Project.objects(path__in=project_ids).only("id")
|
||||
return list({*project_ids, *(child.id for child in subprojects)})
|
||||
|
||||
|
||||
def _update_subproject_names(project: Project, update_path: bool = False) -> int:
|
||||
"""
|
||||
Update sub project names when the base project name changes
|
||||
Optionally update the paths
|
||||
"""
|
||||
child_projects = _get_sub_projects(project_ids=[project.id], _only=("id", "name"))
|
||||
updated = 0
|
||||
for child in child_projects[project.id]:
|
||||
child_suffix = name_separator.join(
|
||||
child.name.split(name_separator)[len(project.path) + 1 :]
|
||||
)
|
||||
updates = {"name": name_separator.join((project.name, child_suffix))}
|
||||
if update_path:
|
||||
updates["path"] = project.path + child.path[len(project.path) :]
|
||||
updated += child.update(upsert=False, **updates)
|
||||
|
||||
return updated
|
||||
|
||||
|
||||
def _reposition_project_with_children(project: Project, parent: Project) -> int:
|
||||
new_location = parent.name if parent else None
|
||||
project.name = name_separator.join(
|
||||
filter(None, (new_location, project.name.split(name_separator)[-1]))
|
||||
)
|
||||
_save_under_parent(project, parent=parent)
|
||||
|
||||
moved = 1 + _update_subproject_names(project=project, update_path=True)
|
||||
return moved
|
@ -26,7 +26,6 @@ from apiserver.database.model.task.task import (
|
||||
TaskStatusMessage,
|
||||
TaskSystemTags,
|
||||
ArtifactModes,
|
||||
external_task_types,
|
||||
)
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.utils import get_company_or_none_constraint, id as create_id
|
||||
@ -51,18 +50,6 @@ class TaskBLL:
|
||||
self.events_es = events_es or es_factory.connect("events")
|
||||
self.redis: StrictRedis = redis or redman.connection("apiserver")
|
||||
|
||||
@classmethod
|
||||
def get_types(cls, company, project_ids: Optional[Sequence]) -> set:
|
||||
"""
|
||||
Return the list of unique task types used by company and public tasks
|
||||
If project ids passed then only tasks from these projects are considered
|
||||
"""
|
||||
query = get_company_or_none_constraint(company)
|
||||
if project_ids:
|
||||
query &= Q(project__in=project_ids)
|
||||
res = Task.objects(query).distinct(field="type")
|
||||
return set(res).intersection(external_task_types)
|
||||
|
||||
@staticmethod
|
||||
def get_task_with_access(
|
||||
task_id, company_id, only=None, allow_public=False, requires_write_access=False
|
||||
|
@ -10,4 +10,9 @@ featured {
|
||||
|
||||
# default featured index for public projects not specified in the order
|
||||
public_default: 9999
|
||||
}
|
||||
|
||||
sub_projects {
|
||||
# the max sub project depth
|
||||
max_depth: 10
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
from mongoengine import StringField, DateTimeField, IntField
|
||||
from mongoengine import StringField, DateTimeField, IntField, ListField
|
||||
|
||||
from apiserver.database import Database, strict
|
||||
from apiserver.database.fields import StrippedStringField, SafeSortedListField
|
||||
@ -10,13 +10,15 @@ class Project(AttributedDocument):
|
||||
|
||||
get_all_query_options = GetMixin.QueryParameterOptions(
|
||||
pattern_fields=("name", "description"),
|
||||
list_fields=("tags", "system_tags", "id"),
|
||||
list_fields=("tags", "system_tags", "id", "parent", "path"),
|
||||
)
|
||||
|
||||
meta = {
|
||||
"db_alias": Database.backend,
|
||||
"strict": strict,
|
||||
"indexes": [
|
||||
"parent",
|
||||
"path",
|
||||
("company", "name"),
|
||||
{
|
||||
"name": "%s.project.main_text_index" % Database.backend,
|
||||
@ -44,3 +46,5 @@ class Project(AttributedDocument):
|
||||
logo_url = StringField()
|
||||
logo_blob = StringField(exclude_by_default=True)
|
||||
company_origin = StringField(exclude_by_default=True)
|
||||
parent = StringField(reference_field="Project")
|
||||
path = ListField(StringField(required=True), exclude_by_default=True)
|
||||
|
@ -33,6 +33,7 @@ from mongoengine import Q
|
||||
|
||||
from apiserver.bll.event import EventBLL
|
||||
from apiserver.bll.event.event_common import EventType
|
||||
from apiserver.bll.project import project_ids_with_children
|
||||
from apiserver.bll.task.artifacts import get_artifact_id
|
||||
from apiserver.bll.task.param_utils import (
|
||||
split_param_name,
|
||||
@ -423,6 +424,22 @@ class PrePopulate:
|
||||
items.append(results[0])
|
||||
return items
|
||||
|
||||
@classmethod
|
||||
def _check_projects_hierarchy(cls, projects: Set[Project]):
|
||||
"""
|
||||
For any exported project all its parents up to the root should be present
|
||||
"""
|
||||
if not projects:
|
||||
return
|
||||
|
||||
project_ids = {p.id for p in projects}
|
||||
orphans = [p.id for p in projects if p.parent and p.parent not in project_ids]
|
||||
if not orphans:
|
||||
return
|
||||
|
||||
print(f"ERROR: the following projects are exported without their parents: {orphans}")
|
||||
exit(1)
|
||||
|
||||
@classmethod
|
||||
def _resolve_entities(
|
||||
cls,
|
||||
@ -434,6 +451,7 @@ class PrePopulate:
|
||||
|
||||
if projects:
|
||||
print("Reading projects...")
|
||||
projects = project_ids_with_children(projects)
|
||||
entities[cls.project_cls].update(
|
||||
cls._resolve_type(cls.project_cls, projects)
|
||||
)
|
||||
@ -463,6 +481,8 @@ class PrePopulate:
|
||||
project_ids = {p.id for p in entities[cls.project_cls]}
|
||||
entities[cls.project_cls].update(o for o in objs if o.id not in project_ids)
|
||||
|
||||
cls._check_projects_hierarchy(entities[cls.project_cls])
|
||||
|
||||
model_ids = {
|
||||
model_id
|
||||
for task in entities[cls.task_cls]
|
||||
@ -634,11 +654,12 @@ class PrePopulate:
|
||||
"""
|
||||
Export the requested experiments, projects and models and return the list of artifact files
|
||||
Always do the export on sorted items since the order of items influence hash
|
||||
The projects should be sorted by name so that on import the hierarchy is correctly restored from top to bottom
|
||||
"""
|
||||
artifacts = []
|
||||
now = datetime.utcnow()
|
||||
for cls_ in sorted(entities, key=attrgetter("__name__")):
|
||||
items = sorted(entities[cls_], key=attrgetter("id"))
|
||||
items = sorted(entities[cls_], key=attrgetter("name", "id"))
|
||||
if not items:
|
||||
continue
|
||||
base_filename = cls._get_base_filename(cls_)
|
||||
|
@ -167,10 +167,27 @@ _definitions {
|
||||
type: string
|
||||
}
|
||||
// extra properties
|
||||
stats: {
|
||||
stats {
|
||||
description: "Additional project stats"
|
||||
"$ref": "#/definitions/stats"
|
||||
}
|
||||
sub_projects {
|
||||
description: "The list of sub projects"
|
||||
type: array
|
||||
items {
|
||||
type: object
|
||||
properties {
|
||||
id {
|
||||
description: "Subproject ID"
|
||||
type: string
|
||||
}
|
||||
name {
|
||||
description: "Subproject name"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
metric_variant_result {
|
||||
@ -405,6 +422,17 @@ get_all {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.13": ${get_all."2.1"} {
|
||||
request {
|
||||
properties {
|
||||
shallow_search {
|
||||
description: "If set to 'true' then the search with the specified criteria is performed among top level projects only (or if parents specified, among the direct children of the these parents). Otherwise the search is performed among all the company projects (or among all of the descendants of the specified parents)."
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_all_ex {
|
||||
internal: true
|
||||
@ -438,6 +466,11 @@ get_all_ex {
|
||||
type: array
|
||||
items: {type: string}
|
||||
}
|
||||
shallow_search {
|
||||
description: "If set to 'true' then the search with the specified criteria is performed among top level projects only (or if parents specified, among the direct children of the these parents). Otherwise the search is performed among all the company projects (or among all of the descendants of the specified parents)."
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -498,6 +531,66 @@ update {
|
||||
}
|
||||
}
|
||||
}
|
||||
move {
|
||||
"2.13" {
|
||||
description: "Moves a project and all of its subprojects under the different location"
|
||||
request {
|
||||
type: object
|
||||
required: [project]
|
||||
properties {
|
||||
project {
|
||||
description: "Project id"
|
||||
type: string
|
||||
}
|
||||
new_location {
|
||||
description: "The name location for the project"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
moved {
|
||||
description: "The number of projects moved"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
merge {
|
||||
"2.13" {
|
||||
description: "Moves all the source project's contents to the destination project and remove the source project"
|
||||
request {
|
||||
type: object
|
||||
required: [project]
|
||||
properties {
|
||||
project {
|
||||
description: "Project id"
|
||||
type: string
|
||||
}
|
||||
destination_project {
|
||||
description: "The ID of the destination project"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
moved_entities {
|
||||
description: "The number of tasks, models and dataviews moved from the merged project into the destination"
|
||||
type: integer
|
||||
}
|
||||
moved_projects {
|
||||
description: "The number of child projects moved from the merged project into the destination"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
delete {
|
||||
"2.1" {
|
||||
description: "Deletes a project"
|
||||
|
@ -16,7 +16,6 @@ from apiserver.apimodels.models import (
|
||||
GetFrameworksRequest,
|
||||
DeleteModelRequest,
|
||||
)
|
||||
from apiserver.bll.model import ModelBLL
|
||||
from apiserver.bll.organization import OrgBLL, Tags
|
||||
from apiserver.bll.project import ProjectBLL
|
||||
from apiserver.bll.task import TaskBLL
|
||||
@ -38,7 +37,6 @@ from apiserver.timing_context import TimingContext
|
||||
|
||||
log = config.logger(__file__)
|
||||
org_bll = OrgBLL()
|
||||
model_bll = ModelBLL()
|
||||
project_bll = ProjectBLL()
|
||||
|
||||
|
||||
@ -131,7 +129,7 @@ def get_all(call: APICall, company_id, _):
|
||||
def get_frameworks(call: APICall, company_id, request: GetFrameworksRequest):
|
||||
call.result.data = {
|
||||
"frameworks": sorted(
|
||||
model_bll.get_frameworks(company_id, project_ids=request.projects)
|
||||
project_bll.get_model_frameworks(company_id, project_ids=request.projects)
|
||||
)
|
||||
}
|
||||
|
||||
|
@ -1,4 +1,3 @@
|
||||
from datetime import datetime
|
||||
from typing import Sequence
|
||||
|
||||
import attr
|
||||
@ -8,13 +7,13 @@ from apiserver.apierrors import errors
|
||||
from apiserver.apierrors.errors.bad_request import InvalidProjectId
|
||||
from apiserver.apimodels.base import UpdateResponse, MakePublicRequest, IdResponse
|
||||
from apiserver.apimodels.projects import (
|
||||
GetHyperParamReq,
|
||||
ProjectReq,
|
||||
GetHyperParamRequest,
|
||||
ProjectRequest,
|
||||
ProjectTagsRequest,
|
||||
ProjectTaskParentsRequest,
|
||||
ProjectHyperparamValuesRequest,
|
||||
ProjectsGetRequest,
|
||||
DeleteRequest,
|
||||
DeleteRequest, MoveRequest, MergeRequest,
|
||||
)
|
||||
from apiserver.bll.organization import OrgBLL, Tags
|
||||
from apiserver.bll.project import ProjectBLL
|
||||
@ -47,10 +46,6 @@ create_fields = {
|
||||
"default_output_destination": None,
|
||||
}
|
||||
|
||||
get_all_query_options = Project.QueryParameterOptions(
|
||||
pattern_fields=("name", "description"), list_fields=("tags", "system_tags", "id"),
|
||||
)
|
||||
|
||||
|
||||
@endpoint("projects.get_by_id", required_fields=["project"])
|
||||
def get_by_id(call):
|
||||
@ -72,27 +67,46 @@ def get_by_id(call):
|
||||
call.result.data = {"project": project_dict}
|
||||
|
||||
|
||||
def _adjust_search_parameters(data: dict, shallow_search: bool):
|
||||
"""
|
||||
1. Make sure that there is no external query on path
|
||||
2. If not shallow_search and parent is provided then parent can be at any place in path
|
||||
3. If shallow_search and no parent provided then use a top level parent
|
||||
"""
|
||||
data.pop("path", None)
|
||||
if not shallow_search:
|
||||
if "parent" in data:
|
||||
data["path"] = data.pop("parent")
|
||||
return
|
||||
|
||||
if "parent" not in data:
|
||||
data["parent"] = [None, ""]
|
||||
|
||||
|
||||
@endpoint("projects.get_all_ex", request_data_model=ProjectsGetRequest)
|
||||
def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
|
||||
conform_tag_fields(call, call.data)
|
||||
allow_public = not request.non_public
|
||||
shallow_search = request.shallow_search or request.include_stats
|
||||
with TimingContext("mongo", "projects_get_all"):
|
||||
data = call.data
|
||||
if request.active_users:
|
||||
ids = project_bll.get_projects_with_active_user(
|
||||
company=company_id,
|
||||
users=request.active_users,
|
||||
project_ids=call.data.get("id"),
|
||||
project_ids=data.get("id"),
|
||||
allow_public=allow_public,
|
||||
)
|
||||
if not ids:
|
||||
call.result.data = {"projects": []}
|
||||
return
|
||||
call.data["id"] = ids
|
||||
data["id"] = ids
|
||||
|
||||
_adjust_search_parameters(data, shallow_search=shallow_search)
|
||||
|
||||
projects = Project.get_many_with_join(
|
||||
company=company_id,
|
||||
query_dict=call.data,
|
||||
query_options=get_all_query_options,
|
||||
query_dict=data,
|
||||
allow_public=allow_public,
|
||||
)
|
||||
|
||||
@ -102,7 +116,7 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
|
||||
return
|
||||
|
||||
project_ids = {project["id"] for project in projects}
|
||||
stats = project_bll.get_project_stats(
|
||||
stats, children = project_bll.get_project_stats(
|
||||
company=company_id,
|
||||
project_ids=list(project_ids),
|
||||
specific_state=request.stats_for_state,
|
||||
@ -110,6 +124,7 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
|
||||
|
||||
for project in projects:
|
||||
project["stats"] = stats[project["id"]]
|
||||
project["sub_projects"] = children[project["id"]]
|
||||
|
||||
call.result.data = {"projects": projects}
|
||||
|
||||
@ -117,12 +132,13 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
|
||||
@endpoint("projects.get_all")
|
||||
def get_all(call: APICall):
|
||||
conform_tag_fields(call, call.data)
|
||||
data = call.data
|
||||
_adjust_search_parameters(data, shallow_search=data.get("shallow_search", False))
|
||||
with translate_errors_context(), TimingContext("mongo", "projects_get_all"):
|
||||
projects = Project.get_many(
|
||||
company=call.identity.company,
|
||||
query_dict=call.data,
|
||||
query_options=get_all_query_options,
|
||||
parameters=call.data,
|
||||
query_dict=data,
|
||||
parameters=data,
|
||||
allow_public=True,
|
||||
)
|
||||
conform_output_tags(call, projects)
|
||||
@ -161,22 +177,15 @@ def update(call: APICall):
|
||||
:return: updated - `int` - number of projects updated
|
||||
fields - `[string]` - updated fields
|
||||
"""
|
||||
project_id = call.data["project"]
|
||||
|
||||
with translate_errors_context():
|
||||
project = Project.get_for_writing(company=call.identity.company, id=project_id)
|
||||
if not project:
|
||||
raise errors.bad_request.InvalidProjectId(id=project_id)
|
||||
|
||||
fields = parse_from_call(
|
||||
call.data, create_fields, Project.get_fields(), discard_none_values=False
|
||||
)
|
||||
conform_tag_fields(call, fields, validate=True)
|
||||
fields["last_update"] = datetime.utcnow()
|
||||
with TimingContext("mongo", "projects_update"):
|
||||
updated = project.update(upsert=False, **fields)
|
||||
conform_output_tags(call, fields)
|
||||
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
|
||||
fields = parse_from_call(
|
||||
call.data, create_fields, Project.get_fields(), discard_none_values=False
|
||||
)
|
||||
conform_tag_fields(call, fields, validate=True)
|
||||
updated = ProjectBLL.update(
|
||||
company=call.identity.company, project_id=call.data["project"], **fields
|
||||
)
|
||||
conform_output_tags(call, fields)
|
||||
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
|
||||
|
||||
|
||||
def _reset_cached_tags(company: str, projects: Sequence[str]):
|
||||
@ -184,20 +193,47 @@ def _reset_cached_tags(company: str, projects: Sequence[str]):
|
||||
org_bll.reset_tags(company, Tags.Model, projects=projects)
|
||||
|
||||
|
||||
@endpoint("projects.move", request_data_model=MoveRequest)
|
||||
def move(call: APICall, company: str, request: MoveRequest):
|
||||
moved, affected_projects = ProjectBLL.move_project(
|
||||
company=company,
|
||||
user=call.identity.user,
|
||||
project_id=request.project,
|
||||
new_location=request.new_location,
|
||||
)
|
||||
_reset_cached_tags(company, projects=list(affected_projects))
|
||||
|
||||
call.result.data = {"moved": moved}
|
||||
|
||||
|
||||
@endpoint("projects.merge", request_data_model=MergeRequest)
|
||||
def merge(call: APICall, company: str, request: MergeRequest):
|
||||
moved_entitites, moved_projects, affected_projects = ProjectBLL.merge_project(
|
||||
company, source_id=request.project, destination_id=request.destination_project
|
||||
)
|
||||
|
||||
_reset_cached_tags(company, projects=list(affected_projects))
|
||||
|
||||
call.result.data = {
|
||||
"moved_entities": moved_entitites,
|
||||
"moved_projects": moved_projects,
|
||||
}
|
||||
|
||||
|
||||
@endpoint("projects.delete", request_data_model=DeleteRequest)
|
||||
def delete(call: APICall, company_id: str, request: DeleteRequest):
|
||||
res = delete_project(
|
||||
res, affected_projects = delete_project(
|
||||
company=company_id,
|
||||
project_id=request.project,
|
||||
force=request.force,
|
||||
delete_contents=request.delete_contents,
|
||||
)
|
||||
_reset_cached_tags(company_id, projects=[request.project])
|
||||
_reset_cached_tags(company_id, projects=list(affected_projects))
|
||||
call.result.data = {**attr.asdict(res)}
|
||||
|
||||
|
||||
@endpoint("projects.get_unique_metric_variants", request_data_model=ProjectReq)
|
||||
def get_unique_metric_variants(call: APICall, company_id: str, request: ProjectReq):
|
||||
@endpoint("projects.get_unique_metric_variants", request_data_model=ProjectRequest)
|
||||
def get_unique_metric_variants(call: APICall, company_id: str, request: ProjectRequest):
|
||||
|
||||
metrics = task_bll.get_unique_metric_variants(
|
||||
company_id, [request.project] if request.project else None
|
||||
@ -209,9 +245,9 @@ def get_unique_metric_variants(call: APICall, company_id: str, request: ProjectR
|
||||
@endpoint(
|
||||
"projects.get_hyper_parameters",
|
||||
min_version="2.9",
|
||||
request_data_model=GetHyperParamReq,
|
||||
request_data_model=GetHyperParamRequest,
|
||||
)
|
||||
def get_hyper_parameters(call: APICall, company_id: str, request: GetHyperParamReq):
|
||||
def get_hyper_parameters(call: APICall, company_id: str, request: GetHyperParamRequest):
|
||||
|
||||
total, remaining, parameters = TaskBLL.get_aggregated_project_parameters(
|
||||
company_id,
|
||||
@ -303,7 +339,7 @@ def get_task_parents(
|
||||
call: APICall, company_id: str, request: ProjectTaskParentsRequest
|
||||
):
|
||||
call.result.data = {
|
||||
"parents": org_bll.get_parent_tasks(
|
||||
"parents": project_bll.get_task_parents(
|
||||
company_id, projects=request.projects, state=request.tasks_state
|
||||
)
|
||||
}
|
||||
|
@ -216,7 +216,7 @@ def get_all(call: APICall, company_id, _):
|
||||
@endpoint("tasks.get_types", request_data_model=GetTypesRequest)
|
||||
def get_types(call: APICall, company_id, request: GetTypesRequest):
|
||||
call.result.data = {
|
||||
"types": list(task_bll.get_types(company_id, project_ids=request.projects))
|
||||
"types": list(project_bll.get_task_types(company_id, project_ids=request.projects))
|
||||
}
|
||||
|
||||
|
||||
|
240
apiserver/tests/automated/test_subprojects.py
Normal file
240
apiserver/tests/automated/test_subprojects.py
Normal file
@ -0,0 +1,240 @@
|
||||
from time import sleep
|
||||
from typing import Sequence, Optional, Tuple
|
||||
|
||||
from boltons.iterutils import first
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.utils import id as db_id
|
||||
from apiserver.tests.automated import TestService
|
||||
|
||||
|
||||
class TestSubProjects(TestService):
|
||||
def setUp(self, **kwargs):
|
||||
super().setUp(version="2.13")
|
||||
|
||||
def test_project_aggregations(self):
|
||||
child = self._temp_project(name="Aggregation/Pr1")
|
||||
project = self.api.projects.get_all_ex(name="^Aggregation$").projects[0].id
|
||||
child_project = self.api.projects.get_all_ex(id=[child]).projects[0]
|
||||
self.assertEqual(child_project.parent.id, project)
|
||||
|
||||
user = self.api.users.get_current_user().user.id
|
||||
|
||||
# test aggregations on project with empty subprojects
|
||||
res = self.api.users.get_all_ex(active_in_projects=[project])
|
||||
self.assertEqual(res.users, [])
|
||||
res = self.api.projects.get_all_ex(id=[project], active_users=[user])
|
||||
self.assertEqual(res.projects, [])
|
||||
res = self.api.models.get_frameworks(projects=[project])
|
||||
self.assertEqual(res.frameworks, [])
|
||||
res = self.api.tasks.get_types(projects=[project])
|
||||
self.assertEqual(res.types, [])
|
||||
res = self.api.projects.get_task_parents(projects=[project])
|
||||
self.assertEqual(res.parents, [])
|
||||
|
||||
# test aggregations with non-empty subprojects
|
||||
task1 = self._temp_task(project=child)
|
||||
self._temp_task(project=child, parent=task1)
|
||||
framework = "Test framework"
|
||||
self._temp_model(project=child, framework=framework)
|
||||
res = self.api.users.get_all_ex(active_in_projects=[project])
|
||||
self._assert_ids(res.users, [user])
|
||||
res = self.api.projects.get_all_ex(id=[project], active_users=[user])
|
||||
self._assert_ids(res.projects, [project])
|
||||
res = self.api.projects.get_task_parents(projects=[project])
|
||||
self._assert_ids(res.parents, [task1])
|
||||
res = self.api.models.get_frameworks(projects=[project])
|
||||
self.assertEqual(res.frameworks, [framework])
|
||||
res = self.api.tasks.get_types(projects=[project])
|
||||
self.assertEqual(res.types, ["testing"])
|
||||
|
||||
def _assert_ids(self, actual: Sequence[dict], expected: Sequence[str]):
|
||||
self.assertEqual([a["id"] for a in actual], expected)
|
||||
|
||||
def test_project_operations(self):
|
||||
# create
|
||||
with self.api.raises(errors.bad_request.InvalidProjectName):
|
||||
self._temp_project(name="/")
|
||||
project1 = self._temp_project(name="Root1/Pr1")
|
||||
project1_child = self._temp_project(name="Root1/Pr1/Pr2")
|
||||
with self.api.raises(errors.bad_request.ExpectedUniqueData):
|
||||
self._temp_project(name="Root1/Pr1/Pr2")
|
||||
|
||||
# update
|
||||
with self.api.raises(errors.bad_request.CannotUpdateProjectLocation):
|
||||
self.api.projects.update(project=project1, name="Root2/Pr2")
|
||||
res = self.api.projects.update(project=project1, name="Root1/Pr2")
|
||||
self.assertEqual(res.updated, 1)
|
||||
res = self.api.projects.get_by_id(project=project1_child)
|
||||
self.assertEqual(res.project.name, "Root1/Pr2/Pr2")
|
||||
|
||||
# move
|
||||
res = self.api.projects.move(project=project1, new_location="Root2")
|
||||
self.assertEqual(res.moved, 2)
|
||||
res = self.api.projects.get_by_id(project=project1_child)
|
||||
self.assertEqual(res.project.name, "Root2/Pr2/Pr2")
|
||||
|
||||
# merge
|
||||
project_with_task, (active, archived) = self._temp_project_with_tasks(
|
||||
"Root1/Pr3/Pr4"
|
||||
)
|
||||
project1_parent = self._getProjectParent(project1)
|
||||
self._assertTags(project1_parent, tags=[], system_tags=[])
|
||||
self._assertTags(project1_parent, tags=[], system_tags=[])
|
||||
project_with_task_parent = self._getProjectParent(project_with_task)
|
||||
self._assertTags(project_with_task_parent)
|
||||
# self._assertTags(project_id=None)
|
||||
|
||||
merge_source = self.api.projects.get_by_id(
|
||||
project=project_with_task
|
||||
).project.parent
|
||||
res = self.api.projects.merge(
|
||||
project=merge_source, destination_project=project1
|
||||
)
|
||||
self.assertEqual(res.moved_entities, 0)
|
||||
self.assertEqual(res.moved_projects, 1)
|
||||
res = self.api.projects.get_by_id(project=project_with_task)
|
||||
self.assertEqual(res.project.name, "Root2/Pr2/Pr4")
|
||||
with self.api.raises(errors.bad_request.InvalidProjectId):
|
||||
self.api.projects.get_by_id(project=merge_source)
|
||||
|
||||
self._assertTags(project1_parent)
|
||||
self._assertTags(project1)
|
||||
self._assertTags(project_with_task_parent, tags=[], system_tags=[])
|
||||
# self._assertTags(project_id=None)
|
||||
|
||||
# delete
|
||||
with self.api.raises(errors.bad_request.ProjectHasTasks):
|
||||
self.api.projects.delete(project=project1)
|
||||
res = self.api.projects.delete(project=project1, force=True)
|
||||
self.assertEqual(res.deleted, 3)
|
||||
self.assertEqual(res.disassociated_tasks, 2)
|
||||
res = self.api.tasks.get_by_id(task=active).task
|
||||
self.assertIsNone(res.get("project"))
|
||||
for p_id in (project1, project1_child, project_with_task):
|
||||
with self.api.raises(errors.bad_request.InvalidProjectId):
|
||||
self.api.projects.get_by_id(project=p_id)
|
||||
|
||||
self._assertTags(project1_parent, tags=[], system_tags=[])
|
||||
# self._assertTags(project_id=None, tags=[], system_tags=[])
|
||||
|
||||
def _getProjectParent(self, project_id: str):
|
||||
return self.api.projects.get_all_ex(id=[project_id]).projects[0].parent.id
|
||||
|
||||
def _assertTags(
|
||||
self,
|
||||
project_id: Optional[str],
|
||||
tags: Sequence[str] = ("test",),
|
||||
system_tags: Sequence[str] = (EntityVisibility.archived.value,),
|
||||
):
|
||||
if project_id:
|
||||
res = self.api.projects.get_task_tags(
|
||||
projects=[project_id], include_system=True
|
||||
)
|
||||
else:
|
||||
res = self.api.organization.get_tags(include_system=True)
|
||||
|
||||
self.assertEqual(set(res.tags), set(tags))
|
||||
self.assertEqual(set(res.system_tags), set(system_tags))
|
||||
|
||||
def test_get_all_search_options(self):
|
||||
project1 = self._temp_project(name="project1")
|
||||
project2 = self._temp_project(name="project1/project2")
|
||||
self._temp_project(name="project3")
|
||||
|
||||
# local search finds only at the specified level
|
||||
res = self.api.projects.get_all_ex(
|
||||
name="project1", shallow_search=True
|
||||
).projects
|
||||
self.assertEqual([p.id for p in res], [project1])
|
||||
res = self.api.projects.get_all_ex(name="project1", parent=[project1]).projects
|
||||
self.assertEqual([p.id for p in res], [project2])
|
||||
|
||||
# global search finds all or below the specified level
|
||||
res = self.api.projects.get_all_ex(name="project1").projects
|
||||
self.assertEqual(set(p.id for p in res), {project1, project2})
|
||||
project4 = self._temp_project(name="project1/project2/project1")
|
||||
res = self.api.projects.get_all_ex(name="project1", parent=[project2]).projects
|
||||
self.assertEqual([p.id for p in res], [project4])
|
||||
|
||||
self.api.projects.delete(project=project1, force=True)
|
||||
|
||||
def test_get_all_with_stats(self):
|
||||
project4, _ = self._temp_project_with_tasks(name="project1/project3/project4")
|
||||
project5, _ = self._temp_project_with_tasks(name="project1/project3/project5")
|
||||
project2 = self._temp_project(name="project2")
|
||||
res = self.api.projects.get_all(shallow_search=True).projects
|
||||
self.assertTrue(any(p for p in res if p.id == project2))
|
||||
self.assertFalse(any(p for p in res if p.id in [project4, project5]))
|
||||
|
||||
project1 = first(p.id for p in res if p.name == "project1")
|
||||
res = self.api.projects.get_all_ex(
|
||||
id=[project1, project2], include_stats=True
|
||||
).projects
|
||||
self.assertEqual(set(p.id for p in res), {project1, project2})
|
||||
res1 = next(p for p in res if p.id == project1)
|
||||
self.assertEqual(res1.stats["active"]["status_count"]["created"], 0)
|
||||
self.assertEqual(res1.stats["active"]["status_count"]["stopped"], 2)
|
||||
self.assertEqual(res1.stats["active"]["total_runtime"], 2)
|
||||
self.assertEqual(
|
||||
{sp.name for sp in res1.sub_projects},
|
||||
{
|
||||
"project1/project3",
|
||||
"project1/project3/project4",
|
||||
"project1/project3/project5",
|
||||
},
|
||||
)
|
||||
res2 = next(p for p in res if p.id == project2)
|
||||
self.assertEqual(res2.stats["active"]["status_count"]["created"], 0)
|
||||
self.assertEqual(res2.stats["active"]["status_count"]["stopped"], 0)
|
||||
self.assertEqual(res2.stats["active"]["total_runtime"], 0)
|
||||
self.assertEqual(res2.sub_projects, [])
|
||||
|
||||
def _run_tasks(self, *tasks):
|
||||
"""Imitate 1 second of running"""
|
||||
for task_id in tasks:
|
||||
self.api.tasks.started(task=task_id)
|
||||
sleep(1)
|
||||
for task_id in tasks:
|
||||
self.api.tasks.stopped(task=task_id)
|
||||
|
||||
def _temp_project_with_tasks(self, name) -> Tuple[str, Tuple[str, str]]:
|
||||
pr_id = self._temp_project(name=name)
|
||||
task_active = self._temp_task(project=pr_id)
|
||||
task_archived = self._temp_task(
|
||||
project=pr_id, system_tags=[EntityVisibility.archived.value], tags=["test"]
|
||||
)
|
||||
self._run_tasks(task_active, task_archived)
|
||||
return pr_id, (task_active, task_archived)
|
||||
|
||||
delete_params = dict(can_fail=True, force=True)
|
||||
|
||||
def _temp_project(self, name, **kwargs):
|
||||
return self.create_temp(
|
||||
"projects",
|
||||
delete_params=self.delete_params,
|
||||
name=name,
|
||||
description="",
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _temp_task(self, **kwargs):
|
||||
return self.create_temp(
|
||||
"tasks",
|
||||
delete_params=self.delete_params,
|
||||
type="testing",
|
||||
name=db_id(),
|
||||
input=dict(view=dict()),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _temp_model(self, **kwargs):
|
||||
return self.create_temp(
|
||||
service="models",
|
||||
delete_params=self.delete_params,
|
||||
name="test",
|
||||
uri="file:///a",
|
||||
labels={},
|
||||
**kwargs,
|
||||
)
|
Loading…
Reference in New Issue
Block a user