From 6b0c45a861c33ecc26a9522b148d2a1a4ee97182 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Mon, 3 May 2021 18:07:37 +0300 Subject: [PATCH] Fix batch operations results --- apiserver/apimodels/batch.py | 16 +- apiserver/apimodels/models.py | 13 +- apiserver/apimodels/tasks.py | 35 ++- apiserver/bll/task/task_operations.py | 24 +- apiserver/bll/util.py | 10 +- apiserver/schema/services/_common.conf | 13 +- apiserver/schema/services/models.conf | 32 ++- apiserver/schema/services/tasks.conf | 126 ++++++---- apiserver/services/models.py | 95 +++----- apiserver/services/tasks.py | 220 +++++++++--------- .../tests/automated/test_batch_operations.py | 62 ++--- 11 files changed, 368 insertions(+), 278 deletions(-) diff --git a/apiserver/apimodels/batch.py b/apiserver/apimodels/batch.py index 5863f1d..af21385 100644 --- a/apiserver/apimodels/batch.py +++ b/apiserver/apimodels/batch.py @@ -1,9 +1,11 @@ from typing import Sequence +from jsonmodels.fields import StringField from jsonmodels.models import Base from jsonmodels.validators import Length -from apiserver.apimodels import ListField, IntField +from apiserver.apimodels import ListField +from apiserver.apimodels.base import UpdateResponse class BatchRequest(Base): @@ -11,5 +13,13 @@ class BatchRequest(Base): class BatchResponse(Base): - succeeded: int = IntField() - failures: Sequence[dict] = ListField([dict]) + succeeded: Sequence[dict] = ListField([dict]) + failed: Sequence[dict] = ListField([dict]) + + +class UpdateBatchItem(UpdateResponse): + id: str = StringField() + + +class UpdateBatchResponse(BatchResponse): + succeeded: Sequence[UpdateBatchItem] = ListField(UpdateBatchItem) diff --git a/apiserver/apimodels/models.py b/apiserver/apimodels/models.py index 72350da..cabc700 100644 --- a/apiserver/apimodels/models.py +++ b/apiserver/apimodels/models.py @@ -3,13 +3,12 @@ from six import string_types from apiserver.apimodels import ListField, DictField from apiserver.apimodels.base import UpdateResponse -from apiserver.apimodels.batch import BatchRequest, BatchResponse +from apiserver.apimodels.batch import BatchRequest from apiserver.apimodels.metadata import ( MetadataItem, DeleteMetadata, AddOrUpdateMetadata, ) -from apiserver.apimodels.tasks import PublishResponse as TaskPublishResponse class GetFrameworksRequest(models.Base): @@ -51,10 +50,6 @@ class ModelsDeleteManyRequest(BatchRequest): force = fields.BoolField(default=False) -class ModelsDeleteManyResponse(BatchResponse): - urls = fields.ListField([str]) - - class PublishModelRequest(ModelRequest): force_publish_task = fields.BoolField(default=False) publish_task = fields.BoolField(default=True) @@ -62,7 +57,7 @@ class PublishModelRequest(ModelRequest): class ModelTaskPublishResponse(models.Base): id = fields.StringField(required=True) - data = fields.EmbeddedField(TaskPublishResponse) + data = fields.EmbeddedField(UpdateResponse) class PublishModelResponse(UpdateResponse): @@ -74,10 +69,6 @@ class ModelsPublishManyRequest(BatchRequest): publish_task = fields.BoolField(default=True) -class ModelsPublishManyResponse(BatchResponse): - published_tasks = fields.ListField([ModelTaskPublishResponse]) - - class DeleteMetadataRequest(DeleteMetadata): model = fields.StringField(required=True) diff --git a/apiserver/apimodels/tasks.py b/apiserver/apimodels/tasks.py index 29a1aa6..7acd531 100644 --- a/apiserver/apimodels/tasks.py +++ b/apiserver/apimodels/tasks.py @@ -1,13 +1,12 @@ from typing import Sequence -import six from jsonmodels import models from jsonmodels.fields import StringField, BoolField, IntField, EmbeddedField from jsonmodels.validators import Enum, Length from apiserver.apimodels import DictField, ListField from apiserver.apimodels.base import UpdateResponse -from apiserver.apimodels.batch import BatchRequest +from apiserver.apimodels.batch import BatchRequest, UpdateBatchItem, BatchResponse from apiserver.database.model.task.task import ( TaskType, ArtifactModes, @@ -45,19 +44,43 @@ class EnqueueResponse(UpdateResponse): queued = IntField() +class EnqueueBatchItem(UpdateBatchItem): + queued: bool = BoolField() + + +class EnqueueManyResponse(BatchResponse): + succeeded: Sequence[EnqueueBatchItem] = ListField(EnqueueBatchItem) + + class DequeueResponse(UpdateResponse): dequeued = IntField() +class DequeueBatchItem(UpdateBatchItem): + dequeued: bool = BoolField() + + +class DequeueManyResponse(BatchResponse): + succeeded: Sequence[DequeueBatchItem] = ListField(DequeueBatchItem) + + class ResetResponse(UpdateResponse): - deleted_indices = ListField(items_types=six.string_types) dequeued = DictField() - frames = DictField() events = DictField() deleted_models = IntField() urls = DictField() +class ResetBatchItem(UpdateBatchItem): + dequeued: bool = BoolField() + deleted_models = IntField() + urls = DictField() + + +class ResetManyResponse(BatchResponse): + succeeded: Sequence[ResetBatchItem] = ListField(ResetBatchItem) + + class TaskRequest(models.Base): task = StringField(required=True) @@ -89,10 +112,6 @@ class PublishRequest(UpdateRequest): publish_model = BoolField(default=True) -class PublishResponse(UpdateResponse): - pass - - class TaskData(models.Base): """ This is a partial description of task can be updated incrementally diff --git a/apiserver/bll/task/task_operations.py b/apiserver/bll/task/task_operations.py index 48a5d02..a5d735a 100644 --- a/apiserver/bll/task/task_operations.py +++ b/apiserver/bll/task/task_operations.py @@ -56,13 +56,13 @@ def archive_task( except APIError: # dequeue may fail if the task was not enqueued pass - task.update( + + return task.update( status_message=status_message, status_reason=status_reason, add_to_set__system_tags=EntityVisibility.archived.value, last_change=datetime.utcnow(), ) - return 1 def unarchive_task( @@ -82,6 +82,26 @@ def unarchive_task( ) +def dequeue_task( + task_id: str, + company_id: str, + status_message: str, + status_reason: str, +) -> Tuple[int, dict]: + query = dict(id=task_id, company=company_id) + task = Task.get_for_writing(**query) + if not task: + raise errors.bad_request.InvalidTaskId(**query) + + res = TaskBLL.dequeue_and_change_status( + task, + company_id, + status_message=status_message, + status_reason=status_reason, + ) + return 1, res + + def enqueue_task( task_id: str, company_id: str, diff --git a/apiserver/bll/util.py b/apiserver/bll/util.py index 95d228f..62c5e41 100644 --- a/apiserver/bll/util.py +++ b/apiserver/bll/util.py @@ -113,13 +113,13 @@ T = TypeVar("T") def run_batch_operation( - func: Callable[[str], T], init_res: T, ids: Sequence[str] -) -> Tuple[T, Sequence]: - res = init_res + func: Callable[[str], T], ids: Sequence[str] +) -> Tuple[Sequence[Tuple[str, T]], Sequence[dict]]: + results = list() failures = list() for _id in ids: try: - res += func(_id) + results.append((_id, func(_id))) except APIError as err: failures.append( { @@ -131,4 +131,4 @@ def run_batch_operation( }, } ) - return res, failures + return results, failures diff --git a/apiserver/schema/services/_common.conf b/apiserver/schema/services/_common.conf index 7594887..9ab40f8 100644 --- a/apiserver/schema/services/_common.conf +++ b/apiserver/schema/services/_common.conf @@ -43,9 +43,18 @@ batch_operation { type: object properties { succeeded { - type: integer + type: array + items { + type: object + properties { + id: { + description: ID of the succeeded entity + type: string + } + } + } } - failures { + failed { type: array items { type: object diff --git a/apiserver/schema/services/models.conf b/apiserver/schema/services/models.conf index 97e220d..90535a8 100644 --- a/apiserver/schema/services/models.conf +++ b/apiserver/schema/services/models.conf @@ -680,11 +680,11 @@ publish_many { } response { properties { - succeeded.description: "Number of models published" - published_tasks { - type: array - items: ${_definitions.published_task_item} + succeeded.items.properties.updated { + description: "Indicates whether the model was updated" + type: boolean } + succeeded.items.properties.published_task: ${_definitions.published_task_item} } } } @@ -733,7 +733,10 @@ archive_many { } response { properties { - succeeded.description: "Number of models archived" + succeeded.items.properties.archived { + description: "Indicates whether the model was archived" + type: boolean + } } } } @@ -748,7 +751,10 @@ unarchive_many { } response { properties { - succeeded.description: "Number of models unarchived" + succeeded.items.properties.unarchived { + description: "Indicates whether the model was unarchived" + type: boolean + } } } } @@ -768,11 +774,13 @@ delete_many { } response { properties { - succeeded.description: "Number of models deleted" - urls { - descrition: "The urls of the deleted model files" - type: array - items {type: string} + succeeded.items.properties.deleted { + description: "Indicates whether the model was deleted" + type: boolean + } + succeeded.items.properties.url { + description: "The url of the model file" + type: string } } } @@ -814,7 +822,7 @@ delete { response { properties { url { - descrition: "The url of the model file" + description: "The url of the model file" type: string } } diff --git a/apiserver/schema/services/tasks.conf b/apiserver/schema/services/tasks.conf index 3365a8a..4be36e9 100644 --- a/apiserver/schema/services/tasks.conf +++ b/apiserver/schema/services/tasks.conf @@ -39,6 +39,20 @@ _definitions { } } } + response { + properties { + succeeded.items.properties.updated { + description: "Number of tasks updated (0 or 1)" + type: integer + enum: [ 0, 1 ] + } + succeeded.items.properties.fields { + description: "Updated fields names and values" + type: object + additionalProperties: true + } + } + } } update_response { type: object @@ -1370,21 +1384,11 @@ reset { } ${_references.status_change_request} response: ${_definitions.update_response} { properties { - deleted_indices { - description: "List of deleted ES indices that were removed as part of the reset process" - type: array - items { type: string } - } dequeued { description: "Response from queues.remove_task" type: object additionalProperties: true } - frames { - description: "Response from frames.rollback" - type: object - additionalProperties: true - } events { description: "Response from events.delete_for_task" type: object @@ -1442,18 +1446,26 @@ reset_many { } response { properties { - succeeded.description: "Number of tasks reset" - dequeued { - description: "Number of tasks dequeued" + succeeded.items.properties.dequeued { + description: "Indicates whether the task was dequeued" + type: boolean + } + succeeded.items.properties.updated { + description: "Number of tasks updated (0 or 1)" + type: integer + enum: [ 0, 1 ] + } + succeeded.items.properties.fields { + description: "Updated fields names and values" type: object additionalProperties: true } - deleted_models { + succeeded.items.properties.deleted_models { description: "Number of output models deleted by the reset" type: integer } - urls { - description: "The urls of the files that were uploaded by the tasks. Returned if the 'return_file_urls' was set to 'true'" + succeeded.items.properties.urls { + description: "The urls of the files that were uploaded by the task. Returned if the 'return_file_urls' was set to 'true'" "$ref": "#/definitions/task_urls" } } @@ -1486,21 +1498,24 @@ delete_many { } response { properties { - succeeded.description: "Number of tasks deleted" - updated_children { + succeeded.items.properties.deleted { + description: "Indicates whether the task was deleted" + type: boolean + } + succeeded.items.properties.updated_children { description: "Number of child tasks whose parent property was updated" type: integer } - updated_models { + succeeded.items.properties.updated_models { description: "Number of models whose task property was updated" type: integer } - deleted_models { + succeeded.items.properties.deleted_models { description: "Number of deleted output models" type: integer } - urls { - description: "The urls of the files that were uploaded by the tasks. Returned if the 'return_file_urls' was set to 'true'" + succeeded.items.properties.urls { + description: "The urls of the files that were uploaded by the task. Returned if the 'return_file_urls' was set to 'true'" "$ref": "#/definitions/task_urls" } } @@ -1609,31 +1624,53 @@ archive { } } archive_many { - "2.13": ${_definitions.change_many_request} { + "2.13": ${_definitions.batch_operation} { description: Archive tasks request { properties { ids.description: "IDs of the tasks to archive" + status_reason { + description: Reason for status change + type: string + } + status_message { + description: Extra information regarding status change + type: string + } } - response { - properties { - succeeded.description: "Number of tasks archived" + response { + properties { + succeeded.items.properties.archived { + description: "Indicates whether the task was archived" + type: boolean + } } } } } } unarchive_many { - "2.13": ${_definitions.change_many_request} { + "2.13": ${_definitions.batch_operation} { description: Unarchive tasks request { properties { ids.description: "IDs of the tasks to unarchive" + status_reason { + description: Reason for status change + type: string + } + status_message { + description: Extra information regarding status change + type: string + } } } response { properties { - succeeded.description: "Number of tasks unarchived" + succeeded.items.properties.unarchived { + description: "Indicates whether the task was unarchived" + type: boolean + } } } } @@ -1687,11 +1724,6 @@ stop_many { } } } - response { - properties { - succeeded.description: "Number of tasks stopped" - } - } } } stopped { @@ -1775,11 +1807,6 @@ publish_many { } } } - response { - properties { - succeeded.description: "Number of tasks published" - } - } } } enqueue { @@ -1837,7 +1864,10 @@ enqueue_many { } response { properties { - succeeded.description: "Number of tasks enqueued" + succeeded.items.properties.queued { + description: "Indicates whether the task was queued" + type: boolean + } } } } @@ -1863,6 +1893,24 @@ dequeue { } } } +dequeue_many { + "2.13": ${_definitions.change_many_request} { + description: Dequeue tasks + request { + properties { + ids.description: "IDs of the tasks to dequeue" + } + } + response { + properties { + succeeded.items.properties.dequeued { + description: "Indicates whether the task was dequeued" + type: boolean + } + } + } + } +} set_requirements { "2.1" { description: """Set the script requirements for a task""" diff --git a/apiserver/services/models.py b/apiserver/services/models.py index 6c4cff3..642011b 100644 --- a/apiserver/services/models.py +++ b/apiserver/services/models.py @@ -1,8 +1,7 @@ from datetime import datetime from functools import partial -from typing import Sequence, Tuple, Set +from typing import Sequence -import attr from mongoengine import Q, EmbeddedDocument from apiserver import database @@ -15,15 +14,12 @@ from apiserver.apimodels.models import ( CreateModelResponse, PublishModelRequest, PublishModelResponse, - ModelTaskPublishResponse, GetFrameworksRequest, DeleteModelRequest, DeleteMetadataRequest, AddOrUpdateMetadataRequest, ModelsPublishManyRequest, - ModelsPublishManyResponse, ModelsDeleteManyRequest, - ModelsDeleteManyResponse, ) from apiserver.bll.model import ModelBLL from apiserver.bll.organization import OrgBLL, Tags @@ -501,26 +497,13 @@ def set_ready(call: APICall, company_id: str, request: PublishModelRequest): ) -@attr.s(auto_attribs=True) -class PublishRes: - published: int = 0 - published_tasks: Sequence = [] - - def __add__(self, other: Tuple[int, ModelTaskPublishResponse]): - published, response = other - return PublishRes( - published=self.published + published, - published_tasks=[*self.published_tasks, *([response] if response else [])], - ) - - @endpoint( "models.publish_many", request_data_model=ModelsPublishManyRequest, - response_data_model=ModelsPublishManyResponse, + response_data_model=BatchResponse, ) def publish_many(call: APICall, company_id, request: ModelsPublishManyRequest): - res, failures = run_batch_operation( + results, failures = run_batch_operation( func=partial( ModelBLL.publish_model, company_id=company_id, @@ -528,11 +511,16 @@ def publish_many(call: APICall, company_id, request: ModelsPublishManyRequest): publish_task_func=publish_task if request.publish_task else None, ), ids=request.ids, - init_res=PublishRes(), ) - call.result.data_model = ModelsPublishManyResponse( - succeeded=res.published, published_tasks=res.published_tasks, failures=failures, + call.result.data_model = BatchResponse( + succeeded=[ + dict( + id=_id, updated=bool(updated), published_task=published_task.to_struct() + ) + for _id, (updated, published_task) in results + ], + failed=failures, ) @@ -546,41 +534,30 @@ def delete(call: APICall, company_id, request: DeleteModelRequest): company_id, projects=[model.project] if model.project else [] ) - call.result.data = dict(deleted=del_count > 0, url=model.uri) - - -@attr.s(auto_attribs=True) -class DeleteRes: - deleted: int = 0 - projects: Set = set() - urls: Set = set() - - def __add__(self, other: Tuple[int, Model]): - del_count, model = other - return DeleteRes( - deleted=self.deleted + del_count, - projects=self.projects | {model.project}, - urls=self.urls | {model.uri}, - ) + call.result.data = dict(deleted=bool(del_count), url=model.uri) @endpoint( "models.delete_many", request_data_model=ModelsDeleteManyRequest, - response_data_model=ModelsDeleteManyResponse, + response_data_model=BatchResponse, ) def delete(call: APICall, company_id, request: ModelsDeleteManyRequest): - res, failures = run_batch_operation( + results, failures = run_batch_operation( func=partial(ModelBLL.delete_model, company_id=company_id, force=request.force), ids=request.ids, - init_res=DeleteRes(), ) - if res.deleted: - _reset_cached_tags(company_id, projects=list(res.projects)) - res.urls.discard(None) - call.result.data_model = ModelsDeleteManyResponse( - succeeded=res.deleted, urls=list(res.urls), failures=failures, + if results: + projects = set(model.project for _, (_, model) in results) + _reset_cached_tags(company_id, projects=list(projects)) + + call.result.data_model = BatchResponse( + succeeded=[ + dict(id=_id, deleted=bool(deleted), url=model.uri) + for _id, (deleted, model) in results + ], + failed=failures, ) @@ -590,12 +567,13 @@ def delete(call: APICall, company_id, request: ModelsDeleteManyRequest): response_data_model=BatchResponse, ) def archive_many(call: APICall, company_id, request: BatchRequest): - archived, failures = run_batch_operation( - func=partial(ModelBLL.archive_model, company_id=company_id), - ids=request.ids, - init_res=0, + results, failures = run_batch_operation( + func=partial(ModelBLL.archive_model, company_id=company_id), ids=request.ids, + ) + call.result.data_model = BatchResponse( + succeeded=[dict(id=_id, archived=bool(archived)) for _id, archived in results], + failed=failures, ) - call.result.data_model = BatchResponse(succeeded=archived, failures=failures) @endpoint( @@ -604,12 +582,15 @@ def archive_many(call: APICall, company_id, request: BatchRequest): response_data_model=BatchResponse, ) def unarchive_many(call: APICall, company_id, request: BatchRequest): - unarchived, failures = run_batch_operation( - func=partial(ModelBLL.unarchive_model, company_id=company_id), - ids=request.ids, - init_res=0, + results, failures = run_batch_operation( + func=partial(ModelBLL.unarchive_model, company_id=company_id), ids=request.ids, + ) + call.result.data_model = BatchResponse( + succeeded=[ + dict(id=_id, unarchived=bool(unarchived)) for _id, unarchived in results + ], + failed=failures, ) - call.result.data_model = BatchResponse(succeeded=unarchived, failures=failures,) @endpoint("models.make_public", min_version="2.9", request_data_model=MakePublicRequest) diff --git a/apiserver/services/tasks.py b/apiserver/services/tasks.py index 9442bb2..8689302 100644 --- a/apiserver/services/tasks.py +++ b/apiserver/services/tasks.py @@ -1,7 +1,7 @@ from copy import deepcopy from datetime import datetime from functools import partial -from typing import Sequence, Union, Tuple, Set +from typing import Sequence, Union, Tuple import attr import dpath @@ -17,12 +17,15 @@ from apiserver.apimodels.base import ( MakePublicRequest, MoveRequest, ) -from apiserver.apimodels.batch import BatchResponse +from apiserver.apimodels.batch import ( + BatchResponse, + UpdateBatchResponse, + UpdateBatchItem, +) from apiserver.apimodels.tasks import ( StartedResponse, ResetResponse, PublishRequest, - PublishResponse, CreateRequest, UpdateRequest, SetRequirementsRequest, @@ -54,6 +57,12 @@ from apiserver.apimodels.tasks import ( DeleteManyRequest, PublishManyRequest, TaskBatchRequest, + EnqueueManyResponse, + EnqueueBatchItem, + DequeueBatchItem, + DequeueManyResponse, + ResetManyResponse, + ResetBatchItem, ) from apiserver.bll.event import EventBLL from apiserver.bll.model import ModelBLL @@ -77,10 +86,10 @@ from apiserver.bll.task.param_utils import ( params_unprepare_from_saved, escape_paths, ) -from apiserver.bll.task.task_cleanup import CleanupResult from apiserver.bll.task.task_operations import ( stop_task, enqueue_task, + dequeue_task, reset_task, archive_task, delete_task, @@ -287,21 +296,13 @@ def stop(call: APICall, company_id, req_model: UpdateRequest): ) -@attr.s(auto_attribs=True) -class StopRes: - stopped: int = 0 - - def __add__(self, other: dict): - return StopRes(stopped=self.stopped + 1) - - @endpoint( "tasks.stop_many", request_data_model=StopManyRequest, - response_data_model=BatchResponse, + response_data_model=UpdateBatchResponse, ) def stop_many(call: APICall, company_id, request: StopManyRequest): - res, failures = run_batch_operation( + results, failures = run_batch_operation( func=partial( stop_task, company_id=company_id, @@ -310,9 +311,11 @@ def stop_many(call: APICall, company_id, request: StopManyRequest): force=request.force, ), ids=request.ids, - init_res=StopRes(), ) - call.result.data_model = BatchResponse(succeeded=res.stopped, failures=failures) + call.result.data_model = UpdateBatchResponse( + succeeded=[UpdateBatchItem(id=_id, **res) for _id, res in results], + failed=failures, + ) @endpoint( @@ -854,22 +857,13 @@ def enqueue(call: APICall, company_id, request: EnqueueRequest): call.result.data_model = EnqueueResponse(queued=queued, **res) -@attr.s(auto_attribs=True) -class EnqueueRes: - queued: int = 0 - - def __add__(self, other: Tuple[int, dict]): - queued, _ = other - return EnqueueRes(queued=self.queued + queued) - - @endpoint( "tasks.enqueue_many", request_data_model=EnqueueManyRequest, - response_data_model=BatchResponse, + response_data_model=EnqueueManyResponse, ) def enqueue_many(call: APICall, company_id, request: EnqueueManyRequest): - res, failures = run_batch_operation( + results, failures = run_batch_operation( func=partial( enqueue_task, company_id=company_id, @@ -879,9 +873,14 @@ def enqueue_many(call: APICall, company_id, request: EnqueueManyRequest): validate=request.validate_tasks, ), ids=request.ids, - init_res=EnqueueRes(), ) - call.result.data_model = BatchResponse(succeeded=res.queued, failures=failures) + call.result.data_model = EnqueueManyResponse( + succeeded=[ + EnqueueBatchItem(id=_id, queued=bool(queued), **res) + for _id, (queued, res) in results + ], + failed=failures, + ) @endpoint( @@ -890,23 +889,37 @@ def enqueue_many(call: APICall, company_id, request: EnqueueManyRequest): response_data_model=DequeueResponse, ) def dequeue(call: APICall, company_id, request: UpdateRequest): - task = TaskBLL.get_task_with_access( - request.task, + dequeued, res = dequeue_task( + task_id=request.task, company_id=company_id, - only=("id", "execution", "status", "project", "enqueue_status"), - requires_write_access=True, + status_message=request.status_message, + status_reason=request.status_reason, ) - res = DequeueResponse( - **TaskBLL.dequeue_and_change_status( - task, - company_id, + call.result.data_model = DequeueResponse(dequeued=dequeued, **res) + + +@endpoint( + "tasks.dequeue_many", + request_data_model=TaskBatchRequest, + response_data_model=DequeueManyResponse, +) +def dequeue_many(call: APICall, company_id, request: TaskBatchRequest): + results, failures = run_batch_operation( + func=partial( + dequeue_task, + company_id=company_id, status_message=request.status_message, status_reason=request.status_reason, - ) + ), + ids=request.ids, + ) + call.result.data_model = DequeueManyResponse( + succeeded=[ + DequeueBatchItem(id=_id, dequeued=bool(dequeued), **res) + for _id, (dequeued, res) in results + ], + failed=failures, ) - - res.dequeued = 1 - call.result.data_model = res @endpoint( @@ -932,25 +945,13 @@ def reset(call: APICall, company_id, request: ResetRequest): call.result.data_model = res -@attr.s(auto_attribs=True) -class ResetRes: - reset: int = 0 - dequeued: int = 0 - cleanup_res: CleanupResult = None - - def __add__(self, other: Tuple[dict, CleanupResult, dict]): - dequeued, other_res, _ = other - dequeued = dequeued.get("removed", 0) if dequeued else 0 - return ResetRes( - reset=self.reset + 1, - dequeued=self.dequeued + dequeued, - cleanup_res=self.cleanup_res + other_res if self.cleanup_res else other_res, - ) - - -@endpoint("tasks.reset_many", request_data_model=ResetManyRequest) +@endpoint( + "tasks.reset_many", + request_data_model=ResetManyRequest, + response_data_model=ResetManyResponse, +) def reset_many(call: APICall, company_id, request: ResetManyRequest): - res, failures = run_batch_operation( + results, failures = run_batch_operation( func=partial( reset_task, company_id=company_id, @@ -960,18 +961,26 @@ def reset_many(call: APICall, company_id, request: ResetManyRequest): clear_all=request.clear_all, ), ids=request.ids, - init_res=ResetRes(), ) - if res.cleanup_res: - cleanup_res = dict( - deleted_models=res.cleanup_res.deleted_models, - urls=attr.asdict(res.cleanup_res.urls) if res.cleanup_res.urls else {}, - ) - else: - cleanup_res = {} - call.result.data = dict( - succeeded=res.reset, dequeued=res.dequeued, **cleanup_res, failures=failures, + def clean_res(res: dict) -> dict: + # do not return artifacts since they are not serializable + fields = res.get("fields") + if fields: + fields.pop("execution.artifacts", None) + return res + + call.result.data_model = ResetManyResponse( + succeeded=[ + ResetBatchItem( + id=_id, + dequeued=bool(dequeued.get("removed")) if dequeued else False, + **attr.asdict(cleanup), + **clean_res(res), + ) + for _id, (dequeued, cleanup, res) in results + ], + failed=failures, ) @@ -1004,7 +1013,7 @@ def archive(call: APICall, company_id, request: ArchiveRequest): response_data_model=BatchResponse, ) def archive_many(call: APICall, company_id, request: TaskBatchRequest): - archived, failures = run_batch_operation( + results, failures = run_batch_operation( func=partial( archive_task, company_id=company_id, @@ -1012,9 +1021,11 @@ def archive_many(call: APICall, company_id, request: TaskBatchRequest): status_reason=request.status_reason, ), ids=request.ids, - init_res=0, ) - call.result.data_model = BatchResponse(succeeded=archived, failures=failures) + call.result.data_model = BatchResponse( + succeeded=[dict(id=_id, archived=bool(archived)) for _id, archived in results], + failed=failures, + ) @endpoint( @@ -1023,7 +1034,7 @@ def archive_many(call: APICall, company_id, request: TaskBatchRequest): response_data_model=BatchResponse, ) def unarchive_many(call: APICall, company_id, request: TaskBatchRequest): - unarchived, failures = run_batch_operation( + results, failures = run_batch_operation( func=partial( unarchive_task, company_id=company_id, @@ -1031,9 +1042,13 @@ def unarchive_many(call: APICall, company_id, request: TaskBatchRequest): status_reason=request.status_reason, ), ids=request.ids, - init_res=0, ) - call.result.data_model = BatchResponse(succeeded=unarchived, failures=failures) + call.result.data_model = BatchResponse( + succeeded=[ + dict(id=_id, unarchived=bool(unarchived)) for _id, unarchived in results + ], + failed=failures, + ) @endpoint("tasks.delete", request_data_model=DeleteRequest) @@ -1051,25 +1066,9 @@ def delete(call: APICall, company_id, request: DeleteRequest): call.result.data = dict(deleted=bool(deleted), **attr.asdict(cleanup_res)) -@attr.s(auto_attribs=True) -class DeleteRes: - deleted: int = 0 - projects: Set = set() - cleanup_res: CleanupResult = None - - def __add__(self, other: Tuple[int, Task, CleanupResult]): - del_count, task, other_res = other - - return DeleteRes( - deleted=self.deleted + del_count, - projects=self.projects | {task.project}, - cleanup_res=self.cleanup_res + other_res if self.cleanup_res else other_res, - ) - - @endpoint("tasks.delete_many", request_data_model=DeleteManyRequest) def delete_many(call: APICall, company_id, request: DeleteManyRequest): - res, failures = run_batch_operation( + results, failures = run_batch_operation( func=partial( delete_task, company_id=company_id, @@ -1079,20 +1078,25 @@ def delete_many(call: APICall, company_id, request: DeleteManyRequest): delete_output_models=request.delete_output_models, ), ids=request.ids, - init_res=DeleteRes(), ) - if res.deleted: - _reset_cached_tags(company_id, projects=list(res.projects)) + if results: + projects = set(task.project for _, (_, task, _) in results) + _reset_cached_tags(company_id, projects=list(projects)) - cleanup_res = attr.asdict(res.cleanup_res) if res.cleanup_res else {} - call.result.data = dict(succeeded=res.deleted, **cleanup_res, failures=failures) + call.result.data = dict( + succeeded=[ + dict(id=_id, deleted=bool(deleted), **attr.asdict(cleanup_res)) + for _id, (deleted, _, cleanup_res) in results + ], + failed=failures, + ) @endpoint( "tasks.publish", request_data_model=PublishRequest, - response_data_model=PublishResponse, + response_data_model=UpdateResponse, ) def publish(call: APICall, company_id, request: PublishRequest): updates = publish_task( @@ -1103,24 +1107,16 @@ def publish(call: APICall, company_id, request: PublishRequest): status_reason=request.status_reason, status_message=request.status_message, ) - call.result.data_model = PublishResponse(**updates) - - -@attr.s(auto_attribs=True) -class PublishRes: - published: int = 0 - - def __add__(self, other: dict): - return PublishRes(published=self.published + 1) + call.result.data_model = UpdateResponse(**updates) @endpoint( "tasks.publish_many", request_data_model=PublishManyRequest, - response_data_model=BatchResponse, + response_data_model=UpdateBatchResponse, ) def publish_many(call: APICall, company_id, request: PublishManyRequest): - res, failures = run_batch_operation( + results, failures = run_batch_operation( func=partial( publish_task, company_id=company_id, @@ -1132,10 +1128,12 @@ def publish_many(call: APICall, company_id, request: PublishManyRequest): status_message=request.status_message, ), ids=request.ids, - init_res=PublishRes(), ) - call.result.data_model = BatchResponse(succeeded=res.published, failures=failures) + call.result.data_model = UpdateBatchResponse( + succeeded=[UpdateBatchItem(id=_id, **res) for _id, res in results], + failed=failures, + ) @endpoint( diff --git a/apiserver/tests/automated/test_batch_operations.py b/apiserver/tests/automated/test_batch_operations.py index 06b90ce..48ca572 100644 --- a/apiserver/tests/automated/test_batch_operations.py +++ b/apiserver/tests/automated/test_batch_operations.py @@ -21,8 +21,8 @@ class TestBatchOperations(TestService): # enqueue res = self.api.tasks.enqueue_many(ids=ids) - self.assertEqual(res.succeeded, 2) - self._assert_failures(res, [missing_id]) + self._assert_succeeded(res, tasks) + self._assert_failed(res, [missing_id]) data = self.api.tasks.get_all_ex(id=ids).tasks self.assertEqual({t.status for t in data}, {"queued"}) @@ -30,15 +30,15 @@ class TestBatchOperations(TestService): for t in tasks: self.api.tasks.started(task=t) res = self.api.tasks.stop_many(ids=ids) - self.assertEqual(res.succeeded, 2) - self._assert_failures(res, [missing_id]) + self._assert_succeeded(res, tasks) + self._assert_failed(res, [missing_id]) data = self.api.tasks.get_all_ex(id=ids).tasks self.assertEqual({t.status for t in data}, {"stopped"}) # publish res = self.api.tasks.publish_many(ids=ids, publish_model=False) - self.assertEqual(res.succeeded, 2) - self._assert_failures(res, [missing_id]) + self._assert_succeeded(res, tasks) + self._assert_failed(res, [missing_id]) data = self.api.tasks.get_all_ex(id=ids).tasks self.assertEqual({t.status for t in data}, {"published"}) @@ -46,23 +46,26 @@ class TestBatchOperations(TestService): res = self.api.tasks.reset_many( ids=ids, delete_output_models=True, return_file_urls=True, force=True ) - self.assertEqual(res.succeeded, 2) - self.assertEqual(res.deleted_models, 2) - self.assertEqual(set(res.urls.model_urls), {"uri_0", "uri_1"}) - self._assert_failures(res, [missing_id]) + self._assert_succeeded(res, tasks) + self.assertEqual(sum(t.deleted_models for t in res.succeeded), 2) + self.assertEqual( + set(url for t in res.succeeded for url in t.urls.model_urls), + {"uri_0", "uri_1"}, + ) + self._assert_failed(res, [missing_id]) data = self.api.tasks.get_all_ex(id=ids).tasks self.assertEqual({t.status for t in data}, {"created"}) # archive/unarchive res = self.api.tasks.archive_many(ids=ids) - self.assertEqual(res.succeeded, 2) - self._assert_failures(res, [missing_id]) + self._assert_succeeded(res, tasks) + self._assert_failed(res, [missing_id]) data = self.api.tasks.get_all_ex(id=ids).tasks self.assertTrue(all("archived" in t.system_tags for t in data)) res = self.api.tasks.unarchive_many(ids=ids) - self.assertEqual(res.succeeded, 2) - self._assert_failures(res, [missing_id]) + self._assert_succeeded(res, tasks) + self._assert_failed(res, [missing_id]) data = self.api.tasks.get_all_ex(id=ids).tasks self.assertFalse(any("archived" in t.system_tags for t in data)) @@ -70,8 +73,8 @@ class TestBatchOperations(TestService): res = self.api.tasks.delete_many( ids=ids, delete_output_models=True, return_file_urls=True ) - self.assertEqual(res.succeeded, 2) - self._assert_failures(res, [missing_id]) + self._assert_succeeded(res, tasks) + self._assert_failed(res, [missing_id]) data = self.api.tasks.get_all_ex(id=ids).tasks self.assertEqual(data, []) @@ -90,33 +93,36 @@ class TestBatchOperations(TestService): res = self.api.models.publish_many( ids=ids, publish_task=True, force_publish_task=True ) - self.assertEqual(res.succeeded, 1) - self.assertEqual(res.published_tasks[0].id, task) - self._assert_failures(res, [ids[1], missing_id]) + self._assert_succeeded(res, [ids[0]]) + self.assertEqual(res.succeeded[0].published_task.id, task) + self._assert_failed(res, [ids[1], missing_id]) # archive/unarchive res = self.api.models.archive_many(ids=ids) - self.assertEqual(res.succeeded, 2) - self._assert_failures(res, [missing_id]) + self._assert_succeeded(res, models) + self._assert_failed(res, [missing_id]) data = self.api.models.get_all_ex(id=ids).models self.assertTrue(all("archived" in m.system_tags for m in data)) res = self.api.models.unarchive_many(ids=ids) - self.assertEqual(res.succeeded, 2) - self._assert_failures(res, [missing_id]) + self._assert_succeeded(res, models) + self._assert_failed(res, [missing_id]) data = self.api.models.get_all_ex(id=ids).models self.assertFalse(any("archived" in m.system_tags for m in data)) # delete res = self.api.models.delete_many(ids=[*models, missing_id], force=True) - self.assertEqual(res.succeeded, 2) - self.assertEqual(set(res.urls), set(uris)) - self._assert_failures(res, [missing_id]) + self._assert_succeeded(res, models) + self.assertEqual(set(m.url for m in res.succeeded), set(uris)) + self._assert_failed(res, [missing_id]) data = self.api.models.get_all_ex(id=ids).models self.assertEqual(data, []) - def _assert_failures(self, res, failed_ids): - self.assertEqual(set(f.id for f in res.failures), set(failed_ids)) + def _assert_succeeded(self, res, succeeded_ids): + self.assertEqual(set(f.id for f in res.succeeded), set(succeeded_ids)) + + def _assert_failed(self, res, failed_ids): + self.assertEqual(set(f.id for f in res.failed), set(failed_ids)) def _temp_model(self, **kwargs): self.update_missing(kwargs, name=self.name, uri="file:///a/b", labels={})