"""Service endpoints of the ``tasks.*`` API.

Every function decorated with ``@endpoint`` is registered under the endpoint
name passed to the decorator.  Most handlers are thin adapters that unpack the
request model, delegate to the task BLL / task operations helpers, and store
the outcome on ``call.result``.
"""
from copy import deepcopy
from datetime import datetime
from functools import partial
from typing import Sequence, Union, Tuple, Set

import attr
import dpath
from mongoengine import EmbeddedDocument, Q
from mongoengine.queryset.transform import COMPARISON_OPERATORS
from pymongo import UpdateOne

from apiserver.apierrors import errors
from apiserver.apierrors.errors.bad_request import InvalidTaskId
from apiserver.apimodels.base import (
    UpdateResponse,
    IdResponse,
    MakePublicRequest,
    MoveRequest,
)
from apiserver.apimodels.tasks import (
    StartedResponse,
    ResetResponse,
    PublishRequest,
    PublishResponse,
    CreateRequest,
    UpdateRequest,
    SetRequirementsRequest,
    TaskRequest,
    DeleteRequest,
    PingRequest,
    EnqueueRequest,
    EnqueueResponse,
    DequeueResponse,
    CloneRequest,
    AddOrUpdateArtifactsRequest,
    GetTypesRequest,
    ResetRequest,
    GetHyperParamsRequest,
    EditHyperParamsRequest,
    DeleteHyperParamsRequest,
    GetConfigurationsRequest,
    EditConfigurationRequest,
    DeleteConfigurationRequest,
    GetConfigurationNamesRequest,
    DeleteArtifactsRequest,
    ArchiveResponse,
    ArchiveRequest,
    AddUpdateModelRequest,
    DeleteModelsRequest,
    ModelItemType,
    StopManyResponse,
    StopManyRequest,
    EnqueueManyRequest,
    EnqueueManyResponse,
    ResetManyRequest,
    ArchiveManyRequest,
    ArchiveManyResponse,
    DeleteManyRequest,
    PublishManyRequest,
)
from apiserver.bll.event import EventBLL
from apiserver.bll.model import ModelBLL
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL, project_ids_with_children
from apiserver.bll.queue import QueueBLL
from apiserver.bll.task import (
    TaskBLL,
    ChangeStatusRequest,
    update_project_time,
)
from apiserver.bll.task.artifacts import (
    artifacts_prepare_for_save,
    artifacts_unprepare_from_saved,
    Artifacts,
)
from apiserver.bll.task.hyperparams import HyperParams
from apiserver.bll.task.non_responsive_tasks_watchdog import NonResponsiveTasksWatchdog
from apiserver.bll.task.param_utils import (
    params_prepare_for_save,
    params_unprepare_from_saved,
    escape_paths,
)
from apiserver.bll.task.task_cleanup import CleanupResult
from apiserver.bll.task.task_operations import (
    stop_task,
    enqueue_task,
    reset_task,
    archive_task,
    delete_task,
    publish_task,
)
from apiserver.bll.task.utils import update_task, deleted_prefix, get_task_for_update
from apiserver.bll.util import SetFieldsResolver, run_batch_operation
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.task.output import Output
from apiserver.database.model.task.task import (
    Task,
    TaskStatus,
    Script,
    ModelItem,
)
from apiserver.database.utils import get_fields_attr, parse_from_call, get_options
from apiserver.service_repo import APICall, endpoint
from apiserver.services.utils import (
    conform_tag_fields,
    conform_output_tags,
    ModelsBackwardsCompatibility,
    DockerCmdBackwardsCompatibility,
)
from apiserver.timing_context import TimingContext
from apiserver.utilities.partial_version import PartialVersion

# Names of all fields declared on the Task model. Treated as read-only:
# helpers that need extra pseudo-fields build their own local union.
task_fields = set(Task.get_fields())
# Script fields whose "strip" attribute is truthy; string values of these
# fields are whitespace-stripped before saving (see prepare_for_save).
task_script_stripped_fields = {
    f for f, v in get_fields_attr(Script, "strip").items() if v
}

task_bll = TaskBLL()
event_bll = EventBLL()
queue_bll = QueueBLL()
org_bll = OrgBLL()
project_bll = ProjectBLL()

# Module-level side effect: the watchdog is started when this module is imported.
NonResponsiveTasksWatchdog.start()


def set_task_status_from_call(
    request: UpdateRequest, company_id, new_status=None, **set_fields
) -> dict:
    """Change a task's status (and optionally set extra fields) from a request.

    :param request: request carrying ``task``, ``status_reason``,
        ``status_message`` and ``force``.
    :param company_id: company used for the write-access check.
    :param new_status: target status; when falsy the current status is kept.
    :param set_fields: additional update commands, resolved through
        ``SetFieldsResolver`` against the loaded task.
    :return: the result of ``ChangeStatusRequest.execute``.
    """
    fields_resolver = SetFieldsResolver(set_fields)
    task = TaskBLL.get_task_with_access(
        request.task,
        company_id=company_id,
        only=tuple(
            {"status", "project", "started", "duration"} | fields_resolver.get_names()
        ),
        requires_write_access=True,
    )

    if "duration" not in fields_resolver.get_names():
        # NOTE(review): `Task.started` is the model *field*, not a status value,
        # so this comparison looks like it can never be true for the string
        # statuses passed by the endpoints below. Left unchanged on purpose --
        # confirm whether `TaskStatus.in_progress` was intended.
        if new_status == Task.started:
            fields_resolver.add_fields(min__duration=max(0, task.duration or 0))
        elif new_status in (
            TaskStatus.completed,
            TaskStatus.failed,
            TaskStatus.stopped,
        ):
            # Elapsed run time in whole seconds: now - started.
            # (The operands used to be reversed, yielding a negative duration.)
            fields_resolver.add_fields(
                duration=int((datetime.utcnow() - task.started).total_seconds())
                if task.started
                else 0
            )

    status_reason = request.status_reason
    status_message = request.status_message
    force = request.force
    return ChangeStatusRequest(
        task=task,
        new_status=new_status or task.status,
        status_reason=status_reason,
        status_message=status_message,
        force=force,
    ).execute(**fields_resolver.get_fields(task))


