mirror of
https://github.com/clearml/clearml-server
synced 2025-02-01 03:16:44 +00:00
baba8b5b73
Add initial support for project ordering Add support for sortable task duration (used by the UI in the experiment's table) Add support for project name in worker's current task info Add support for results and artifacts in pre-populates examples Add demo server features
446 lines
14 KiB
Python
446 lines
14 KiB
Python
from collections import defaultdict
|
|
from datetime import datetime
|
|
from itertools import groupby
|
|
from operator import itemgetter
|
|
|
|
import dpath
|
|
from mongoengine import Q
|
|
|
|
import database
|
|
from apierrors import errors
|
|
from apierrors.errors.bad_request import InvalidProjectId
|
|
from apimodels.base import UpdateResponse, MakePublicRequest
|
|
from apimodels.projects import (
|
|
GetHyperParamReq,
|
|
GetHyperParamResp,
|
|
ProjectReq,
|
|
ProjectTagsRequest,
|
|
)
|
|
from bll.organization import OrgBLL, Tags
|
|
from bll.task import TaskBLL
|
|
from database.errors import translate_errors_context
|
|
from database.model import EntityVisibility
|
|
from database.model.model import Model
|
|
from database.model.project import Project
|
|
from database.model.task.task import Task, TaskStatus
|
|
from database.utils import parse_from_call, get_options, get_company_or_none_constraint
|
|
from service_repo import APICall, endpoint
|
|
from services.utils import (
|
|
conform_tag_fields,
|
|
conform_output_tags,
|
|
get_tags_filter_dictionary,
|
|
get_tags_response,
|
|
)
|
|
from timing_context import TimingContext
|
|
|
|
org_bll = OrgBLL()
task_bll = TaskBLL()

# Mongo aggregation condition: true when the task's system_tags mark it archived.
# Requires system_tags to be an array (see ensure_valid_fields in the pipelines).
archived_tasks_cond = {"$in": [EntityVisibility.archived.value, "$system_tags"]}

# Fields accepted by projects.create / projects.update; values are the
# per-field type casts used by parse_from_call (None means no cast).
create_fields = {
    "name": None,
    "description": None,
    "tags": list,
    "system_tags": list,
    "default_output_destination": None,
}

# Query options shared by projects.get_all and projects.get_all_ex.
get_all_query_options = Project.QueryParameterOptions(
    pattern_fields=("name", "description"), list_fields=("tags", "system_tags", "id"),
)
|
|
|
|
|
|
@endpoint("projects.get_by_id", required_fields=["project"])
def get_by_id(call):
    """Return a single project by ID, raising InvalidProjectId if not found."""
    assert isinstance(call, APICall)
    requested_id = call.data["project"]

    with translate_errors_context():
        with TimingContext("mongo", "projects_by_id"):
            # restrict lookup to the caller's company (or company-less/public entries)
            company_filter = get_company_or_none_constraint(call.identity.company)
            project = Project.objects(Q(id=requested_id) & company_filter).first()

        if not project:
            raise errors.bad_request.InvalidProjectId(id=requested_id)

        result = project.to_proper_dict()
        conform_output_tags(call, result)
        call.result.data = {"project": result}
|
|
|
|
|
|
def make_projects_get_all_pipelines(company_id, project_ids, specific_state=None):
    """
    Build the two Mongo aggregation pipelines used by `projects.get_all_ex`
    to decorate projects with task statistics.

    :param company_id: company whose tasks are aggregated (tasks with no
        company or an empty company — i.e. public — are matched as well)
    :param project_ids: project IDs to aggregate over
    :param specific_state: optional EntityVisibility; when given, the runtime
        pipeline only accumulates run time for that visibility state
    :return: (status_count_pipeline, runtime_pipeline) — the first counts
        tasks per project/status/archived-flag, the second sums task run
        times (whole seconds) per project per visibility state
    """
    archived = EntityVisibility.archived.value

    def ensure_valid_fields():
        """
        Make sure system tags is always an array (required by subsequent $in in archived_tasks_cond
        """
        return {
            "$addFields": {
                "system_tags": {
                    "$cond": {
                        "if": {"$ne": [{"$type": "$system_tags"}, "array"]},
                        "then": [],
                        "else": "$system_tags",
                    }
                },
                # some documents may lack a status; normalize for the $group key
                "status": {"$ifNull": ["$status", "unknown"]},
            }
        }

    status_count_pipeline = [
        # count tasks per project per status
        {
            "$match": {
                "company": {"$in": [None, "", company_id]},
                "project": {"$in": project_ids},
            }
        },
        ensure_valid_fields(),
        {
            "$group": {
                "_id": {
                    "project": "$project",
                    "status": "$status",
                    # grouping also splits counts by archived vs. non-archived
                    archived: archived_tasks_cond,
                },
                "count": {"$sum": 1},
            }
        },
        # for each project, create a list of (status, count, archived)
        {
            "$group": {
                "_id": "$_id.project",
                "counts": {
                    "$push": {
                        "status": "$_id.status",
                        "count": "$count",
                        archived: "$_id.%s" % archived,
                    }
                },
            }
        },
    ]

    def runtime_subquery(additional_cond):
        # accumulator expression: total run time (seconds) of tasks
        # that ran to completion and match additional_cond
        return {
            # the sum of
            "$sum": {
                # for each task
                "$cond": {
                    # if completed and started and completed > started
                    "if": {
                        "$and": [
                            "$started",
                            "$completed",
                            {"$gt": ["$completed", "$started"]},
                            additional_cond,
                        ]
                    },
                    # then: floor((completed - started) / 1000)
                    "then": {
                        "$floor": {
                            "$divide": [
                                {"$subtract": ["$completed", "$started"]},
                                1000.0,
                            ]
                        }
                    },
                    "else": 0,
                }
            }
        }

    group_step = {"_id": "$project"}

    # add one run-time accumulator per requested visibility state
    for state in EntityVisibility:
        if specific_state and state != specific_state:
            continue
        if state == EntityVisibility.active:
            group_step[state.value] = runtime_subquery({"$not": archived_tasks_cond})
        elif state == EntityVisibility.archived:
            group_step[state.value] = runtime_subquery(archived_tasks_cond)

    runtime_pipeline = [
        # only count run time for these types of tasks
        {
            "$match": {
                "type": {"$in": ["training", "testing"]},
                "company": {"$in": [None, "", company_id]},
                "project": {"$in": project_ids},
            }
        },
        ensure_valid_fields(),
        {
            # for each project
            "$group": group_step
        },
    ]

    return status_count_pipeline, runtime_pipeline
|
|
|
|
|
|
@endpoint("projects.get_all_ex")
def get_all_ex(call: APICall):
    """
    Extended project listing: same filtering as `projects.get_all`, plus
    optional per-project task statistics (status counts and total run time)
    when `include_stats` is set.
    """
    include_stats = call.data.get("include_stats")
    # default: report stats for active (non-archived) tasks only
    stats_for_state = call.data.get("stats_for_state", EntityVisibility.active.value)
    allow_public = not call.data.get("non_public", False)

    # validate the requested state; an explicit falsy value means "all states"
    if stats_for_state:
        try:
            specific_state = EntityVisibility(stats_for_state)
        except ValueError:
            raise errors.bad_request.FieldsValueError(stats_for_state=stats_for_state)
    else:
        specific_state = None

    conform_tag_fields(call, call.data)
    with translate_errors_context(), TimingContext("mongo", "projects_get_all"):
        projects = Project.get_many_with_join(
            company=call.identity.company,
            query_dict=call.data,
            query_options=get_all_query_options,
            allow_public=allow_public,
        )
        conform_output_tags(call, projects)

        if not include_stats:
            call.result.data = {"projects": projects}
            return

        ids = [project["id"] for project in projects]
        status_count_pipeline, runtime_pipeline = make_projects_get_all_pipelines(
            call.identity.company, ids, specific_state=specific_state
        )

        # every task status starts at 0 so the response always carries all statuses
        default_counts = dict.fromkeys(get_options(TaskStatus), 0)

        def set_default_count(entry):
            # overlay actual counts on top of the zeroed defaults
            return dict(default_counts, **entry)

        # status_count: project_id -> {state ("active"/"archived") -> {status -> count}}
        status_count = defaultdict(lambda: {})
        key = itemgetter(EntityVisibility.archived.value)
        for result in Task.aggregate(status_count_pipeline):
            # groupby requires its input sorted by the same key
            for k, group in groupby(sorted(result["counts"], key=key), key):
                section = (
                    EntityVisibility.archived if k else EntityVisibility.active
                ).value
                status_count[result["_id"]][section] = set_default_count(
                    {
                        count_entry["status"]: count_entry["count"]
                        for count_entry in group
                    }
                )

        # runtime: project_id -> {state -> total run time in seconds}
        runtime = {
            result["_id"]: {k: v for k, v in result.items() if k != "_id"}
            for result in Task.aggregate(runtime_pipeline)
        }

        def safe_get(obj, path, default=None):
            # dpath.get raises KeyError when the path does not exist
            try:
                return dpath.get(obj, path)
            except KeyError:
                return default

        def get_status_counts(project_id, section):
            # projects with no tasks fall back to 0 runtime / zeroed counts
            path = "/".join((project_id, section))
            return {
                "total_runtime": safe_get(runtime, path, 0),
                "status_count": safe_get(status_count, path, default_counts),
            }

        report_for_states = [
            s for s in EntityVisibility if not specific_state or specific_state == s
        ]

        for project in projects:
            project["stats"] = {
                task_state.value: get_status_counts(project["id"], task_state.value)
                for task_state in report_for_states
            }

        call.result.data = {"projects": projects}
|
|
|
|
|
|
@endpoint("projects.get_all")
def get_all(call: APICall):
    """List projects visible to the caller (public projects included)."""
    data = call.data
    conform_tag_fields(call, data)

    with translate_errors_context(), TimingContext("mongo", "projects_get_all"):
        results = Project.get_many(
            company=call.identity.company,
            query_dict=data,
            query_options=get_all_query_options,
            parameters=data,
            allow_public=True,
        )
        conform_output_tags(call, results)
        call.result.data = {"projects": results}
|
|
|
|
|
|
@endpoint("projects.create", required_fields=["name", "description"])
def create(call):
    """Create a new project for the calling user's company and return its id."""
    assert isinstance(call, APICall)

    with translate_errors_context():
        project_fields = parse_from_call(
            call.data, create_fields, Project.get_fields()
        )
        conform_tag_fields(call, project_fields, validate=True)

        # creation and last-update timestamps start out identical
        creation_time = datetime.utcnow()
        new_project = Project(
            id=database.utils.id(),
            user=call.identity.user,
            company=call.identity.company,
            created=creation_time,
            last_update=creation_time,
            **project_fields
        )
        with TimingContext("mongo", "projects_save"):
            new_project.save()

        call.result.data = {"id": new_project.id}
|
|
|
|
|
|
@endpoint(
    "projects.update", required_fields=["project"], response_data_model=UpdateResponse
)
def update(call: APICall):
    """
    update

    :summary: Update project information.
             See `projects.create` for parameters.
    :return: updated - `int` - number of projects updated
             fields - `[string]` - updated fields
    """
    project_id = call.data["project"]

    with translate_errors_context():
        # fetch first so a missing or foreign project fails before any write
        project = Project.get_for_writing(company=call.identity.company, id=project_id)
        if not project:
            raise errors.bad_request.InvalidProjectId(id=project_id)

        # discard_none_values=False: explicit nulls in the request clear fields
        fields = parse_from_call(
            call.data, create_fields, Project.get_fields(), discard_none_values=False
        )
        conform_tag_fields(call, fields, validate=True)
        fields["last_update"] = datetime.utcnow()
        with TimingContext("mongo", "projects_update"):
            updated = project.update(upsert=False, **fields)
        # echo the applied fields back in response form (tags conformed)
        conform_output_tags(call, fields)
        call.result.data_model = UpdateResponse(updated=updated, fields=fields)
|
|
|
|
|
|
@endpoint("projects.delete", required_fields=["project"])
def delete(call):
    """
    Delete a project.

    Non-archived tasks or models still assigned to the project block the
    deletion unless `force` is passed, in which case they are disassociated
    (their `project` field is cleared) before the project is removed.

    :return: deleted - `int` - 1 on success
             disassociated_tasks - `int` - number of tasks disassociated
    """
    assert isinstance(call, APICall)
    project_id = call.data["project"]
    force = call.data.get("force", False)

    with translate_errors_context():
        project = Project.get_for_writing(company=call.identity.company, id=project_id)
        if not project:
            raise errors.bad_request.InvalidProjectId(id=project_id)

        # NOTE: from this point on we'll use the project ID and won't check for company, since we assume we already
        # have the correct project ID.

        # Find the non-archived tasks and models which belong to the project.
        # NOTE(review): archived tasks/models are excluded here and keep
        # referencing the deleted project — confirm this is intended.
        disassociated = {}
        for cls, error in (
            (Task, errors.bad_request.ProjectHasTasks),
            (Model, errors.bad_request.ProjectHasModels),
        ):
            res = cls.objects(
                project=project_id, system_tags__nin=[EntityVisibility.archived.value]
            ).only("id")
            if res and not force:
                raise error("use force=true to delete", id=project_id)
            disassociated[cls] = res.update(project=None)

        project.delete()

        # Bug fix: previously a single `updated_count` variable was reassigned
        # on every loop iteration, so the reported value was the *models*
        # count (last iteration) rather than the number of disassociated tasks.
        call.result.data = {
            "deleted": 1,
            "disassociated_tasks": disassociated[Task],
        }
|
|
|
|
|
|
@endpoint("projects.get_unique_metric_variants", request_data_model=ProjectReq)
def get_unique_metric_variants(call: APICall, company_id: str, request: ProjectReq):
    """Return the unique metric/variant combinations reported for the project."""
    # no project in the request means "across all projects"
    project_ids = [request.project] if request.project else None
    call.result.data = {
        "metrics": task_bll.get_unique_metric_variants(company_id, project_ids)
    }
|
|
|
|
|
|
@endpoint(
    "projects.get_hyper_parameters",
    min_version="2.2",
    request_data_model=GetHyperParamReq,
    response_data_model=GetHyperParamResp,
)
def get_hyper_parameters(call: APICall, company_id: str, request: GetHyperParamReq):
    """Return a page of execution parameters aggregated over the project's tasks."""
    # no project in the request means "across all projects"
    project_ids = [request.project] if request.project else None

    total, remaining, parameters = TaskBLL.get_aggregated_project_execution_parameters(
        company_id,
        project_ids=project_ids,
        page=request.page,
        page_size=request.page_size,
    )

    call.result.data = dict(
        total=total,
        remaining=remaining,
        parameters=parameters,
    )
|
|
|
|
|
|
@endpoint(
    "projects.get_task_tags", min_version="2.8", request_data_model=ProjectTagsRequest
)
def get_tags(call: APICall, company, request: ProjectTagsRequest):
    """Return the tags used by tasks in the requested projects."""
    task_tags = org_bll.get_tags(
        company,
        Tags.Task,
        include_system=request.include_system,
        filter_=get_tags_filter_dictionary(request.filter),
        projects=request.projects,
    )
    call.result.data = get_tags_response(task_tags)
|
|
|
|
|
|
@endpoint(
    "projects.get_model_tags", min_version="2.8", request_data_model=ProjectTagsRequest
)
def get_model_tags(call: APICall, company, request: ProjectTagsRequest):
    """
    Return the tags used by models in the requested projects.

    Renamed from `get_tags`: the previous name redefined the task-tags
    handler of the same name, shadowing it at module level. Routes are
    registered by the @endpoint decorator string, so the endpoint itself
    is unaffected by the rename.
    """
    ret = org_bll.get_tags(
        company,
        Tags.Model,
        include_system=request.include_system,
        filter_=get_tags_filter_dictionary(request.filter),
        projects=request.projects,
    )
    call.result.data = get_tags_response(ret)
|
|
|
|
|
|
@endpoint(
    "projects.make_public", min_version="2.9", request_data_model=MakePublicRequest
)
def make_public(call: APICall, company_id, request: MakePublicRequest):
    """Mark the requested projects as public."""
    with translate_errors_context():
        result = Project.set_public(
            company_id, ids=request.ids, invalid_cls=InvalidProjectId, enabled=True
        )
        call.result.data = result
|
|
|
|
|
|
@endpoint(
    "projects.make_private", min_version="2.9", request_data_model=MakePublicRequest
)
def make_private(call: APICall, company_id, request: MakePublicRequest):
    """
    Make the requested projects private (company-owned) again.

    Renamed from `make_public`: the previous name redefined the
    `projects.make_public` handler, shadowing it at module level. Routes
    are registered by the @endpoint decorator string, so the endpoint
    itself is unaffected by the rename.
    """
    with translate_errors_context():
        call.result.data = Project.set_public(
            company_id, ids=request.ids, invalid_cls=InvalidProjectId, enabled=False
        )
|