clearml-server/apiserver/services/projects.py
allegroai bca3a6e556 Set default task active duration to None
Move endpoints to new API version
Support tasks.move and models.move for moving tasks and models into projects
Support new project name in tasks.clone
Improve task active duration migration
2021-01-05 18:05:44 +02:00

441 lines
14 KiB
Python

from collections import defaultdict
from datetime import datetime
from itertools import groupby
from operator import itemgetter
import dpath
from mongoengine import Q
from apiserver.apierrors import errors
from apiserver.apierrors.errors.bad_request import InvalidProjectId
from apiserver.apimodels.base import UpdateResponse, MakePublicRequest, IdResponse
from apiserver.apimodels.projects import (
GetHyperParamReq,
ProjectReq,
ProjectTagsRequest,
)
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL
from apiserver.bll.task import TaskBLL
from apiserver.database.errors import translate_errors_context
from apiserver.database.model import EntityVisibility
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task, TaskStatus
from apiserver.database.utils import parse_from_call, get_options, get_company_or_none_constraint
from apiserver.service_repo import APICall, endpoint
from apiserver.services.utils import (
conform_tag_fields,
conform_output_tags,
get_tags_filter_dictionary,
get_tags_response,
)
from apiserver.timing_context import TimingContext
# Business-logic helpers shared by the endpoints in this module.
org_bll = OrgBLL()
task_bll = TaskBLL()

# Aggregation expression that evaluates to true when the document's
# ``system_tags`` array contains the "archived" visibility tag.
archived_tasks_cond = {"$in": [EntityVisibility.archived.value, "$system_tags"]}

# Fields accepted from the request by projects.create / projects.update, handed to
# parse_from_call. NOTE(review): the values look like per-field converters
# (None = take as-is, list = coerce to list) — confirm against parse_from_call.
create_fields = dict(
    name=None,
    description=None,
    tags=list,
    system_tags=list,
    default_output_destination=None,
)

# Query options for projects.get_all / get_all_ex: pattern matching on name and
# description, list-style filtering on tags, system_tags and id.
get_all_query_options = Project.QueryParameterOptions(
    pattern_fields=("name", "description"),
    list_fields=("tags", "system_tags", "id"),
)
@endpoint("projects.get_by_id", required_fields=["project"])
def get_by_id(call):
    """
    Return a single project by id.

    The lookup is restricted by ``get_company_or_none_constraint`` for the caller's
    company.

    :raises errors.bad_request.InvalidProjectId: no matching project was found
    """
    assert isinstance(call, APICall)
    requested_id = call.data["project"]

    with translate_errors_context():
        with TimingContext("mongo", "projects_by_id"):
            company_filter = get_company_or_none_constraint(call.identity.company)
            found = Project.objects(Q(id=requested_id) & company_filter).first()

        if not found:
            raise errors.bad_request.InvalidProjectId(id=requested_id)

        as_dict = found.to_proper_dict()
        conform_output_tags(call, as_dict)
        call.result.data = {"project": as_dict}
