mirror of
https://github.com/clearml/clearml-server
synced 2025-01-31 10:56:48 +00:00
481 lines
15 KiB
Python
481 lines
15 KiB
Python
from collections import defaultdict
|
|
from datetime import datetime
|
|
from itertools import groupby
|
|
from operator import itemgetter
|
|
|
|
import dpath
|
|
from mongoengine import Q
|
|
|
|
from apiserver.apierrors import errors
|
|
from apiserver.apierrors.errors.bad_request import InvalidProjectId
|
|
from apiserver.apimodels.base import UpdateResponse, MakePublicRequest, IdResponse
|
|
from apiserver.apimodels.projects import (
|
|
GetHyperParamReq,
|
|
ProjectReq,
|
|
ProjectTagsRequest,
|
|
ProjectTaskParentsRequest,
|
|
ProjectHyperparamValuesRequest,
|
|
)
|
|
from apiserver.bll.organization import OrgBLL, Tags
|
|
from apiserver.bll.project import ProjectBLL
|
|
from apiserver.bll.task import TaskBLL
|
|
from apiserver.database.errors import translate_errors_context
|
|
from apiserver.database.model import EntityVisibility
|
|
from apiserver.database.model.model import Model
|
|
from apiserver.database.model.project import Project
|
|
from apiserver.database.model.task.task import Task, TaskStatus
|
|
from apiserver.database.utils import (
|
|
parse_from_call,
|
|
get_options,
|
|
get_company_or_none_constraint,
|
|
)
|
|
from apiserver.service_repo import APICall, endpoint
|
|
from apiserver.services.utils import (
|
|
conform_tag_fields,
|
|
conform_output_tags,
|
|
get_tags_filter_dictionary,
|
|
get_tags_response,
|
|
)
|
|
from apiserver.timing_context import TimingContext
|
|
|
|
# Business-logic singletons shared by all endpoints in this module
org_bll = OrgBLL()
task_bll = TaskBLL()

# Mongo aggregation condition: evaluates to True for tasks carrying the
# "archived" system tag (requires $system_tags to be an array — see
# ensure_valid_fields() in make_projects_get_all_pipelines)
archived_tasks_cond = {"$in": [EntityVisibility.archived.value, "$system_tags"]}


# Fields accepted by projects.create / projects.update; a value of `list`
# marks list-typed fields for parse_from_call, None means plain scalar
create_fields = {
    "name": None,
    "description": None,
    "tags": list,
    "system_tags": list,
    "default_output_destination": None,
}

# Query options shared by the get_all* endpoints: substring/pattern matching
# on name and description, list filtering on tags, system_tags and id
get_all_query_options = Project.QueryParameterOptions(
    pattern_fields=("name", "description"), list_fields=("tags", "system_tags", "id"),
)
|
|
|
|
|
|
@endpoint("projects.get_by_id", required_fields=["project"])
def get_by_id(call: APICall):
    """
    Return a single project by ID.

    The project must either belong to the caller's company or be public
    (have no company). Raises InvalidProjectId when no such project exists.
    """
    # Fix: replaced `assert isinstance(call, APICall)` (stripped under -O and
    # inconsistent with sibling endpoints) with a parameter annotation.
    project_id = call.data["project"]

    with translate_errors_context():
        with TimingContext("mongo", "projects_by_id"):
            # visible if owned by the caller's company or company-less (public)
            query = Q(id=project_id) & get_company_or_none_constraint(
                call.identity.company
            )
            project = Project.objects(query).first()
        if not project:
            raise errors.bad_request.InvalidProjectId(id=project_id)

        project_dict = project.to_proper_dict()
        conform_output_tags(call, project_dict)

        call.result.data = {"project": project_dict}