@endpoint("tasks.get_by_id", request_data_model=TaskRequest)
def get_by_id(call: APICall, company_id, req_model: TaskRequest):
    """Return a single task (public tasks allowed) as ``{"task": {...}}``."""
    task = TaskBLL.get_task_with_access(
        req_model.task, company_id=company_id, allow_public=True
    )
    task_dict = task.to_proper_dict()
    unprepare_from_saved(call, task_dict)
    call.result.data = {"task": task_dict}


def escape_execution_parameters(call: APICall) -> dict:
    """Return a copy of ``call.data`` with keys, projection and ordering escaped.

    Keys are passed through ``escape_paths``; the projection and ordering
    entries (as read/written by ``Task.get_projection`` / ``Task.get_ordering``
    and their setters) are escaped the same way.  When ``call.data`` is empty
    it is returned as is.
    """
    if not call.data:
        return call.data

    keys = list(call.data)
    call_data = {
        safe_key: call.data[key] for key, safe_key in zip(keys, escape_paths(keys))
    }

    projection = Task.get_projection(call_data)
    if projection:
        Task.set_projection(call_data, escape_paths(projection))

    ordering = Task.get_ordering(call_data)
    if ordering:
        Task.set_ordering(call_data, escape_paths(ordering))

    return call_data


def _process_include_subprojects(call_data: dict):
    """Expand ``call_data["project"]`` with child projects when requested.

    Always removes the ``include_subprojects`` key from ``call_data``.  The
    project list is replaced only when both a project filter and a truthy
    ``include_subprojects`` flag are present.
    """
    include_subprojects = call_data.pop("include_subprojects", False)
    project_ids = call_data.get("project")
    if not project_ids or not include_subprojects:
        return

    if not isinstance(project_ids, list):
        project_ids = [project_ids]
    call_data["project"] = project_ids_with_children(project_ids)


@endpoint("tasks.get_all_ex", required_fields=[])
def get_all_ex(call: APICall, company_id, _):
    """Query tasks with joins, honouring ``include_subprojects``."""
    conform_tag_fields(call, call.data)

    call_data = escape_execution_parameters(call)

    with translate_errors_context():
        with TimingContext("mongo", "task_get_all_ex"):
            _process_include_subprojects(call_data)
            tasks = Task.get_many_with_join(
                company=company_id, query_dict=call_data, allow_public=True,
            )
        unprepare_from_saved(call, tasks)
        call.result.data = {"tasks": tasks}


@endpoint("tasks.get_by_id_ex", required_fields=["id"])
def get_by_id_ex(call: APICall, company_id, _):
    """Query tasks with joins using the request data as the query."""
    conform_tag_fields(call, call.data)

    call_data = escape_execution_parameters(call)

    with translate_errors_context():
        with TimingContext("mongo", "task_get_by_id_ex"):
            tasks = Task.get_many_with_join(
                company=company_id, query_dict=call_data, allow_public=True,
            )

        unprepare_from_saved(call, tasks)
        call.result.data = {"tasks": tasks}


@endpoint("tasks.get_all", required_fields=[])
def get_all(call: APICall, company_id, _):
    """Query tasks using the request data as both parameters and query."""
    conform_tag_fields(call, call.data)

    call_data = escape_execution_parameters(call)

    with translate_errors_context():
        with TimingContext("mongo", "task_get_all"):
            tasks = Task.get_many(
                company=company_id,
                parameters=call_data,
                query_dict=call_data,
                allow_public=True,
            )
        unprepare_from_saved(call, tasks)
        call.result.data = {"tasks": tasks}


@endpoint("tasks.get_types", request_data_model=GetTypesRequest)
def get_types(call: APICall, company_id, request: GetTypesRequest):
    """Return the task types reported by the project BLL for the given projects."""
    call.result.data = {
        "types": list(
            project_bll.get_task_types(company_id, project_ids=request.projects)
        )
    }


@endpoint(
    "tasks.stop", request_data_model=UpdateRequest, response_data_model=UpdateResponse
)
def stop(call: APICall, company_id, req_model: UpdateRequest):
    """
    stop
    :summary: Stop a running task. Requires task status 'in_progress' and
              execution_progress 'running', or force=True.
              Development task is stopped immediately. For a non-development task
              only its status message is set to 'stopping'
    """
    call.result.data_model = UpdateResponse(
        **stop_task(
            task_id=req_model.task,
            company_id=company_id,
            user_name=call.identity.user_name,
            status_reason=req_model.status_reason,
            force=req_model.force,
        )
    )


@attr.s(auto_attribs=True)
class StopRes:
    """Accumulator for ``tasks.stop_many``: counts successfully stopped tasks."""

    stopped: int = 0

    def __add__(self, other: dict):
        # Each added result stands for one stopped task; its content is unused.
        return StopRes(stopped=self.stopped + 1)


@endpoint(
    "tasks.stop_many",
    request_data_model=StopManyRequest,
    response_data_model=StopManyResponse,
)
def stop_many(call: APICall, company_id, request: StopManyRequest):
    """Stop several tasks through ``run_batch_operation``."""
    res, failures = run_batch_operation(
        func=partial(
            stop_task,
            company_id=company_id,
            user_name=call.identity.user_name,
            status_reason=request.status_reason,
            force=request.force,
        ),
        ids=request.ids,
        init_res=StopRes(),
    )
    call.result.data_model = StopManyResponse(stopped=res.stopped, failures=failures)


