From a75534ec3445dff1a33bac87d665b27ecbf496c7 Mon Sep 17 00:00:00 2001
From: allegroai <>
Date: Mon, 3 May 2021 17:52:54 +0300
Subject: [PATCH] Add batch operations support

---
 apiserver/apimodels/batch.py                  |  14 +
 apiserver/apimodels/models.py                 |  29 +-
 apiserver/apimodels/tasks.py                  |  51 ++-
 apiserver/bll/model/__init__.py               | 116 +++++
 apiserver/bll/task/task_bll.py                | 145 +-----
 apiserver/bll/task/task_cleanup.py            |  29 +-
 apiserver/bll/task/task_operations.py         | 329 ++++++++++++++
 apiserver/bll/util.py                         |  38 +-
 apiserver/mongo/initialize/pre_populate.py    |   1 +
 apiserver/schema/services/_common.conf        |  42 ++
 apiserver/schema/services/models.conf         | 127 ++++--
 apiserver/schema/services/tasks.conf          | 381 ++++++++--------
 apiserver/services/models.py                  | 239 ++++++----
 apiserver/services/tasks.py                   | 426 ++++++++++--------
 apiserver/services/utils.py                   |  17 +-
 apiserver/tests/automated/__init__.py         |  10 -
 .../tests/automated/test_batch_operations.py  | 124 +++++
 17 files changed, 1444 insertions(+), 674 deletions(-)
 create mode 100644 apiserver/apimodels/batch.py
 create mode 100644 apiserver/bll/model/__init__.py
 create mode 100644 apiserver/bll/task/task_operations.py
 create mode 100644 apiserver/tests/automated/test_batch_operations.py

diff --git a/apiserver/apimodels/batch.py b/apiserver/apimodels/batch.py
new file mode 100644
index 0000000..01a1402
--- /dev/null
+++ b/apiserver/apimodels/batch.py
@@ -0,0 +1,14 @@
+from typing import Sequence
+
+from jsonmodels.models import Base
+from jsonmodels.validators import Length
+
+from apiserver.apimodels import ListField
+
+
+class BatchRequest(Base):
+    ids: Sequence[str] = ListField([str], validators=Length(minimum_value=1))
+
+
+class BatchResponse(Base):
+    failures: Sequence[dict] = ListField([dict])
diff --git a/apiserver/apimodels/models.py b/apiserver/apimodels/models.py
index dbe9176..30a0971 100644
--- a/apiserver/apimodels/models.py
+++ b/apiserver/apimodels/models.py
@@ -3,6 +3,7 @@ from six import string_types
 
 from apiserver.apimodels import ListField, DictField
 from apiserver.apimodels.base import UpdateResponse
+from apiserver.apimodels.batch import BatchRequest, BatchResponse
 from apiserver.apimodels.metadata import (
     MetadataItem,
     DeleteMetadata,
@@ -46,6 +47,23 @@ class DeleteModelRequest(ModelRequest):
     force = fields.BoolField(default=False)
 
 
+class ModelsDeleteManyRequest(BatchRequest):
+    force = fields.BoolField(default=False)
+
+
+class ModelsArchiveManyRequest(BatchRequest):
+    pass
+
+
+class ModelsArchiveManyResponse(BatchResponse):
+    archived = fields.IntField(required=True)
+
+
+class ModelsDeleteManyResponse(BatchResponse):
+    deleted = fields.IntField()
+    urls = fields.ListField([str])
+
+
 class PublishModelRequest(ModelRequest):
     force_publish_task = fields.BoolField(default=False)
     publish_task = fields.BoolField(default=True)
@@ -58,7 +76,16 @@ class ModelTaskPublishResponse(models.Base):
 
 class PublishModelResponse(UpdateResponse):
     published_task = fields.EmbeddedField(ModelTaskPublishResponse)
-    updated = fields.IntField()
+
+
+class ModelsPublishManyRequest(BatchRequest):
+    force_publish_task = fields.BoolField(default=False)
+    publish_task = fields.BoolField(default=True)
+
+
+class ModelsPublishManyResponse(BatchResponse):
+    published = fields.IntField(required=True)
+    published_tasks = fields.ListField([ModelTaskPublishResponse])
 
 
 class DeleteMetadataRequest(DeleteMetadata):
diff --git a/apiserver/apimodels/tasks.py b/apiserver/apimodels/tasks.py
index 52611ce..39f7430 100644
--- a/apiserver/apimodels/tasks.py
+++ b/apiserver/apimodels/tasks.py
@@ -7,6 +7,7 @@ from jsonmodels.validators import Enum, Length
 
 from apiserver.apimodels import DictField, ListField
 from apiserver.apimodels.base import UpdateResponse
+from apiserver.apimodels.batch import BatchRequest, BatchResponse
 from apiserver.database.model.task.task import (
     TaskType,
     ArtifactModes,
@@ -52,7 +53,7 @@ class ResetResponse(UpdateResponse):
     dequeued = DictField()
     frames = DictField()
     events = DictField()
-    model_deleted = IntField()
+    deleted_models = IntField()
     urls = DictField()
 
 
@@ -230,6 +231,54 @@ class ArchiveResponse(models.Base):
     archived = IntField()
 
 
+class TaskBatchRequest(BatchRequest):
+    status_reason = StringField(default="")
+    status_message = StringField(default="")
+
+
+class StopManyRequest(TaskBatchRequest):
+    force = BoolField(default=False)
+
+
+class StopManyResponse(BatchResponse):
+    stopped = IntField(required=True)
+
+
+class ArchiveManyRequest(TaskBatchRequest):
+    pass
+
+
+class ArchiveManyResponse(BatchResponse):
+    archived = IntField(required=True)
+
+
+class EnqueueManyRequest(TaskBatchRequest):
+    queue = StringField()
+
+
+class EnqueueManyResponse(BatchResponse):
+    queued = IntField()
+
+
+class DeleteManyRequest(TaskBatchRequest):
+    move_to_trash = BoolField(default=True)
+    return_file_urls = BoolField(default=False)
+    delete_output_models = BoolField(default=True)
+    force = BoolField(default=False)
+
+
+class ResetManyRequest(TaskBatchRequest):
+    clear_all = BoolField(default=False)
+    return_file_urls = BoolField(default=False)
+    delete_output_models = BoolField(default=True)
+    force = BoolField(default=False)
+
+
+class PublishManyRequest(TaskBatchRequest):
+    publish_model = BoolField(default=True)
+    force = BoolField(default=False)
+
+
 class ModelItemType(object):
     input = "input"
     output = "output"
diff --git a/apiserver/bll/model/__init__.py b/apiserver/bll/model/__init__.py
new file mode 100644
index 0000000..e98d3bd
--- /dev/null
+++ b/apiserver/bll/model/__init__.py
@@ -0,0 +1,116 @@
+from datetime import datetime
+from typing import Callable, Tuple
+
+from apiserver.apierrors import errors
+from apiserver.apimodels.models import ModelTaskPublishResponse
+from apiserver.bll.task.utils import deleted_prefix
+from apiserver.database.model import EntityVisibility
+from apiserver.database.model.model import Model
+from apiserver.database.model.task.task import Task, TaskStatus
+
+
+class ModelBLL:
+    @classmethod
+    def get_company_model_by_id(
+        cls, company_id: str, model_id: str, only_fields=None
+    ) -> Model:
+        query = dict(company=company_id, id=model_id)
+        qs = Model.objects(**query)
+        if only_fields:
+            qs = qs.only(*only_fields)
+        model = qs.first()
+        if not model:
+            raise errors.bad_request.InvalidModelId(**query)
+        return model
+
+    @classmethod
+    def publish_model(
+        cls,
+        model_id: str,
+        company_id: str,
+        force_publish_task: bool = False,
+        publish_task_func: Callable[[str, str, bool], dict] = None,
+    ) -> Tuple[int, ModelTaskPublishResponse]:
+        model = cls.get_company_model_by_id(company_id=company_id, model_id=model_id)
+        if model.ready:
+            raise errors.bad_request.ModelIsReady(company=company_id, model=model_id)
+
+        published_task = None
+        if model.task and publish_task_func:
+            task = (
+                Task.objects(id=model.task, company=company_id)
+                .only("id", "status")
+                .first()
+            )
+            if task and task.status != TaskStatus.published:
+                task_publish_res = publish_task_func(
+                    model.task, company_id, force_publish_task
+                )
+                published_task = ModelTaskPublishResponse(
+                    id=model.task, data=task_publish_res
+                )
+
+        updated = model.update(upsert=False, ready=True)
+        return updated, published_task
+
+    @classmethod
+    def delete_model(
+        cls, model_id: str, company_id: str, force: bool
+    ) -> Tuple[int, Model]:
+        model = cls.get_company_model_by_id(
+            company_id=company_id,
+            model_id=model_id,
+            only_fields=("id", "task", "project", "uri"),
+        )
+        deleted_model_id = f"{deleted_prefix}{model_id}"
+
+        using_tasks = Task.objects(models__input__model=model_id).only("id")
+        if using_tasks:
+            if not force:
+                raise errors.bad_request.ModelInUse(
+                    "as execution model, use force=True to delete",
+                    num_tasks=len(using_tasks),
+                )
+            # update deleted model id in using tasks
+            Task._get_collection().update_many(
+                filter={"_id": {"$in": [t.id for t in using_tasks]}},
+                update={"$set": {"models.input.$[elem].model": deleted_model_id}},
+                array_filters=[{"elem.model": model_id}],
+                upsert=False,
+            )
+
+        if model.task:
+            task = Task.objects(id=model.task).first()
+            if task and task.status == TaskStatus.published:
+                if not force:
+                    raise errors.bad_request.ModelCreatingTaskExists(
+                        "and published, use force=True to delete", task=model.task
+                    )
+                if task.models.output and model_id in task.models.output:
+                    now = datetime.utcnow()
+                    Task._get_collection().update_one(
+                        filter={"_id": model.task, "models.output.model": model_id},
+                        update={
+                            "$set": {
+                                "models.output.$[elem].model": deleted_model_id,
+                                "output.error": f"model deleted on {now.isoformat()}",
+                                "last_change": now,
+                            },
+                        },
+                        array_filters=[{"elem.model": model_id}],
+                        upsert=False,
+                    )
+
+        del_count = Model.objects(id=model_id, company=company_id).delete()
+        return del_count, model
+
+    @classmethod
+    def archive_model(cls, model_id: str, company_id: str):
+        cls.get_company_model_by_id(
+            company_id=company_id, model_id=model_id, only_fields=("id",)
+        )
+        archived = Model.objects(company=company_id, id=model_id).update(
+            add_to_set__system_tags=EntityVisibility.archived.value
+        )
+
+        return archived
diff --git a/apiserver/bll/task/task_bll.py b/apiserver/bll/task/task_bll.py
index 5cc457d..75ed954 100644
--- a/apiserver/bll/task/task_bll.py
+++ b/apiserver/bll/task/task_bll.py
@@ -11,6 +11,7 @@ from six import string_types
 
 import apiserver.database.utils as dbutils
 from apiserver.apierrors import errors
+from apiserver.apimodels.tasks import TaskInputModel
 from apiserver.bll.queue import QueueBLL
 from apiserver.bll.organization import OrgBLL, Tags
 from apiserver.bll.project import ProjectBLL, project_ids_with_children
@@ -23,7 +24,6 @@ from apiserver.database.model.task.output import Output
 from apiserver.database.model.task.task import (
     Task,
     TaskStatus,
-    TaskStatusMessage,
     TaskSystemTags,
     ArtifactModes,
     ModelItem,
@@ -41,11 +41,9 @@ from .artifacts import artifacts_prepare_for_save
 from .param_utils import params_prepare_for_save
 from .utils import (
     ChangeStatusRequest,
-    validate_status_change,
     update_project_time,
     deleted_prefix,
 )
-from ...apimodels.tasks import TaskInputModel
 
 log = config.logger(__file__)
 org_bll = OrgBLL()
@@ -482,147 +480,6 @@ class TaskBLL:
             **extra_updates,
         )
 
-    @classmethod
-    def model_set_ready(
-        cls,
-        model_id: str,
-        company_id: str,
-        publish_task: bool,
-        force_publish_task: bool = False,
-    ) -> tuple:
-        with translate_errors_context():
-            query = dict(id=model_id, company=company_id)
-            model = Model.objects(**query).first()
-            if not model:
-                raise errors.bad_request.InvalidModelId(**query)
-            elif model.ready:
-                raise errors.bad_request.ModelIsReady(**query)
-
-            published_task_data = {}
-            if model.task and publish_task:
-                task = (
-                    Task.objects(id=model.task, company=company_id)
-                    .only("id", "status")
-                    .first()
-                )
-                if task and task.status != TaskStatus.published:
-                    published_task_data["data"] = cls.publish_task(
-                        task_id=model.task,
-                        company_id=company_id,
-                        publish_model=False,
-                        force=force_publish_task,
-                    )
-                    published_task_data["id"] = model.task
-
-            updated = model.update(upsert=False, ready=True)
-            return updated, published_task_data
-
-    @classmethod
-    def publish_task(
-        cls,
-        task_id: str,
-        company_id: str,
-        publish_model: bool,
-        force: bool,
-        status_reason: str = "",
-        status_message: str = "",
-    ) -> dict:
-        task = cls.get_task_with_access(
-            task_id, company_id=company_id, requires_write_access=True
-        )
-        if not force:
-            validate_status_change(task.status, TaskStatus.published)
-
-        previous_task_status = task.status
-        output = task.output or Output()
-        publish_failed = False
-
-        try:
-            # set state to publishing
-            task.status = TaskStatus.publishing
-            task.save()
-
-            # publish task models
-            if task.models.output and publish_model:
-                model_ids = [m.model for m in task.models.output]
-                for model in Model.objects(id__in=model_ids, ready__ne=True).only("id"):
-                    cls.model_set_ready(
-                        model_id=model.id, company_id=company_id, publish_task=False,
-                    )
-
-            # set task status to published, and update (or set) it's new output (view and models)
-            return ChangeStatusRequest(
-                task=task,
-                new_status=TaskStatus.published,
-                force=force,
-                status_reason=status_reason,
-                status_message=status_message,
-            ).execute(published=datetime.utcnow(), output=output)
-
-        except Exception as ex:
-            publish_failed = True
-            raise ex
-        finally:
-            if publish_failed:
-                task.status = previous_task_status
-                task.save()
-
-    @classmethod
-    def stop_task(
-        cls,
-        task_id: str,
-        company_id: str,
-        user_name: str,
-        status_reason: str,
-        force: bool,
-    ) -> dict:
-        """
-        Stop a running task. Requires task status 'in_progress' and
-        execution_progress 'running', or force=True. Development task or
-        task that has no associated worker is stopped immediately.
-        For a non-development task with worker only the status message
-        is set to 'stopping' to allow the worker to stop the task and report by itself
-        :return: updated task fields
-        """
-
-        task = cls.get_task_with_access(
-            task_id,
-            company_id=company_id,
-            only=(
-                "status",
-                "project",
-                "tags",
-                "system_tags",
-                "last_worker",
-                "last_update",
-            ),
-            requires_write_access=True,
-        )
-
-        def is_run_by_worker(t: Task) -> bool:
-            """Checks if there is an active worker running the task"""
-            update_timeout = config.get("apiserver.workers.task_update_timeout", 600)
-            return (
-                t.last_worker
-                and t.last_update
-                and (datetime.utcnow() - t.last_update).total_seconds() < update_timeout
-            )
-
-        if TaskSystemTags.development in task.system_tags or not is_run_by_worker(task):
-            new_status = TaskStatus.stopped
-            status_message = f"Stopped by {user_name}"
-        else:
-            new_status = task.status
-            status_message = TaskStatusMessage.stopping
-
-        return ChangeStatusRequest(
-            task=task,
-            new_status=new_status,
-            status_reason=status_reason,
-            status_message=status_message,
-            force=force,
-        ).execute()
-
     @staticmethod
     def get_aggregated_project_parameters(
         company_id,
diff --git a/apiserver/bll/task/task_cleanup.py b/apiserver/bll/task/task_cleanup.py
index f58a63a..319cfeb 100644
--- a/apiserver/bll/task/task_cleanup.py
+++ b/apiserver/bll/task/task_cleanup.py
@@ -68,6 +68,16 @@ class TaskUrls:
     event_urls: Sequence[str]
     artifact_urls: Sequence[str]
 
+    def __add__(self, other: "TaskUrls"):
+        if not other:
+            return self
+
+        return TaskUrls(
+            model_urls=list(set(self.model_urls) | set(other.model_urls)),
+            event_urls=list(set(self.event_urls) | set(other.event_urls)),
+            artifact_urls=list(set(self.artifact_urls) | set(other.artifact_urls)),
+        )
+
 
 @attr.s(auto_attribs=True)
 class CleanupResult:
@@ -80,6 +90,17 @@ class CleanupResult:
     deleted_models: int
     urls: TaskUrls = None
 
+    def __add__(self, other: "CleanupResult"):
+        if not other:
+            return self
+
+        return CleanupResult(
+            updated_children=self.updated_children + other.updated_children,
+            updated_models=self.updated_models + other.updated_models,
+            deleted_models=self.deleted_models + other.deleted_models,
+            urls=self.urls + other.urls if self.urls else other.urls,
+        )
+
 
 def collect_plot_image_urls(company: str, task: str) -> Set[str]:
     urls = set()
@@ -224,7 +245,7 @@ def verify_task_children_and_ouptuts(task: Task, force: bool) -> TaskOutputs[Model]:
             models=len(models.published),
         )
 
-    if task.models.output:
+    if task.models and task.models.output:
         with TimingContext("mongo", "get_task_output_model"):
             model_ids = [m.model for m in task.models.output]
             for output_model in Model.objects(id__in=model_ids):
@@ -243,11 +264,13 @@ def verify_task_children_and_ouptuts(task: Task, force: bool) -> TaskOutputs[Model]:
         with TimingContext("mongo", "get_execution_models"):
             model_ids = models.draft.ids
             dependent_tasks = Task.objects(models__input__model__in=model_ids).only(
-                "id", "models__input"
+                "id", "models"
             )
             input_models = {
                 m.model
-                for m in chain.from_iterable(t.models.input for t in dependent_tasks)
+                for m in chain.from_iterable(
+                    t.models.input for t in dependent_tasks if t.models
+                )
             }
             if input_models:
                 models.draft = DocumentGroup(
diff --git a/apiserver/bll/task/task_operations.py b/apiserver/bll/task/task_operations.py
new file mode 100644
index 0000000..d5423a9
--- /dev/null
+++ b/apiserver/bll/task/task_operations.py
@@ -0,0 +1,329 @@
+from datetime import datetime
+from typing import Callable, Any, Tuple, Union
+
+from apiserver.apierrors import errors, APIError
+from apiserver.bll.queue import QueueBLL
+from apiserver.bll.task import (
+    TaskBLL,
+    validate_status_change,
+    ChangeStatusRequest,
+    update_project_time,
+)
+from apiserver.bll.task.task_cleanup import cleanup_task, CleanupResult
+from apiserver.config_repo import config
+from apiserver.database.model import EntityVisibility
+from apiserver.database.model.model import Model
+from apiserver.database.model.task.output import Output
+from apiserver.database.model.task.task import (
+    TaskStatus,
+    Task,
+    TaskSystemTags,
+    TaskStatusMessage,
+    ArtifactModes,
+    Execution,
+    DEFAULT_LAST_ITERATION,
+)
+from apiserver.utilities.dicts import nested_set
+
+queue_bll = QueueBLL()
+
+
+def archive_task(
+    task: Union[str, Task], company_id: str, status_message: str, status_reason: str,
+) -> int:
+    """
+    Dequeue and archive the task
+    Return 1 if successful
+    """
+    if isinstance(task, str):
+        task = TaskBLL.get_task_with_access(
+            task,
+            company_id=company_id,
+            only=("id", "execution", "status", "project", "system_tags"),
+            requires_write_access=True,
+        )
+    try:
+        TaskBLL.dequeue_and_change_status(
+            task, company_id, status_message, status_reason,
+        )
+    except APIError:
+        # dequeue may fail if the task was not enqueued
+        pass
+    task.update(
+        status_message=status_message,
+        status_reason=status_reason,
+        add_to_set__system_tags={EntityVisibility.archived.value},
+        last_change=datetime.utcnow(),
+    )
+    return 1
+
+
+def enqueue_task(
+    task_id: str,
+    company_id: str,
+    queue_id: str,
+    status_message: str,
+    status_reason: str,
+) -> Tuple[int, dict]:
+    if not queue_id:
+        # try to get default queue
+        queue_id = queue_bll.get_default(company_id).id
+
+    query = dict(id=task_id, company=company_id)
+    task = Task.get_for_writing(
+        _only=("type", "script", "execution", "status", "project", "id"), **query
+    )
+    if not task:
+        raise errors.bad_request.InvalidTaskId(**query)
+
+    res = ChangeStatusRequest(
+        task=task,
+        new_status=TaskStatus.queued,
+        status_reason=status_reason,
+        status_message=status_message,
+        allow_same_state_transition=False,
+    ).execute()
+
+    try:
+        queue_bll.add_task(company_id=company_id, queue_id=queue_id, task_id=task.id)
+    except Exception:
+        # failed enqueueing, revert to previous state
+        ChangeStatusRequest(
+            task=task,
+            current_status_override=TaskStatus.queued,
+            new_status=task.status,
+            force=True,
+            status_reason="failed enqueueing",
+        ).execute()
+        raise
+
+    # set the current queue ID in the task
+    if task.execution:
+        Task.objects(**query).update(execution__queue=queue_id, multi=False)
+    else:
+        Task.objects(**query).update(execution=Execution(queue=queue_id), multi=False)
+
+    nested_set(res, ("fields", "execution.queue"), queue_id)
+    return 1, res
+
+
+def delete_task(
+    task_id: str,
+    company_id: str,
+    move_to_trash: bool,
+    force: bool,
+    return_file_urls: bool,
+    delete_output_models: bool,
+) -> Tuple[int, Task, CleanupResult]:
+    task = TaskBLL.get_task_with_access(
+        task_id, company_id=company_id, requires_write_access=True
+    )
+
+    if (
+        task.status != TaskStatus.created
+        and EntityVisibility.archived.value not in task.system_tags
+        and not force
+    ):
+        raise errors.bad_request.TaskCannotBeDeleted(
+            "due to status, use force=True",
+            task=task.id,
+            expected=TaskStatus.created,
+            current=task.status,
+        )
+
+    cleanup_res = cleanup_task(
+        task,
+        force=force,
+        return_file_urls=return_file_urls,
+        delete_output_models=delete_output_models,
+    )
+
+    if move_to_trash:
+        collection_name = task._get_collection_name()
+        archived_collection = "{}__trash".format(collection_name)
+        task.switch_collection(archived_collection)
+        try:
+            # A simple save() won't do due to mongoengine caching (nothing will be saved), so we have to force
+            # an insert. However, if for some reason such an ID exists, let's make sure we'll keep going.
+            task.save(force_insert=True)
+        except Exception:
+            pass
+        task.switch_collection(collection_name)
+
+    task.delete()
+    update_project_time(task.project)
+    return 1, task, cleanup_res
+
+
+def reset_task(
+    task_id: str,
+    company_id: str,
+    force: bool,
+    return_file_urls: bool,
+    delete_output_models: bool,
+    clear_all: bool,
+) -> Tuple[dict, CleanupResult, dict]:
+    task = TaskBLL.get_task_with_access(
+        task_id, company_id=company_id, requires_write_access=True
+    )
+
+    if not force and task.status == TaskStatus.published:
+        raise errors.bad_request.InvalidTaskStatus(task_id=task.id, status=task.status)
+
+    dequeued = {}
+    updates = {}
+
+    try:
+        dequeued = TaskBLL.dequeue(task, company_id, silent_fail=True)
+    except APIError:
+        # dequeue may fail if the task was not enqueued
+        pass
+
+    cleaned_up = cleanup_task(
+        task,
+        force=force,
+        update_children=False,
+        return_file_urls=return_file_urls,
+        delete_output_models=delete_output_models,
+    )
+
+    updates.update(
+        set__last_iteration=DEFAULT_LAST_ITERATION,
+        set__last_metrics={},
+        set__metric_stats={},
+        set__models__output=[],
+        unset__output__result=1,
+        unset__output__error=1,
+        unset__last_worker=1,
+        unset__last_worker_report=1,
+    )
+
+    if clear_all:
+        updates.update(
+            set__execution=Execution(), unset__script=1,
+        )
+    else:
+        updates.update(unset__execution__queue=1)
+        if task.execution and task.execution.artifacts:
+            updates.update(
+                set__execution__artifacts={
+                    key: artifact
+                    for key, artifact in task.execution.artifacts.items()
+                    if artifact.mode == ArtifactModes.input
+                }
+            )
+
+    res = ChangeStatusRequest(
+        task=task,
+        new_status=TaskStatus.created,
+        force=force,
+        status_reason="reset",
+        status_message="reset",
+    ).execute(
+        started=None, completed=None, published=None, active_duration=None, **updates,
+    )
+
+    return dequeued, cleaned_up, res
+
+
+def publish_task(
+    task_id: str,
+    company_id: str,
+    force: bool,
+    publish_model_func: Callable[[str, str], Any] = None,
+    status_message: str = "",
+    status_reason: str = "",
+) -> dict:
+    task = TaskBLL.get_task_with_access(
+        task_id, company_id=company_id, requires_write_access=True
+    )
+    if not force:
+        validate_status_change(task.status, TaskStatus.published)
+
+    previous_task_status = task.status
+    output = task.output or Output()
+    publish_failed = False
+
+    try:
+        # set state to publishing
+        task.status = TaskStatus.publishing
+        task.save()
+
+        # publish task models
+        if task.models and task.models.output and publish_model_func:
+            model_id = task.models.output[-1].model
+            model = (
+                Model.objects(id=model_id, company=company_id)
+                .only("id", "ready")
+                .first()
+            )
+            if model and not model.ready:
+                publish_model_func(model.id, company_id)
+
+        # set the task status to published, and update (or set) its new output (view and models)
+        return ChangeStatusRequest(
+            task=task,
+            new_status=TaskStatus.published,
+            force=force,
+            status_reason=status_reason,
+            status_message=status_message,
+        ).execute(published=datetime.utcnow(), output=output)
+
+    except Exception as ex:
+        publish_failed = True
+        raise ex
+    finally:
+        if publish_failed:
+            task.status = previous_task_status
+            task.save()
+
+
+def stop_task(
+    task_id: str, company_id: str, user_name: str, status_reason: str, force: bool,
+) -> dict:
+    """
+    Stop a running task. Requires task status 'in_progress' and
+    execution_progress 'running', or force=True. A development task, or a
+    task that has no associated worker, is stopped immediately.
+    For a non-development task that has a worker, only the status message
+    is set to 'stopping', to allow the worker to stop the task and report by itself.
+    :return: updated task fields
+    """
+
+    task = TaskBLL.get_task_with_access(
+        task_id,
+        company_id=company_id,
+        only=(
+            "status",
+            "project",
+            "tags",
+            "system_tags",
+            "last_worker",
+            "last_update",
+        ),
+        requires_write_access=True,
+    )
+
+    def is_run_by_worker(t: Task) -> bool:
+        """Checks if there is an active worker running the task"""
+        update_timeout = config.get("apiserver.workers.task_update_timeout", 600)
+        return (
+            t.last_worker
+            and t.last_update
+            and (datetime.utcnow() - t.last_update).total_seconds() < update_timeout
+        )
+
+    if TaskSystemTags.development in task.system_tags or not is_run_by_worker(task):
+        new_status = TaskStatus.stopped
+        status_message = f"Stopped by {user_name}"
+    else:
+        new_status = task.status
+        status_message = TaskStatusMessage.stopping
+
+    return ChangeStatusRequest(
+        task=task,
+        new_status=new_status,
+        status_reason=status_reason,
+        status_message=status_message,
+        force=force,
+    ).execute()
diff --git a/apiserver/bll/util.py b/apiserver/bll/util.py
index 290d32a..95d228f 100644
--- a/apiserver/bll/util.py
+++ b/apiserver/bll/util.py
@@ -1,10 +1,21 @@
 import functools
 import itertools
 from concurrent.futures.thread import ThreadPoolExecutor
-from typing import Optional, Callable, Dict, Any, Set, Iterable
+from typing import (
+    Optional,
+    Callable,
+    Dict,
+    Any,
+    Set,
+    Iterable,
+    Tuple,
+    Sequence,
+    TypeVar,
+)
 
 from boltons import iterutils
 
+from apiserver.apierrors import APIError
 from apiserver.database.model import AttributedDocument
 from apiserver.database.model.settings import Settings
 
@@ -96,3 +107,28 @@ def parallel_chunked_decorator(func: Callable = None, chunk_size: int = 100):
         )
 
     return wrapper
+
+
+T = TypeVar("T")
+
+
+def run_batch_operation(
+    func: Callable[[str], T], init_res: T, ids: Sequence[str]
+) -> Tuple[T, Sequence]:
+    res = init_res
+    failures = list()
+    for _id in ids:
+        try:
+            res += func(_id)
+        except APIError as err:
+            failures.append(
+                {
+                    "id": _id,
+                    "error": {
+                        "codes": [err.code, err.subcode],
+                        "msg": err.msg,
+                        "data": err.error_data,
+                    },
+                }
+            )
+    return res, failures
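Note: `run_batch_operation` above is the single helper behind all of the new `*_many` endpoints — it folds a per-entity operation over the requested ids, accumulating successes through the result type's `__add__` and turning each `APIError` into a per-id failure record. A minimal sketch of how it is meant to be driven (the `archive_one` function below is a hypothetical stand-in for real operations such as `archive_task`/`archive_model`; it runs only against this code base):

    from functools import partial

    from apiserver.apierrors import errors
    from apiserver.bll.util import run_batch_operation

    def archive_one(task_id: str, company_id: str) -> int:
        # toy per-entity operation: fail on unknown ids, return 1 on success
        if task_id not in ("t1", "t2"):
            raise errors.bad_request.InvalidTaskId(id=task_id)
        return 1

    res, failures = run_batch_operation(
        func=partial(archive_one, company_id="company-0"),
        init_res=0,  # a plain int works as accumulator: 0 + 1 + 1 ...
        ids=["t1", "t2", "missing"],
    )
    assert res == 2
    assert failures[0]["id"] == "missing"  # codes/msg/data mirror the APIError

This is exactly the shape used by `models.archive_many` below, which passes `init_res=0`; richer endpoints substitute an attrs class with a custom `__add__`.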
diff --git a/apiserver/mongo/initialize/pre_populate.py b/apiserver/mongo/initialize/pre_populate.py
index 291f40b..6c49d62 100644
--- a/apiserver/mongo/initialize/pre_populate.py
+++ b/apiserver/mongo/initialize/pre_populate.py
@@ -465,6 +465,7 @@ class PrePopulate:
         task_models = chain.from_iterable(
             models
             for task in entities[cls.task_cls]
+            if task.models
             for models in (task.models.input, task.models.output)
             if models
         )
diff --git a/apiserver/schema/services/_common.conf b/apiserver/schema/services/_common.conf
index 94913b5..1aec228 100644
--- a/apiserver/schema/services/_common.conf
+++ b/apiserver/schema/services/_common.conf
@@ -31,3 +31,45 @@ credentials {
         }
     }
 }
+batch_operation {
+    request {
+        type: object
+        required: [ids]
+        properties {
+            ids {
+                description: IDs of the entities to operate on
+                type: array
+                items {type: string}
+            }
+        }
+    }
+    response {
+        failures {
+            type: array
+            items {
+                type: object
+                id: {
+                    description: ID of the failed entity
+                    type: string
+                }
+                error: {
+                    description: Error info
+                    type: object
+                    properties {
+                        codes {
+                            type: array
+                            items {type: integer}
+                        }
+                        msg {
+                            type: string
+                        }
+                        data {
+                            type: object
+                            additionalProperties: True
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
\ No newline at end of file
diff --git a/apiserver/schema/services/models.conf b/apiserver/schema/services/models.conf
index 769e0da..f7e692c 100644
--- a/apiserver/schema/services/models.conf
+++ b/apiserver/schema/services/models.conf
@@ -94,6 +94,32 @@ _definitions {
             }
         }
     }
+    published_task_item {
+        description: "Result of publishing of the model's associated task (if exists). Returned only if the task was published successfully as part of the model publishing."
+        type: object
+        properties {
+            id {
+                description: "Task id"
+                type: string
+            }
+            data {
+                description: "Data returned from the task publishing operation."
+                type: object
+                properties {
+                    updated {
+                        description: "Number of tasks updated (0 or 1)"
+                        type: integer
+                        enum: [ 0, 1 ]
+                    }
+                    fields {
+                        description: "Updated fields names and values"
+                        type: object
+                        additionalProperties: true
+                    }
+                }
+            }
+        }
+    }
 }
 
 get_by_id {
@@ -628,6 +654,33 @@ update {
         }
     }
 }
+publish_many {
+    "2.13": ${_definitions.batch_operation} {
+        description: Publish models
+        request {
+            force_publish_task {
+                description: "Publish the associated tasks (if exist) even if they are not in the 'stopped' state. Optional, the default value is False."
+                type: boolean
+            }
+            publish_tasks {
+                description: "Indicates that the associated tasks (if exist) should be published. Optional, the default value is True."
+                type: boolean
+            }
+        }
+        response {
+            properties {
+                published {
+                    description: "Number of models published"
+                    type: integer
+                }
+                published_tasks {
+                    type: array
+                    items: ${_definitions.published_task_item}
+                }
+            }
+        }
+    }
+}
 set_ready {
     "2.1" {
         description: "Set the model ready flag to True. If the model is an output model of a task then try to publish the task."
@@ -657,39 +710,44 @@ set_ready {
                     type: integer
                     enum: [0, 1]
                 }
-                published_task {
-                    description: "Result of publishing of the model's associated task (if exists). Returned only if the task was published successfully as part of the model publishing."
-                    type: object
-                    properties {
-                        id {
-                            description: "Task id"
-                            type: string
-                        }
-                        data {
-                            description: "Data returned from the task publishing operation."
-                            type: object
-                            properties {
-                                committed_versions_results {
-                                    description: "Committed versions results"
-                                    type: array
-                                    items {
-                                        type: object
-                                        additionalProperties: true
-                                    }
-                                }
-                                updated {
-                                    description: "Number of tasks updated (0 or 1)"
-                                    type: integer
-                                    enum: [ 0, 1 ]
-                                }
-                                fields {
-                                    description: "Updated fields names and values"
-                                    type: object
-                                    additionalProperties: true
-                                }
-                            }
-                        }
-                    }
-                }
+                published_task: ${_definitions.published_task_item}
             }
         }
     }
 }
+archive_many {
+    "2.13": ${_definitions.batch_operation} {
+        description: Archive models
+        response {
+            properties {
+                archived {
+                    description: "Number of models archived"
+                    type: integer
+                }
+            }
+        }
+    }
+}
+delete_many {
+    "2.13": ${_definitions.batch_operation} {
+        description: Delete models
+        request {
+            force {
+                description: """Force. Required if there are tasks that use the model as an execution model, or if the model's creating task is published.
+                """
+                type: boolean
+            }
+        }
+        response {
+            properties {
+                deleted {
+                    description: "Number of models deleted"
+                    type: integer
+                }
+                urls {
+                    description: "The urls of the deleted model files"
+                    type: array
+                    items {type: string}
+                }
+            }
+        }
+    }
+}
@@ -875,5 +933,4 @@ delete_metadata {
         }
     }
 }
-}
-
+}
\ No newline at end of file
diff --git a/apiserver/schema/services/tasks.conf b/apiserver/schema/services/tasks.conf
index 6e3bb02..6bd5000 100644
--- a/apiserver/schema/services/tasks.conf
+++ b/apiserver/schema/services/tasks.conf
@@ -26,6 +26,35 @@ _references {
     }
 }
 _definitions {
     include "_common.conf"
+    change_many_request: ${_definitions.batch_operation} {
+        request {
+            properties {
+                status_reason {
+                    description: Reason for status change
+                    type: string
+                }
+                status_message {
+                    description: Extra information regarding status change
+                    type: string
+                }
+            }
+        }
+    }
+    update_response {
+        type: object
+        properties {
+            updated {
+                description: "Number of tasks updated (0 or 1)"
+                type: integer
+                enum: [ 0, 1 ]
+            }
+            fields {
+                description: "Updated fields names and values"
+                type: object
+                additionalProperties: true
+            }
+        }
+    }
     multi_field_pattern_data {
         type: object
         properties {
@@ -1216,21 +1245,7 @@ update {
             }
         }
     }
-    response {
-        type: object
-        properties {
-            updated {
-                description: "Number of tasks updated (0 or 1)"
-                type: integer
-                enum: [ 0, 1 ]
-            }
-            fields {
-                description: "Updated fields names and values"
-                type: object
-                additionalProperties: true
-            }
-        }
-    }
+    response: ${_definitions.update_response}
 }
 update_batch {
@@ -1328,21 +1343,7 @@ edit {
             }
         }
     }
-    response {
-        type: object
-        properties {
-            updated {
-                description: "Number of tasks updated (0 or 1)"
-                type: integer
-                enum: [ 0, 1 ]
-            }
-            fields {
-                description: "Updated fields names and values"
-                type: object
-                additionalProperties: true
-            }
-        }
-    }
+    response: ${_definitions.update_response}
 }
 "2.13": ${edit."2.1"} {
     request {
@@ -1376,8 +1377,7 @@ reset {
             default: false
         }
     } ${_references.status_change_request}
-    response {
-        type: object
+    response: ${_definitions.update_response} {
         properties {
             deleted_indices {
                 description: "List of deleted ES indices that were removed as part of the reset process"
@@ -1403,16 +1403,6 @@ reset {
                 description: "Number of output models deleted by the reset"
                 type: integer
             }
-            updated {
-                description: "Number of tasks updated (0 or 1)"
-                type: integer
-                enum: [ 0, 1 ]
-            }
-            fields {
-                description: "Updated fields names and values"
-                type: object
-                additionalProperties: true
-            }
         }
     }
 }
@@ -1435,6 +1425,101 @@ reset {
         }
     }
 }
+reset_many {
+    "2.13": ${_definitions.batch_operation} {
+        description: Reset tasks
+        request {
+            properties {
+                force = ${_references.force_arg} {
+                    description: "If not true, call fails if the task status is 'completed'"
+                }
+                clear_all {
+                    description: "Clear script and execution sections completely"
+                    type: boolean
+                    default: false
+                }
+                return_file_urls {
+                    description: "If set to 'true' then return the urls of the files that were uploaded by the tasks. Default value is 'false'"
+                    type: boolean
+                }
+                delete_output_models {
+                    description: "If set to 'true' then delete output models of the tasks that are not referenced by other tasks. Default value is 'true'"
+                    type: boolean
+                }
+            }
+        }
+        response {
+            properties {
+                reset {
+                    description: "Number of tasks reset"
+                    type: integer
+                }
+                dequeued {
+                    description: "Number of tasks dequeued"
+                    type: object
+                    additionalProperties: true
+                }
+                deleted_models {
+                    description: "Number of output models deleted by the reset"
+                    type: integer
+                }
+                urls {
+                    description: "The urls of the files that were uploaded by the tasks. Returned if the 'return_file_urls' was set to 'true'"
+                    "$ref": "#/definitions/task_urls"
+                }
+            }
+        }
+    }
+}
+delete_many {
+    "2.13": ${_definitions.batch_operation} {
+        description: Delete tasks
+        request {
+            properties {
+                move_to_trash {
+                    description: "Move task to trash instead of deleting it. For internal use only, tasks in the trash are not visible from the API and cannot be restored!"
+                    type: boolean
+                    default: false
+                }
+                force = ${_references.force_arg} {
+                    description: "If not true, call fails if the task status is 'in_progress'"
+                }
+                return_file_urls {
+                    description: "If set to 'true' then return the urls of the files that were uploaded by the tasks. Default value is 'false'"
+                    type: boolean
+                }
+                delete_output_models {
+                    description: "If set to 'true' then delete output models of the tasks that are not referenced by other tasks. Default value is 'true'"
+                    type: boolean
+                }
+            }
+        }
+        response {
+            properties {
+                deleted {
+                    description: "Number of tasks deleted"
+                    type: integer
+                }
+                updated_children {
+                    description: "Number of child tasks whose parent property was updated"
+                    type: integer
+                }
+                updated_models {
+                    description: "Number of models whose task property was updated"
+                    type: integer
+                }
+                deleted_models {
+                    description: "Number of deleted output models"
+                    type: integer
+                }
+                urls {
+                    description: "The urls of the files that were uploaded by the tasks. Returned if the 'return_file_urls' was set to 'true'"
+                    "$ref": "#/definitions/task_urls"
+                }
+            }
+        }
+    }
+}
 delete {
     "2.1" {
         description: """Delete a task along with any information stored for it (statistics, frame updates etc.)
@@ -1472,15 +1557,6 @@ delete {
                     description: "Number of models whose task property was updated"
                     type: integer
                 }
-                updated_versions {
-                    description: "Number of dataset versions whose task property was updated"
-                    type: integer
-                }
-                frames {
-                    description: "Response from frames.rollback"
-                    type: object
-                    additionalProperties: true
-                }
                 events {
                     description: "Response from events.delete_for_task"
                     type: object
@@ -1545,6 +1621,19 @@ archive {
         }
     }
 }
+archive_many {
+    "2.13": ${_definitions.change_many_request} {
+        description: Archive tasks
+        response {
+            properties {
+                archived {
+                    description: "Number of tasks archived"
+                    type: integer
+                }
+            }
+        }
+    }
+}
 started {
     "2.1" {
         description: "Mark a task status as in_progress. Optionally allows to set the task's execution progress."
@@ -1557,24 +1646,13 @@ started {
             description: "If not true, call fails if the task status is not 'not_started'"
         }
     } ${_references.status_change_request}
-    response {
-        type: object
+    response: ${_definitions.update_response} {
         properties {
             started {
                 description: "Number of tasks started (0 or 1)"
                 type: integer
                 enum: [ 0, 1 ]
             }
-            updated {
-                description: "Number of tasks updated (0 or 1)"
-                type: integer
-                enum: [ 0, 1 ]
-            }
-            fields {
-                description: "Updated fields names and values"
-                type: object
-                additionalProperties: true
-            }
         }
     }
 }
@@ -1591,18 +1669,24 @@ stop {
             description: "If not true, call fails if the task status is not 'in_progress'"
         }
     } ${_references.status_change_request}
-    response {
-        type: object
-        properties {
-            updated {
-                description: "Number of tasks updated (0 or 1)"
-                type: integer
-                enum: [ 0, 1 ]
-            }
-            fields {
-                description: "Updated fields names and values"
-                type: object
-                additionalProperties: true
-            }
-        }
-    }
+    response: ${_definitions.update_response}
 }
+stop_many {
+    "2.13": ${_definitions.change_many_request} {
+        description: "Request to stop running tasks"
+        request {
+            properties {
+                force = ${_references.force_arg} {
+                    description: "If not true, call fails if the task status is not 'in_progress'"
+                }
+            }
+        }
+        response {
+            properties {
+                stopped {
+                    description: "Number of tasks stopped"
+                    type: integer
+                }
+            }
+        }
+    }
+}
@@ -1620,21 +1704,7 @@ stopped {
             description: "If not true, call fails if the task status is not 'stopped'"
         }
     } ${_references.status_change_request}
-    response {
-        type: object
-        properties {
-            updated {
-                description: "Number of tasks updated (0 or 1)"
-                type: integer
-                enum: [ 0, 1 ]
-            }
-            fields {
-                description: "Updated fields names and values"
-                type: object
-                additionalProperties: true
-            }
-        }
-    }
+    response: ${_definitions.update_response}
 }
 failed {
@@ -1647,21 +1717,7 @@ failed {
         ]
         properties.force = ${_references.force_arg}
     } ${_references.status_change_request}
-    response {
-        type: object
-        properties {
-            updated {
-                description: "Number of tasks updated (0 or 1)"
-                type: integer
-                enum: [ 0, 1 ]
-            }
-            fields {
-                description: "Updated fields names and values"
-                type: object
-                additionalProperties: true
-            }
-        }
-    }
+    response: ${_definitions.update_response}
 }
 close {
@@ -1674,21 +1730,7 @@ close {
         ]
         properties.force = ${_references.force_arg}
     } ${_references.status_change_request}
-    response {
-        type: object
-        properties {
-            updated {
-                description: "Number of tasks updated (0 or 1)"
-                type: integer
-                enum: [ 0, 1 ]
-            }
-            fields {
-                description: "Updated fields names and values"
-                type: object
-                additionalProperties: true
-            }
-        }
-    }
+    response: ${_definitions.update_response}
 }
 publish {
@@ -1713,26 +1755,28 @@ publish {
             }
         }
     } ${_references.status_change_request}
-    response {
-        type: object
-        properties {
-            committed_versions_results {
-                description: "Committed versions results"
-                type: array
-                items {
-                    type: object
-                    additionalProperties: true
-                }
-            }
-            updated {
-                description: "Number of tasks updated (0 or 1)"
-                type: integer
-                enum: [ 0, 1 ]
-            }
-            fields {
-                description: "Updated fields names and values"
-                type: object
-                additionalProperties: true
-            }
-        }
-    }
+    response: ${_definitions.update_response}
+}
+publish_many {
+    "2.13": ${_definitions.change_many_request} {
+        description: Publish tasks
+        request {
+            properties {
+                force = ${_references.force_arg} {
+                    description: "If not true, call fails if the task status is not 'stopped'"
+                }
+                publish_model {
+                    description: "Indicates that the task output model (if exists) should be published. Optional, the default value is True."
+                    type: boolean
+                }
+            }
+        }
+        response {
+            properties {
+                published {
+                    description: "Number of tasks published"
+                    type: integer
+                }
+            }
+        }
+    }
 }
@@ -1763,23 +1807,25 @@ enqueue {
             }
         }
     } ${_references.status_change_request}
-    response {
-        type: object
+    response: ${_definitions.update_response} {
         properties {
             queued {
                 description: "Number of tasks queued (0 or 1)"
                 type: integer
                 enum: [ 0, 1 ]
             }
-            updated {
-                description: "Number of tasks updated (0 or 1)"
-                type: integer
-                enum: [ 0, 1 ]
-            }
-            fields {
-                description: "Updated fields names and values"
-                type: object
-                additionalProperties: true
-            }
         }
     }
 }
+enqueue_many {
+    "2.13": ${_definitions.change_many_request} {
+        description: Enqueue tasks
+        request {
+            properties {
+                queue {
+                    description: "Queue id. If not provided, tasks are added to the default queue"
+                    type: string
+                }
+            }
+        }
+        response {
+            properties {
+                queued {
+                    description: "Number of tasks enqueued"
+                    type: integer
+                }
+            }
+        }
+    }
+}
@@ -1795,24 +1841,13 @@ dequeue {
         task
     ]
 } ${_references.status_change_request}
-    response {
-        type: object
+    response: ${_definitions.update_response} {
         properties {
             dequeued {
                 description: "Number of tasks dequeued (0 or 1)"
                 type: integer
                 enum: [ 0, 1 ]
             }
-            updated {
-                description: "Number of tasks updated (0 or 1)"
-                type: integer
-                enum: [ 0, 1 ]
-            }
-            fields {
-                description: "Updated fields names and values"
-                type: object
-                additionalProperties: true
-            }
         }
     }
 }
@@ -1837,21 +1872,7 @@ set_requirements {
             }
         }
     }
-    response {
-        type: object
-        properties {
-            updated {
-                description: "Number of tasks updated (0 or 1)"
-                type: integer
-                enum: [ 0, 1 ]
-            }
-            fields {
-                description: "Updated fields names and values"
-                type: object
-                additionalProperties: true
-            }
-        }
-    }
+    response: ${_definitions.update_response}
 }
 
@@ -1867,21 +1888,7 @@ completed {
             description: "If not true, call fails if the task status is not in_progress/stopped"
         }
     } ${_references.status_change_request}
-    response {
-        type: object
-        properties {
-            updated {
-                description: "Number of tasks updated (0 or 1)"
-                type: integer
-                enum: [ 0, 1 ]
-            }
-            fields {
-                description: "Updated fields names and values"
-                type: object
-                additionalProperties: true
-            }
-        }
-    }
+    response: ${_definitions.update_response}
 }
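For orientation, this is roughly what a call to one of the new batch endpoints looks like on the wire. The endpoint name and payload follow the schema definitions above; the host, port, version prefix and auth header are illustrative assumptions, not taken from this patch:

    import requests

    resp = requests.post(
        "http://localhost:8008/v2.13/tasks.archive_many",  # assumed routing scheme
        json={
            "ids": ["task-1", "task-2"],
            "status_reason": "end of experiment",   # from change_many_request
            "status_message": "archived in bulk",
        },
        headers={"Authorization": "Bearer <token>"},  # placeholder credentials
    )
    data = resp.json()["data"]
    print(data["archived"])  # number of tasks archived
    for failure in data.get("failures", []):
        # per-entity failures follow the batch_operation definition in _common.conf
        print(failure["id"], failure["error"]["codes"], failure["error"]["msg"])

Note that a batch call returns 200 even when some ids fail; callers are expected to inspect `failures` rather than rely on the HTTP status.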
diff --git a/apiserver/services/models.py b/apiserver/services/models.py
index 01db602..dd7152a 100644
--- a/apiserver/services/models.py
+++ b/apiserver/services/models.py
@@ -1,6 +1,8 @@
 from datetime import datetime
-from typing import Sequence
+from functools import partial
+from typing import Sequence, Tuple, Set
 
+import attr
 from mongoengine import Q, EmbeddedDocument
 
 from apiserver import database
@@ -17,11 +19,19 @@ from apiserver.apimodels.models import (
     DeleteModelRequest,
     DeleteMetadataRequest,
     AddOrUpdateMetadataRequest,
+    ModelsPublishManyRequest,
+    ModelsPublishManyResponse,
+    ModelsDeleteManyRequest,
+    ModelsDeleteManyResponse,
+    ModelsArchiveManyRequest,
+    ModelsArchiveManyResponse,
 )
+from apiserver.bll.model import ModelBLL
 from apiserver.bll.organization import OrgBLL, Tags
 from apiserver.bll.project import ProjectBLL, project_ids_with_children
 from apiserver.bll.task import TaskBLL
-from apiserver.bll.task.utils import deleted_prefix
+from apiserver.bll.task.task_operations import publish_task
+from apiserver.bll.util import run_batch_operation
 from apiserver.config_repo import config
 from apiserver.database.errors import translate_errors_context
 from apiserver.database.model import validate_id
@@ -80,7 +90,7 @@ def get_by_task_id(call: APICall, company_id, _):
     task = Task.get(_only=["models"], **query)
     if not task:
         raise errors.bad_request.InvalidTaskId(**query)
-    if not task.models.output:
+    if not task.models or not task.models.output:
         raise errors.bad_request.MissingTaskFields(field="models.output")
 
     model_id = task.models.output[-1].model
@@ -198,17 +208,6 @@ def _reset_cached_tags(company: str, projects: Sequence[str]):
     )
 
 
-def _get_company_model(company_id: str, model_id: str, only_fields=None) -> Model:
-    query = dict(company=company_id, id=model_id)
-    qs = Model.objects(**query)
-    if only_fields:
-        qs = qs.only(*only_fields)
-    model = qs.first()
-    if not model:
-        raise errors.bad_request.InvalidModelId(**query)
-    return model
-
-
 @endpoint("models.update_for_task", required_fields=["task"])
 def update_for_task(call: APICall, company_id, _):
     if call.requested_endpoint_version > ModelsBackwardsCompatibility.max_version:
@@ -242,7 +241,7 @@ def update_for_task(call: APICall, company_id, _):
     )
 
     if override_model_id:
-        model = _get_company_model(
+        model = ModelBLL.get_company_model_by_id(
             company_id=company_id, model_id=override_model_id
         )
     else:
@@ -253,7 +252,7 @@ def update_for_task(call: APICall, company_id, _):
         if "comment" not in call.data:
             call.data["comment"] = f"Created by task `{task.name}` ({task.id})"
 
-    if task.models.output:
+    if task.models and task.models.output:
         # model exists, update
         model_id = task.models.output[-1].model
         res = _update_model(call, company_id, model_id=model_id).to_struct()
@@ -272,7 +271,9 @@ def update_for_task(call: APICall, company_id, _):
             company=company_id,
             project=task.project,
             framework=task.execution.framework,
-            parent=task.models.input[0].model if task.models.input else None,
+            parent=task.models.input[0].model
+            if task.models and task.models.input
+            else None,
             design=task.execution.model_desc,
             labels=task.execution.model_labels,
             ready=(task.status == TaskStatus.published),
@@ -377,7 +378,9 @@ def edit(call: APICall, company_id, _):
     model_id = call.data["model"]
 
     with translate_errors_context():
-        model = _get_company_model(company_id=company_id, model_id=model_id)
+        model = ModelBLL.get_company_model_by_id(
+            company_id=company_id, model_id=model_id
+        )
 
         fields = parse_model_fields(call, create_fields)
         fields = prepare_update_fields(call, company_id, fields)
@@ -423,7 +426,9 @@ def _update_model(call: APICall, company_id, model_id=None):
     model_id = model_id or call.data["model"]
 
     with translate_errors_context():
-        model = _get_company_model(company_id=company_id, model_id=model_id)
+        model = ModelBLL.get_company_model_by_id(
+            company_id=company_id, model_id=model_id
+        )
 
         data = prepare_update_fields(call, company_id, call.data)
 
@@ -463,94 +468,131 @@ def update(call, company_id, _):
     request_data_model=PublishModelRequest,
     response_data_model=PublishModelResponse,
 )
-def set_ready(call: APICall, company_id, req_model: PublishModelRequest):
-    updated, published_task_data = TaskBLL.model_set_ready(
-        model_id=req_model.model,
+def set_ready(call: APICall, company_id: str, request: PublishModelRequest):
+    updated, published_task = ModelBLL.publish_model(
+        model_id=request.model,
         company_id=company_id,
-        publish_task=req_model.publish_task,
-        force_publish_task=req_model.force_publish_task,
+        force_publish_task=request.force_publish_task,
+        publish_task_func=publish_task if request.publish_task else None,
+    )
+    call.result.data_model = PublishModelResponse(
+        updated=updated, published_task=published_task
     )
 
-    call.result.data_model = PublishModelResponse(
-        updated=updated,
-        published_task=ModelTaskPublishResponse(**published_task_data)
-        if published_task_data
-        else None,
+
+@attr.s(auto_attribs=True)
+class PublishRes:
+    published: int = 0
+    published_tasks: Sequence = []
+
+    def __add__(self, other: Tuple[int, ModelTaskPublishResponse]):
+        published, response = other
+        return PublishRes(
+            published=self.published + published,
+            published_tasks=[*self.published_tasks, *([response] if response else [])],
+        )
+
+
+@endpoint(
+    "models.publish_many",
+    request_data_model=ModelsPublishManyRequest,
+    response_data_model=ModelsPublishManyResponse,
+)
+def publish_many(call: APICall, company_id, request: ModelsPublishManyRequest):
+    res, failures = run_batch_operation(
+        func=partial(
+            ModelBLL.publish_model,
+            company_id=company_id,
+            force_publish_task=request.force_publish_task,
+            publish_task_func=publish_task if request.publish_task else None,
+        ),
+        ids=request.ids,
+        init_res=PublishRes(),
+    )
+
+    call.result.data_model = ModelsPublishManyResponse(
+        published=res.published, published_tasks=res.published_tasks, failures=failures,
     )
 
 
 @endpoint("models.delete", request_data_model=DeleteModelRequest)
 def delete(call: APICall, company_id, request: DeleteModelRequest):
-    model_id = request.model
-    force = request.force
-
-    with translate_errors_context():
-        model = _get_company_model(
-            company_id=company_id,
-            model_id=model_id,
-            only_fields=("id", "task", "project", "uri"),
+    del_count, model = ModelBLL.delete_model(
+        model_id=request.model, company_id=company_id, force=request.force
+    )
+    if del_count:
+        _reset_cached_tags(
+            company_id, projects=[model.project] if model.project else []
         )
-        deleted_model_id = f"{deleted_prefix}{model_id}"
-
-        using_tasks = Task.objects(models__input__model=model_id).only("id")
-        if using_tasks:
-            if not force:
-                raise errors.bad_request.ModelInUse(
-                    "as execution model, use force=True to delete",
-                    num_tasks=len(using_tasks),
-                )
-            # update deleted model id in using tasks
-            Task._get_collection().update_many(
-                filter={"_id": {"$in": [t.id for t in using_tasks]}},
-                update={"$set": {"models.input.$[elem].model": deleted_model_id}},
-                array_filters=[{"elem.model": model_id}],
-                upsert=False,
-            )
+    call.result.data = dict(deleted=del_count > 0, url=model.uri)
 
-        if model.task:
-            task: Task = Task.objects(id=model.task).first()
-            if task and task.status == TaskStatus.published:
-                if not force:
-                    raise errors.bad_request.ModelCreatingTaskExists(
-                        "and published, use force=True to delete", task=model.task
-                    )
-                if task.models.output and model_id in task.models.output:
-                    now = datetime.utcnow()
-                    Task._get_collection().update_one(
-                        filter={"_id": model.task, "models.output.model": model_id},
-                        update={
-                            "$set": {
-                                "models.output.$[elem].model": deleted_model_id,
-                                "output.error": f"model deleted on {now.isoformat()}",
-                            },
-                            "last_change": now,
-                        },
-                        array_filters=[{"elem.model": model_id}],
-                        upsert=False,
-                    )
 
-        del_count = Model.objects(id=model_id, company=company_id).delete()
-        if del_count:
-            _reset_cached_tags(company_id, projects=[model.project])
-        call.result.data = dict(deleted=del_count > 0, url=model.uri,)
+@attr.s(auto_attribs=True)
+class DeleteRes:
+    deleted: int = 0
+    projects: Set = set()
+    urls: Set = set()
+
+    def __add__(self, other: Tuple[int, Model]):
+        del_count, model = other
+        return DeleteRes(
+            deleted=self.deleted + del_count,
+            projects=self.projects | {model.project},
+            urls=self.urls | {model.uri},
+        )
+
+
+@endpoint(
+    "models.delete_many",
+    request_data_model=ModelsDeleteManyRequest,
+    response_data_model=ModelsDeleteManyResponse,
+)
+def delete_many(call: APICall, company_id, request: ModelsDeleteManyRequest):
+    res, failures = run_batch_operation(
+        func=partial(ModelBLL.delete_model, company_id=company_id, force=request.force),
+        ids=request.ids,
+        init_res=DeleteRes(),
+    )
+    if res.deleted:
+        _reset_cached_tags(company_id, projects=list(res.projects))
+
+    res.urls.discard(None)
+    call.result.data_model = ModelsDeleteManyResponse(
+        deleted=res.deleted, urls=list(res.urls), failures=failures,
+    )
+
+
+@endpoint(
+    "models.archive_many",
+    request_data_model=ModelsArchiveManyRequest,
+    response_data_model=ModelsArchiveManyResponse,
+)
+def archive_many(call: APICall, company_id, request: ModelsArchiveManyRequest):
+    archived, failures = run_batch_operation(
+        func=partial(ModelBLL.archive_model, company_id=company_id),
+        ids=request.ids,
+        init_res=0,
+    )
+    call.result.data_model = ModelsArchiveManyResponse(
+        archived=archived, failures=failures,
+    )
 
 
 @endpoint("models.make_public", min_version="2.9", request_data_model=MakePublicRequest)
 def make_public(call: APICall, company_id, request: MakePublicRequest):
-    with translate_errors_context():
-        call.result.data = Model.set_public(
-            company_id, ids=request.ids, invalid_cls=InvalidModelId, enabled=True
-        )
+    call.result.data = Model.set_public(
+        company_id, ids=request.ids, invalid_cls=InvalidModelId, enabled=True
+    )
 
 
 @endpoint(
     "models.make_private", min_version="2.9", request_data_model=MakePublicRequest
 )
 def make_public(call: APICall, company_id, request: MakePublicRequest):
-    with translate_errors_context():
-        call.result.data = Model.set_public(
-            company_id, request.ids, invalid_cls=InvalidModelId, enabled=False
-        )
+    call.result.data = Model.set_public(
+        company_id, request.ids, invalid_cls=InvalidModelId, enabled=False
+    )
 
 
 @endpoint("models.move", request_data_model=MoveRequest)
@@ -560,17 +602,16 @@ def move(call: APICall, company_id: str, request: MoveRequest):
             "project or project_name is required"
         )
 
-    with translate_errors_context():
-        return {
-            "project_id": project_bll.move_under_project(
-                entity_cls=Model,
-                user=call.identity.user,
-                company=company_id,
-                ids=request.ids,
-                project=request.project,
-                project_name=request.project_name,
-            )
-        }
+    return {
+        "project_id": project_bll.move_under_project(
+            entity_cls=Model,
+            user=call.identity.user,
+            company=company_id,
+            ids=request.ids,
+            project=request.project,
+            project_name=request.project_name,
+        )
+    }
 
 
 @endpoint("models.add_or_update_metadata", min_version="2.13")
@@ -578,7 +619,7 @@ def add_or_update_metadata(
     _: APICall, company_id: str, request: AddOrUpdateMetadataRequest
 ):
     model_id = request.model
-    _get_company_model(company_id=company_id, model_id=model_id)
+    ModelBLL.get_company_model_by_id(company_id=company_id, model_id=model_id)
 
     return {
         "updated": metadata_add_or_update(
@@ -590,6 +631,8 @@ def add_or_update_metadata(
 @endpoint("models.delete_metadata", min_version="2.13")
 def delete_metadata(_: APICall, company_id: str, request: DeleteMetadataRequest):
     model_id = request.model
-    _get_company_model(company_id=company_id, model_id=model_id, only_fields=("id",))
+    ModelBLL.get_company_model_by_id(
+        company_id=company_id, model_id=model_id, only_fields=("id",)
+    )
 
     return {"updated": metadata_delete(cls=Model, _id=model_id, keys=request.keys)}
diff --git a/apiserver/services/tasks.py b/apiserver/services/tasks.py
index f2241d1..f6be307 100644
--- a/apiserver/services/tasks.py
+++ b/apiserver/services/tasks.py
@@ -1,6 +1,7 @@
 from copy import deepcopy
 from datetime import datetime
-from typing import Sequence, Union, Tuple
+from functools import partial
+from typing import Sequence, Union, Tuple, Set
 
 import attr
 import dpath
@@ -8,7 +9,7 @@ from mongoengine import EmbeddedDocument, Q
 from mongoengine.queryset.transform import COMPARISON_OPERATORS
 from pymongo import UpdateOne
 
-from apiserver.apierrors import errors, APIError
+from apiserver.apierrors import errors
 from apiserver.apierrors.errors.bad_request import InvalidTaskId
 from apiserver.apimodels.base import (
     UpdateResponse,
@@ -47,8 +48,18 @@ from apiserver.apimodels.tasks import (
     AddUpdateModelRequest,
     DeleteModelsRequest,
     ModelItemType,
+    StopManyResponse,
+    StopManyRequest,
+    EnqueueManyRequest,
+    EnqueueManyResponse,
+    ResetManyRequest,
+    ArchiveManyRequest,
+    ArchiveManyResponse,
+    DeleteManyRequest,
+    PublishManyRequest,
 )
 from apiserver.bll.event import EventBLL
+from apiserver.bll.model import ModelBLL
 from apiserver.bll.organization import OrgBLL, Tags
 from apiserver.bll.project import ProjectBLL, project_ids_with_children
 from apiserver.bll.queue import QueueBLL
@@ -69,19 +80,23 @@ from apiserver.bll.task.param_utils import (
     params_unprepare_from_saved,
     escape_paths,
 )
-from apiserver.bll.task.task_cleanup import cleanup_task
+from apiserver.bll.task.task_cleanup import CleanupResult
+from apiserver.bll.task.task_operations import (
+    stop_task,
+    enqueue_task,
+    reset_task,
+    archive_task,
+    delete_task,
+    publish_task,
+)
 from apiserver.bll.task.utils import update_task, deleted_prefix, get_task_for_update
-from apiserver.bll.util import SetFieldsResolver
+from apiserver.bll.util import SetFieldsResolver, run_batch_operation
 from apiserver.database.errors import translate_errors_context
-from apiserver.database.model import EntityVisibility
 from apiserver.database.model.task.output import Output
 from apiserver.database.model.task.task import (
     Task,
     TaskStatus,
     Script,
-    DEFAULT_LAST_ITERATION,
-    Execution,
-    ArtifactModes,
     ModelItem,
 )
 from apiserver.database.utils import get_fields_attr, parse_from_call, get_options
@@ -199,9 +214,7 @@ def get_all_ex(call: APICall, company_id, _):
     with TimingContext("mongo", "task_get_all_ex"):
         _process_include_subprojects(call_data)
         tasks = Task.get_many_with_join(
-            company=company_id,
-            query_dict=call_data,
-            allow_public=True,  # required in case projection is requested for public dataset/versions
+            company=company_id, query_dict=call_data, allow_public=True,
         )
         unprepare_from_saved(call, tasks)
         call.result.data = {"tasks": tasks}
@@ -235,7 +248,7 @@ def get_all(call: APICall, company_id, _):
             company=company_id,
             parameters=call_data,
             query_dict=call_data,
-            allow_public=True,  # required in case projection is requested for public dataset/versions
+            allow_public=True,
         )
         unprepare_from_saved(call, tasks)
         call.result.data = {"tasks": tasks}
@@ -263,7 +276,7 @@ def stop(call: APICall, company_id, req_model: UpdateRequest):
     """
 
     call.result.data_model = UpdateResponse(
-        **TaskBLL.stop_task(
+        **stop_task(
             task_id=req_model.task,
             company_id=company_id,
             user_name=call.identity.user_name,
@@ -273,6 +286,34 @@ def stop(call: APICall, company_id, req_model: UpdateRequest):
     )
 
 
+@attr.s(auto_attribs=True)
+class StopRes:
+    stopped: int = 0
+
+    def __add__(self, other: dict):
+        return StopRes(stopped=self.stopped + 1)
+
+
+@endpoint(
+    "tasks.stop_many",
+    request_data_model=StopManyRequest,
+    response_data_model=StopManyResponse,
+)
+def stop_many(call: APICall, company_id, request: StopManyRequest):
+    res, failures = run_batch_operation(
+        func=partial(
+            stop_task,
+            company_id=company_id,
+            user_name=call.identity.user_name,
+            status_reason=request.status_reason,
+            force=request.force,
+        ),
+        ids=request.ids,
+        init_res=StopRes(),
+    )
+    call.result.data_model = StopManyResponse(stopped=res.stopped, failures=failures)
+
+
 @endpoint(
     "tasks.stopped",
     request_data_model=UpdateRequest,
@@ -792,61 +833,44 @@ def delete_configuration(
     request_data_model=EnqueueRequest,
     response_data_model=EnqueueResponse,
 )
-def enqueue(call: APICall, company_id, req_model: EnqueueRequest):
-    task_id = req_model.task
-    queue_id = req_model.queue
-    status_message = req_model.status_message
-    status_reason = req_model.status_reason
+def enqueue(call: APICall, company_id, request: EnqueueRequest):
+    queued, res = enqueue_task(
+        task_id=request.task,
+        company_id=company_id,
+        queue_id=request.queue,
+        status_message=request.status_message,
+        status_reason=request.status_reason,
+    )
+    call.result.data_model = EnqueueResponse(queued=queued, **res)
 
-    if not queue_id:
-        # try to get default queue
-        queue_id = queue_bll.get_default(company_id).id
 
-    with translate_errors_context():
-        query = dict(id=task_id, company=company_id)
-        task = Task.get_for_writing(
-            _only=("type", "script", "execution", "status", "project", "id"), **query
-        )
-        if not task:
-            raise errors.bad_request.InvalidTaskId(**query)
+@attr.s(auto_attribs=True)
+class EnqueueRes:
+    queued: int = 0
 
-        res = EnqueueResponse(
-            **ChangeStatusRequest(
-                task=task,
-                new_status=TaskStatus.queued,
-                status_reason=status_reason,
-                status_message=status_message,
-                allow_same_state_transition=False,
-            ).execute()
-        )
+    def __add__(self, other: Tuple[int, dict]):
+        queued, _ = other
+        return EnqueueRes(queued=self.queued + queued)
 
-        try:
-            queue_bll.add_task(
-                company_id=company_id, queue_id=queue_id, task_id=task.id
-            )
-        except Exception:
-            # failed enqueueing, revert to previous state
-            ChangeStatusRequest(
-                task=task,
-                current_status_override=TaskStatus.queued,
-                new_status=task.status,
-                force=True,
-                status_reason="failed enqueueing",
-            ).execute()
-            raise
 
-        # set the current queue ID in the task
-        if task.execution:
-            Task.objects(**query).update(execution__queue=queue_id, multi=False)
-        else:
-            Task.objects(**query).update(
-                execution=Execution(queue=queue_id), multi=False
-            )
-
-        res.queued = 1
-        res.fields.update(**{"execution.queue": queue_id})
-
-        call.result.data_model = res
+@endpoint(
+    "tasks.enqueue_many",
+    request_data_model=EnqueueManyRequest,
+    response_data_model=EnqueueManyResponse,
+)
+def enqueue_many(call: APICall, company_id, request: EnqueueManyRequest):
+    res, failures = run_batch_operation(
+        func=partial(
+            enqueue_task,
+            company_id=company_id,
+            queue_id=request.queue,
+            status_message=request.status_message,
+            status_reason=request.status_reason,
+        ),
+        ids=request.ids,
+        init_res=EnqueueRes(),
+    )
+    call.result.data_model = EnqueueManyResponse(queued=res.queued, failures=failures)
 
 
 @endpoint(
@@ -878,164 +902,161 @@ def dequeue(call: APICall, company_id, request: UpdateRequest):
     "tasks.reset", request_data_model=ResetRequest, response_data_model=ResetResponse
 )
 def reset(call: APICall, company_id, request: ResetRequest):
-    task = TaskBLL.get_task_with_access(
-        request.task, company_id=company_id, requires_write_access=True
-    )
-
-    force = request.force
-
-    if not force and task.status == TaskStatus.published:
-        raise errors.bad_request.InvalidTaskStatus(task_id=task.id, status=task.status)
-
-    api_results = {}
-    updates = {}
-
-    try:
-
dequeued = TaskBLL.dequeue(task, company_id, silent_fail=True) - except APIError: - # dequeue may fail if the task was not enqueued - pass - else: - if dequeued: - api_results.update(dequeued=dequeued) - - cleaned_up = cleanup_task( - task, - force=force, - update_children=False, + dequeued, cleanup_res, updates = reset_task( + task_id=request.task, + company_id=company_id, + force=request.force, return_file_urls=request.return_file_urls, delete_output_models=request.delete_output_models, - ) - api_results.update(attr.asdict(cleaned_up)) - - updates.update( - set__last_iteration=DEFAULT_LAST_ITERATION, - set__last_metrics={}, - set__metric_stats={}, - set__models__output=[], - unset__output__result=1, - unset__output__error=1, - unset__last_worker=1, - unset__last_worker_report=1, - ) - - if request.clear_all: - updates.update( - set__execution=Execution(), unset__script=1, - ) - else: - updates.update(unset__execution__queue=1) - if task.execution and task.execution.artifacts: - updates.update( - set__execution__artifacts={ - key: artifact - for key, artifact in task.execution.artifacts.items() - if artifact.mode == ArtifactModes.input - } - ) - - res = ResetResponse( - **ChangeStatusRequest( - task=task, - new_status=TaskStatus.created, - force=force, - status_reason="reset", - status_message="reset", - ).execute( - started=None, - completed=None, - published=None, - active_duration=None, - **updates, - ) + clear_all=request.clear_all, ) + res = ResetResponse(**updates, dequeued=dequeued) # do not return artifacts since they are not serializable res.fields.pop("execution.artifacts", None) - for key, value in api_results.items(): + for key, value in attr.asdict(cleanup_res).items(): setattr(res, key, value) call.result.data_model = res +@attr.s(auto_attribs=True) +class ResetRes: + reset: int = 0 + dequeued: int = 0 + cleanup_res: CleanupResult = None + + def __add__(self, other: Tuple[dict, CleanupResult, dict]): + dequeued, other_res, _ = other + dequeued = dequeued.get("removed", 0) if dequeued else 0 + return ResetRes( + reset=self.reset + 1, + dequeued=self.dequeued + dequeued, + cleanup_res=self.cleanup_res + other_res if self.cleanup_res else other_res, + ) + + +@endpoint("tasks.reset_many", request_data_model=ResetManyRequest) +def reset_many(call: APICall, company_id, request: ResetManyRequest): + res, failures = run_batch_operation( + func=partial( + reset_task, + company_id=company_id, + force=request.force, + return_file_urls=request.return_file_urls, + delete_output_models=request.delete_output_models, + clear_all=request.clear_all, + ), + ids=request.ids, + init_res=ResetRes(), + ) + + if res.cleanup_res: + cleanup_res = dict( + deleted_models=res.cleanup_res.deleted_models, + urls=attr.asdict(res.cleanup_res.urls), + ) + else: + cleanup_res = {} + call.result.data = dict( + reset=res.reset, dequeued=res.dequeued, **cleanup_res, failures=failures, + ) + + @endpoint( "tasks.archive", request_data_model=ArchiveRequest, response_data_model=ArchiveResponse, ) def archive(call: APICall, company_id, request: ArchiveRequest): - archived = 0 tasks = TaskBLL.assert_exists( company_id, task_ids=request.tasks, only=("id", "execution", "status", "project", "system_tags"), ) + archived = 0 for task in tasks: - try: - TaskBLL.dequeue_and_change_status( - task, company_id, request.status_message, request.status_reason, - ) - except APIError: - # dequeue may fail if the task was not enqueued - pass - task.update( + archived += archive_task( + company_id=company_id, + task=task, 
status_message=request.status_message, status_reason=request.status_reason, - system_tags=sorted( - set(task.system_tags) | {EntityVisibility.archived.value} - ), - last_change=datetime.utcnow(), ) - archived += 1 - call.result.data_model = ArchiveResponse(archived=archived) +@endpoint( + "tasks.archive_many", + request_data_model=ArchiveManyRequest, + response_data_model=ArchiveManyResponse, +) +def archive_many(call: APICall, company_id, request: ArchiveManyRequest): + archived, failures = run_batch_operation( + func=partial( + archive_task, + company_id=company_id, + status_message=request.status_message, + status_reason=request.status_reason, + ), + ids=request.ids, + init_res=0, + ) + call.result.data_model = ArchiveManyResponse(archived=archived, failures=failures) + + @endpoint("tasks.delete", request_data_model=DeleteRequest) def delete(call: APICall, company_id, request: DeleteRequest): - task = TaskBLL.get_task_with_access( - request.task, company_id=company_id, requires_write_access=True + deleted, task, cleanup_res = delete_task( + task_id=request.task, + company_id=company_id, + move_to_trash=request.move_to_trash, + force=request.force, + return_file_urls=request.return_file_urls, + delete_output_models=request.delete_output_models, ) + if deleted: + _reset_cached_tags(company_id, projects=[task.project] if task.project else []) + call.result.data = dict(deleted=bool(deleted), **attr.asdict(cleanup_res)) - move_to_trash = request.move_to_trash - force = request.force - if task.status != TaskStatus.created and not force: - raise errors.bad_request.TaskCannotBeDeleted( - "due to status, use force=True", - task=task.id, - expected=TaskStatus.created, - current=task.status, +@attr.s(auto_attribs=True) +class DeleteRes: + deleted: int = 0 + projects: Set = set() + cleanup_res: CleanupResult = None + + def __add__(self, other: Tuple[int, Task, CleanupResult]): + del_count, task, other_res = other + + return DeleteRes( + deleted=self.deleted + del_count, + projects=self.projects | {task.project}, + cleanup_res=self.cleanup_res + other_res if self.cleanup_res else other_res, ) - with translate_errors_context(): - result = cleanup_task( - task, - force=force, + +@endpoint("tasks.delete_many", request_data_model=DeleteManyRequest) +def delete_many(call: APICall, company_id, request: DeleteManyRequest): + res, failures = run_batch_operation( + func=partial( + delete_task, + company_id=company_id, + move_to_trash=request.move_to_trash, + force=request.force, return_file_urls=request.return_file_urls, delete_output_models=request.delete_output_models, - ) + ), + ids=request.ids, + init_res=DeleteRes(), + ) - if move_to_trash: - collection_name = task._get_collection_name() - archived_collection = "{}__trash".format(collection_name) - task.switch_collection(archived_collection) - try: - # A simple save() won't do due to mongoengine caching (nothing will be saved), so we have to force - # an insert. However, if for some reason such an ID exists, let's make sure we'll keep going. 
-                with TimingContext("mongo", "save_task"):
-                    task.save(force_insert=True)
-            except Exception:
-                pass
-            task.switch_collection(collection_name)
+    if res.deleted:
+        _reset_cached_tags(company_id, projects=list(res.projects))
 
-        task.delete()
-        _reset_cached_tags(company_id, projects=[task.project])
-        update_project_time(task.project)
-
-        call.result.data = dict(deleted=True, **attr.asdict(result))
+    cleanup_res = attr.asdict(res.cleanup_res) if res.cleanup_res else {}
+    call.result.data = dict(deleted=res.deleted, **cleanup_res, failures=failures)
 
 
 @endpoint(
@@ -1043,17 +1064,44 @@ def delete(call: APICall, company_id, request: DeleteRequest):
     request_data_model=PublishRequest,
     response_data_model=PublishResponse,
 )
-def publish(call: APICall, company_id, req_model: PublishRequest):
-    call.result.data_model = PublishResponse(
-        **TaskBLL.publish_task(
-            task_id=req_model.task,
-            company_id=company_id,
-            publish_model=req_model.publish_model,
-            force=req_model.force,
-            status_reason=req_model.status_reason,
-            status_message=req_model.status_message,
-        )
+def publish(call: APICall, company_id, request: PublishRequest):
+    updates = publish_task(
+        task_id=request.task,
+        company_id=company_id,
+        force=request.force,
+        publish_model_func=ModelBLL.publish_model if request.publish_model else None,
+        status_reason=request.status_reason,
+        status_message=request.status_message,
     )
+    call.result.data_model = PublishResponse(**updates)
+
+
+@attr.s(auto_attribs=True)
+class PublishRes:
+    published: int = 0
+
+    def __add__(self, other: dict):
+        return PublishRes(published=self.published + 1)
+
+
+@endpoint("tasks.publish_many", request_data_model=PublishManyRequest)
+def publish_many(call: APICall, company_id, request: PublishManyRequest):
+    res, failures = run_batch_operation(
+        func=partial(
+            publish_task,
+            company_id=company_id,
+            force=request.force,
+            publish_model_func=ModelBLL.publish_model
+            if request.publish_model
+            else None,
+            status_reason=request.status_reason,
+            status_message=request.status_message,
+        ),
+        ids=request.ids,
+        init_res=PublishRes(),
+    )
+
+    call.result.data = dict(published=res.published, failures=failures)
 
 
 @endpoint(
diff --git a/apiserver/services/utils.py b/apiserver/services/utils.py
index 81e88c9..947fd73 100644
--- a/apiserver/services/utils.py
+++ b/apiserver/services/utils.py
@@ -28,16 +28,23 @@ def get_tags_response(ret: dict) -> dict:
 
 def conform_output_tags(call: APICall, documents: Union[dict, Sequence[dict]]):
     """
+    Make sure that tags are always returned sorted
     For old clients both tags and system tags are returned in 'tags' field
     """
-    if call.requested_endpoint_version >= PartialVersion("2.3"):
-        return
     if isinstance(documents, dict):
         documents = [documents]
+
+    merge_tags = call.requested_endpoint_version < PartialVersion("2.3")
     for doc in documents:
-        system_tags = doc.get("system_tags")
-        if system_tags:
-            doc["tags"] = list(set(doc.get("tags", [])) | set(system_tags))
+        if merge_tags:
+            system_tags = doc.get("system_tags")
+            if system_tags:
+                doc["tags"] = list(set(doc.get("tags", [])) | set(system_tags))
+
+        for field in ("system_tags", "tags"):
+            tags = doc.get(field)
+            if tags:
+                doc[field] = sorted(tags)
 
 
 def conform_tag_fields(call: APICall, document: dict, validate=False):
diff --git a/apiserver/tests/automated/__init__.py b/apiserver/tests/automated/__init__.py
index f79f8e0..f8e3a85 100644
--- a/apiserver/tests/automated/__init__.py
+++ b/apiserver/tests/automated/__init__.py
@@ -69,16 +69,6 @@ class TestService(TestCase, TestServiceInterface):
             delete_params=delete_params,
         )
 
-    def create_temp_version(self, *, client=None, **kwargs) -> str:
-        return self._create_temp_helper(
-            service="datasets",
-            create_endpoint="create_version",
-            delete_endpoint="delete_version",
-            object_name="version",
-            create_params=kwargs,
-            client=client,
-        )
-
     def setUp(self, version="1.7"):
         self._api = APIClient(base_url=f"http://localhost:8008/v{version}")
         self._deferred = []
diff --git a/apiserver/tests/automated/test_batch_operations.py b/apiserver/tests/automated/test_batch_operations.py
new file mode 100644
index 0000000..9f4f909
--- /dev/null
+++ b/apiserver/tests/automated/test_batch_operations.py
@@ -0,0 +1,124 @@
+from apiserver.database.utils import id as db_id
+from apiserver.tests.automated import TestService
+
+
+class TestBatchOperations(TestService):
+    name = "batch operation test"
+    comment = "this is a comment"
+    delete_params = dict(can_fail=True, force=True)
+
+    def setUp(self, version="2.13"):
+        super().setUp(version=version)
+
+    def test_tasks(self):
+        tasks = [self._temp_task() for _ in range(2)]
+        models = [
+            self._temp_task_model(task=t, uri=f"uri_{idx}")
+            for idx, t in enumerate(tasks)
+        ]
+        missing_id = db_id()
+        ids = [*tasks, missing_id]
+
+        # enqueue
+        res = self.api.tasks.enqueue_many(ids=ids)
+        self.assertEqual(res.queued, 2)
+        self._assert_failures(res, [missing_id])
+        data = self.api.tasks.get_all_ex(id=ids).tasks
+        self.assertEqual({t.status for t in data}, {"queued"})
+
+        # stop
+        for t in tasks:
+            self.api.tasks.started(task=t)
+        res = self.api.tasks.stop_many(ids=ids)
+        self.assertEqual(res.stopped, 2)
+        self._assert_failures(res, [missing_id])
+        data = self.api.tasks.get_all_ex(id=ids).tasks
+        self.assertEqual({t.status for t in data}, {"stopped"})
+
+        # publish
+        res = self.api.tasks.publish_many(ids=ids, publish_model=False)
+        self.assertEqual(res.published, 2)
+        self._assert_failures(res, [missing_id])
+        data = self.api.tasks.get_all_ex(id=ids).tasks
+        self.assertEqual({t.status for t in data}, {"published"})
+
+        # reset
+        res = self.api.tasks.reset_many(
+            ids=ids, delete_output_models=True, return_file_urls=True, force=True
+        )
+        self.assertEqual(res.reset, 2)
+        self.assertEqual(res.deleted_models, 2)
+        self.assertEqual(set(res.urls.model_urls), {"uri_0", "uri_1"})
+        self._assert_failures(res, [missing_id])
+        data = self.api.tasks.get_all_ex(id=ids).tasks
+        self.assertEqual({t.status for t in data}, {"created"})
+
+        # archive
+        res = self.api.tasks.archive_many(ids=ids)
+        self.assertEqual(res.archived, 2)
+        self._assert_failures(res, [missing_id])
+        data = self.api.tasks.get_all_ex(id=ids).tasks
+        self.assertTrue(all("archived" in t.system_tags for t in data))
+
+        # delete
+        res = self.api.tasks.delete_many(
+            ids=ids, delete_output_models=True, return_file_urls=True
+        )
+        self.assertEqual(res.deleted, 2)
+        self._assert_failures(res, [missing_id])
+        data = self.api.tasks.get_all_ex(id=ids).tasks
+        self.assertEqual(data, [])
+
+    def test_models(self):
+        uris = [f"file:///{i}" for i in range(2)]
+        models = [self._temp_model(uri=uri) for uri in uris]
+        missing_id = db_id()
+        ids = [*models, missing_id]
+
+        # publish
+        task = self._temp_task()
+        self.api.models.edit(model=ids[0], ready=False, task=task)
+        self.api.tasks.add_or_update_model(
+            task=task, name="output", type="input", model=ids[0]
+        )
+        res = self.api.models.publish_many(
+            ids=ids, publish_task=True, force_publish_task=True
+        )
+        self.assertEqual(res.published, 1)
+        self.assertEqual(res.published_tasks[0].id, task)
+        self._assert_failures(res, [ids[1], missing_id])
+
+        # archive
+        res = self.api.models.archive_many(ids=ids)
+        self.assertEqual(res.archived, 2)
+        self._assert_failures(res, [missing_id])
+        data = self.api.models.get_all_ex(id=ids).models
+        for m in data:
+            self.assertIn("archived", m.system_tags)
+
+        # delete
+        res = self.api.models.delete_many(ids=[*models, missing_id], force=True)
+        self.assertEqual(res.deleted, 2)
+        self.assertEqual(set(res.urls), set(uris))
+        self._assert_failures(res, [missing_id])
+        data = self.api.models.get_all_ex(id=ids).models
+        self.assertEqual(data, [])
+
+    def _assert_failures(self, res, failed_ids):
+        self.assertEqual(set(f.id for f in res.failures), set(failed_ids))
+
+    def _temp_model(self, **kwargs):
+        self.update_missing(kwargs, name=self.name, uri="file:///a/b", labels={})
+        return self.create_temp("models", delete_params=self.delete_params, **kwargs)
+
+    def _temp_task(self):
+        return self.create_temp(
+            service="tasks", type="testing", name=self.name, input=dict(view={}),
+        )
+
+    def _temp_task_model(self, task, **kwargs) -> str:
+        model = self._temp_model(ready=False, task=task, **kwargs)
+        self.api.tasks.add_or_update_model(
+            task=task, name="output", type="output", model=model
+        )
+        return model
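
Note: run_batch_operation, imported above from apiserver/bll/util.py, is used by every *_many endpoint in this patch but its body is not shown in this chunk. Below is a minimal sketch consistent with its call sites, for orientation only; the exact shape of the failure entries (the tests only rely on the "id" key being present) is an assumption, not the patch's actual implementation.

    from typing import Callable, Sequence, Tuple, TypeVar

    from apiserver.apierrors import APIError

    T = TypeVar("T")


    def run_batch_operation(
        func: Callable[[str], T], ids: Sequence[str], init_res: T
    ) -> Tuple[T, Sequence[dict]]:
        """Apply func to every id, folding successes and collecting failures."""
        res = init_res
        failures = []
        for _id in ids:
            try:
                # Fold each successful per-entity result into the accumulator
                # via its __add__ (plain int addition when init_res is 0, as
                # in archive_many above).
                res = res + func(_id)
            except APIError as err:
                # A failing id is reported back instead of aborting the batch;
                # the error payload shape here is assumed for illustration.
                failures.append({"id": _id, "error": str(err)})
        return res, failures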
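The per-endpoint accumulator classes (StopRes, EnqueueRes, ResetRes, DeleteRes, PublishRes) exist so that run_batch_operation can fold heterogeneous per-entity results into one summary via __add__. A hypothetical fold for tasks.stop_many, with made-up result dicts; note that StopRes.__add__ deliberately ignores the per-task dict and only counts successes:

    acc = StopRes()             # stopped == 0
    acc = acc + {"updated": 1}  # one successful stop_task result; stopped == 1
    acc = acc + {"fields": {}}  # dict contents are irrelevant; stopped == 2
    # An id whose call raises APIError never reaches the accumulator; it is
    # appended to the failures list instead.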