|
|
|
|
|
|
def make_projects_get_all_pipelines(company_id, project_ids, specific_state=None):
    """
    Build two MongoDB aggregation pipelines over the Task collection:

    - status_count_pipeline: per project, a list of {status, count, archived}
      entries counting tasks per status, split by archived/active visibility.
    - runtime_pipeline: per project, the summed run time (seconds) of
      training/testing tasks, with one accumulator per visibility state.

    :param company_id: tasks are matched when their company is this value,
        "" or None (i.e. company-owned or public)
    :param project_ids: list of project IDs to restrict the aggregation to
    :param specific_state: when given (an EntityVisibility member), only that
        state gets a runtime accumulator; otherwise both active and archived
    :return: (status_count_pipeline, runtime_pipeline) tuple
    """
    archived = EntityVisibility.archived.value

    def ensure_valid_fields():
        """
        Make sure system tags is always an array (required by subsequent $in in archived_tasks_cond
        """
        return {
            "$addFields": {
                "system_tags": {
                    "$cond": {
                        "if": {"$ne": [{"$type": "$system_tags"}, "array"]},
                        "then": [],
                        "else": "$system_tags",
                    }
                },
                # tasks missing a status are reported under "unknown"
                "status": {"$ifNull": ["$status", "unknown"]},
            }
        }

    status_count_pipeline = [
        # count tasks per project per status
        {
            "$match": {
                "company": {"$in": [None, "", company_id]},
                "project": {"$in": project_ids},
            }
        },
        ensure_valid_fields(),
        {
            "$group": {
                "_id": {
                    "project": "$project",
                    "status": "$status",
                    # group key includes the archived flag so active/archived
                    # counts for the same status stay separate
                    archived: archived_tasks_cond,
                },
                "count": {"$sum": 1},
            }
        },
        # for each project, create a list of (status, count, archived)
        {
            "$group": {
                "_id": "$_id.project",
                "counts": {
                    "$push": {
                        "status": "$_id.status",
                        "count": "$count",
                        archived: "$_id.%s" % archived,
                    }
                },
            }
        },
    ]

    def runtime_subquery(additional_cond):
        # $sum accumulator: per-task run time in whole seconds, counted only
        # for tasks satisfying additional_cond
        return {
            # the sum of
            "$sum": {
                # for each task
                "$cond": {
                    # if completed and started and completed > started
                    "if": {
                        "$and": [
                            "$started",
                            "$completed",
                            {"$gt": ["$completed", "$started"]},
                            additional_cond,
                        ]
                    },
                    # then: floor((completed - started) / 1000)
                    "then": {
                        "$floor": {
                            "$divide": [
                                {"$subtract": ["$completed", "$started"]},
                                1000.0,
                            ]
                        }
                    },
                    "else": 0,
                }
            }
        }

    group_step = {"_id": "$project"}

    # add one runtime accumulator per requested visibility state
    for state in EntityVisibility:
        if specific_state and state != specific_state:
            continue
        if state == EntityVisibility.active:
            group_step[state.value] = runtime_subquery({"$not": archived_tasks_cond})
        elif state == EntityVisibility.archived:
            group_step[state.value] = runtime_subquery(archived_tasks_cond)

    runtime_pipeline = [
        # only count run time for these types of tasks
        {
            "$match": {
                "type": {"$in": ["training", "testing"]},
                "company": {"$in": [None, "", company_id]},
                "project": {"$in": project_ids},
            }
        },
        ensure_valid_fields(),
        {
            # for each project
            "$group": group_step
        },
    ]

    return status_count_pipeline, runtime_pipeline
|
|
|
|
|
|
@endpoint("projects.get_all_ex")
def get_all_ex(call: APICall):
    """
    Extended project listing: like projects.get_all but can decorate each
    project with per-visibility-state task statistics (status counts and
    total run time) when include_stats is passed.
    """
    include_stats = call.data.get("include_stats")
    # defaults to active-only stats; an explicit falsy value yields both states
    stats_for_state = call.data.get("stats_for_state", EntityVisibility.active.value)
    allow_public = not call.data.get("non_public", False)

    if stats_for_state:
        try:
            specific_state = EntityVisibility(stats_for_state)
        except ValueError:
            raise errors.bad_request.FieldsValueError(stats_for_state=stats_for_state)
    else:
        specific_state = None

    conform_tag_fields(call, call.data)
    with translate_errors_context(), TimingContext("mongo", "projects_get_all"):
        projects = Project.get_many_with_join(
            company=call.identity.company,
            query_dict=call.data,
            query_options=get_all_query_options,
            allow_public=allow_public,
        )
        conform_output_tags(call, projects)

        if not include_stats:
            call.result.data = {"projects": projects}
            return

        ids = [project["id"] for project in projects]
        status_count_pipeline, runtime_pipeline = make_projects_get_all_pipelines(
            call.identity.company, ids, specific_state=specific_state
        )

        # every known task status appears in the response, even with count 0
        default_counts = dict.fromkeys(get_options(TaskStatus), 0)

        def set_default_count(entry):
            # overlay actual counts on top of the zeroed defaults
            return dict(default_counts, **entry)

        # status_count maps project_id -> {section ("active"/"archived") -> {status: count}}
        status_count = defaultdict(lambda: {})
        key = itemgetter(EntityVisibility.archived.value)
        for result in Task.aggregate(status_count_pipeline):
            # groupby requires its input pre-sorted by the same key
            for k, group in groupby(sorted(result["counts"], key=key), key):
                section = (
                    EntityVisibility.archived if k else EntityVisibility.active
                ).value
                status_count[result["_id"]][section] = set_default_count(
                    {
                        count_entry["status"]: count_entry["count"]
                        for count_entry in group
                    }
                )

        # runtime maps project_id -> {section: total seconds} (all non-_id keys)
        runtime = {
            result["_id"]: {k: v for k, v in result.items() if k != "_id"}
            for result in Task.aggregate(runtime_pipeline)
        }

        def safe_get(obj, path, default=None):
            # dpath.get raises KeyError for a missing path
            try:
                return dpath.get(obj, path)
            except KeyError:
                return default

        def get_status_counts(project_id, section):
            # dpath path like "<project_id>/<active|archived>"
            path = "/".join((project_id, section))
            return {
                "total_runtime": safe_get(runtime, path, 0),
                "status_count": safe_get(status_count, path, default_counts),
            }

        report_for_states = [
            s for s in EntityVisibility if not specific_state or specific_state == s
        ]

        for project in projects:
            project["stats"] = {
                task_state.value: get_status_counts(project["id"], task_state.value)
                for task_state in report_for_states
            }

        call.result.data = {"projects": projects}
