# clearml-server/apiserver/services/reports.py
#
# 370 lines
# 11 KiB
# Python

import textwrap
from datetime import datetime
from typing import Sequence
from apiserver.apimodels.reports import (
CreateReportRequest,
UpdateReportRequest,
PublishReportRequest,
ArchiveReportRequest,
DeleteReportRequest,
MoveReportRequest,
GetTasksDataRequest,
EventsRequest,
GetAllRequest,
)
from apiserver.apierrors import errors
from apiserver.apimodels.base import UpdateResponse
from apiserver.services.utils import process_include_subprojects, sort_tags_response
from apiserver.bll.organization import OrgBLL
from apiserver.bll.project import ProjectBLL
from apiserver.bll.task import TaskBLL, ChangeStatusRequest
from apiserver.database.model import EntityVisibility
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task, TaskType, TaskStatus
from apiserver.service_repo import APICall, endpoint
from apiserver.services.events import (
_get_task_or_model_index_company,
event_bll,
_get_metrics_response,
_get_metric_variants_from_request,
_get_multitask_plots,
)
from apiserver.services.tasks import (
escape_execution_parameters,
_hidden_query,
unprepare_from_saved,
)
# Business-logic singletons shared by the endpoints in this module.
org_bll = OrgBLL()
project_bll = ProjectBLL()
task_bll = TaskBLL()

# Basename of the hidden sub-project that stores reports, and the system tag
# attached to report tasks and to the reports projects created for them.
reports_project_name = ".reports"
reports_tag = "reports"

# Fields that reports.update copies from the raw call data.
update_fields = {"name", "tags", "comment", "report"}
def _assert_report(company_id, task_id, only_fields=None, requires_write_access=True):
    """
    Load a task the caller can access and verify that it is a report.

    :param company_id: company of the calling identity
    :param task_id: id of the task to load
    :param only_fields: optional sequence of field names to load; "type" is
        always added since it is needed for the report check below
    :param requires_write_access: forwarded to TaskBLL.get_task_with_access
    :return: the loaded task
    :raises errors.bad_request.OperationSupportedOnReportsOnly: if the loaded
        task is not of the report type
    """
    if only_fields and "type" not in only_fields:
        # Build a new tuple instead of using "+=", which would extend a
        # caller-supplied list in place
        only_fields = (*only_fields, "type")
    task = TaskBLL.get_task_with_access(
        task_id=task_id,
        company_id=company_id,
        only=only_fields,
        requires_write_access=requires_write_access,
    )
    if task.type != TaskType.report:
        raise errors.bad_request.OperationSupportedOnReportsOnly(id=task_id)
    return task
@endpoint("reports.update", response_data_model=UpdateResponse)
def update_report(call: APICall, company_id: str, request: UpdateReportRequest):
    """
    Partially update a report that is still in the "created" status.

    Only keys listed in ``update_fields`` are taken from the raw call data.
    Responds with the number of updated documents and, on success, the applied
    fields. In the response only, tags are sorted and the report body is
    shortened to 100 characters.
    """
    report = _assert_report(
        task_id=request.task, company_id=company_id, only_fields=("status",),
    )
    if report.status != TaskStatus.created:
        raise errors.bad_request.InvalidTaskStatus(
            expected=TaskStatus.created, status=report.status
        )

    changes = {}
    for key, val in call.data.items():
        if key in update_fields:
            changes[key] = val
    if not changes:
        return UpdateResponse(updated=0)

    timestamp = datetime.utcnow()
    updated = report.update(
        upsert=False,
        **changes,
        last_change=timestamp,
        last_update=timestamp,
        last_changed_by=call.identity.user,
    )
    if not updated:
        return UpdateResponse(updated=0)

    # The adjustments below only shape the response; the stored values were
    # written exactly as submitted.
    new_tags = changes.get("tags")
    if new_tags:
        changes["tags"] = sorted(new_tags)
    new_body = changes.get("report")
    if new_body:
        changes["report"] = textwrap.shorten(new_body, width=100)
    return UpdateResponse(updated=updated, fields=changes)
def _ensure_reports_project(company: str, user: str, name: str):
    """
    Find or create the hidden reports project for the given project path.

    If the last path segment of ``name`` already is the reports project
    basename, ``name`` is used as is. Otherwise the reports project is
    addressed as a sub-project of ``name``. An empty ``name`` resolves to the
    top-level reports project.

    :param company: company id
    :param user: user on whose behalf the project may be created
    :param name: "/"-separated path of the parent project (may be empty)
    :return: the result of ``project_bll.find_or_create`` (callers use it as
        the project id)
    """
    name = name.strip("/")
    _, _, basename = name.rpartition("/")
    if basename != reports_project_name:
        # With no parent project, avoid building "/.reports", which has an
        # empty leading path segment
        name = f"{name}/{reports_project_name}" if name else reports_project_name
    return project_bll.find_or_create(
        user=user,
        company=company,
        project_name=name,
        description="Reports project",
        system_tags=[reports_tag, EntityVisibility.hidden.value],
    )
