# clearml-server/server/services/projects.py

from collections import defaultdict
from datetime import datetime
from itertools import groupby
from operator import itemgetter

import dpath
from mongoengine import Q

import database
from apierrors import errors
from apimodels.base import UpdateResponse
from apimodels.projects import GetHyperParamReq, GetHyperParamResp, ProjectReq
from bll.task import TaskBLL
from database.errors import translate_errors_context
from database.model import EntityVisibility
from database.model.model import Model
from database.model.project import Project
from database.model.task.task import Task, TaskStatus
from database.utils import parse_from_call, get_options, get_company_or_none_constraint
from service_repo import APICall, endpoint
from services.utils import conform_tag_fields, conform_output_tags
from timing_context import TimingContext

task_bll = TaskBLL()

archived_tasks_cond = {"$in": [EntityVisibility.archived.value, "$system_tags"]}

create_fields = {
    "name": None,
    "description": None,
    "tags": list,
    "system_tags": list,
    "default_output_destination": None,
}

get_all_query_options = Project.QueryParameterOptions(
    pattern_fields=("name", "description"),
    list_fields=("tags", "system_tags", "id"),
)
@endpoint("projects.get_by_id", required_fields=["project"])
def get_by_id(call):
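    """Return the project with the given ID, if visible to the caller's company."""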
    assert isinstance(call, APICall)
    project_id = call.data["project"]

    with translate_errors_context():
        with TimingContext("mongo", "projects_by_id"):
            query = Q(id=project_id) & get_company_or_none_constraint(
                call.identity.company
            )
            project = Project.objects(query).first()

        if not project:
            raise errors.bad_request.InvalidProjectId(id=project_id)

        project_dict = project.to_proper_dict()
        conform_output_tags(call, project_dict)

        call.result.data = {"project": project_dict}


def make_projects_get_all_pipelines(project_ids, specific_state=None):
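    """
    Build the two MongoDB aggregation pipelines used by `get_all_ex`: one that
    counts tasks per project and per status, and one that sums task run time
    per project (split into active/archived according to `specific_state`).
    """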
    archived = EntityVisibility.archived.value

    def ensure_system_tags():
        """
        Make sure system_tags is always an array (required by the subsequent
        $in in archived_tasks_cond)
        """
        return {
            "$addFields": {
                "system_tags": {
                    "$cond": {
                        "if": {"$ne": [{"$type": "$system_tags"}, "array"]},
                        "then": [],
                        "else": "$system_tags",
                    }
                }
            }
        }

    status_count_pipeline = [
        # count tasks per project per status
        {"$match": {"project": {"$in": project_ids}}},
        ensure_system_tags(),
        {
            "$group": {
                "_id": {
                    "project": "$project",
                    "status": "$status",
                    archived: archived_tasks_cond,
                },
                "count": {"$sum": 1},
            }
        },
        # for each project, create a list of (status, count, archived)
        {
            "$group": {
                "_id": "$_id.project",
                "counts": {
                    "$push": {
                        "status": "$_id.status",
                        "count": "$count",
                        archived: "$_id.%s" % archived,
                    }
                },
            }
        },
    ]

    def runtime_subquery(additional_cond):
        return {
            # the sum of
            "$sum": {
                # for each task
                "$cond": {
                    # if completed and started and completed > started
                    "if": {
                        "$and": [
                            "$started",
                            "$completed",
                            {"$gt": ["$completed", "$started"]},
                            additional_cond,
                        ]
                    },
                    # then: floor((completed - started) / 1000)
                    "then": {
                        "$floor": {
                            "$divide": [
                                {"$subtract": ["$completed", "$started"]},
                                1000.0,
                            ]
                        }
                    },
                    "else": 0,
                }
            }
        }

    group_step = {"_id": "$project"}

    for state in EntityVisibility:
        if specific_state and state != specific_state:
            continue
        if state == EntityVisibility.active:
            group_step[state.value] = runtime_subquery({"$not": archived_tasks_cond})
        elif state == EntityVisibility.archived:
            group_step[state.value] = runtime_subquery(archived_tasks_cond)

    runtime_pipeline = [
        # only count run time for these types of tasks
        {
            "$match": {
                "type": {"$in": ["training", "testing", "annotation"]},
                "project": {"$in": project_ids},
            }
        },
        ensure_system_tags(),
        {
            # for each project
            "$group": group_step
        },
    ]

    return status_count_pipeline, runtime_pipeline
@endpoint("projects.get_all_ex")
def get_all_ex(call: APICall):
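    """
    Get projects visible to the caller's company (public projects included).
    When `include_stats` is set, each project is returned with per-state task
    status counts and total run time (states selected via `stats_for_state`).
    """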
    include_stats = call.data.get("include_stats")
    stats_for_state = call.data.get("stats_for_state", EntityVisibility.active.value)

    if stats_for_state:
        try:
            specific_state = EntityVisibility(stats_for_state)
        except ValueError:
            raise errors.bad_request.FieldsValueError(stats_for_state=stats_for_state)
    else:
        specific_state = None

    conform_tag_fields(call, call.data)

    with translate_errors_context(), TimingContext("mongo", "projects_get_all"):
        projects = Project.get_many_with_join(
            company=call.identity.company,
            query_dict=call.data,
            query_options=get_all_query_options,
            allow_public=True,
        )
        conform_output_tags(call, projects)

        if not include_stats:
            call.result.data = {"projects": projects}
            return

        ids = [project["id"] for project in projects]
        status_count_pipeline, runtime_pipeline = make_projects_get_all_pipelines(
            ids, specific_state=specific_state
        )

        default_counts = dict.fromkeys(get_options(TaskStatus), 0)

        def set_default_count(entry):
            return dict(default_counts, **entry)

        # per project: task count per status, grouped into active/archived sections
        status_count = defaultdict(lambda: {})
        key = itemgetter(EntityVisibility.archived.value)
        for result in Task.aggregate(*status_count_pipeline):
            for k, group in groupby(sorted(result["counts"], key=key), key):
                section = (
                    EntityVisibility.archived if k else EntityVisibility.active
                ).value
                status_count[result["_id"]][section] = set_default_count(
                    {
                        count_entry["status"]: count_entry["count"]
                        for count_entry in group
                    }
                )

        # per project: total task run time per active/archived section
        runtime = {
            result["_id"]: {k: v for k, v in result.items() if k != "_id"}
            for result in Task.aggregate(*runtime_pipeline)
        }

        def safe_get(obj, path, default=None):
            try:
                return dpath.get(obj, path)
            except KeyError:
                return default

        def get_status_counts(project_id, section):
            path = "/".join((project_id, section))
            return {
                "total_runtime": safe_get(runtime, path, 0),
                "status_count": safe_get(status_count, path, default_counts),
            }

        report_for_states = [
            s for s in EntityVisibility if not specific_state or specific_state == s
        ]

        for project in projects:
            project["stats"] = {
                task_state.value: get_status_counts(project["id"], task_state.value)
                for task_state in report_for_states
            }

        call.result.data = {"projects": projects}
@endpoint("projects.get_all")
def get_all(call: APICall):
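    """Get projects visible to the caller's company (public projects included)."""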
    conform_tag_fields(call, call.data)
    with translate_errors_context(), TimingContext("mongo", "projects_get_all"):
        projects = Project.get_many(
            company=call.identity.company,
            query_dict=call.data,
            query_options=get_all_query_options,
            parameters=call.data,
            allow_public=True,
        )
        conform_output_tags(call, projects)

        call.result.data = {"projects": projects}
@endpoint("projects.create", required_fields=["name", "description"])
def create(call):
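    """Create a new project owned by the calling user and company."""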
    assert isinstance(call, APICall)
    identity = call.identity

    with translate_errors_context():
        fields = parse_from_call(call.data, create_fields, Project.get_fields())
        conform_tag_fields(call, fields)

        now = datetime.utcnow()
        project = Project(
            id=database.utils.id(),
            user=identity.user,
            company=identity.company,
            created=now,
            last_update=now,
            **fields
        )
        with TimingContext("mongo", "projects_save"):
            project.save()

        call.result.data = {"id": project.id}


@endpoint(
    "projects.update", required_fields=["project"], response_data_model=UpdateResponse
)
def update(call: APICall):
"""
update
:summary: Update project information.
See `project.create` for parameters.
:return: updated - `int` - number of projects updated
fields - `[string]` - updated fields
"""
    project_id = call.data["project"]

    with translate_errors_context():
        project = Project.get_for_writing(company=call.identity.company, id=project_id)
        if not project:
            raise errors.bad_request.InvalidProjectId(id=project_id)

        fields = parse_from_call(
            call.data, create_fields, Project.get_fields(), discard_none_values=False
        )
        conform_tag_fields(call, fields)
        fields["last_update"] = datetime.utcnow()

        with TimingContext("mongo", "projects_update"):
            updated = project.update(upsert=False, **fields)

        conform_output_tags(call, fields)
        call.result.data_model = UpdateResponse(updated=updated, fields=fields)
@endpoint("projects.delete", required_fields=["project"])
def delete(call):
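    """
    Delete a project. Unless `force` is set, fail if the project still contains
    non-archived tasks or models; otherwise disassociate them from the project
    before deleting it.
    """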
    assert isinstance(call, APICall)
    project_id = call.data["project"]
    force = call.data.get("force", False)

    with translate_errors_context():
        project = Project.get_for_writing(company=call.identity.company, id=project_id)
        if not project:
            raise errors.bad_request.InvalidProjectId(id=project_id)

        # NOTE: from this point on we'll use the project ID and won't check for company, since we assume we already
        # have the correct project ID.

        # Find the non-archived tasks and models which belong to the project
        for cls, error in (
            (Task, errors.bad_request.ProjectHasTasks),
            (Model, errors.bad_request.ProjectHasModels),
        ):
            res = cls.objects(
                project=project_id, system_tags__nin=[EntityVisibility.archived.value]
            ).only("id")
            if res and not force:
                raise error("use force=true to delete", id=project_id)
            updated_count = res.update(project=None)

        project.delete()
        call.result.data = {"deleted": 1, "disassociated_tasks": updated_count}
@endpoint("projects.get_unique_metric_variants", request_data_model=ProjectReq)
def get_unique_metric_variants(call: APICall, company_id: str, request: ProjectReq):
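    """
    Return the unique metric/variant combinations reported by tasks,
    optionally restricted to a single project.
    """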
    metrics = task_bll.get_unique_metric_variants(
        company_id, [request.project] if request.project else None
    )

    call.result.data = {"metrics": metrics}


@endpoint(
    "projects.get_hyper_parameters",
    min_version="2.2",
    request_data_model=GetHyperParamReq,
    response_data_model=GetHyperParamResp,
)
def get_hyper_parameters(call: APICall, company_id: str, request: GetHyperParamReq):
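    """
    Return a paged listing of the execution (hyper) parameters aggregated
    over the requested project's tasks, with total and remaining counts.
    """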
    total, remaining, parameters = TaskBLL.get_aggregated_project_execution_parameters(
        company_id,
        project_ids=[request.project] if request.project else None,
        page=request.page,
        page_size=request.page_size,
    )

    call.result.data = {
        "total": total,
        "remaining": remaining,
        "parameters": parameters,
    }