|
|
|
|
|
|
@endpoint("projects.get_all")
def get_all(call: APICall):
    """List the projects visible to the caller (company-owned plus public)."""
    data = call.data
    conform_tag_fields(call, data)
    with translate_errors_context(), TimingContext("mongo", "projects_get_all"):
        found = Project.get_many(
            company=call.identity.company,
            query_dict=data,
            query_options=get_all_query_options,
            parameters=data,
            allow_public=True,
        )
        conform_output_tags(call, found)

        call.result.data = {"projects": found}
|
|
|
|
|
|
@endpoint(
    "projects.create",
    required_fields=["name", "description"],
    response_data_model=IdResponse,
)
def create(call: APICall):
    """Create a new project for the calling user's company and return its ID."""
    with translate_errors_context():
        project_fields = parse_from_call(call.data, create_fields, Project.get_fields())
        conform_tag_fields(call, project_fields, validate=True)

        return IdResponse(
            id=ProjectBLL.create(
                user=call.identity.user,
                company=call.identity.company,
                **project_fields,
            )
        )
|
|
|
|
|
|
@endpoint(
    "projects.update", required_fields=["project"], response_data_model=UpdateResponse
)
def update(call: APICall):
    """
    Update an existing project's information (accepted fields are the same as
    in `project.create`).

    :return: updated - `int` - number of projects updated
             fields - `[string]` - updated fields
    """
    project_id = call.data["project"]

    with translate_errors_context():
        project = Project.get_for_writing(company=call.identity.company, id=project_id)
        if not project:
            raise errors.bad_request.InvalidProjectId(id=project_id)

        update_fields = parse_from_call(
            call.data, create_fields, Project.get_fields(), discard_none_values=False
        )
        conform_tag_fields(call, update_fields, validate=True)
        update_fields["last_update"] = datetime.utcnow()
        with TimingContext("mongo", "projects_update"):
            updated_count = project.update(upsert=False, **update_fields)
        conform_output_tags(call, update_fields)
        call.result.data_model = UpdateResponse(
            updated=updated_count, fields=update_fields
        )
|
|
|
|
|
|
@endpoint("projects.delete", required_fields=["project"])
def delete(call: APICall):
    """
    Delete a project.

    Non-archived tasks or models still attached to the project block deletion
    unless force=true is passed, in which case they are disassociated (their
    project field is cleared) before the project document is removed.

    :return: deleted - always 1 on success
             disassociated_tasks - number of tasks whose project was cleared
    """
    project_id = call.data["project"]
    force = call.data.get("force", False)

    with translate_errors_context():
        project = Project.get_for_writing(company=call.identity.company, id=project_id)
        if not project:
            raise errors.bad_request.InvalidProjectId(id=project_id)

        # NOTE: from this point on we'll use the project ID and won't check for company,
        # since we assume we already have the correct project ID.

        # Find the tasks/models which belong to the project; only non-archived
        # entities block deletion (and get disassociated under force)
        disassociated = {}
        for cls, error in (
            (Task, errors.bad_request.ProjectHasTasks),
            (Model, errors.bad_request.ProjectHasModels),
        ):
            res = cls.objects(
                project=project_id, system_tags__nin=[EntityVisibility.archived.value]
            ).only("id")
            if res and not force:
                raise error("use force=true to delete", id=project_id)

            disassociated[cls] = res.update(project=None)

        project.delete()

        # Bug fix: previously a single `updated_count` variable was overwritten
        # on every loop iteration, so "disassociated_tasks" actually reported
        # the MODEL update count (the last iteration). Report the Task count.
        call.result.data = {"deleted": 1, "disassociated_tasks": disassociated[Task]}