@endpoint("reports.create")
def create_report(call: APICall, company_id: str, request: CreateReportRequest):
    """
    Create a new report task.

    The report is stored in the hidden reports sub-project of
    ``request.project`` (or the top-level reports project when no project is
    given), which is created on demand. Responds with the new report id and
    the id of the project it was placed in.
    """
    user_id = call.identity.user

    parent_name = ""
    if request.project:
        # Resolve the parent project's name through get_for_writing
        parent_name = Project.get_for_writing(
            company=company_id, id=request.project, _only=("name",)
        ).name

    project_id = _ensure_reports_project(
        company=company_id, user=user_id, name=parent_name
    )
    report_fields = dict(
        project=project_id,
        name=request.name,
        tags=request.tags,
        comment=request.comment,
        type=TaskType.report,
        system_tags=[reports_tag, EntityVisibility.hidden.value],
    )
    report = task_bll.create(company=company_id, user=user_id, fields=report_fields)
    report.save()
    call.result.data = {"id": report.id, "project_id": project_id}
def _delete_reports_project_if_empty(project_id):
    """
    Delete the given project if it is a reports project with no tasks left.

    Does nothing when the project does not exist, when its basename is not
    the reports project basename, or when any task still references it.
    """
    project = Project.objects(id=project_id).only("basename").first()
    if not project or project.basename != reports_project_name:
        return
    if Task.objects(project=project_id).count():
        return
    project.delete()
@endpoint("reports.get_all_ex")
def get_all_ex(call: APICall, company_id, request: GetAllRequest):
    """
    Return report tasks matching the query in the raw call data.

    The query is forced to the report task type and always includes
    sub-projects. Any extra values the query writes into ``ret_params`` are
    merged into the response next to ``tasks``.
    """
    query = call.data
    query["type"] = TaskType.report
    query["include_subprojects"] = True
    process_include_subprojects(query)

    extra_params = {}
    found = Task.get_many_with_join(
        company=company_id,
        query_dict=query,
        allow_public=request.allow_public,
        ret_params=extra_params,
    )
    unprepare_from_saved(call, found)
    call.result.data = {"tasks": found, **extra_params}
def _get_task_metrics_from_request(
    task_ids: Sequence[str], request: EventsRequest
) -> dict:
    """
    Map every task id to the same metric -> variants selection.

    :param task_ids: ids of the tasks to query
    :param request: events request whose ``metrics`` entries expose ``metric``
        and ``variants`` attributes
    :return: ``{task_id: {metric: variants}}``, with a separate inner dict
        built for each task id
    """
    return {
        task_id: {entry.metric: entry.variants for entry in request.metrics}
        for task_id in task_ids
    }
@endpoint("reports.get_task_data")
def get_task_data(call: APICall, company_id, request: GetTasksDataRequest):
    """
    Return tasks matching the call's query, optionally with their event data.

    Besides ``tasks`` (and any extra values written into ``ret_params`` by the
    query), the response may contain:
    - ``debug_images`` when ``request.debug_images`` is set
    - ``plots`` when ``request.plots`` is set
    - ``scalar_metrics_iter_histogram`` when
      ``request.scalar_metrics_iter_histogram`` is set

    All events queries use the index company resolved for the found tasks by
    ``_get_task_or_model_index_company``.
    """
    call_data = escape_execution_parameters(call)
    process_include_subprojects(call_data)

    ret_params = {}
    tasks = Task.get_many_with_join(
        company=company_id,
        query_dict=call_data,
        query=_hidden_query(call_data),
        allow_public=request.allow_public,
        ret_params=ret_params,
    )
    unprepare_from_saved(call, tasks)
    res = {"tasks": tasks, **ret_params}

    events_requested = (
        request.debug_images or request.plots or request.scalar_metrics_iter_histogram
    )
    # Skip the events lookups when nothing was requested or no tasks matched
    # (there is no index company to resolve for an empty task list)
    if not (events_requested and tasks):
        call.result.data = res
        return

    task_ids = [task["id"] for task in tasks]
    # NOTE(review): the resolved index company may differ from company_id
    # (e.g. for public tasks), so every events query below uses `company`
    company, tasks_or_models = _get_task_or_model_index_company(company_id, task_ids)

    if request.debug_images:
        result = event_bll.debug_images_iterator.get_task_events(
            company_id=company,
            task_metrics=_get_task_metrics_from_request(task_ids, request.debug_images),
            iter_count=request.debug_images.iters,
        )
        res["debug_images"] = [
            r.to_struct() for r in _get_metrics_response(result.metric_events)
        ]

    if request.plots:
        res["plots"] = _get_multitask_plots(
            company=company,
            tasks_or_models=tasks_or_models,
            last_iters=request.plots.iters,
            metrics=_get_metric_variants_from_request(request.plots.metrics),
        )[0]

    if request.scalar_metrics_iter_histogram:
        histogram_request = request.scalar_metrics_iter_histogram
        res[
            "scalar_metrics_iter_histogram"
        ] = event_bll.metrics.compare_scalar_metrics_average_per_iter(
            # Use the resolved index company, consistent with the queries above
            company_id=company,
            tasks=tasks_or_models,
            samples=histogram_request.samples,
            key=histogram_request.key,
            metric_variants=_get_metric_variants_from_request(
                histogram_request.metrics
            ),
        )

    call.result.data = res