@endpoint(
    "tasks.stopped",
    request_data_model=UpdateRequest,
    response_data_model=UpdateResponse,
)
def stopped(call: APICall, company_id, req_model: UpdateRequest):
    """Mark a task as stopped and set its ``completed`` time to now (UTC)."""
    call.result.data_model = UpdateResponse(
        **set_task_status_from_call(
            req_model,
            company_id,
            new_status=TaskStatus.stopped,
            completed=datetime.utcnow(),
        )
    )


@endpoint(
    "tasks.started",
    request_data_model=UpdateRequest,
    response_data_model=StartedResponse,
)
def started(call: APICall, company_id, req_model: UpdateRequest):
    """Mark a task as in progress; ``started`` in the response mirrors ``updated``."""
    res = StartedResponse(
        **set_task_status_from_call(
            req_model,
            company_id,
            new_status=TaskStatus.in_progress,
            min__started=datetime.utcnow(),  # don't override a previous, smaller "started" field value
        )
    )
    res.started = res.updated
    call.result.data_model = res


@endpoint(
    "tasks.failed", request_data_model=UpdateRequest, response_data_model=UpdateResponse
)
def failed(call: APICall, company_id, req_model: UpdateRequest):
    """Mark a task as failed."""
    call.result.data_model = UpdateResponse(
        **set_task_status_from_call(req_model, company_id, new_status=TaskStatus.failed)
    )


@endpoint(
    "tasks.close", request_data_model=UpdateRequest, response_data_model=UpdateResponse
)
def close(call: APICall, company_id, req_model: UpdateRequest):
    """Mark a task as closed."""
    call.result.data_model = UpdateResponse(
        **set_task_status_from_call(req_model, company_id, new_status=TaskStatus.closed)
    )


# Fields accepted on task creation, mapped to the descriptor that
# parse_from_call receives for each of them.
create_fields = {
    "name": None,
    "tags": list,
    "system_tags": list,
    "type": None,
    "error": None,
    "comment": None,
    "parent": Task,
    "project": None,
    "input": None,
    "models": None,
    "container": None,
    "output_dest": None,
    "execution": None,
    "hyperparams": None,
    "configuration": None,
    "script": None,
}


def prepare_for_save(call: APICall, fields: dict, previous_task: Task = None):
    """Normalise ``fields`` in place before they are written; returns ``fields``.

    Runs the tag, params, artifacts, models and docker-cmd preparation helpers,
    then strips whitespace from the string script fields listed in
    ``task_script_stripped_fields``.
    """
    conform_tag_fields(call, fields, validate=True)
    params_prepare_for_save(fields, previous_task=previous_task)
    artifacts_prepare_for_save(fields)
    ModelsBackwardsCompatibility.prepare_for_save(call, fields)
    DockerCmdBackwardsCompatibility.prepare_for_save(call, fields)

    # Strip all script fields (remove leading and trailing whitespace chars) to avoid unusable names and paths
    for field in task_script_stripped_fields:
        try:
            path = f"script/{field}"
            value = dpath.get(fields, path)
            if isinstance(value, str):
                value = value.strip()
            dpath.set(fields, path, value)
        except KeyError:
            # The field is not present in this request: nothing to strip.
            pass

    return fields


def unprepare_from_saved(call: APICall, tasks_data: Union[Sequence[dict], dict]):
    """Convert stored task dict(s) in place to the form returned to the caller.

    Accepts a single dict or a sequence of dicts.  Legacy parameters are copied
    only for requested endpoint versions below 2.9.
    """
    if isinstance(tasks_data, dict):
        tasks_data = [tasks_data]

    conform_output_tags(call, tasks_data)
    ModelsBackwardsCompatibility.unprepare_from_saved(call, tasks_data)
    DockerCmdBackwardsCompatibility.unprepare_from_saved(call, tasks_data)

    need_legacy_params = call.requested_endpoint_version < PartialVersion("2.9")

    for data in tasks_data:
        params_unprepare_from_saved(
            fields=data, copy_to_legacy=need_legacy_params,
        )
        artifacts_unprepare_from_saved(fields=data)


def prepare_create_fields(
    call: APICall, valid_fields=None, output=None, previous_task: Task = None,
):
    """Parse creation/edit fields from ``call.data`` and prepare them for saving.

    :param valid_fields: accepted fields; defaults to ``create_fields``.
    :param output: existing output document to receive ``output_dest``; a new
        ``Output`` is created when this is falsy.
    :param previous_task: forwarded to ``prepare_for_save``.
    :return: the prepared fields dict.
    """
    valid_fields = valid_fields if valid_fields is not None else create_fields
    # Local union: the module-level task_fields set must not be mutated.
    t_fields = task_fields | {"output_dest"}

    fields = parse_from_call(call.data, valid_fields, t_fields)

    # Move output_dest to output.destination
    output_dest = fields.get("output_dest")
    if output_dest is not None:
        fields.pop("output_dest")
        if output:
            output.destination = output_dest
        else:
            output = Output(destination=output_dest)
        fields["output"] = output

    # Add models updated time
    models = fields.get("models")
    if models:
        now = datetime.utcnow()
        for field in ("input", "output"):
            field_models = models.get(field)
            if not field_models:
                continue
            for model in field_models:
                model["updated"] = now

    return prepare_for_save(call, fields, previous_task=previous_task)