def make_projects_get_all_pipelines(company_id, project_ids, specific_state=None):
    """
    Build the two Task aggregation pipelines used for ``projects.get_all_ex`` stats.

    :param company_id: company whose tasks are aggregated; tasks whose ``company``
        is None or an empty string are matched as well.
    :param project_ids: ids of the projects to aggregate over.
    :param specific_state: optional EntityVisibility member. When set, the runtime
        pipeline only computes the sum for that state. The status-count pipeline
        is not affected by it.
    :return: ``(status_count_pipeline, runtime_pipeline)``:
        - status_count_pipeline yields one document per project of the form
          ``{"_id": <project>, "counts": [{"status", "count", <archived>: bool}, ...]}``
          where ``<archived>`` is ``EntityVisibility.archived.value``.
        - runtime_pipeline yields one document per project of the form
          ``{"_id": <project>, <state value>: <summed runtime>, ...}`` and only
          considers tasks of type "training" or "testing".
    """
    archived = EntityVisibility.archived.value

    def ensure_valid_fields():
        """
        Return an $addFields stage that normalizes fields used by later stages.

        Makes sure system_tags is always an array (required by the subsequent $in in
        archived_tasks_cond), and replaces a missing/null status with "unknown".
        """
        return {
            "$addFields": {
                "system_tags": {
                    "$cond": {
                        "if": {"$ne": [{"$type": "$system_tags"}, "array"]},
                        "then": [],
                        "else": "$system_tags",
                    }
                },
                "status": {"$ifNull": ["$status", "unknown"]},
            }
        }

    status_count_pipeline = [
        # count tasks per project per status
        {
            "$match": {
                "company": {"$in": [None, "", company_id]},
                "project": {"$in": project_ids},
            }
        },
        ensure_valid_fields(),
        {
            "$group": {
                # the archived flag is part of the key, so active and archived
                # tasks of the same status are counted separately
                "_id": {
                    "project": "$project",
                    "status": "$status",
                    archived: archived_tasks_cond,
                },
                "count": {"$sum": 1},
            }
        },
        # for each project, create a list of (status, count, archived)
        {
            "$group": {
                "_id": "$_id.project",
                "counts": {
                    "$push": {
                        "status": "$_id.status",
                        "count": "$count",
                        archived: "$_id.%s" % archived,
                    }
                },
            }
        },
    ]

    def runtime_subquery(additional_cond):
        """
        Return a $sum accumulator over tasks matching ``additional_cond``.

        Each task that has both started and completed set, with completed > started,
        contributes floor((completed - started) / 1000). NOTE(review): if these are
        date fields, $subtract yields milliseconds, so the sum is in seconds —
        confirm the field types.
        """
        return {
            # the sum of
            "$sum": {
                # for each task
                "$cond": {
                    # if completed and started and completed > started
                    # (and the caller-supplied visibility condition holds)
                    "if": {
                        "$and": [
                            "$started",
                            "$completed",
                            {"$gt": ["$completed", "$started"]},
                            additional_cond,
                        ]
                    },
                    # then: floor((completed - started) / 1000)
                    "then": {
                        "$floor": {
                            "$divide": [
                                {"$subtract": ["$completed", "$started"]},
                                1000.0,
                            ]
                        }
                    },
                    "else": 0,
                }
            }
        }

    group_step = {"_id": "$project"}

    # one runtime accumulator per visibility state (or only the requested one);
    # members other than active/archived get no accumulator
    for state in EntityVisibility:
        if specific_state and state != specific_state:
            continue
        if state == EntityVisibility.active:
            group_step[state.value] = runtime_subquery({"$not": archived_tasks_cond})
        elif state == EntityVisibility.archived:
            group_step[state.value] = runtime_subquery(archived_tasks_cond)

    runtime_pipeline = [
        # only count run time for these types of tasks
        {
            "$match": {
                "type": {"$in": ["training", "testing"]},
                "company": {"$in": [None, "", company_id]},
                "project": {"$in": project_ids},
            }
        },
        ensure_valid_fields(),
        {
            # for each project
            "$group": group_step
        },
    ]

    return status_count_pipeline, runtime_pipeline
@endpoint("projects.get_all_ex")
def get_all_ex(call: APICall):
    """
    Return projects matching the query in ``call.data``, optionally with task stats.

    When ``include_stats`` is set, every project gets a ``stats`` dict keyed by
    visibility state value. ``stats_for_state`` (default: active) restricts it to a
    single state; an empty value reports all states. Each entry holds:
        total_runtime - summed task runtime from the runtime pipeline (0 if none)
        status_count - {task status: count}, pre-filled with 0 for every TaskStatus

    :raises errors.bad_request.FieldsValueError: ``stats_for_state`` is not a valid
        EntityVisibility value
    """
    include_stats = call.data.get("include_stats")
    stats_for_state = call.data.get("stats_for_state", EntityVisibility.active.value)
    allow_public = not call.data.get("non_public", False)

    if stats_for_state:
        try:
            specific_state = EntityVisibility(stats_for_state)
        except ValueError:
            raise errors.bad_request.FieldsValueError(stats_for_state=stats_for_state)
    else:
        specific_state = None

    conform_tag_fields(call, call.data)
    with translate_errors_context(), TimingContext("mongo", "projects_get_all"):
        projects = Project.get_many_with_join(
            company=call.identity.company,
            query_dict=call.data,
            query_options=get_all_query_options,
            allow_public=allow_public,
        )
        conform_output_tags(call, projects)

        if not include_stats:
            call.result.data = {"projects": projects}
            return

        ids = [project["id"] for project in projects]
        status_count_pipeline, runtime_pipeline = make_projects_get_all_pipelines(
            call.identity.company, ids, specific_state=specific_state
        )

        default_counts = dict.fromkeys(get_options(TaskStatus), 0)

        # project id -> visibility section value -> {status: count}
        status_count = defaultdict(dict)
        is_archived = itemgetter(EntityVisibility.archived.value)
        for result in Task.aggregate(status_count_pipeline):
            # groupby only merges adjacent items, so sort by the same key first
            grouped = groupby(sorted(result["counts"], key=is_archived), is_archived)
            for archived_flag, group in grouped:
                section = (
                    EntityVisibility.archived
                    if archived_flag
                    else EntityVisibility.active
                ).value
                status_count[result["_id"]][section] = dict(
                    default_counts,
                    **{entry["status"]: entry["count"] for entry in group},
                )

        # project id -> visibility section value -> total runtime
        runtime = {
            result["_id"]: {k: v for k, v in result.items() if k != "_id"}
            for result in Task.aggregate(runtime_pipeline)
        }

    report_for_states = [
        s for s in EntityVisibility if not specific_state or specific_state == s
    ]

    for project in projects:
        # plain dict lookups instead of dpath string paths: no dependence on the
        # id containing path separators or glob characters
        project_runtime = runtime.get(project["id"], {})
        project_counts = status_count.get(project["id"], {})
        project["stats"] = {
            state.value: {
                "total_runtime": project_runtime.get(state.value, 0),
                # fresh copy per project so response entries never share a dict
                "status_count": project_counts.get(state.value, dict(default_counts)),
            }
            for state in report_for_states
        }

    call.result.data = {"projects": projects}