@endpoint("reports.move")
def move(call: APICall, company_id: str, request: MoveReportRequest):
    """
    Move a report into the reports sub-project of another project.

    The target is taken from ``request.project_name`` when given, otherwise
    from the name of the project with id ``request.project``. Afterwards the
    report's previous project is deleted if it is a reports project left with
    no tasks. Responds with the id of the destination project.
    """
    if not request.project and not request.project_name:
        raise errors.bad_request.MissingRequiredFields(
            "project or project_name is required"
        )

    report = _assert_report(
        company_id=company_id, task_id=request.task, only_fields=("project",),
    )
    previous_project = report.project
    user_id = call.identity.user

    target_name = request.project_name
    if not target_name:
        target_name = Project.get_for_writing(
            company=company_id, id=request.project, _only=("name",)
        ).name

    target_id = _ensure_reports_project(
        company=company_id, user=user_id, name=target_name
    )
    project_bll.move_under_project(
        entity_cls=Task,
        user=user_id,
        company=company_id,
        ids=[request.task],
        project=target_id,
    )
    _delete_reports_project_if_empty(previous_project)
    return {"project_id": target_id}
@endpoint(
    "reports.publish", response_data_model=UpdateResponse,
)
def publish(call: APICall, company_id, request: PublishReportRequest):
    """
    Publish a report.

    The status change is forced, the status reason is cleared, and
    ``request.message`` becomes the status message. Responds with the update
    result produced by ``ChangeStatusRequest.execute``.
    """
    report = _assert_report(company_id=company_id, task_id=request.task)
    status_change = ChangeStatusRequest(
        task=report,
        new_status=TaskStatus.published,
        force=True,
        status_reason="",
        status_message=request.message,
        user_id=call.identity.user,
    )
    updates = status_change.execute(published=datetime.utcnow())
    call.result.data_model = UpdateResponse(**updates)
@endpoint("reports.archive")
def archive(call: APICall, company_id, request: ArchiveReportRequest):
    """
    Archive a report by adding the archived system tag.

    ``request.message`` is stored as the status message and the status reason
    is cleared. Responds with ``{"archived": <result of task.update>}``.
    """
    report = _assert_report(company_id=company_id, task_id=request.task)
    changes = dict(
        status_message=request.message,
        status_reason="",
        add_to_set__system_tags=EntityVisibility.archived.value,
        last_change=datetime.utcnow(),
        last_changed_by=call.identity.user,
    )
    return {"archived": report.update(**changes)}
@endpoint("reports.unarchive")
def unarchive(call: APICall, company_id, request: ArchiveReportRequest):
    """
    Unarchive a report by removing the archived system tag.

    ``request.message`` is stored as the status message and the status reason
    is cleared. Responds with ``{"unarchived": <result of task.update>}``.
    """
    report = _assert_report(company_id=company_id, task_id=request.task)
    changes = dict(
        status_message=request.message,
        status_reason="",
        pull__system_tags=EntityVisibility.archived.value,
        last_change=datetime.utcnow(),
        last_changed_by=call.identity.user,
    )
    return {"unarchived": report.update(**changes)}
# @endpoint("reports.share")
# def share(call: APICall, company_id, request: ShareReportRequest):
# _assert_report(
# company_id=company_id, user_id=call.identity.user, task_id=request.task
# )
# call.result.data = {
# "changed": task_bll.share_task(
# company_id=company_id, task_ids=[request.task], share=request.share
# )
# }
@endpoint("reports.delete")
def delete(call: APICall, company_id, request: DeleteReportRequest):
    """
    Delete a report.

    Unless ``request.force`` is set, only reports in the "created" status or
    archived reports may be deleted. Afterwards the report's project is
    deleted if it is a reports project left with no tasks.

    :raises errors.bad_request.TaskCannotBeDeleted: if the report is neither
        in the "created" status nor archived and force was not requested
    """
    task = _assert_report(
        company_id=company_id,
        task_id=request.task,
        # "status" and "system_tags" are read by the deletion guard below.
        # Fields not listed here are not loaded, so without them the guard
        # would not see the stored values.
        only_fields=("project", "status", "system_tags"),
    )
    if (
        task.status != TaskStatus.created
        and EntityVisibility.archived.value not in task.system_tags
        and not request.force
    ):
        raise errors.bad_request.TaskCannotBeDeleted(
            "due to status, use force=True",
            task=task.id,
            expected=TaskStatus.created,
            current=task.status,
        )

    task.delete()
    _delete_reports_project_if_empty(task.project)
    call.result.data = {"deleted": 1}
@endpoint("reports.get_tags")
def get_tags(call: APICall, company_id: str, _):
    """
    Respond with the distinct tags used by the company's report tasks.

    The tags are formatted by ``sort_tags_response``.
    """
    reports = Task.objects(company=company_id, type=TaskType.report)
    distinct_tags = reports.distinct(field="tags")
    call.result.data = sort_tags_response({"tags": distinct_tags})