def _validate_and_get_task_from_call(call: APICall, **kwargs) -> Tuple[Task, dict]:
    """Build an unsaved task from the call, validate it, return ``(task, fields)``."""
    with translate_errors_context(
        field_does_not_exist_cls=errors.bad_request.ValidationError
    ), TimingContext("code", "parse_call"):
        fields = prepare_create_fields(call, **kwargs)
        task = task_bll.create(call, fields)

    with TimingContext("code", "validate"):
        task_bll.validate(task)

    return task, fields


@endpoint("tasks.validate", request_data_model=CreateRequest)
def validate(call: APICall, company_id, req_model: CreateRequest):
    """Validate task creation data without saving anything."""
    parent = call.data.get("parent")
    # A parent id carrying the "deleted" prefix is dropped before validation.
    if parent and parent.startswith(deleted_prefix):
        call.data.pop("parent")
    _validate_and_get_task_from_call(call)


def _update_cached_tags(company: str, project: str, fields: dict):
    """Push the ``tags`` / ``system_tags`` found in ``fields`` to the tags cache."""
    org_bll.update_tags(
        company,
        Tags.Task,
        project=project,
        tags=fields.get("tags"),
        system_tags=fields.get("system_tags"),
    )


def _reset_cached_tags(company: str, projects: Sequence[str]):
    """Reset the cached task tags of the given projects."""
    org_bll.reset_tags(company, Tags.Task, projects=projects)


@endpoint(
    "tasks.create", request_data_model=CreateRequest, response_data_model=IdResponse
)
def create(call: APICall, company_id, req_model: CreateRequest):
    """Create and save a task; responds with the new task id."""
    task, fields = _validate_and_get_task_from_call(call)

    with translate_errors_context(), TimingContext("mongo", "save_task"):
        task.save()
        _update_cached_tags(company_id, project=task.project, fields=fields)
        update_project_time(task.project)

    call.result.data_model = IdResponse(id=task.id)


@endpoint("tasks.clone", request_data_model=CloneRequest)
def clone_task(call: APICall, company_id, request: CloneRequest):
    """Clone a task; ``new_project`` is included in the result only when truthy."""
    task, new_project = task_bll.clone_task(
        company_id=company_id,
        user_id=call.identity.user,
        task_id=request.task,
        name=request.new_task_name,
        comment=request.new_task_comment,
        parent=request.new_task_parent,
        project=request.new_task_project,
        tags=request.new_task_tags,
        system_tags=request.new_task_system_tags,
        hyperparams=request.new_task_hyperparams,
        configuration=request.new_task_configuration,
        container=request.new_task_container,
        execution_overrides=request.execution_overrides,
        input_models=request.new_task_input_models,
        validate_references=request.validate_references,
        new_project_name=request.new_project_name,
    )
    call.result.data = {
        "id": task.id,
        **({"new_project": new_project} if new_project else {}),
    }


def prepare_update_fields(call: APICall, task, call_data):
    """Parse the user-updatable fields from ``call_data``.

    :return: ``(prepared_fields, valid_fields)`` where ``valid_fields`` is a
        deep copy of ``Task.user_set_allowed()``.
    """
    valid_fields = deepcopy(Task.user_set_allowed())
    update_fields = {k: v for k, v in create_fields.items() if k in valid_fields}
    update_fields["output__error"] = None
    # Local union: the module-level task_fields set must not be mutated.
    t_fields = task_fields | {"output__error"}
    fields = parse_from_call(call_data, update_fields, t_fields)
    return prepare_for_save(call, fields), valid_fields


@endpoint(
    "tasks.update", request_data_model=UpdateRequest, response_data_model=UpdateResponse
)
def update(call: APICall, company_id, req_model: UpdateRequest):
    """Partially update a task with the user-updatable fields of the request.

    :raises InvalidTaskId: when the task cannot be fetched for writing.
    """
    task_id = req_model.task

    with translate_errors_context():
        task = Task.get_for_writing(
            id=task_id, company=company_id, _only=["id", "project"]
        )
        if not task:
            raise errors.bad_request.InvalidTaskId(id=task_id)

        partial_update_dict, valid_fields = prepare_update_fields(call, task, call.data)

        if not partial_update_dict:
            return UpdateResponse(updated=0)

        updated_count, updated_fields = Task.safe_update(
            company_id=company_id,
            id=task_id,
            partial_update_dict=partial_update_dict,
            injected_update=dict(last_change=datetime.utcnow()),
        )
        if updated_count:
            new_project = updated_fields.get("project", task.project)
            if new_project != task.project:
                # The task moved: the tags cache of both projects is stale.
                _reset_cached_tags(company_id, projects=[new_project, task.project])
            else:
                _update_cached_tags(
                    company_id, project=task.project, fields=updated_fields
                )
            update_project_time(updated_fields.get("project"))
        unprepare_from_saved(call, updated_fields)
        return UpdateResponse(updated=updated_count, fields=updated_fields)


@endpoint(
    "tasks.set_requirements",
    request_data_model=SetRequirementsRequest,
    response_data_model=UpdateResponse,
)
def set_requirements(call: APICall, company_id, req_model: SetRequirementsRequest):
    """Set ``script.requirements`` on a task.

    :raises MissingTaskFields: when the task has no ``script``.
    """
    requirements = req_model.requirements
    with translate_errors_context():
        task = TaskBLL.get_task_with_access(
            req_model.task,
            company_id=company_id,
            only=("status", "script"),
            requires_write_access=True,
        )
        if not task.script:
            raise errors.bad_request.MissingTaskFields(
                "Task has no script field", task=task.id
            )
        res = update_task(task, update_cmds=dict(script__requirements=requirements))
        call.result.data_model = UpdateResponse(updated=res)
        if res:
            call.result.data_model.fields = {"script.requirements": requirements}


