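"""
Project service endpoints: project CRUD, extended listing with per-project task
statistics, unique metric variants, aggregated hyper parameters, task/model tag
listing, and switching projects between public and private.
"""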
from collections import defaultdict
from datetime import datetime
from itertools import groupby
from operator import itemgetter

import dpath
from mongoengine import Q

from apiserver.apierrors import errors
from apiserver.apierrors.errors.bad_request import InvalidProjectId
from apiserver.apimodels.base import UpdateResponse, MakePublicRequest, IdResponse
from apiserver.apimodels.projects import (
    GetHyperParamReq,
    ProjectReq,
    ProjectTagsRequest,
)
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL
from apiserver.bll.task import TaskBLL
from apiserver.database.errors import translate_errors_context
from apiserver.database.model import EntityVisibility
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task, TaskStatus
from apiserver.database.utils import parse_from_call, get_options, get_company_or_none_constraint
from apiserver.service_repo import APICall, endpoint
from apiserver.services.utils import (
    conform_tag_fields,
    conform_output_tags,
    get_tags_filter_dictionary,
    get_tags_response,
)
from apiserver.timing_context import TimingContext

org_bll = OrgBLL()
task_bll = TaskBLL()
archived_tasks_cond = {"$in": [EntityVisibility.archived.value, "$system_tags"]}

create_fields = {
    "name": None,
    "description": None,
    "tags": list,
    "system_tags": list,
    "default_output_destination": None,
}

get_all_query_options = Project.QueryParameterOptions(
    pattern_fields=("name", "description"), list_fields=("tags", "system_tags", "id"),
)

@endpoint("projects.get_by_id", required_fields=["project"])
def get_by_id(call):
    assert isinstance(call, APICall)
    project_id = call.data["project"]

    with translate_errors_context():
        with TimingContext("mongo", "projects_by_id"):
            query = Q(id=project_id) & get_company_or_none_constraint(
                call.identity.company
            )
            project = Project.objects(query).first()
        if not project:
            raise errors.bad_request.InvalidProjectId(id=project_id)

        project_dict = project.to_proper_dict()
        conform_output_tags(call, project_dict)

        call.result.data = {"project": project_dict}

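# Build two MongoDB aggregation pipelines over the Task collection:
# - status_count_pipeline: task counts per project, per status, split by archived state
# - runtime_pipeline: total task run time in seconds (floor((completed - started) / 1000))
#   per project, optionally restricted to a single active/archived state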
def make_projects_get_all_pipelines(company_id, project_ids, specific_state=None):
    archived = EntityVisibility.archived.value

    def ensure_valid_fields():
        """
        Make sure system_tags is always an array (required by the subsequent $in in archived_tasks_cond)
        """
        return {
            "$addFields": {
                "system_tags": {
                    "$cond": {
                        "if": {"$ne": [{"$type": "$system_tags"}, "array"]},
                        "then": [],
                        "else": "$system_tags",
                    }
                },
                "status": {"$ifNull": ["$status", "unknown"]},
            }
        }

    status_count_pipeline = [
        # count tasks per project per status
        {
            "$match": {
                "company": {"$in": [None, "", company_id]},
                "project": {"$in": project_ids},
            }
        },
        ensure_valid_fields(),
        {
            "$group": {
                "_id": {
                    "project": "$project",
                    "status": "$status",
                    archived: archived_tasks_cond,
                },
                "count": {"$sum": 1},
            }
        },
        # for each project, create a list of (status, count, archived)
        {
            "$group": {
                "_id": "$_id.project",
                "counts": {
                    "$push": {
                        "status": "$_id.status",
                        "count": "$count",
                        archived: "$_id.%s" % archived,
                    }
                },
            }
        },
    ]

    def runtime_subquery(additional_cond):
        return {
            # the sum of
            "$sum": {
                # for each task
                "$cond": {
                    # if completed and started and completed > started
                    "if": {
                        "$and": [
                            "$started",
                            "$completed",
                            {"$gt": ["$completed", "$started"]},
                            additional_cond,
                        ]
                    },
                    # then: floor((completed - started) / 1000)
                    "then": {
                        "$floor": {
                            "$divide": [
                                {"$subtract": ["$completed", "$started"]},
                                1000.0,
                            ]
                        }
                    },
                    "else": 0,
                }
            }
        }

    group_step = {"_id": "$project"}

    for state in EntityVisibility:
        if specific_state and state != specific_state:
            continue
        if state == EntityVisibility.active:
            group_step[state.value] = runtime_subquery({"$not": archived_tasks_cond})
        elif state == EntityVisibility.archived:
            group_step[state.value] = runtime_subquery(archived_tasks_cond)

    runtime_pipeline = [
        # only count run time for these types of tasks
        {
            "$match": {
                "type": {"$in": ["training", "testing"]},
                "company": {"$in": [None, "", company_id]},
                "project": {"$in": project_ids},
            }
        },
        ensure_valid_fields(),
        {
            # for each project
            "$group": group_step
        },
    ]

    return status_count_pipeline, runtime_pipeline

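# Extended project listing: when include_stats is requested, each returned project
# is augmented with a "stats" section (status counts and total runtime) computed
# from the aggregation pipelines built above.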
@endpoint("projects.get_all_ex")
def get_all_ex(call: APICall):
    include_stats = call.data.get("include_stats")
    stats_for_state = call.data.get("stats_for_state", EntityVisibility.active.value)
    allow_public = not call.data.get("non_public", False)

    if stats_for_state:
        try:
            specific_state = EntityVisibility(stats_for_state)
        except ValueError:
            raise errors.bad_request.FieldsValueError(stats_for_state=stats_for_state)
    else:
        specific_state = None

    conform_tag_fields(call, call.data)
    with translate_errors_context(), TimingContext("mongo", "projects_get_all"):
        projects = Project.get_many_with_join(
            company=call.identity.company,
            query_dict=call.data,
            query_options=get_all_query_options,
            allow_public=allow_public,
        )
        conform_output_tags(call, projects)

        if not include_stats:
            call.result.data = {"projects": projects}
            return

        ids = [project["id"] for project in projects]
        status_count_pipeline, runtime_pipeline = make_projects_get_all_pipelines(
            call.identity.company, ids, specific_state=specific_state
        )

        default_counts = dict.fromkeys(get_options(TaskStatus), 0)

        def set_default_count(entry):
            return dict(default_counts, **entry)

        status_count = defaultdict(lambda: {})
        key = itemgetter(EntityVisibility.archived.value)
        for result in Task.aggregate(status_count_pipeline):
            for k, group in groupby(sorted(result["counts"], key=key), key):
                section = (
                    EntityVisibility.archived if k else EntityVisibility.active
                ).value
                status_count[result["_id"]][section] = set_default_count(
                    {
                        count_entry["status"]: count_entry["count"]
                        for count_entry in group
                    }
                )

        runtime = {
            result["_id"]: {k: v for k, v in result.items() if k != "_id"}
            for result in Task.aggregate(runtime_pipeline)
        }

        def safe_get(obj, path, default=None):
            try:
                return dpath.get(obj, path)
            except KeyError:
                return default

        def get_status_counts(project_id, section):
            path = "/".join((project_id, section))
            return {
                "total_runtime": safe_get(runtime, path, 0),
                "status_count": safe_get(status_count, path, default_counts),
            }

        report_for_states = [
            s for s in EntityVisibility if not specific_state or specific_state == s
        ]

        for project in projects:
            project["stats"] = {
                task_state.value: get_status_counts(project["id"], task_state.value)
                for task_state in report_for_states
            }

        call.result.data = {"projects": projects}

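# Plain project listing, without the extended statistics of projects.get_all_ex.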
@endpoint("projects.get_all")
def get_all(call: APICall):
    conform_tag_fields(call, call.data)
    with translate_errors_context(), TimingContext("mongo", "projects_get_all"):
        projects = Project.get_many(
            company=call.identity.company,
            query_dict=call.data,
            query_options=get_all_query_options,
            parameters=call.data,
            allow_public=True,
        )
        conform_output_tags(call, projects)

        call.result.data = {"projects": projects}

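# Create a new project under the calling user's company.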
@endpoint(
    "projects.create",
    required_fields=["name", "description"],
    response_data_model=IdResponse,
)
def create(call: APICall):
    identity = call.identity

    with translate_errors_context():
        fields = parse_from_call(call.data, create_fields, Project.get_fields())
        conform_tag_fields(call, fields, validate=True)

        return IdResponse(
            id=ProjectBLL.create(
                user=identity.user, company=identity.company, **fields,
            )
        )

@endpoint(
    "projects.update", required_fields=["project"], response_data_model=UpdateResponse
)
def update(call: APICall):
    """
    update

    :summary: Update project information.
        See `projects.create` for parameters.
    :return: updated - `int` - number of projects updated
        fields - `[string]` - updated fields
    """
    project_id = call.data["project"]

    with translate_errors_context():
        project = Project.get_for_writing(company=call.identity.company, id=project_id)
        if not project:
            raise errors.bad_request.InvalidProjectId(id=project_id)

        fields = parse_from_call(
            call.data, create_fields, Project.get_fields(), discard_none_values=False
        )
        conform_tag_fields(call, fields, validate=True)
        fields["last_update"] = datetime.utcnow()
        with TimingContext("mongo", "projects_update"):
            updated = project.update(upsert=False, **fields)

        conform_output_tags(call, fields)
        call.result.data_model = UpdateResponse(updated=updated, fields=fields)

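# Delete a project. Non-archived tasks or models under it block the deletion
# unless force=true, in which case they are detached from the project first.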
@endpoint("projects.delete", required_fields=["project"])
def delete(call):
    assert isinstance(call, APICall)
    project_id = call.data["project"]
    force = call.data.get("force", False)

    with translate_errors_context():
        project = Project.get_for_writing(company=call.identity.company, id=project_id)
        if not project:
            raise errors.bad_request.InvalidProjectId(id=project_id)

        # NOTE: from this point on we'll use the project ID and won't check for company,
        # since we assume we already have the correct project ID.

        # Find the tasks and models which belong to the project
        for cls, error in (
            (Task, errors.bad_request.ProjectHasTasks),
            (Model, errors.bad_request.ProjectHasModels),
        ):
            res = cls.objects(
                project=project_id, system_tags__nin=[EntityVisibility.archived.value]
            ).only("id")
            if res and not force:
                raise error("use force=true to delete", id=project_id)

            updated_count = res.update(project=None)

        project.delete()

        call.result.data = {"deleted": 1, "disassociated_tasks": updated_count}

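# List the unique metric/variant combinations reported by tasks, optionally
# filtered to a single project.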
@endpoint("projects.get_unique_metric_variants", request_data_model=ProjectReq)
def get_unique_metric_variants(call: APICall, company_id: str, request: ProjectReq):
    metrics = task_bll.get_unique_metric_variants(
        company_id, [request.project] if request.project else None
    )

    call.result.data = {"metrics": metrics}

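# Paged listing of hyper parameters aggregated over the project's tasks; the
# response carries the total count, the remaining count and the current page.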
@endpoint(
    "projects.get_hyper_parameters",
    min_version="2.9",
    request_data_model=GetHyperParamReq,
)
def get_hyper_parameters(call: APICall, company_id: str, request: GetHyperParamReq):
    total, remaining, parameters = TaskBLL.get_aggregated_project_parameters(
        company_id,
        project_ids=[request.project] if request.project else None,
        page=request.page,
        page_size=request.page_size,
    )

    call.result.data = {
        "total": total,
        "remaining": remaining,
        "parameters": parameters,
    }

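# Tags used by tasks in the requested projects (optionally including system tags).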
@endpoint(
    "projects.get_task_tags", min_version="2.8", request_data_model=ProjectTagsRequest
)
def get_tags(call: APICall, company, request: ProjectTagsRequest):
    ret = org_bll.get_tags(
        company,
        Tags.Task,
        include_system=request.include_system,
        filter_=get_tags_filter_dictionary(request.filter),
        projects=request.projects,
    )
    call.result.data = get_tags_response(ret)

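# Same as projects.get_task_tags, but for tags used by models.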
@endpoint(
    "projects.get_model_tags", min_version="2.8", request_data_model=ProjectTagsRequest
)
def get_model_tags(call: APICall, company, request: ProjectTagsRequest):
    ret = org_bll.get_tags(
        company,
        Tags.Model,
        include_system=request.include_system,
        filter_=get_tags_filter_dictionary(request.filter),
        projects=request.projects,
    )
    call.result.data = get_tags_response(ret)

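# projects.make_public / projects.make_private toggle the public flag on the
# given project ids within the calling company.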
@endpoint(
    "projects.make_public", min_version="2.9", request_data_model=MakePublicRequest
)
def make_public(call: APICall, company_id, request: MakePublicRequest):
    with translate_errors_context():
        call.result.data = Project.set_public(
            company_id, ids=request.ids, invalid_cls=InvalidProjectId, enabled=True
        )

@endpoint(
    "projects.make_private", min_version="2.9", request_data_model=MakePublicRequest
)
def make_private(call: APICall, company_id, request: MakePublicRequest):
    with translate_errors_context():
        call.result.data = Project.set_public(
            company_id, ids=request.ids, invalid_cls=InvalidProjectId, enabled=False
        )