|
|
|
|
|
|
@endpoint("projects.get_unique_metric_variants", request_data_model=ProjectReq)
def get_unique_metric_variants(call: APICall, company_id: str, request: ProjectReq):
    """Return the unique metric/variant pairs reported by tasks of the project."""
    project_ids = [request.project] if request.project else None
    call.result.data = {
        "metrics": task_bll.get_unique_metric_variants(company_id, project_ids)
    }
|
|
|
|
|
|
@endpoint(
    "projects.get_hyper_parameters",
    min_version="2.9",
    request_data_model=GetHyperParamReq,
)
def get_hyper_parameters(call: APICall, company_id: str, request: GetHyperParamReq):
    """Return one page of the aggregated hyper-parameter names for the project."""
    project_ids = [request.project] if request.project else None
    total, remaining, parameters = TaskBLL.get_aggregated_project_parameters(
        company_id,
        project_ids=project_ids,
        page=request.page,
        page_size=request.page_size,
    )

    call.result.data = dict(
        total=total,
        remaining=remaining,
        parameters=parameters,
    )
|
|
|
|
|
|
@endpoint(
    "projects.get_hyperparam_values",
    min_version="2.13",
    request_data_model=ProjectHyperparamValuesRequest,
)
def get_hyperparam_values(
    call: APICall, company_id: str, request: ProjectHyperparamValuesRequest
):
    """Return the distinct values of a single hyper-parameter across projects."""
    total, values = task_bll.get_hyperparam_distinct_values(
        company_id,
        project_ids=request.projects,
        section=request.section,
        name=request.name,
        allow_public=request.allow_public,
    )

    call.result.data = dict(total=total, values=values)
|
|
|
|
|
|
@endpoint(
    "projects.get_task_tags", min_version="2.8", request_data_model=ProjectTagsRequest
)
def get_tags(call: APICall, company, request: ProjectTagsRequest):
    """Return the user (and optionally system) tags of tasks in the given projects."""
    tags = org_bll.get_tags(
        company,
        Tags.Task,
        include_system=request.include_system,
        filter_=get_tags_filter_dictionary(request.filter),
        projects=request.projects,
    )
    call.result.data = get_tags_response(tags)
|
|
|
|
|
|
@endpoint(
    "projects.get_model_tags", min_version="2.8", request_data_model=ProjectTagsRequest
)
def get_model_tags(call: APICall, company, request: ProjectTagsRequest):
    """
    Return the user (and optionally system) tags of models in the given projects.

    Fix: this handler was previously also named `get_tags`, silently shadowing
    the task-tags handler at module level. The endpoint is registered by the
    decorator at definition time, so renaming the function is safe for callers.
    """
    ret = org_bll.get_tags(
        company,
        Tags.Model,
        include_system=request.include_system,
        filter_=get_tags_filter_dictionary(request.filter),
        projects=request.projects,
    )
    call.result.data = get_tags_response(ret)
|
|
|
|
|
|
@endpoint(
    "projects.make_public", min_version="2.9", request_data_model=MakePublicRequest
)
def make_public(call: APICall, company_id, request: MakePublicRequest):
    """Mark the requested company projects as publicly visible."""
    result = Project.set_public(
        company_id,
        ids=request.ids,
        invalid_cls=InvalidProjectId,
        enabled=True,
    )
    call.result.data = result
|
|
|
|
|
|
@endpoint(
    "projects.make_private", min_version="2.9", request_data_model=MakePublicRequest
)
def make_private(call: APICall, company_id, request: MakePublicRequest):
    """
    Revoke public visibility of the requested projects.

    Fix: this handler was previously also named `make_public`, shadowing the
    make_public handler at module level (and mislabeling the private endpoint).
    The endpoint is registered by the decorator at definition time, so renaming
    the function is safe for callers.
    """
    call.result.data = Project.set_public(
        company_id, ids=request.ids, invalid_cls=InvalidProjectId, enabled=False
    )
|
|
|
|
|
|
@endpoint(
    "projects.get_task_parents",
    min_version="2.12",
    request_data_model=ProjectTaskParentsRequest,
)
def get_task_parents(
    call: APICall, company_id: str, request: ProjectTaskParentsRequest
):
    """Return the distinct parent tasks of tasks in the requested projects."""
    parents = org_bll.get_parent_tasks(
        company_id, projects=request.projects, state=request.tasks_state
    )
    call.result.data = {"parents": parents}
|