@endpoint("tasks.update_batch")
def update_batch(call: APICall, company_id, _):
    """Apply partial updates to several tasks with a single bulk write.

    :raises BatchContainsNoItems: when the call carries no batched data.
    :raises InvalidTaskId: when some of the tasks cannot be fetched for writing.
    """
    items = call.batched_data
    if items is None:
        raise errors.bad_request.BatchContainsNoItems()

    with translate_errors_context():
        items = {i["task"]: i for i in items}
        tasks = {
            t.id: t
            for t in Task.get_many_for_writing(
                company=company_id, query=Q(id__in=list(items))
            )
        }

        if len(tasks) < len(items):
            missing = tuple(set(items).difference(tasks))
            raise errors.bad_request.InvalidTaskId(ids=missing)

        now = datetime.utcnow()

        bulk_ops = []
        updated_projects = set()
        for task_id, data in items.items():
            task = tasks[task_id]
            fields, valid_fields = prepare_update_fields(call, task, data)
            partial_update_dict = Task.get_safe_update_dict(fields)
            if not partial_update_dict:
                continue
            partial_update_dict.update(last_change=now)
            update_op = UpdateOne(
                {"_id": task_id, "company": company_id}, {"$set": partial_update_dict}
            )
            bulk_ops.append(update_op)

            # Collect the projects whose cached tags need to be reset.
            new_project = partial_update_dict.get("project", task.project)
            if new_project != task.project:
                updated_projects.update({new_project, task.project})
            elif any(f in partial_update_dict for f in ("tags", "system_tags")):
                updated_projects.add(task.project)

        updated = 0
        if bulk_ops:
            res = Task._get_collection().bulk_write(bulk_ops)
            updated = res.modified_count

        if updated and updated_projects:
            projects = list(updated_projects)
            _reset_cached_tags(company_id, projects=projects)
            update_project_time(project_ids=projects)

        call.result.data = {"updated": updated}


@endpoint(
    "tasks.edit", request_data_model=UpdateRequest, response_data_model=UpdateResponse
)
def edit(call: APICall, company_id, req_model: UpdateRequest):
    """Edit a task; only tasks in status ``created`` unless ``force`` is set.

    :raises InvalidTaskId: when the task cannot be fetched for writing.
    :raises InvalidTaskStatus: when the task is not ``created`` and not forced.
    """
    task_id = req_model.task
    force = req_model.force

    with translate_errors_context():
        task = Task.get_for_writing(id=task_id, company=company_id)
        if not task:
            raise errors.bad_request.InvalidTaskId(id=task_id)

        if not force and task.status != TaskStatus.created:
            raise errors.bad_request.InvalidTaskStatus(
                expected=TaskStatus.created, status=task.status
            )

        edit_fields = create_fields.copy()
        edit_fields.update(dict(status=None))

        with translate_errors_context(
            field_does_not_exist_cls=errors.bad_request.ValidationError
        ), TimingContext("code", "parse_and_validate"):
            fields = prepare_create_fields(
                call, valid_fields=edit_fields, output=task.output, previous_task=task
            )

        # For embedded documents, overlay the passed dict on the stored values
        # so that sub-fields missing from the request are kept.
        for key in fields:
            field = getattr(task, key, None)
            value = fields[key]
            if (
                field
                and isinstance(value, dict)
                and isinstance(field, EmbeddedDocument)
            ):
                d = field.to_mongo(use_db_field=False).to_dict()
                d.update(value)
                fields[key] = d

        task_bll.validate(task_bll.create(call, fields))

        # make sure field names do not end in mongoengine comparison operators
        fixed_fields = {
            (k if k not in COMPARISON_OPERATORS else "%s__" % k): v
            for k, v in fields.items()
        }
        if fixed_fields:
            now = datetime.utcnow()
            last_change = dict(last_change=now)
            # last_update is touched only when a non user-settable field changes
            if not set(fields).issubset(Task.user_set_allowed()):
                last_change.update(last_update=now)
            fields.update(**last_change)
            fixed_fields.update(**last_change)
            updated = task.update(upsert=False, **fixed_fields)
            if updated:
                new_project = fixed_fields.get("project", task.project)
                if new_project != task.project:
                    _reset_cached_tags(company_id, projects=[new_project, task.project])
                else:
                    _update_cached_tags(
                        company_id, project=task.project, fields=fixed_fields
                    )
                update_project_time(fields.get("project"))
            unprepare_from_saved(call, fields)
            call.result.data_model = UpdateResponse(updated=updated, fields=fields)
        else:
            call.result.data_model = UpdateResponse(updated=0)


@endpoint(
    "tasks.get_hyper_params", request_data_model=GetHyperParamsRequest,
)
def get_hyper_params(call: APICall, company_id, request: GetHyperParamsRequest):
    """Return the hyper parameters of the requested tasks, one entry per task."""
    with translate_errors_context():
        tasks_params = HyperParams.get_params(company_id, task_ids=request.tasks)

    call.result.data = {
        "params": [{"task": task, **data} for task, data in tasks_params.items()]
    }


@endpoint("tasks.edit_hyper_params", request_data_model=EditHyperParamsRequest)
def edit_hyper_params(call: APICall, company_id, request: EditHyperParamsRequest):
    """Edit the hyper parameters of a task via ``HyperParams.edit_params``."""
    with translate_errors_context():
        call.result.data = {
            "updated": HyperParams.edit_params(
                company_id,
                task_id=request.task,
                hyperparams=request.hyperparams,
                replace_hyperparams=request.replace_hyperparams,
                force=request.force,
            )
        }