@endpoint("projects.get_all")
def get_all(call: APICall):
    """
    Return the projects matching the query in ``call.data``.

    Public projects are always included (``allow_public=True``).
    """
    conform_tag_fields(call, call.data)
    query = call.data

    with translate_errors_context(), TimingContext("mongo", "projects_get_all"):
        found = Project.get_many(
            company=call.identity.company,
            query_dict=query,
            query_options=get_all_query_options,
            parameters=query,
            allow_public=True,
        )
        conform_output_tags(call, found)
        call.result.data = {"projects": found}
@endpoint(
    "projects.create",
    required_fields=["name", "description"],
    response_data_model=IdResponse,
)
def create(call: APICall):
    """
    Create a new project owned by the calling user and company.

    Only the keys listed in ``create_fields`` are taken from the request; tag
    fields are validated by ``conform_tag_fields``.

    :return: IdResponse carrying the id returned by ``ProjectBLL.create``
    """
    who = call.identity
    with translate_errors_context():
        new_fields = parse_from_call(call.data, create_fields, Project.get_fields())
        conform_tag_fields(call, new_fields, validate=True)
        new_id = ProjectBLL.create(user=who.user, company=who.company, **new_fields)
        return IdResponse(id=new_id)
@endpoint(
    "projects.update", required_fields=["project"], response_data_model=UpdateResponse
)
def update(call: APICall):
    """
    update

    :summary: Update project information.
              See `projects.create` for the accepted fields.
    :return: updated - `int` - number of projects updated
             fields - the field values that were written (including last_update)
    :raises errors.bad_request.InvalidProjectId: project not found for writing
    """
    project_id = call.data["project"]

    with translate_errors_context():
        target = Project.get_for_writing(company=call.identity.company, id=project_id)
        if not target:
            raise errors.bad_request.InvalidProjectId(id=project_id)

        # None values are kept so that a field can be explicitly cleared
        changes = parse_from_call(
            call.data, create_fields, Project.get_fields(), discard_none_values=False
        )
        conform_tag_fields(call, changes, validate=True)
        changes["last_update"] = datetime.utcnow()

        with TimingContext("mongo", "projects_update"):
            updated = target.update(upsert=False, **changes)

        conform_output_tags(call, changes)
        call.result.data_model = UpdateResponse(updated=updated, fields=changes)
@endpoint("projects.delete", required_fields=["project"])
def delete(call: APICall):
    """
    Delete a project, detaching its non-archived tasks and models.

    Without ``force`` the request is rejected while the project still holds
    non-archived tasks or models. With ``force`` their ``project`` field is cleared
    and the project document is deleted.

    :return: deleted - always 1 on success
             disassociated_tasks - number of tasks whose project field was cleared
    :raises errors.bad_request.InvalidProjectId: project not found for writing
    :raises errors.bad_request.ProjectHasTasks: tasks remain and force is not set
    :raises errors.bad_request.ProjectHasModels: models remain and force is not set
    """
    assert isinstance(call, APICall)
    project_id = call.data["project"]
    force = call.data.get("force", False)

    with translate_errors_context():
        project = Project.get_for_writing(company=call.identity.company, id=project_id)
        if not project:
            raise errors.bad_request.InvalidProjectId(id=project_id)

        # NOTE: from this point on we'll use the project ID and won't check for company,
        # since we assume we already have the correct project ID.
        # NOTE(review): archived tasks/models are excluded by the query below, so their
        # project reference is left in place after deletion — confirm this is intended.
        disassociated = {}
        for cls, error in (
            (Task, errors.bad_request.ProjectHasTasks),
            (Model, errors.bad_request.ProjectHasModels),
        ):
            res = cls.objects(
                project=project_id, system_tags__nin=[EntityVisibility.archived.value]
            ).only("id")
            if res and not force:
                raise error("use force=true to delete", id=project_id)

            # keep a count per entity class; previously a single variable was
            # overwritten, so the reported "tasks" count was really the model count
            disassociated[cls] = res.update(project=None)

        project.delete()

        call.result.data = {"deleted": 1, "disassociated_tasks": disassociated[Task]}
