Fix batch operations results

allegroai 2021-05-03 18:07:37 +03:00
parent dc9623e964
commit 6b0c45a861
11 changed files with 368 additions and 278 deletions

View File

@@ -1,9 +1,11 @@
from typing import Sequence
from jsonmodels.fields import StringField
from jsonmodels.models import Base
from jsonmodels.validators import Length
from apiserver.apimodels import ListField, IntField
from apiserver.apimodels import ListField
from apiserver.apimodels.base import UpdateResponse
class BatchRequest(Base):
@@ -11,5 +13,13 @@ class BatchRequest(Base):
class BatchResponse(Base):
succeeded: int = IntField()
failures: Sequence[dict] = ListField([dict])
succeeded: Sequence[dict] = ListField([dict])
failed: Sequence[dict] = ListField([dict])
class UpdateBatchItem(UpdateResponse):
id: str = StringField()
class UpdateBatchResponse(BatchResponse):
succeeded: Sequence[UpdateBatchItem] = ListField(UpdateBatchItem)
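
For context, a minimal sketch of how the reworked models compose. This assumes the apiserver package is importable and that UpdateResponse exposes the updated/fields pair described by the update_response definition further down; all values are illustrative:

from apiserver.apimodels.batch import UpdateBatchItem, UpdateBatchResponse

# Illustrative only: build a per-item batch response the way the service
# endpoints below do. Succeeded entries carry the entity id; failed entries
# keep the dict shape produced by run_batch_operation.
response = UpdateBatchResponse(
    succeeded=[UpdateBatchItem(id="task-1", updated=1, fields={"status": "stopped"})],
    failed=[{"id": "task-2", "error": {"codes": [400, 100], "msg": "task not found"}}],
)
print(response.to_struct())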

View File

@@ -3,13 +3,12 @@ from six import string_types
from apiserver.apimodels import ListField, DictField
from apiserver.apimodels.base import UpdateResponse
from apiserver.apimodels.batch import BatchRequest, BatchResponse
from apiserver.apimodels.batch import BatchRequest
from apiserver.apimodels.metadata import (
MetadataItem,
DeleteMetadata,
AddOrUpdateMetadata,
)
from apiserver.apimodels.tasks import PublishResponse as TaskPublishResponse
class GetFrameworksRequest(models.Base):
@@ -51,10 +50,6 @@ class ModelsDeleteManyRequest(BatchRequest):
force = fields.BoolField(default=False)
class ModelsDeleteManyResponse(BatchResponse):
urls = fields.ListField([str])
class PublishModelRequest(ModelRequest):
force_publish_task = fields.BoolField(default=False)
publish_task = fields.BoolField(default=True)
@@ -62,7 +57,7 @@ class PublishModelRequest(ModelRequest):
class ModelTaskPublishResponse(models.Base):
id = fields.StringField(required=True)
data = fields.EmbeddedField(TaskPublishResponse)
data = fields.EmbeddedField(UpdateResponse)
class PublishModelResponse(UpdateResponse):
@@ -74,10 +69,6 @@ class ModelsPublishManyRequest(BatchRequest):
publish_task = fields.BoolField(default=True)
class ModelsPublishManyResponse(BatchResponse):
published_tasks = fields.ListField([ModelTaskPublishResponse])
class DeleteMetadataRequest(DeleteMetadata):
model = fields.StringField(required=True)

View File

@@ -1,13 +1,12 @@
from typing import Sequence
import six
from jsonmodels import models
from jsonmodels.fields import StringField, BoolField, IntField, EmbeddedField
from jsonmodels.validators import Enum, Length
from apiserver.apimodels import DictField, ListField
from apiserver.apimodels.base import UpdateResponse
from apiserver.apimodels.batch import BatchRequest
from apiserver.apimodels.batch import BatchRequest, UpdateBatchItem, BatchResponse
from apiserver.database.model.task.task import (
TaskType,
ArtifactModes,
@@ -45,19 +44,43 @@ class EnqueueResponse(UpdateResponse):
queued = IntField()
class EnqueueBatchItem(UpdateBatchItem):
queued: bool = BoolField()
class EnqueueManyResponse(BatchResponse):
succeeded: Sequence[EnqueueBatchItem] = ListField(EnqueueBatchItem)
class DequeueResponse(UpdateResponse):
dequeued = IntField()
class DequeueBatchItem(UpdateBatchItem):
dequeued: bool = BoolField()
class DequeueManyResponse(BatchResponse):
succeeded: Sequence[DequeueBatchItem] = ListField(DequeueBatchItem)
class ResetResponse(UpdateResponse):
deleted_indices = ListField(items_types=six.string_types)
dequeued = DictField()
frames = DictField()
events = DictField()
deleted_models = IntField()
urls = DictField()
class ResetBatchItem(UpdateBatchItem):
dequeued: bool = BoolField()
deleted_models = IntField()
urls = DictField()
class ResetManyResponse(BatchResponse):
succeeded: Sequence[ResetBatchItem] = ListField(ResetBatchItem)
class TaskRequest(models.Base):
task = StringField(required=True)
@@ -89,10 +112,6 @@ class PublishRequest(UpdateRequest):
publish_model = BoolField(default=True)
class PublishResponse(UpdateResponse):
pass
class TaskData(models.Base):
"""
This is a partial description of a task that can be updated incrementally

View File

@@ -56,13 +56,13 @@ def archive_task(
except APIError:
# dequeue may fail if the task was not enqueued
pass
task.update(
return task.update(
status_message=status_message,
status_reason=status_reason,
add_to_set__system_tags=EntityVisibility.archived.value,
last_change=datetime.utcnow(),
)
return 1
def unarchive_task(
@@ -82,6 +82,26 @@ def unarchive_task(
)
def dequeue_task(
task_id: str,
company_id: str,
status_message: str,
status_reason: str,
) -> Tuple[int, dict]:
query = dict(id=task_id, company=company_id)
task = Task.get_for_writing(**query)
if not task:
raise errors.bad_request.InvalidTaskId(**query)
res = TaskBLL.dequeue_and_change_status(
task,
company_id,
status_message=status_message,
status_reason=status_reason,
)
return 1, res
def enqueue_task(
task_id: str,
company_id: str,
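
The new dequeue_task helper mirrors enqueue_task and returns a (count, result) pair. A hedged usage sketch; the import path follows the services imports further below, the ids are illustrative, and the result dict comes from TaskBLL.dequeue_and_change_status:

from apiserver.bll.task.task_operations import dequeue_task

# Raises InvalidTaskId when the task does not exist.
dequeued, res = dequeue_task(
    task_id="task-1",
    company_id="company-1",
    status_message="stopping",
    status_reason="user request",
)
# dequeued is 1 on success; res later feeds DequeueResponse / DequeueBatchItem.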

View File

@@ -113,13 +113,13 @@ T = TypeVar("T")
def run_batch_operation(
func: Callable[[str], T], init_res: T, ids: Sequence[str]
) -> Tuple[T, Sequence]:
res = init_res
func: Callable[[str], T], ids: Sequence[str]
) -> Tuple[Sequence[Tuple[str, T]], Sequence[dict]]:
results = list()
failures = list()
for _id in ids:
try:
res += func(_id)
results.append((_id, func(_id)))
except APIError as err:
failures.append(
{
@@ -131,4 +131,4 @@ def run_batch_operation(
},
}
)
return res, failures
return results, failures
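
With the init_res accumulator gone, callers receive per-id results and shape their own responses. A self-contained sketch of the new contract; the operation is a stand-in, not the real ModelBLL call:

from functools import partial

def run_batch_operation(func, ids):
    # Simplified copy of the helper above, minus the APIError handling.
    results, failures = [], []
    for _id in ids:
        results.append((_id, func(_id)))
    return results, failures

def archive_model(_id: str, company_id: str) -> int:
    return 1  # stand-in for ModelBLL.archive_model

results, failures = run_batch_operation(
    func=partial(archive_model, company_id="company-1"), ids=["m1", "m2"],
)
succeeded = [dict(id=_id, archived=bool(res)) for _id, res in results]
# succeeded == [{"id": "m1", "archived": True}, {"id": "m2", "archived": True}]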

View File

@@ -43,9 +43,18 @@ batch_operation {
type: object
properties {
succeeded {
type: integer
type: array
items {
type: object
properties {
id: {
description: ID of the succeeded entity
type: string
}
}
}
}
failures {
failed {
type: array
items {
type: object
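
On the wire, a response built from this definition now looks roughly as follows (a hypothetical payload; the error shape of failed items is assumed from run_batch_operation above):

batch_response = {
    "succeeded": [
        {"id": "entity-1"},  # endpoints add operation-specific fields per item
        {"id": "entity-2"},
    ],
    "failed": [  # renamed from "failures"
        {"id": "entity-3", "error": {"codes": [400, 100], "msg": "..."}},
    ],
}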

View File

@@ -680,11 +680,11 @@ publish_many {
}
response {
properties {
succeeded.description: "Number of models published"
published_tasks {
type: array
items: ${_definitions.published_task_item}
succeeded.items.properties.updated {
description: "Indicates whether the model was updated"
type: boolean
}
succeeded.items.properties.published_task: ${_definitions.published_task_item}
}
}
}
@@ -733,7 +733,10 @@ archive_many {
}
response {
properties {
succeeded.description: "Number of models archived"
succeeded.items.properties.archived {
description: "Indicates whether the model was archived"
type: boolean
}
}
}
}
@@ -748,7 +751,10 @@ unarchive_many {
}
response {
properties {
succeeded.description: "Number of models unarchived"
succeeded.items.properties.unarchived {
description: "Indicates whether the model was unarchived"
type: boolean
}
}
}
}
@@ -768,11 +774,13 @@ delete_many {
}
response {
properties {
succeeded.description: "Number of models deleted"
urls {
descrition: "The urls of the deleted model files"
type: array
items {type: string}
succeeded.items.properties.deleted {
description: "Indicates whether the model was deleted"
type: boolean
}
succeeded.items.properties.url {
description: "The url of the model file"
type: string
}
}
}
@@ -814,7 +822,7 @@ delete {
response {
properties {
url {
descrition: "The url of the model file"
description: "The url of the model file"
type: string
}
}

View File

@@ -39,6 +39,20 @@ _definitions {
}
}
}
response {
properties {
succeeded.items.properties.updated {
description: "Number of tasks updated (0 or 1)"
type: integer
enum: [ 0, 1 ]
}
succeeded.items.properties.fields {
description: "Updated fields names and values"
type: object
additionalProperties: true
}
}
}
}
update_response {
type: object
@@ -1370,21 +1384,11 @@ reset {
} ${_references.status_change_request}
response: ${_definitions.update_response} {
properties {
deleted_indices {
description: "List of deleted ES indices that were removed as part of the reset process"
type: array
items { type: string }
}
dequeued {
description: "Response from queues.remove_task"
type: object
additionalProperties: true
}
frames {
description: "Response from frames.rollback"
type: object
additionalProperties: true
}
events {
description: "Response from events.delete_for_task"
type: object
@@ -1442,18 +1446,26 @@ reset_many {
}
response {
properties {
succeeded.description: "Number of tasks reset"
dequeued {
description: "Number of tasks dequeued"
succeeded.items.properties.dequeued {
description: "Indicates whether the task was dequeued"
type: boolean
}
succeeded.items.properties.updated {
description: "Number of tasks updated (0 or 1)"
type: integer
enum: [ 0, 1 ]
}
succeeded.items.properties.fields {
description: "Updated fields names and values"
type: object
additionalProperties: true
}
deleted_models {
succeeded.items.properties.deleted_models {
description: "Number of output models deleted by the reset"
type: integer
}
urls {
description: "The urls of the files that were uploaded by the tasks. Returned if the 'return_file_urls' was set to 'true'"
succeeded.items.properties.urls {
description: "The urls of the files that were uploaded by the task. Returned if the 'return_file_urls' was set to 'true'"
"$ref": "#/definitions/task_urls"
}
}
@@ -1486,21 +1498,24 @@ delete_many {
}
response {
properties {
succeeded.description: "Number of tasks deleted"
updated_children {
succeeded.items.properties.deleted {
description: "Indicates whether the task was deleted"
type: boolean
}
succeeded.items.properties.updated_children {
description: "Number of child tasks whose parent property was updated"
type: integer
}
updated_models {
succeeded.items.properties.updated_models {
description: "Number of models whose task property was updated"
type: integer
}
deleted_models {
succeeded.items.properties.deleted_models {
description: "Number of deleted output models"
type: integer
}
urls {
description: "The urls of the files that were uploaded by the tasks. Returned if the 'return_file_urls' was set to 'true'"
succeeded.items.properties.urls {
description: "The urls of the files that were uploaded by the task. Returned if the 'return_file_urls' was set to 'true'"
"$ref": "#/definitions/task_urls"
}
}
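
Taken together, a succeeded entry for tasks.delete_many now bundles the per-task cleanup counters, roughly as below (values illustrative; the urls shape is assumed from the task_urls definition):

succeeded_item = {
    "id": "task-1",
    "deleted": True,
    "updated_children": 0,
    "updated_models": 1,
    "deleted_models": 1,
    "urls": {"model_urls": ["uri_0"]},  # present when return_file_urls is true
}
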
@@ -1609,31 +1624,53 @@ archive {
}
}
archive_many {
"2.13": ${_definitions.change_many_request} {
"2.13": ${_definitions.batch_operation} {
description: Archive tasks
request {
properties {
ids.description: "IDs of the tasks to archive"
status_reason {
description: Reason for status change
type: string
}
status_message {
description: Extra information regarding status change
type: string
}
}
response {
properties {
succeeded.description: "Number of tasks archived"
response {
properties {
succeeded.items.properties.archived {
description: "Indicates whether the task was archived"
type: boolean
}
}
}
}
}
}
unarchive_many {
"2.13": ${_definitions.change_many_request} {
"2.13": ${_definitions.batch_operation} {
description: Unarchive tasks
request {
properties {
ids.description: "IDs of the tasks to unarchive"
status_reason {
description: Reason for status change
type: string
}
status_message {
description: Extra information regarding status change
type: string
}
}
}
response {
properties {
succeeded.description: "Number of tasks unarchived"
succeeded.items.properties.unarchived {
description: "Indicates whether the task was unarchived"
type: boolean
}
}
}
}
@@ -1687,11 +1724,6 @@ stop_many {
}
}
}
response {
properties {
succeeded.description: "Number of tasks stopped"
}
}
}
}
stopped {
@@ -1775,11 +1807,6 @@ publish_many {
}
}
}
response {
properties {
succeeded.description: "Number of tasks published"
}
}
}
}
enqueue {
@@ -1837,7 +1864,10 @@ enqueue_many {
}
response {
properties {
succeeded.description: "Number of tasks enqueued"
succeeded.items.properties.queued {
description: "Indicates whether the task was queued"
type: boolean
}
}
}
}
@@ -1863,6 +1893,24 @@ dequeue {
}
}
}
dequeue_many {
"2.13": ${_definitions.change_many_request} {
description: Dequeue tasks
request {
properties {
ids.description: "IDs of the tasks to dequeue"
}
}
response {
properties {
succeeded.items.properties.dequeued {
description: "Indicates whether the task was dequeued"
type: boolean
}
}
}
}
}
set_requirements {
"2.1" {
description: """Set the script requirements for a task"""

View File

@@ -1,8 +1,7 @@
from datetime import datetime
from functools import partial
from typing import Sequence, Tuple, Set
from typing import Sequence
import attr
from mongoengine import Q, EmbeddedDocument
from apiserver import database
@@ -15,15 +14,12 @@ from apiserver.apimodels.models import (
CreateModelResponse,
PublishModelRequest,
PublishModelResponse,
ModelTaskPublishResponse,
GetFrameworksRequest,
DeleteModelRequest,
DeleteMetadataRequest,
AddOrUpdateMetadataRequest,
ModelsPublishManyRequest,
ModelsPublishManyResponse,
ModelsDeleteManyRequest,
ModelsDeleteManyResponse,
)
from apiserver.bll.model import ModelBLL
from apiserver.bll.organization import OrgBLL, Tags
@@ -501,26 +497,13 @@ def set_ready(call: APICall, company_id: str, request: PublishModelRequest):
)
@attr.s(auto_attribs=True)
class PublishRes:
published: int = 0
published_tasks: Sequence = []
def __add__(self, other: Tuple[int, ModelTaskPublishResponse]):
published, response = other
return PublishRes(
published=self.published + published,
published_tasks=[*self.published_tasks, *([response] if response else [])],
)
@endpoint(
"models.publish_many",
request_data_model=ModelsPublishManyRequest,
response_data_model=ModelsPublishManyResponse,
response_data_model=BatchResponse,
)
def publish_many(call: APICall, company_id, request: ModelsPublishManyRequest):
res, failures = run_batch_operation(
results, failures = run_batch_operation(
func=partial(
ModelBLL.publish_model,
company_id=company_id,
@@ -528,11 +511,16 @@ def publish_many(call: APICall, company_id, request: ModelsPublishManyRequest):
publish_task_func=publish_task if request.publish_task else None,
),
ids=request.ids,
init_res=PublishRes(),
)
call.result.data_model = ModelsPublishManyResponse(
succeeded=res.published, published_tasks=res.published_tasks, failures=failures,
call.result.data_model = BatchResponse(
succeeded=[
dict(
id=_id, updated=bool(updated), published_task=published_task.to_struct()
)
for _id, (updated, published_task) in results
],
failed=failures,
)
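
Each succeeded entry for models.publish_many therefore inlines the published task's update result. An illustrative item, with assumed values:

succeeded_item = dict(
    id="model-1",
    updated=True,  # whether the model itself changed state
    published_task={  # ModelTaskPublishResponse.to_struct()
        "id": "task-1",
        "data": {"updated": 1, "fields": {"status": "published"}},
    },
)
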
@@ -546,41 +534,30 @@ def delete(call: APICall, company_id, request: DeleteModelRequest):
company_id, projects=[model.project] if model.project else []
)
call.result.data = dict(deleted=del_count > 0, url=model.uri)
@attr.s(auto_attribs=True)
class DeleteRes:
deleted: int = 0
projects: Set = set()
urls: Set = set()
def __add__(self, other: Tuple[int, Model]):
del_count, model = other
return DeleteRes(
deleted=self.deleted + del_count,
projects=self.projects | {model.project},
urls=self.urls | {model.uri},
)
call.result.data = dict(deleted=bool(del_count), url=model.uri)
@endpoint(
"models.delete_many",
request_data_model=ModelsDeleteManyRequest,
response_data_model=ModelsDeleteManyResponse,
response_data_model=BatchResponse,
)
def delete(call: APICall, company_id, request: ModelsDeleteManyRequest):
res, failures = run_batch_operation(
results, failures = run_batch_operation(
func=partial(ModelBLL.delete_model, company_id=company_id, force=request.force),
ids=request.ids,
init_res=DeleteRes(),
)
if res.deleted:
_reset_cached_tags(company_id, projects=list(res.projects))
res.urls.discard(None)
call.result.data_model = ModelsDeleteManyResponse(
succeeded=res.deleted, urls=list(res.urls), failures=failures,
if results:
projects = set(model.project for _, (_, model) in results)
_reset_cached_tags(company_id, projects=list(projects))
call.result.data_model = BatchResponse(
succeeded=[
dict(id=_id, deleted=bool(deleted), url=model.uri)
for _id, (deleted, model) in results
],
failed=failures,
)
@@ -590,12 +567,13 @@ def delete(call: APICall, company_id, request: ModelsDeleteManyRequest):
response_data_model=BatchResponse,
)
def archive_many(call: APICall, company_id, request: BatchRequest):
archived, failures = run_batch_operation(
func=partial(ModelBLL.archive_model, company_id=company_id),
ids=request.ids,
init_res=0,
results, failures = run_batch_operation(
func=partial(ModelBLL.archive_model, company_id=company_id), ids=request.ids,
)
call.result.data_model = BatchResponse(
succeeded=[dict(id=_id, archived=bool(archived)) for _id, archived in results],
failed=failures,
)
call.result.data_model = BatchResponse(succeeded=archived, failures=failures)
@endpoint(
@@ -604,12 +582,15 @@ def archive_many(call: APICall, company_id, request: BatchRequest):
response_data_model=BatchResponse,
)
def unarchive_many(call: APICall, company_id, request: BatchRequest):
unarchived, failures = run_batch_operation(
func=partial(ModelBLL.unarchive_model, company_id=company_id),
ids=request.ids,
init_res=0,
results, failures = run_batch_operation(
func=partial(ModelBLL.unarchive_model, company_id=company_id), ids=request.ids,
)
call.result.data_model = BatchResponse(
succeeded=[
dict(id=_id, unarchived=bool(unarchived)) for _id, unarchived in results
],
failed=failures,
)
call.result.data_model = BatchResponse(succeeded=unarchived, failures=failures,)
@endpoint("models.make_public", min_version="2.9", request_data_model=MakePublicRequest)

View File

@@ -1,7 +1,7 @@
from copy import deepcopy
from datetime import datetime
from functools import partial
from typing import Sequence, Union, Tuple, Set
from typing import Sequence, Union, Tuple
import attr
import dpath
@@ -17,12 +17,15 @@ from apiserver.apimodels.base import (
MakePublicRequest,
MoveRequest,
)
from apiserver.apimodels.batch import BatchResponse
from apiserver.apimodels.batch import (
BatchResponse,
UpdateBatchResponse,
UpdateBatchItem,
)
from apiserver.apimodels.tasks import (
StartedResponse,
ResetResponse,
PublishRequest,
PublishResponse,
CreateRequest,
UpdateRequest,
SetRequirementsRequest,
@@ -54,6 +57,12 @@ from apiserver.apimodels.tasks import (
DeleteManyRequest,
PublishManyRequest,
TaskBatchRequest,
EnqueueManyResponse,
EnqueueBatchItem,
DequeueBatchItem,
DequeueManyResponse,
ResetManyResponse,
ResetBatchItem,
)
from apiserver.bll.event import EventBLL
from apiserver.bll.model import ModelBLL
@@ -77,10 +86,10 @@ from apiserver.bll.task.param_utils import (
params_unprepare_from_saved,
escape_paths,
)
from apiserver.bll.task.task_cleanup import CleanupResult
from apiserver.bll.task.task_operations import (
stop_task,
enqueue_task,
dequeue_task,
reset_task,
archive_task,
delete_task,
@@ -287,21 +296,13 @@ def stop(call: APICall, company_id, req_model: UpdateRequest):
)
@attr.s(auto_attribs=True)
class StopRes:
stopped: int = 0
def __add__(self, other: dict):
return StopRes(stopped=self.stopped + 1)
@endpoint(
"tasks.stop_many",
request_data_model=StopManyRequest,
response_data_model=BatchResponse,
response_data_model=UpdateBatchResponse,
)
def stop_many(call: APICall, company_id, request: StopManyRequest):
res, failures = run_batch_operation(
results, failures = run_batch_operation(
func=partial(
stop_task,
company_id=company_id,
@@ -310,9 +311,11 @@ def stop_many(call: APICall, company_id, request: StopManyRequest):
force=request.force,
),
ids=request.ids,
init_res=StopRes(),
)
call.result.data_model = BatchResponse(succeeded=res.stopped, failures=failures)
call.result.data_model = UpdateBatchResponse(
succeeded=[UpdateBatchItem(id=_id, **res) for _id, res in results],
failed=failures,
)
@endpoint(
@@ -854,22 +857,13 @@ def enqueue(call: APICall, company_id, request: EnqueueRequest):
call.result.data_model = EnqueueResponse(queued=queued, **res)
@attr.s(auto_attribs=True)
class EnqueueRes:
queued: int = 0
def __add__(self, other: Tuple[int, dict]):
queued, _ = other
return EnqueueRes(queued=self.queued + queued)
@endpoint(
"tasks.enqueue_many",
request_data_model=EnqueueManyRequest,
response_data_model=BatchResponse,
response_data_model=EnqueueManyResponse,
)
def enqueue_many(call: APICall, company_id, request: EnqueueManyRequest):
res, failures = run_batch_operation(
results, failures = run_batch_operation(
func=partial(
enqueue_task,
company_id=company_id,
@@ -879,9 +873,14 @@ def enqueue_many(call: APICall, company_id, request: EnqueueManyRequest):
validate=request.validate_tasks,
),
ids=request.ids,
init_res=EnqueueRes(),
)
call.result.data_model = BatchResponse(succeeded=res.queued, failures=failures)
call.result.data_model = EnqueueManyResponse(
succeeded=[
EnqueueBatchItem(id=_id, queued=bool(queued), **res)
for _id, (queued, res) in results
],
failed=failures,
)
@endpoint(
@@ -890,23 +889,37 @@ def enqueue_many(call: APICall, company_id, request: EnqueueManyRequest):
response_data_model=DequeueResponse,
)
def dequeue(call: APICall, company_id, request: UpdateRequest):
task = TaskBLL.get_task_with_access(
request.task,
dequeued, res = dequeue_task(
task_id=request.task,
company_id=company_id,
only=("id", "execution", "status", "project", "enqueue_status"),
requires_write_access=True,
status_message=request.status_message,
status_reason=request.status_reason,
)
res = DequeueResponse(
**TaskBLL.dequeue_and_change_status(
task,
company_id,
call.result.data_model = DequeueResponse(dequeued=dequeued, **res)
@endpoint(
"tasks.dequeue_many",
request_data_model=TaskBatchRequest,
response_data_model=DequeueManyResponse,
)
def dequeue_many(call: APICall, company_id, request: TaskBatchRequest):
results, failures = run_batch_operation(
func=partial(
dequeue_task,
company_id=company_id,
status_message=request.status_message,
status_reason=request.status_reason,
)
),
ids=request.ids,
)
call.result.data_model = DequeueManyResponse(
succeeded=[
DequeueBatchItem(id=_id, dequeued=bool(dequeued), **res)
for _id, (dequeued, res) in results
],
failed=failures,
)
res.dequeued = 1
call.result.data_model = res
@endpoint(
@@ -932,25 +945,13 @@ def reset(call: APICall, company_id, request: ResetRequest):
call.result.data_model = res
@attr.s(auto_attribs=True)
class ResetRes:
reset: int = 0
dequeued: int = 0
cleanup_res: CleanupResult = None
def __add__(self, other: Tuple[dict, CleanupResult, dict]):
dequeued, other_res, _ = other
dequeued = dequeued.get("removed", 0) if dequeued else 0
return ResetRes(
reset=self.reset + 1,
dequeued=self.dequeued + dequeued,
cleanup_res=self.cleanup_res + other_res if self.cleanup_res else other_res,
)
@endpoint("tasks.reset_many", request_data_model=ResetManyRequest)
@endpoint(
"tasks.reset_many",
request_data_model=ResetManyRequest,
response_data_model=ResetManyResponse,
)
def reset_many(call: APICall, company_id, request: ResetManyRequest):
res, failures = run_batch_operation(
results, failures = run_batch_operation(
func=partial(
reset_task,
company_id=company_id,
@@ -960,18 +961,26 @@ def reset_many(call: APICall, company_id, request: ResetManyRequest):
clear_all=request.clear_all,
),
ids=request.ids,
init_res=ResetRes(),
)
if res.cleanup_res:
cleanup_res = dict(
deleted_models=res.cleanup_res.deleted_models,
urls=attr.asdict(res.cleanup_res.urls) if res.cleanup_res.urls else {},
)
else:
cleanup_res = {}
call.result.data = dict(
succeeded=res.reset, dequeued=res.dequeued, **cleanup_res, failures=failures,
def clean_res(res: dict) -> dict:
# do not return artifacts since they are not serializable
fields = res.get("fields")
if fields:
fields.pop("execution.artifacts", None)
return res
call.result.data_model = ResetManyResponse(
succeeded=[
ResetBatchItem(
id=_id,
dequeued=bool(dequeued.get("removed")) if dequeued else False,
**attr.asdict(cleanup),
**clean_res(res),
)
for _id, (dequeued, cleanup, res) in results
],
failed=failures,
)
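
The nested clean_res helper exists because the update fields can include task artifacts, which are not JSON-serializable. A standalone illustration of its effect on hypothetical data:

def clean_res(res: dict) -> dict:
    # Same logic as above: drop execution.artifacts from the returned fields.
    fields = res.get("fields")
    if fields:
        fields.pop("execution.artifacts", None)
    return res

res = {"updated": 1, "fields": {"status": "created", "execution.artifacts": ["<artifact>"]}}
assert clean_res(res) == {"updated": 1, "fields": {"status": "created"}}
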
@@ -1004,7 +1013,7 @@ def archive(call: APICall, company_id, request: ArchiveRequest):
response_data_model=BatchResponse,
)
def archive_many(call: APICall, company_id, request: TaskBatchRequest):
archived, failures = run_batch_operation(
results, failures = run_batch_operation(
func=partial(
archive_task,
company_id=company_id,
@@ -1012,9 +1021,11 @@ def archive_many(call: APICall, company_id, request: TaskBatchRequest):
status_reason=request.status_reason,
),
ids=request.ids,
init_res=0,
)
call.result.data_model = BatchResponse(succeeded=archived, failures=failures)
call.result.data_model = BatchResponse(
succeeded=[dict(id=_id, archived=bool(archived)) for _id, archived in results],
failed=failures,
)
@endpoint(
@@ -1023,7 +1034,7 @@ def archive_many(call: APICall, company_id, request: TaskBatchRequest):
response_data_model=BatchResponse,
)
def unarchive_many(call: APICall, company_id, request: TaskBatchRequest):
unarchived, failures = run_batch_operation(
results, failures = run_batch_operation(
func=partial(
unarchive_task,
company_id=company_id,
@@ -1031,9 +1042,13 @@ def unarchive_many(call: APICall, company_id, request: TaskBatchRequest):
status_reason=request.status_reason,
),
ids=request.ids,
init_res=0,
)
call.result.data_model = BatchResponse(succeeded=unarchived, failures=failures)
call.result.data_model = BatchResponse(
succeeded=[
dict(id=_id, unarchived=bool(unarchived)) for _id, unarchived in results
],
failed=failures,
)
@endpoint("tasks.delete", request_data_model=DeleteRequest)
@@ -1051,25 +1066,9 @@ def delete(call: APICall, company_id, request: DeleteRequest):
call.result.data = dict(deleted=bool(deleted), **attr.asdict(cleanup_res))
@attr.s(auto_attribs=True)
class DeleteRes:
deleted: int = 0
projects: Set = set()
cleanup_res: CleanupResult = None
def __add__(self, other: Tuple[int, Task, CleanupResult]):
del_count, task, other_res = other
return DeleteRes(
deleted=self.deleted + del_count,
projects=self.projects | {task.project},
cleanup_res=self.cleanup_res + other_res if self.cleanup_res else other_res,
)
@endpoint("tasks.delete_many", request_data_model=DeleteManyRequest)
def delete_many(call: APICall, company_id, request: DeleteManyRequest):
res, failures = run_batch_operation(
results, failures = run_batch_operation(
func=partial(
delete_task,
company_id=company_id,
@@ -1079,20 +1078,25 @@ def delete_many(call: APICall, company_id, request: DeleteManyRequest):
delete_output_models=request.delete_output_models,
),
ids=request.ids,
init_res=DeleteRes(),
)
if res.deleted:
_reset_cached_tags(company_id, projects=list(res.projects))
if results:
projects = set(task.project for _, (_, task, _) in results)
_reset_cached_tags(company_id, projects=list(projects))
cleanup_res = attr.asdict(res.cleanup_res) if res.cleanup_res else {}
call.result.data = dict(succeeded=res.deleted, **cleanup_res, failures=failures)
call.result.data = dict(
succeeded=[
dict(id=_id, deleted=bool(deleted), **attr.asdict(cleanup_res))
for _id, (deleted, _, cleanup_res) in results
],
failed=failures,
)
@endpoint(
"tasks.publish",
request_data_model=PublishRequest,
response_data_model=PublishResponse,
response_data_model=UpdateResponse,
)
def publish(call: APICall, company_id, request: PublishRequest):
updates = publish_task(
@@ -1103,24 +1107,16 @@ def publish(call: APICall, company_id, request: PublishRequest):
status_reason=request.status_reason,
status_message=request.status_message,
)
call.result.data_model = PublishResponse(**updates)
@attr.s(auto_attribs=True)
class PublishRes:
published: int = 0
def __add__(self, other: dict):
return PublishRes(published=self.published + 1)
call.result.data_model = UpdateResponse(**updates)
@endpoint(
"tasks.publish_many",
request_data_model=PublishManyRequest,
response_data_model=BatchResponse,
response_data_model=UpdateBatchResponse,
)
def publish_many(call: APICall, company_id, request: PublishManyRequest):
res, failures = run_batch_operation(
results, failures = run_batch_operation(
func=partial(
publish_task,
company_id=company_id,
@@ -1132,10 +1128,12 @@ def publish_many(call: APICall, company_id, request: PublishManyRequest):
status_message=request.status_message,
),
ids=request.ids,
init_res=PublishRes(),
)
call.result.data_model = BatchResponse(succeeded=res.published, failures=failures)
call.result.data_model = UpdateBatchResponse(
succeeded=[UpdateBatchItem(id=_id, **res) for _id, res in results],
failed=failures,
)
@endpoint(

View File

@@ -21,8 +21,8 @@ class TestBatchOperations(TestService):
# enqueue
res = self.api.tasks.enqueue_many(ids=ids)
self.assertEqual(res.succeeded, 2)
self._assert_failures(res, [missing_id])
self._assert_succeeded(res, tasks)
self._assert_failed(res, [missing_id])
data = self.api.tasks.get_all_ex(id=ids).tasks
self.assertEqual({t.status for t in data}, {"queued"})
@@ -30,15 +30,15 @@ class TestBatchOperations(TestService):
for t in tasks:
self.api.tasks.started(task=t)
res = self.api.tasks.stop_many(ids=ids)
self.assertEqual(res.succeeded, 2)
self._assert_failures(res, [missing_id])
self._assert_succeeded(res, tasks)
self._assert_failed(res, [missing_id])
data = self.api.tasks.get_all_ex(id=ids).tasks
self.assertEqual({t.status for t in data}, {"stopped"})
# publish
res = self.api.tasks.publish_many(ids=ids, publish_model=False)
self.assertEqual(res.succeeded, 2)
self._assert_failures(res, [missing_id])
self._assert_succeeded(res, tasks)
self._assert_failed(res, [missing_id])
data = self.api.tasks.get_all_ex(id=ids).tasks
self.assertEqual({t.status for t in data}, {"published"})
@@ -46,23 +46,26 @@ class TestBatchOperations(TestService):
res = self.api.tasks.reset_many(
ids=ids, delete_output_models=True, return_file_urls=True, force=True
)
self.assertEqual(res.succeeded, 2)
self.assertEqual(res.deleted_models, 2)
self.assertEqual(set(res.urls.model_urls), {"uri_0", "uri_1"})
self._assert_failures(res, [missing_id])
self._assert_succeeded(res, tasks)
self.assertEqual(sum(t.deleted_models for t in res.succeeded), 2)
self.assertEqual(
set(url for t in res.succeeded for url in t.urls.model_urls),
{"uri_0", "uri_1"},
)
self._assert_failed(res, [missing_id])
data = self.api.tasks.get_all_ex(id=ids).tasks
self.assertEqual({t.status for t in data}, {"created"})
# archive/unarchive
res = self.api.tasks.archive_many(ids=ids)
self.assertEqual(res.succeeded, 2)
self._assert_failures(res, [missing_id])
self._assert_succeeded(res, tasks)
self._assert_failed(res, [missing_id])
data = self.api.tasks.get_all_ex(id=ids).tasks
self.assertTrue(all("archived" in t.system_tags for t in data))
res = self.api.tasks.unarchive_many(ids=ids)
self.assertEqual(res.succeeded, 2)
self._assert_failures(res, [missing_id])
self._assert_succeeded(res, tasks)
self._assert_failed(res, [missing_id])
data = self.api.tasks.get_all_ex(id=ids).tasks
self.assertFalse(any("archived" in t.system_tags for t in data))
@@ -70,8 +73,8 @@ class TestBatchOperations(TestService):
res = self.api.tasks.delete_many(
ids=ids, delete_output_models=True, return_file_urls=True
)
self.assertEqual(res.succeeded, 2)
self._assert_failures(res, [missing_id])
self._assert_succeeded(res, tasks)
self._assert_failed(res, [missing_id])
data = self.api.tasks.get_all_ex(id=ids).tasks
self.assertEqual(data, [])
@@ -90,33 +93,36 @@ class TestBatchOperations(TestService):
res = self.api.models.publish_many(
ids=ids, publish_task=True, force_publish_task=True
)
self.assertEqual(res.succeeded, 1)
self.assertEqual(res.published_tasks[0].id, task)
self._assert_failures(res, [ids[1], missing_id])
self._assert_succeeded(res, [ids[0]])
self.assertEqual(res.succeeded[0].published_task.id, task)
self._assert_failed(res, [ids[1], missing_id])
# archive/unarchive
res = self.api.models.archive_many(ids=ids)
self.assertEqual(res.succeeded, 2)
self._assert_failures(res, [missing_id])
self._assert_succeeded(res, models)
self._assert_failed(res, [missing_id])
data = self.api.models.get_all_ex(id=ids).models
self.assertTrue(all("archived" in m.system_tags for m in data))
res = self.api.models.unarchive_many(ids=ids)
self.assertEqual(res.succeeded, 2)
self._assert_failures(res, [missing_id])
self._assert_succeeded(res, models)
self._assert_failed(res, [missing_id])
data = self.api.models.get_all_ex(id=ids).models
self.assertFalse(any("archived" in m.system_tags for m in data))
# delete
res = self.api.models.delete_many(ids=[*models, missing_id], force=True)
self.assertEqual(res.succeeded, 2)
self.assertEqual(set(res.urls), set(uris))
self._assert_failures(res, [missing_id])
self._assert_succeeded(res, models)
self.assertEqual(set(m.url for m in res.succeeded), set(uris))
self._assert_failed(res, [missing_id])
data = self.api.models.get_all_ex(id=ids).models
self.assertEqual(data, [])
def _assert_failures(self, res, failed_ids):
self.assertEqual(set(f.id for f in res.failures), set(failed_ids))
def _assert_succeeded(self, res, succeeded_ids):
self.assertEqual(set(f.id for f in res.succeeded), set(succeeded_ids))
def _assert_failed(self, res, failed_ids):
self.assertEqual(set(f.id for f in res.failed), set(failed_ids))
def _temp_model(self, **kwargs):
self.update_missing(kwargs, name=self.name, uri="file:///a/b", labels={})