@endpoint("tasks.delete_hyper_params", request_data_model=DeleteHyperParamsRequest)
def delete_hyper_params(call: APICall, company_id, request: DeleteHyperParamsRequest):
    """Delete hyper parameters of a task via ``HyperParams.delete_params``."""
    with translate_errors_context():
        call.result.data = {
            "deleted": HyperParams.delete_params(
                company_id,
                task_id=request.task,
                hyperparams=request.hyperparams,
                force=request.force,
            )
        }


@endpoint(
    "tasks.get_configurations", request_data_model=GetConfigurationsRequest,
)
def get_configurations(call: APICall, company_id, request: GetConfigurationsRequest):
    """Return the configurations of the requested tasks, one entry per task."""
    with translate_errors_context():
        tasks_params = HyperParams.get_configurations(
            company_id, task_ids=request.tasks, names=request.names
        )

    call.result.data = {
        "configurations": [
            {"task": task, **data} for task, data in tasks_params.items()
        ]
    }


@endpoint(
    "tasks.get_configuration_names", request_data_model=GetConfigurationNamesRequest,
)
def get_configuration_names(
    call: APICall, company_id, request: GetConfigurationNamesRequest
):
    """Return the configuration names of the requested tasks, one entry per task."""
    with translate_errors_context():
        tasks_params = HyperParams.get_configuration_names(
            company_id, task_ids=request.tasks
        )

    call.result.data = {
        "configurations": [
            {"task": task, **data} for task, data in tasks_params.items()
        ]
    }


@endpoint("tasks.edit_configuration", request_data_model=EditConfigurationRequest)
def edit_configuration(call: APICall, company_id, request: EditConfigurationRequest):
    """Edit the configuration of a task via ``HyperParams.edit_configuration``."""
    with translate_errors_context():
        call.result.data = {
            "updated": HyperParams.edit_configuration(
                company_id,
                task_id=request.task,
                configuration=request.configuration,
                replace_configuration=request.replace_configuration,
                force=request.force,
            )
        }


@endpoint("tasks.delete_configuration", request_data_model=DeleteConfigurationRequest)
def delete_configuration(
    call: APICall, company_id, request: DeleteConfigurationRequest
):
    """Delete task configuration items via ``HyperParams.delete_configuration``."""
    with translate_errors_context():
        call.result.data = {
            "deleted": HyperParams.delete_configuration(
                company_id,
                task_id=request.task,
                configuration=request.configuration,
                force=request.force,
            )
        }


@endpoint(
    "tasks.enqueue",
    request_data_model=EnqueueRequest,
    response_data_model=EnqueueResponse,
)
def enqueue(call: APICall, company_id, request: EnqueueRequest):
    """Enqueue a task using ``enqueue_task``."""
    queued, res = enqueue_task(
        task_id=request.task,
        company_id=company_id,
        queue_id=request.queue,
        status_message=request.status_message,
        status_reason=request.status_reason,
    )
    call.result.data_model = EnqueueResponse(queued=queued, **res)


@attr.s(auto_attribs=True)
class EnqueueRes:
    """Accumulator for ``tasks.enqueue_many``: sums the per-task queued counts."""

    queued: int = 0

    def __add__(self, other: Tuple[int, dict]):
        queued, _ = other
        return EnqueueRes(queued=self.queued + queued)


@endpoint(
    "tasks.enqueue_many",
    request_data_model=EnqueueManyRequest,
    response_data_model=EnqueueManyResponse,
)
def enqueue_many(call: APICall, company_id, request: EnqueueManyRequest):
    """Enqueue several tasks through ``run_batch_operation``."""
    res, failures = run_batch_operation(
        func=partial(
            enqueue_task,
            company_id=company_id,
            queue_id=request.queue,
            status_message=request.status_message,
            status_reason=request.status_reason,
        ),
        ids=request.ids,
        init_res=EnqueueRes(),
    )
    call.result.data_model = EnqueueManyResponse(queued=res.queued, failures=failures)


@endpoint(
    "tasks.dequeue",
    request_data_model=UpdateRequest,
    response_data_model=DequeueResponse,
)
def dequeue(call: APICall, company_id, request: UpdateRequest):
    """Dequeue a task and change its status; ``dequeued`` is reported as 1."""
    task = TaskBLL.get_task_with_access(
        request.task,
        company_id=company_id,
        only=("id", "execution", "status", "project"),
        requires_write_access=True,
    )
    res = DequeueResponse(
        **TaskBLL.dequeue_and_change_status(
            task,
            company_id,
            status_message=request.status_message,
            status_reason=request.status_reason,
        )
    )

    res.dequeued = 1
    call.result.data_model = res


@endpoint(
    "tasks.reset", request_data_model=ResetRequest, response_data_model=ResetResponse
)
def reset(call: APICall, company_id, request: ResetRequest):
    """Reset a task and report the updates together with the cleanup result."""
    dequeued, cleanup_res, updates = reset_task(
        task_id=request.task,
        company_id=company_id,
        force=request.force,
        return_file_urls=request.return_file_urls,
        delete_output_models=request.delete_output_models,
        clear_all=request.clear_all,
    )

    res = ResetResponse(**updates, dequeued=dequeued)
    # do not return artifacts since they are not serializable
    res.fields.pop("execution.artifacts", None)

    # Copy every cleanup result attribute onto the response model.
    for key, value in attr.asdict(cleanup_res).items():
        setattr(res, key, value)

    call.result.data_model = res