@endpoint("projects.get_unique_metric_variants", request_data_model=ProjectReq)
def get_unique_metric_variants(call: APICall, company_id: str, request: ProjectReq):
    """
    Return the result of ``task_bll.get_unique_metric_variants`` under "metrics".

    The lookup is limited to ``request.project`` when one is given; otherwise no
    project filter (None) is passed.
    """
    project_filter = [request.project] if request.project else None
    found = task_bll.get_unique_metric_variants(company_id, project_filter)
    call.result.data = {"metrics": found}
@endpoint(
    "projects.get_hyper_parameters",
    min_version="2.9",
    request_data_model=GetHyperParamReq,
)
def get_hyper_parameters(call: APICall, company_id: str, request: GetHyperParamReq):
    """
    Return a page of hyper-parameters aggregated over the company's tasks.

    Restricted to ``request.project`` when provided. The response carries
    ``total``, ``remaining`` and ``parameters`` exactly as returned by
    ``TaskBLL.get_aggregated_project_parameters``.
    """
    scope = [request.project] if request.project else None
    total, remaining, parameters = TaskBLL.get_aggregated_project_parameters(
        company_id, project_ids=scope, page=request.page, page_size=request.page_size
    )
    call.result.data = dict(total=total, remaining=remaining, parameters=parameters)
@endpoint(
    "projects.get_task_tags", min_version="2.8", request_data_model=ProjectTagsRequest
)
def get_task_tags(call: APICall, company, request: ProjectTagsRequest):
    """
    Return the task tags for the company, optionally limited to ``request.projects``.

    System tags are included when ``request.include_system`` is set, and
    ``request.filter`` is applied via ``get_tags_filter_dictionary``.
    The endpoint is dispatched by its API name; the Python name is kept distinct
    so this handler is not shadowed by another module-level definition.
    """
    ret = org_bll.get_tags(
        company,
        Tags.Task,
        include_system=request.include_system,
        filter_=get_tags_filter_dictionary(request.filter),
        projects=request.projects,
    )
    call.result.data = get_tags_response(ret)
@endpoint(
    "projects.get_model_tags", min_version="2.8", request_data_model=ProjectTagsRequest
)
def get_model_tags(call: APICall, company, request: ProjectTagsRequest):
    """
    Return the model tags for the company, optionally limited to ``request.projects``.

    System tags are included when ``request.include_system`` is set, and
    ``request.filter`` is applied via ``get_tags_filter_dictionary``.
    Named after its endpoint (rather than a generic ``get_tags``) so it does not
    collide with other tag handlers at module level.
    """
    ret = org_bll.get_tags(
        company,
        Tags.Model,
        include_system=request.include_system,
        filter_=get_tags_filter_dictionary(request.filter),
        projects=request.projects,
    )
    call.result.data = get_tags_response(ret)
@endpoint(
    "projects.make_public", min_version="2.9", request_data_model=MakePublicRequest
)
def make_public(call: APICall, company_id, request: MakePublicRequest):
    """
    Mark the requested projects public by calling ``Project.set_public`` with
    ``enabled=True``; unknown ids are reported through ``InvalidProjectId``.
    """
    with translate_errors_context():
        outcome = Project.set_public(
            company_id,
            ids=request.ids,
            invalid_cls=InvalidProjectId,
            enabled=True,
        )
        call.result.data = outcome
@endpoint(
    "projects.make_private", min_version="2.9", request_data_model=MakePublicRequest
)
def make_private(call: APICall, company_id, request: MakePublicRequest):
    """
    Mark the requested projects private by calling ``Project.set_public`` with
    ``enabled=False``; unknown ids are reported through ``InvalidProjectId``.

    Previously this handler was also named ``make_public`` (copy-paste), which
    shadowed the real make_public handler at module level under a misleading name.
    """
    with translate_errors_context():
        call.result.data = Project.set_public(
            company_id, ids=request.ids, invalid_cls=InvalidProjectId, enabled=False
        )