import textwrap from datetime import datetime from typing import Sequence from apiserver.apimodels.reports import ( CreateReportRequest, UpdateReportRequest, PublishReportRequest, ArchiveReportRequest, DeleteReportRequest, MoveReportRequest, GetTasksDataRequest, EventsRequest, GetAllRequest, ) from apiserver.apierrors import errors from apiserver.apimodels.base import UpdateResponse from apiserver.services.utils import process_include_subprojects, sort_tags_response from apiserver.bll.organization import OrgBLL from apiserver.bll.project import ProjectBLL from apiserver.bll.task import TaskBLL, ChangeStatusRequest from apiserver.database.model import EntityVisibility from apiserver.database.model.project import Project from apiserver.database.model.task.task import Task, TaskType, TaskStatus from apiserver.service_repo import APICall, endpoint from apiserver.services.events import ( _get_task_or_model_index_company, event_bll, _get_metrics_response, _get_metric_variants_from_request, _get_multitask_plots, ) from apiserver.services.tasks import ( escape_execution_parameters, _hidden_query, unprepare_from_saved, ) org_bll = OrgBLL() project_bll = ProjectBLL() task_bll = TaskBLL() reports_project_name = ".reports" reports_tag = "reports" update_fields = { "name", "tags", "comment", "report", } def _assert_report(company_id, task_id, only_fields=None, requires_write_access=True): if only_fields and "type" not in only_fields: only_fields += ("type",) task = TaskBLL.get_task_with_access( task_id=task_id, company_id=company_id, only=only_fields, requires_write_access=requires_write_access, ) if task.type != TaskType.report: raise errors.bad_request.OperationSupportedOnReportsOnly(id=task_id) return task @endpoint("reports.update", response_data_model=UpdateResponse) def update_report(call: APICall, company_id: str, request: UpdateReportRequest): task = _assert_report( task_id=request.task, company_id=company_id, only_fields=("status",), ) if task.status != TaskStatus.created: raise errors.bad_request.InvalidTaskStatus( expected=TaskStatus.created, status=task.status ) partial_update_dict = { field: value for field, value in call.data.items() if field in update_fields } if not partial_update_dict: return UpdateResponse(updated=0) now = datetime.utcnow() updated = task.update( upsert=False, **partial_update_dict, last_change=now, last_update=now, last_changed_by=call.identity.user, ) if not updated: return UpdateResponse(updated=0) updated_tags = partial_update_dict.get("tags") if updated_tags: partial_update_dict["tags"] = sorted(updated_tags) updated_report = partial_update_dict.get("report") if updated_report: partial_update_dict["report"] = textwrap.shorten(updated_report, width=100) return UpdateResponse(updated=updated, fields=partial_update_dict) def _ensure_reports_project(company: str, user: str, name: str): name = name.strip("/") _, _, basename = name.rpartition("/") if basename != reports_project_name: name = f"{name}/{reports_project_name}" return project_bll.find_or_create( user=user, company=company, project_name=name, description="Reports project", system_tags=[reports_tag, EntityVisibility.hidden.value], ) @endpoint("reports.create") def create_report(call: APICall, company_id: str, request: CreateReportRequest): user_id = call.identity.user project_id = request.project if request.project: project = Project.get_for_writing( company=company_id, id=project_id, _only=("name",) ) project_name = project.name else: project_name = "" project_id = _ensure_reports_project( company=company_id, user=user_id, name=project_name ) task = task_bll.create( company=company_id, user=user_id, fields=dict( project=project_id, name=request.name, tags=request.tags, comment=request.comment, type=TaskType.report, system_tags=[reports_tag, EntityVisibility.hidden.value], ), ) task.save() call.result.data = {"id": task.id, "project_id": project_id} def _delete_reports_project_if_empty(project_id): project = Project.objects(id=project_id).only("basename").first() if ( project and project.basename == reports_project_name and Task.objects(project=project_id).count() == 0 ): project.delete() @endpoint("reports.get_all_ex") def get_all_ex(call: APICall, company_id, request: GetAllRequest): call_data = call.data call_data["type"] = TaskType.report call_data["include_subprojects"] = True process_include_subprojects(call_data) ret_params = {} tasks = Task.get_many_with_join( company=company_id, query_dict=call_data, allow_public=request.allow_public, ret_params=ret_params, ) unprepare_from_saved(call, tasks) call.result.data = {"tasks": tasks, **ret_params} def _get_task_metrics_from_request( task_ids: Sequence[str], request: EventsRequest ) -> dict: task_metrics = {} for task in task_ids: task_dict = {} for mv in request.metrics: task_dict[mv.metric] = mv.variants task_metrics[task] = task_dict return task_metrics @endpoint("reports.get_task_data") def get_task_data(call: APICall, company_id, request: GetTasksDataRequest): call_data = escape_execution_parameters(call) process_include_subprojects(call_data) ret_params = {} tasks = Task.get_many_with_join( company=company_id, query_dict=call_data, query=_hidden_query(call_data), allow_public=request.allow_public, ret_params=ret_params, ) unprepare_from_saved(call, tasks) res = {"tasks": tasks, **ret_params} if not ( request.debug_images or request.plots or request.scalar_metrics_iter_histogram ): return res task_ids = [task["id"] for task in tasks] company, tasks_or_models = _get_task_or_model_index_company(company_id, task_ids) if request.debug_images: result = event_bll.debug_images_iterator.get_task_events( company_id=company, task_metrics=_get_task_metrics_from_request(task_ids, request.debug_images), iter_count=request.debug_images.iters, ) res["debug_images"] = [ r.to_struct() for r in _get_metrics_response(result.metric_events) ] if request.plots: res["plots"] = _get_multitask_plots( company=company, tasks_or_models=tasks_or_models, last_iters=request.plots.iters, metrics=_get_metric_variants_from_request(request.plots.metrics), )[0] if request.scalar_metrics_iter_histogram: res[ "scalar_metrics_iter_histogram" ] = event_bll.metrics.compare_scalar_metrics_average_per_iter( company_id=company_id, tasks=tasks_or_models, samples=request.scalar_metrics_iter_histogram.samples, key=request.scalar_metrics_iter_histogram.key, metric_variants=_get_metric_variants_from_request( request.scalar_metrics_iter_histogram.metrics ), ) call.result.data = res @endpoint("reports.move") def move(call: APICall, company_id: str, request: MoveReportRequest): if not (request.project or request.project_name): raise errors.bad_request.MissingRequiredFields( "project or project_name is required" ) task = _assert_report( company_id=company_id, task_id=request.task, only_fields=("project",), ) user_id = call.identity.user project_name = request.project_name if not project_name: project = Project.get_for_writing( company=company_id, id=request.project, _only=("name",) ) project_name = project.name project_id = _ensure_reports_project( company=company_id, user=user_id, name=project_name ) project_bll.move_under_project( entity_cls=Task, user=call.identity.user, company=company_id, ids=[request.task], project=project_id, ) _delete_reports_project_if_empty(task.project) return {"project_id": project_id} @endpoint( "reports.publish", response_data_model=UpdateResponse, ) def publish(call: APICall, company_id, request: PublishReportRequest): task = _assert_report(company_id=company_id, task_id=request.task) updates = ChangeStatusRequest( task=task, new_status=TaskStatus.published, force=True, status_reason="", status_message=request.message, user_id=call.identity.user, ).execute(published=datetime.utcnow()) call.result.data_model = UpdateResponse(**updates) @endpoint("reports.archive") def archive(call: APICall, company_id, request: ArchiveReportRequest): task = _assert_report(company_id=company_id, task_id=request.task) archived = task.update( status_message=request.message, status_reason="", add_to_set__system_tags=EntityVisibility.archived.value, last_change=datetime.utcnow(), last_changed_by=call.identity.user, ) return {"archived": archived} @endpoint("reports.unarchive") def unarchive(call: APICall, company_id, request: ArchiveReportRequest): task = _assert_report(company_id=company_id, task_id=request.task) unarchived = task.update( status_message=request.message, status_reason="", pull__system_tags=EntityVisibility.archived.value, last_change=datetime.utcnow(), last_changed_by=call.identity.user, ) return {"unarchived": unarchived} # @endpoint("reports.share") # def share(call: APICall, company_id, request: ShareReportRequest): # _assert_report( # company_id=company_id, user_id=call.identity.user, task_id=request.task # ) # call.result.data = { # "changed": task_bll.share_task( # company_id=company_id, task_ids=[request.task], share=request.share # ) # } @endpoint("reports.delete") def delete(call: APICall, company_id, request: DeleteReportRequest): task = _assert_report( company_id=company_id, task_id=request.task, only_fields=("project",), ) if ( task.status != TaskStatus.created and EntityVisibility.archived.value not in task.system_tags and not request.force ): raise errors.bad_request.TaskCannotBeDeleted( "due to status, use force=True", task=task.id, expected=TaskStatus.created, current=task.status, ) task.delete() _delete_reports_project_if_empty(task.project) call.result.data = {"deleted": 1} @endpoint("reports.get_tags") def get_tags(call: APICall, company_id: str, _): tags = Task.objects(company=company_id, type=TaskType.report).distinct(field="tags") call.result.data = sort_tags_response({"tags": tags})