@attr.s(auto_attribs=True)
class ResetRes:
    """Accumulator for ``tasks.reset_many``."""

    reset: int = 0
    dequeued: int = 0
    cleanup_res: CleanupResult = None

    def __add__(self, other: Tuple[dict, CleanupResult, dict]):
        dequeued, other_res, _ = other
        # The dequeue result is a dict (or falsy); its "removed" entry is counted.
        dequeued = dequeued.get("removed", 0) if dequeued else 0
        return ResetRes(
            reset=self.reset + 1,
            dequeued=self.dequeued + dequeued,
            cleanup_res=(
                (self.cleanup_res + other_res) if self.cleanup_res else other_res
            ),
        )


@endpoint("tasks.reset_many", request_data_model=ResetManyRequest)
def reset_many(call: APICall, company_id, request: ResetManyRequest):
    """Reset several tasks through ``run_batch_operation``."""
    res, failures = run_batch_operation(
        func=partial(
            reset_task,
            company_id=company_id,
            force=request.force,
            return_file_urls=request.return_file_urls,
            delete_output_models=request.delete_output_models,
            clear_all=request.clear_all,
        ),
        ids=request.ids,
        init_res=ResetRes(),
    )

    if res.cleanup_res:
        cleanup_res = dict(
            deleted_models=res.cleanup_res.deleted_models,
            urls=attr.asdict(res.cleanup_res.urls),
        )
    else:
        cleanup_res = {}
    call.result.data = dict(
        reset=res.reset, dequeued=res.dequeued, **cleanup_res, failures=failures,
    )


@endpoint(
    "tasks.archive",
    request_data_model=ArchiveRequest,
    response_data_model=ArchiveResponse,
)
def archive(call: APICall, company_id, request: ArchiveRequest):
    """Archive the requested tasks and report how many were archived."""
    tasks = TaskBLL.assert_exists(
        company_id,
        task_ids=request.tasks,
        only=("id", "execution", "status", "project", "system_tags"),
    )
    archived = 0
    for task in tasks:
        archived += archive_task(
            company_id=company_id,
            task=task,
            status_message=request.status_message,
            status_reason=request.status_reason,
        )

    call.result.data_model = ArchiveResponse(archived=archived)


@endpoint(
    "tasks.archive_many",
    request_data_model=ArchiveManyRequest,
    response_data_model=ArchiveManyResponse,
)
def archive_many(call: APICall, company_id, request: ArchiveManyRequest):
    """Archive several tasks through ``run_batch_operation``."""
    archived, failures = run_batch_operation(
        func=partial(
            archive_task,
            company_id=company_id,
            status_message=request.status_message,
            status_reason=request.status_reason,
        ),
        ids=request.ids,
        init_res=0,
    )
    call.result.data_model = ArchiveManyResponse(archived=archived, failures=failures)


@endpoint("tasks.delete", request_data_model=DeleteRequest)
def delete(call: APICall, company_id, request: DeleteRequest):
    """Delete a task and reset the cached tags of its project (if it has one)."""
    deleted, task, cleanup_res = delete_task(
        task_id=request.task,
        company_id=company_id,
        move_to_trash=request.move_to_trash,
        force=request.force,
        return_file_urls=request.return_file_urls,
        delete_output_models=request.delete_output_models,
    )
    if deleted:
        _reset_cached_tags(company_id, projects=[task.project] if task.project else [])
    call.result.data = dict(deleted=bool(deleted), **attr.asdict(cleanup_res))


@attr.s(auto_attribs=True)
class DeleteRes:
    """Accumulator for ``tasks.delete_many``."""

    deleted: int = 0
    # A factory gives each instance its own set instead of sharing one
    # mutable default object between all instances.
    projects: Set = attr.Factory(set)
    cleanup_res: CleanupResult = None

    def __add__(self, other: Tuple[int, Task, CleanupResult]):
        del_count, task, other_res = other

        return DeleteRes(
            deleted=self.deleted + del_count,
            projects=self.projects | {task.project},
            cleanup_res=(
                (self.cleanup_res + other_res) if self.cleanup_res else other_res
            ),
        )


@endpoint("tasks.delete_many", request_data_model=DeleteManyRequest)
def delete_many(call: APICall, company_id, request: DeleteManyRequest):
    """Delete several tasks through ``run_batch_operation``."""
    res, failures = run_batch_operation(
        func=partial(
            delete_task,
            company_id=company_id,
            move_to_trash=request.move_to_trash,
            force=request.force,
            return_file_urls=request.return_file_urls,
            delete_output_models=request.delete_output_models,
        ),
        ids=request.ids,
        init_res=DeleteRes(),
    )

    if res.deleted:
        _reset_cached_tags(company_id, projects=list(res.projects))

    cleanup_res = attr.asdict(res.cleanup_res) if res.cleanup_res else {}
    call.result.data = dict(deleted=res.deleted, **cleanup_res, failures=failures)


@endpoint(
    "tasks.publish",
    request_data_model=PublishRequest,
    response_data_model=PublishResponse,
)
def publish(call: APICall, company_id, request: PublishRequest):
    """Publish a task; the model publish function is passed only when requested."""
    updates = publish_task(
        task_id=request.task,
        company_id=company_id,
        force=request.force,
        publish_model_func=ModelBLL.publish_model if request.publish_model else None,
        status_reason=request.status_reason,
        status_message=request.status_message,
    )
    call.result.data_model = PublishResponse(**updates)


@attr.s(auto_attribs=True)
class PublishRes:
    """Accumulator for ``tasks.publish_many``: counts published tasks."""

    published: int = 0

    def __add__(self, other: dict):
        # Each added result stands for one published task; its content is unused.
        return PublishRes(published=self.published + 1)


@endpoint("tasks.publish_many", request_data_model=PublishManyRequest)
def publish_many(call: APICall, company_id, request: PublishManyRequest):
    """Publish several tasks through ``run_batch_operation``."""
    res, failures = run_batch_operation(
        func=partial(
            publish_task,
            company_id=company_id,
            force=request.force,
            publish_model_func=ModelBLL.publish_model
            if request.publish_model
            else None,
            status_reason=request.status_reason,
            status_message=request.status_message,
        ),
        ids=request.ids,
        init_res=PublishRes(),
    )

    call.result.data = dict(published=res.published, failures=failures)


@endpoint(
    "tasks.completed",
    min_version="2.2",
    request_data_model=UpdateRequest,
    response_data_model=UpdateResponse,
)
def completed(call: APICall, company_id, request: UpdateRequest):
    """Mark a task as completed and set its ``completed`` time to now (UTC)."""
    call.result.data_model = UpdateResponse(
        **set_task_status_from_call(
            request,
            company_id,
            new_status=TaskStatus.completed,
            completed=datetime.utcnow(),
        )
    )


@endpoint("tasks.ping", request_data_model=PingRequest)
def ping(_, company_id, request: PingRequest):
    """Set the task's last update time to now (UTC)."""
    TaskBLL.set_last_update(
        task_ids=[request.task], company_id=company_id, last_update=datetime.utcnow()
    )


@endpoint(
    "tasks.add_or_update_artifacts",
    min_version="2.10",
    request_data_model=AddOrUpdateArtifactsRequest,
)
def add_or_update_artifacts(
    call: APICall, company_id, request: AddOrUpdateArtifactsRequest
):
    """Add or update task artifacts via ``Artifacts.add_or_update_artifacts``."""
    with translate_errors_context():
        call.result.data = {
            "updated": Artifacts.add_or_update_artifacts(
                company_id=company_id,
                task_id=request.task,
                artifacts=request.artifacts,
                force=request.force,
            )
        }


@endpoint(
    "tasks.delete_artifacts",
    min_version="2.10",
    request_data_model=DeleteArtifactsRequest,
)
def delete_artifacts(call: APICall, company_id, request: DeleteArtifactsRequest):
    """Delete task artifacts via ``Artifacts.delete_artifacts``."""
    with translate_errors_context():
        call.result.data = {
            "deleted": Artifacts.delete_artifacts(
                company_id=company_id,
                task_id=request.task,
                artifact_ids=request.artifacts,
                force=request.force,
            )
        }


@endpoint("tasks.make_public", min_version="2.9", request_data_model=MakePublicRequest)
def make_public(call: APICall, company_id, request: MakePublicRequest):
    """Call ``Task.set_public`` with ``enabled=True`` for the requested ids."""
    with translate_errors_context():
        call.result.data = Task.set_public(
            company_id, request.ids, invalid_cls=InvalidTaskId, enabled=True
        )


@endpoint("tasks.make_private", min_version="2.9", request_data_model=MakePublicRequest)
def make_private(call: APICall, company_id, request: MakePublicRequest):
    """Call ``Task.set_public`` with ``enabled=False`` for the requested ids."""
    with translate_errors_context():
        call.result.data = Task.set_public(
            company_id, request.ids, invalid_cls=InvalidTaskId, enabled=False
        )


@endpoint("tasks.move", request_data_model=MoveRequest)
def move(call: APICall, company_id: str, request: MoveRequest):
    """Move tasks under a project given by id or by name.

    :raises MissingRequiredFields: when neither ``project`` nor
        ``project_name`` is passed.
    :return: ``{"project_id": <target project id>}``.
    """
    if not (request.project or request.project_name):
        raise errors.bad_request.MissingRequiredFields(
            "project or project_name is required"
        )

    # Source projects are collected before the move so their caches can be reset.
    updated_projects = set(
        t.project for t in Task.objects(id__in=request.ids).only("project") if t.project
    )
    project_id = project_bll.move_under_project(
        entity_cls=Task,
        user=call.identity.user,
        company=company_id,
        ids=request.ids,
        project=request.project,
        project_name=request.project_name,
    )

    projects = list(updated_projects | {project_id})
    _reset_cached_tags(company_id, projects=projects)
    update_project_time(projects)

    return {"project_id": project_id}


@endpoint("tasks.add_or_update_model", min_version="2.13")
def add_or_update_model(_: APICall, company_id: str, request: AddUpdateModelRequest):
    """Replace the task model entry with the given name, or append a new one."""
    get_task_for_update(company_id=company_id, task_id=request.task, force=True)

    models_field = f"models__{request.type}"
    model = ModelItem(name=request.name, model=request.model, updated=datetime.utcnow())
    # First try to replace an existing entry with the same name in place
    query = {"id": request.task, f"{models_field}__name": request.name}
    updated = Task.objects(**query).update_one(**{f"set__{models_field}__S": model})

    # Nothing was replaced: push the model as a new entry with the statistics update
    updated = TaskBLL.update_statistics(
        task_id=request.task,
        company_id=company_id,
        last_iteration_max=request.iteration,
        **({f"push__{models_field}": model} if not updated else {}),
    )

    return {"updated": updated}


@endpoint("tasks.delete_models", min_version="2.13")
def delete_models(_: APICall, company_id: str, request: DeleteModelsRequest):
    """Pull the named models from the task's model lists, grouped by model type."""
    task = get_task_for_update(company_id=company_id, task_id=request.task, force=True)

    delete_names = {
        type_: [m.name for m in request.models if m.type == type_]
        for type_ in get_options(ModelItemType)
    }
    # Only the types that actually have names to delete produce a pull command
    commands = {
        f"pull__models__{field}__name__in": names
        for field, names in delete_names.items()
        if names
    }

    updated = task.update(last_change=datetime.utcnow(), **commands,)
    return {"updated": updated}