Fix batch operations results

This commit is contained in:
allegroai 2021-05-03 18:07:37 +03:00
parent dc9623e964
commit 6b0c45a861
11 changed files with 368 additions and 278 deletions

View File

@ -1,9 +1,11 @@
from typing import Sequence from typing import Sequence
from jsonmodels.fields import StringField
from jsonmodels.models import Base from jsonmodels.models import Base
from jsonmodels.validators import Length from jsonmodels.validators import Length
from apiserver.apimodels import ListField, IntField from apiserver.apimodels import ListField
from apiserver.apimodels.base import UpdateResponse
class BatchRequest(Base): class BatchRequest(Base):
@ -11,5 +13,13 @@ class BatchRequest(Base):
class BatchResponse(Base): class BatchResponse(Base):
succeeded: int = IntField() succeeded: Sequence[dict] = ListField([dict])
failures: Sequence[dict] = ListField([dict]) failed: Sequence[dict] = ListField([dict])
class UpdateBatchItem(UpdateResponse):
id: str = StringField()
class UpdateBatchResponse(BatchResponse):
succeeded: Sequence[UpdateBatchItem] = ListField(UpdateBatchItem)

View File

@ -3,13 +3,12 @@ from six import string_types
from apiserver.apimodels import ListField, DictField from apiserver.apimodels import ListField, DictField
from apiserver.apimodels.base import UpdateResponse from apiserver.apimodels.base import UpdateResponse
from apiserver.apimodels.batch import BatchRequest, BatchResponse from apiserver.apimodels.batch import BatchRequest
from apiserver.apimodels.metadata import ( from apiserver.apimodels.metadata import (
MetadataItem, MetadataItem,
DeleteMetadata, DeleteMetadata,
AddOrUpdateMetadata, AddOrUpdateMetadata,
) )
from apiserver.apimodels.tasks import PublishResponse as TaskPublishResponse
class GetFrameworksRequest(models.Base): class GetFrameworksRequest(models.Base):
@ -51,10 +50,6 @@ class ModelsDeleteManyRequest(BatchRequest):
force = fields.BoolField(default=False) force = fields.BoolField(default=False)
class ModelsDeleteManyResponse(BatchResponse):
urls = fields.ListField([str])
class PublishModelRequest(ModelRequest): class PublishModelRequest(ModelRequest):
force_publish_task = fields.BoolField(default=False) force_publish_task = fields.BoolField(default=False)
publish_task = fields.BoolField(default=True) publish_task = fields.BoolField(default=True)
@ -62,7 +57,7 @@ class PublishModelRequest(ModelRequest):
class ModelTaskPublishResponse(models.Base): class ModelTaskPublishResponse(models.Base):
id = fields.StringField(required=True) id = fields.StringField(required=True)
data = fields.EmbeddedField(TaskPublishResponse) data = fields.EmbeddedField(UpdateResponse)
class PublishModelResponse(UpdateResponse): class PublishModelResponse(UpdateResponse):
@ -74,10 +69,6 @@ class ModelsPublishManyRequest(BatchRequest):
publish_task = fields.BoolField(default=True) publish_task = fields.BoolField(default=True)
class ModelsPublishManyResponse(BatchResponse):
published_tasks = fields.ListField([ModelTaskPublishResponse])
class DeleteMetadataRequest(DeleteMetadata): class DeleteMetadataRequest(DeleteMetadata):
model = fields.StringField(required=True) model = fields.StringField(required=True)

View File

@ -1,13 +1,12 @@
from typing import Sequence from typing import Sequence
import six
from jsonmodels import models from jsonmodels import models
from jsonmodels.fields import StringField, BoolField, IntField, EmbeddedField from jsonmodels.fields import StringField, BoolField, IntField, EmbeddedField
from jsonmodels.validators import Enum, Length from jsonmodels.validators import Enum, Length
from apiserver.apimodels import DictField, ListField from apiserver.apimodels import DictField, ListField
from apiserver.apimodels.base import UpdateResponse from apiserver.apimodels.base import UpdateResponse
from apiserver.apimodels.batch import BatchRequest from apiserver.apimodels.batch import BatchRequest, UpdateBatchItem, BatchResponse
from apiserver.database.model.task.task import ( from apiserver.database.model.task.task import (
TaskType, TaskType,
ArtifactModes, ArtifactModes,
@ -45,19 +44,43 @@ class EnqueueResponse(UpdateResponse):
queued = IntField() queued = IntField()
class EnqueueBatchItem(UpdateBatchItem):
queued: bool = BoolField()
class EnqueueManyResponse(BatchResponse):
succeeded: Sequence[EnqueueBatchItem] = ListField(EnqueueBatchItem)
class DequeueResponse(UpdateResponse): class DequeueResponse(UpdateResponse):
dequeued = IntField() dequeued = IntField()
class DequeueBatchItem(UpdateBatchItem):
dequeued: bool = BoolField()
class DequeueManyResponse(BatchResponse):
succeeded: Sequence[DequeueBatchItem] = ListField(DequeueBatchItem)
class ResetResponse(UpdateResponse): class ResetResponse(UpdateResponse):
deleted_indices = ListField(items_types=six.string_types)
dequeued = DictField() dequeued = DictField()
frames = DictField()
events = DictField() events = DictField()
deleted_models = IntField() deleted_models = IntField()
urls = DictField() urls = DictField()
class ResetBatchItem(UpdateBatchItem):
dequeued: bool = BoolField()
deleted_models = IntField()
urls = DictField()
class ResetManyResponse(BatchResponse):
succeeded: Sequence[ResetBatchItem] = ListField(ResetBatchItem)
class TaskRequest(models.Base): class TaskRequest(models.Base):
task = StringField(required=True) task = StringField(required=True)
@ -89,10 +112,6 @@ class PublishRequest(UpdateRequest):
publish_model = BoolField(default=True) publish_model = BoolField(default=True)
class PublishResponse(UpdateResponse):
pass
class TaskData(models.Base): class TaskData(models.Base):
""" """
This is a partial description of task can be updated incrementally This is a partial description of task can be updated incrementally

View File

@ -56,13 +56,13 @@ def archive_task(
except APIError: except APIError:
# dequeue may fail if the task was not enqueued # dequeue may fail if the task was not enqueued
pass pass
task.update(
return task.update(
status_message=status_message, status_message=status_message,
status_reason=status_reason, status_reason=status_reason,
add_to_set__system_tags=EntityVisibility.archived.value, add_to_set__system_tags=EntityVisibility.archived.value,
last_change=datetime.utcnow(), last_change=datetime.utcnow(),
) )
return 1
def unarchive_task( def unarchive_task(
@ -82,6 +82,26 @@ def unarchive_task(
) )
def dequeue_task(
task_id: str,
company_id: str,
status_message: str,
status_reason: str,
) -> Tuple[int, dict]:
query = dict(id=task_id, company=company_id)
task = Task.get_for_writing(**query)
if not task:
raise errors.bad_request.InvalidTaskId(**query)
res = TaskBLL.dequeue_and_change_status(
task,
company_id,
status_message=status_message,
status_reason=status_reason,
)
return 1, res
def enqueue_task( def enqueue_task(
task_id: str, task_id: str,
company_id: str, company_id: str,

View File

@ -113,13 +113,13 @@ T = TypeVar("T")
def run_batch_operation( def run_batch_operation(
func: Callable[[str], T], init_res: T, ids: Sequence[str] func: Callable[[str], T], ids: Sequence[str]
) -> Tuple[T, Sequence]: ) -> Tuple[Sequence[Tuple[str, T]], Sequence[dict]]:
res = init_res results = list()
failures = list() failures = list()
for _id in ids: for _id in ids:
try: try:
res += func(_id) results.append((_id, func(_id)))
except APIError as err: except APIError as err:
failures.append( failures.append(
{ {
@ -131,4 +131,4 @@ def run_batch_operation(
}, },
} }
) )
return res, failures return results, failures

View File

@ -43,9 +43,18 @@ batch_operation {
type: object type: object
properties { properties {
succeeded { succeeded {
type: integer type: array
items {
type: object
properties {
id: {
description: ID of the succeeded entity
type: string
}
}
}
} }
failures { failed {
type: array type: array
items { items {
type: object type: object

View File

@ -680,11 +680,11 @@ publish_many {
} }
response { response {
properties { properties {
succeeded.description: "Number of models published" succeeded.items.properties.updated {
published_tasks { description: "Indicates whether the model was updated"
type: array type: boolean
items: ${_definitions.published_task_item}
} }
succeeded.items.properties.published_task: ${_definitions.published_task_item}
} }
} }
} }
@ -733,7 +733,10 @@ archive_many {
} }
response { response {
properties { properties {
succeeded.description: "Number of models archived" succeeded.items.properties.archived {
description: "Indicates whether the model was archived"
type: boolean
}
} }
} }
} }
@ -748,7 +751,10 @@ unarchive_many {
} }
response { response {
properties { properties {
succeeded.description: "Number of models unarchived" succeeded.items.properties.unarchived {
description: "Indicates whether the model was unarchived"
type: boolean
}
} }
} }
} }
@ -768,11 +774,13 @@ delete_many {
} }
response { response {
properties { properties {
succeeded.description: "Number of models deleted" succeeded.items.properties.deleted {
urls { description: "Indicates whether the model was deleted"
descrition: "The urls of the deleted model files" type: boolean
type: array }
items {type: string} succeeded.items.properties.url {
description: "The url of the model file"
type: string
} }
} }
} }
@ -814,7 +822,7 @@ delete {
response { response {
properties { properties {
url { url {
descrition: "The url of the model file" description: "The url of the model file"
type: string type: string
} }
} }

View File

@ -39,6 +39,20 @@ _definitions {
} }
} }
} }
response {
properties {
succeeded.items.properties.updated {
description: "Number of tasks updated (0 or 1)"
type: integer
enum: [ 0, 1 ]
}
succeeded.items.properties.fields {
description: "Updated fields names and values"
type: object
additionalProperties: true
}
}
}
} }
update_response { update_response {
type: object type: object
@ -1370,21 +1384,11 @@ reset {
} ${_references.status_change_request} } ${_references.status_change_request}
response: ${_definitions.update_response} { response: ${_definitions.update_response} {
properties { properties {
deleted_indices {
description: "List of deleted ES indices that were removed as part of the reset process"
type: array
items { type: string }
}
dequeued { dequeued {
description: "Response from queues.remove_task" description: "Response from queues.remove_task"
type: object type: object
additionalProperties: true additionalProperties: true
} }
frames {
description: "Response from frames.rollback"
type: object
additionalProperties: true
}
events { events {
description: "Response from events.delete_for_task" description: "Response from events.delete_for_task"
type: object type: object
@ -1442,18 +1446,26 @@ reset_many {
} }
response { response {
properties { properties {
succeeded.description: "Number of tasks reset" succeeded.items.properties.dequeued {
dequeued { description: "Indicates whether the task was dequeued"
description: "Number of tasks dequeued" type: boolean
}
succeeded.items.properties.updated {
description: "Number of tasks updated (0 or 1)"
type: integer
enum: [ 0, 1 ]
}
succeeded.items.properties.fields {
description: "Updated fields names and values"
type: object type: object
additionalProperties: true additionalProperties: true
} }
deleted_models { succeeded.items.properties.deleted_models {
description: "Number of output models deleted by the reset" description: "Number of output models deleted by the reset"
type: integer type: integer
} }
urls { succeeded.items.properties.urls {
description: "The urls of the files that were uploaded by the tasks. Returned if the 'return_file_urls' was set to 'true'" description: "The urls of the files that were uploaded by the task. Returned if the 'return_file_urls' was set to 'true'"
"$ref": "#/definitions/task_urls" "$ref": "#/definitions/task_urls"
} }
} }
@ -1486,21 +1498,24 @@ delete_many {
} }
response { response {
properties { properties {
succeeded.description: "Number of tasks deleted" succeeded.items.properties.deleted {
updated_children { description: "Indicates whether the task was deleted"
type: boolean
}
succeeded.items.properties.updated_children {
description: "Number of child tasks whose parent property was updated" description: "Number of child tasks whose parent property was updated"
type: integer type: integer
} }
updated_models { succeeded.items.properties.updated_models {
description: "Number of models whose task property was updated" description: "Number of models whose task property was updated"
type: integer type: integer
} }
deleted_models { succeeded.items.properties.deleted_models {
description: "Number of deleted output models" description: "Number of deleted output models"
type: integer type: integer
} }
urls { succeeded.items.properties.urls {
description: "The urls of the files that were uploaded by the tasks. Returned if the 'return_file_urls' was set to 'true'" description: "The urls of the files that were uploaded by the task. Returned if the 'return_file_urls' was set to 'true'"
"$ref": "#/definitions/task_urls" "$ref": "#/definitions/task_urls"
} }
} }
@ -1609,31 +1624,53 @@ archive {
} }
} }
archive_many { archive_many {
"2.13": ${_definitions.change_many_request} { "2.13": ${_definitions.batch_operation} {
description: Archive tasks description: Archive tasks
request { request {
properties { properties {
ids.description: "IDs of the tasks to archive" ids.description: "IDs of the tasks to archive"
status_reason {
description: Reason for status change
type: string
}
status_message {
description: Extra information regarding status change
type: string
}
} }
response { response {
properties { properties {
succeeded.description: "Number of tasks archived" succeeded.items.properties.archived {
description: "Indicates whether the task was archived"
type: boolean
}
} }
} }
} }
} }
} }
unarchive_many { unarchive_many {
"2.13": ${_definitions.change_many_request} { "2.13": ${_definitions.batch_operation} {
description: Unarchive tasks description: Unarchive tasks
request { request {
properties { properties {
ids.description: "IDs of the tasks to unarchive" ids.description: "IDs of the tasks to unarchive"
status_reason {
description: Reason for status change
type: string
}
status_message {
description: Extra information regarding status change
type: string
}
} }
} }
response { response {
properties { properties {
succeeded.description: "Number of tasks unarchived" succeeded.items.properties.unarchived {
description: "Indicates whether the task was unarchived"
type: boolean
}
} }
} }
} }
@ -1687,11 +1724,6 @@ stop_many {
} }
} }
} }
response {
properties {
succeeded.description: "Number of tasks stopped"
}
}
} }
} }
stopped { stopped {
@ -1775,11 +1807,6 @@ publish_many {
} }
} }
} }
response {
properties {
succeeded.description: "Number of tasks published"
}
}
} }
} }
enqueue { enqueue {
@ -1837,7 +1864,10 @@ enqueue_many {
} }
response { response {
properties { properties {
succeeded.description: "Number of tasks enqueued" succeeded.items.properties.queued {
description: "Indicates whether the task was queued"
type: boolean
}
} }
} }
} }
@ -1863,6 +1893,24 @@ dequeue {
} }
} }
} }
dequeue_many {
"2.13": ${_definitions.change_many_request} {
description: Dequeue tasks
request {
properties {
ids.description: "IDs of the tasks to dequeue"
}
}
response {
properties {
succeeded.items.properties.dequeued {
description: "Indicates whether the task was dequeued"
type: boolean
}
}
}
}
}
set_requirements { set_requirements {
"2.1" { "2.1" {
description: """Set the script requirements for a task""" description: """Set the script requirements for a task"""

View File

@ -1,8 +1,7 @@
from datetime import datetime from datetime import datetime
from functools import partial from functools import partial
from typing import Sequence, Tuple, Set from typing import Sequence
import attr
from mongoengine import Q, EmbeddedDocument from mongoengine import Q, EmbeddedDocument
from apiserver import database from apiserver import database
@ -15,15 +14,12 @@ from apiserver.apimodels.models import (
CreateModelResponse, CreateModelResponse,
PublishModelRequest, PublishModelRequest,
PublishModelResponse, PublishModelResponse,
ModelTaskPublishResponse,
GetFrameworksRequest, GetFrameworksRequest,
DeleteModelRequest, DeleteModelRequest,
DeleteMetadataRequest, DeleteMetadataRequest,
AddOrUpdateMetadataRequest, AddOrUpdateMetadataRequest,
ModelsPublishManyRequest, ModelsPublishManyRequest,
ModelsPublishManyResponse,
ModelsDeleteManyRequest, ModelsDeleteManyRequest,
ModelsDeleteManyResponse,
) )
from apiserver.bll.model import ModelBLL from apiserver.bll.model import ModelBLL
from apiserver.bll.organization import OrgBLL, Tags from apiserver.bll.organization import OrgBLL, Tags
@ -501,26 +497,13 @@ def set_ready(call: APICall, company_id: str, request: PublishModelRequest):
) )
@attr.s(auto_attribs=True)
class PublishRes:
published: int = 0
published_tasks: Sequence = []
def __add__(self, other: Tuple[int, ModelTaskPublishResponse]):
published, response = other
return PublishRes(
published=self.published + published,
published_tasks=[*self.published_tasks, *([response] if response else [])],
)
@endpoint( @endpoint(
"models.publish_many", "models.publish_many",
request_data_model=ModelsPublishManyRequest, request_data_model=ModelsPublishManyRequest,
response_data_model=ModelsPublishManyResponse, response_data_model=BatchResponse,
) )
def publish_many(call: APICall, company_id, request: ModelsPublishManyRequest): def publish_many(call: APICall, company_id, request: ModelsPublishManyRequest):
res, failures = run_batch_operation( results, failures = run_batch_operation(
func=partial( func=partial(
ModelBLL.publish_model, ModelBLL.publish_model,
company_id=company_id, company_id=company_id,
@ -528,11 +511,16 @@ def publish_many(call: APICall, company_id, request: ModelsPublishManyRequest):
publish_task_func=publish_task if request.publish_task else None, publish_task_func=publish_task if request.publish_task else None,
), ),
ids=request.ids, ids=request.ids,
init_res=PublishRes(),
) )
call.result.data_model = ModelsPublishManyResponse( call.result.data_model = BatchResponse(
succeeded=res.published, published_tasks=res.published_tasks, failures=failures, succeeded=[
dict(
id=_id, updated=bool(updated), published_task=published_task.to_struct()
)
for _id, (updated, published_task) in results
],
failed=failures,
) )
@ -546,41 +534,30 @@ def delete(call: APICall, company_id, request: DeleteModelRequest):
company_id, projects=[model.project] if model.project else [] company_id, projects=[model.project] if model.project else []
) )
call.result.data = dict(deleted=del_count > 0, url=model.uri) call.result.data = dict(deleted=bool(del_count), url=model.uri)
@attr.s(auto_attribs=True)
class DeleteRes:
deleted: int = 0
projects: Set = set()
urls: Set = set()
def __add__(self, other: Tuple[int, Model]):
del_count, model = other
return DeleteRes(
deleted=self.deleted + del_count,
projects=self.projects | {model.project},
urls=self.urls | {model.uri},
)
@endpoint( @endpoint(
"models.delete_many", "models.delete_many",
request_data_model=ModelsDeleteManyRequest, request_data_model=ModelsDeleteManyRequest,
response_data_model=ModelsDeleteManyResponse, response_data_model=BatchResponse,
) )
def delete(call: APICall, company_id, request: ModelsDeleteManyRequest): def delete(call: APICall, company_id, request: ModelsDeleteManyRequest):
res, failures = run_batch_operation( results, failures = run_batch_operation(
func=partial(ModelBLL.delete_model, company_id=company_id, force=request.force), func=partial(ModelBLL.delete_model, company_id=company_id, force=request.force),
ids=request.ids, ids=request.ids,
init_res=DeleteRes(),
) )
if res.deleted:
_reset_cached_tags(company_id, projects=list(res.projects))
res.urls.discard(None) if results:
call.result.data_model = ModelsDeleteManyResponse( projects = set(model.project for _, (_, model) in results)
succeeded=res.deleted, urls=list(res.urls), failures=failures, _reset_cached_tags(company_id, projects=list(projects))
call.result.data_model = BatchResponse(
succeeded=[
dict(id=_id, deleted=bool(deleted), url=model.uri)
for _id, (deleted, model) in results
],
failed=failures,
) )
@ -590,12 +567,13 @@ def delete(call: APICall, company_id, request: ModelsDeleteManyRequest):
response_data_model=BatchResponse, response_data_model=BatchResponse,
) )
def archive_many(call: APICall, company_id, request: BatchRequest): def archive_many(call: APICall, company_id, request: BatchRequest):
archived, failures = run_batch_operation( results, failures = run_batch_operation(
func=partial(ModelBLL.archive_model, company_id=company_id), func=partial(ModelBLL.archive_model, company_id=company_id), ids=request.ids,
ids=request.ids, )
init_res=0, call.result.data_model = BatchResponse(
succeeded=[dict(id=_id, archived=bool(archived)) for _id, archived in results],
failed=failures,
) )
call.result.data_model = BatchResponse(succeeded=archived, failures=failures)
@endpoint( @endpoint(
@ -604,12 +582,15 @@ def archive_many(call: APICall, company_id, request: BatchRequest):
response_data_model=BatchResponse, response_data_model=BatchResponse,
) )
def unarchive_many(call: APICall, company_id, request: BatchRequest): def unarchive_many(call: APICall, company_id, request: BatchRequest):
unarchived, failures = run_batch_operation( results, failures = run_batch_operation(
func=partial(ModelBLL.unarchive_model, company_id=company_id), func=partial(ModelBLL.unarchive_model, company_id=company_id), ids=request.ids,
ids=request.ids, )
init_res=0, call.result.data_model = BatchResponse(
succeeded=[
dict(id=_id, unarchived=bool(unarchived)) for _id, unarchived in results
],
failed=failures,
) )
call.result.data_model = BatchResponse(succeeded=unarchived, failures=failures,)
@endpoint("models.make_public", min_version="2.9", request_data_model=MakePublicRequest) @endpoint("models.make_public", min_version="2.9", request_data_model=MakePublicRequest)

View File

@ -1,7 +1,7 @@
from copy import deepcopy from copy import deepcopy
from datetime import datetime from datetime import datetime
from functools import partial from functools import partial
from typing import Sequence, Union, Tuple, Set from typing import Sequence, Union, Tuple
import attr import attr
import dpath import dpath
@ -17,12 +17,15 @@ from apiserver.apimodels.base import (
MakePublicRequest, MakePublicRequest,
MoveRequest, MoveRequest,
) )
from apiserver.apimodels.batch import BatchResponse from apiserver.apimodels.batch import (
BatchResponse,
UpdateBatchResponse,
UpdateBatchItem,
)
from apiserver.apimodels.tasks import ( from apiserver.apimodels.tasks import (
StartedResponse, StartedResponse,
ResetResponse, ResetResponse,
PublishRequest, PublishRequest,
PublishResponse,
CreateRequest, CreateRequest,
UpdateRequest, UpdateRequest,
SetRequirementsRequest, SetRequirementsRequest,
@ -54,6 +57,12 @@ from apiserver.apimodels.tasks import (
DeleteManyRequest, DeleteManyRequest,
PublishManyRequest, PublishManyRequest,
TaskBatchRequest, TaskBatchRequest,
EnqueueManyResponse,
EnqueueBatchItem,
DequeueBatchItem,
DequeueManyResponse,
ResetManyResponse,
ResetBatchItem,
) )
from apiserver.bll.event import EventBLL from apiserver.bll.event import EventBLL
from apiserver.bll.model import ModelBLL from apiserver.bll.model import ModelBLL
@ -77,10 +86,10 @@ from apiserver.bll.task.param_utils import (
params_unprepare_from_saved, params_unprepare_from_saved,
escape_paths, escape_paths,
) )
from apiserver.bll.task.task_cleanup import CleanupResult
from apiserver.bll.task.task_operations import ( from apiserver.bll.task.task_operations import (
stop_task, stop_task,
enqueue_task, enqueue_task,
dequeue_task,
reset_task, reset_task,
archive_task, archive_task,
delete_task, delete_task,
@ -287,21 +296,13 @@ def stop(call: APICall, company_id, req_model: UpdateRequest):
) )
@attr.s(auto_attribs=True)
class StopRes:
stopped: int = 0
def __add__(self, other: dict):
return StopRes(stopped=self.stopped + 1)
@endpoint( @endpoint(
"tasks.stop_many", "tasks.stop_many",
request_data_model=StopManyRequest, request_data_model=StopManyRequest,
response_data_model=BatchResponse, response_data_model=UpdateBatchResponse,
) )
def stop_many(call: APICall, company_id, request: StopManyRequest): def stop_many(call: APICall, company_id, request: StopManyRequest):
res, failures = run_batch_operation( results, failures = run_batch_operation(
func=partial( func=partial(
stop_task, stop_task,
company_id=company_id, company_id=company_id,
@ -310,9 +311,11 @@ def stop_many(call: APICall, company_id, request: StopManyRequest):
force=request.force, force=request.force,
), ),
ids=request.ids, ids=request.ids,
init_res=StopRes(),
) )
call.result.data_model = BatchResponse(succeeded=res.stopped, failures=failures) call.result.data_model = UpdateBatchResponse(
succeeded=[UpdateBatchItem(id=_id, **res) for _id, res in results],
failed=failures,
)
@endpoint( @endpoint(
@ -854,22 +857,13 @@ def enqueue(call: APICall, company_id, request: EnqueueRequest):
call.result.data_model = EnqueueResponse(queued=queued, **res) call.result.data_model = EnqueueResponse(queued=queued, **res)
@attr.s(auto_attribs=True)
class EnqueueRes:
queued: int = 0
def __add__(self, other: Tuple[int, dict]):
queued, _ = other
return EnqueueRes(queued=self.queued + queued)
@endpoint( @endpoint(
"tasks.enqueue_many", "tasks.enqueue_many",
request_data_model=EnqueueManyRequest, request_data_model=EnqueueManyRequest,
response_data_model=BatchResponse, response_data_model=EnqueueManyResponse,
) )
def enqueue_many(call: APICall, company_id, request: EnqueueManyRequest): def enqueue_many(call: APICall, company_id, request: EnqueueManyRequest):
res, failures = run_batch_operation( results, failures = run_batch_operation(
func=partial( func=partial(
enqueue_task, enqueue_task,
company_id=company_id, company_id=company_id,
@ -879,9 +873,14 @@ def enqueue_many(call: APICall, company_id, request: EnqueueManyRequest):
validate=request.validate_tasks, validate=request.validate_tasks,
), ),
ids=request.ids, ids=request.ids,
init_res=EnqueueRes(),
) )
call.result.data_model = BatchResponse(succeeded=res.queued, failures=failures) call.result.data_model = EnqueueManyResponse(
succeeded=[
EnqueueBatchItem(id=_id, queued=bool(queued), **res)
for _id, (queued, res) in results
],
failed=failures,
)
@endpoint( @endpoint(
@ -890,23 +889,37 @@ def enqueue_many(call: APICall, company_id, request: EnqueueManyRequest):
response_data_model=DequeueResponse, response_data_model=DequeueResponse,
) )
def dequeue(call: APICall, company_id, request: UpdateRequest): def dequeue(call: APICall, company_id, request: UpdateRequest):
task = TaskBLL.get_task_with_access( dequeued, res = dequeue_task(
request.task, task_id=request.task,
company_id=company_id, company_id=company_id,
only=("id", "execution", "status", "project", "enqueue_status"), status_message=request.status_message,
requires_write_access=True, status_reason=request.status_reason,
) )
res = DequeueResponse( call.result.data_model = DequeueResponse(dequeued=dequeued, **res)
**TaskBLL.dequeue_and_change_status(
task,
company_id, @endpoint(
"tasks.dequeue_many",
request_data_model=TaskBatchRequest,
response_data_model=DequeueManyResponse,
)
def dequeue_many(call: APICall, company_id, request: TaskBatchRequest):
results, failures = run_batch_operation(
func=partial(
dequeue_task,
company_id=company_id,
status_message=request.status_message, status_message=request.status_message,
status_reason=request.status_reason, status_reason=request.status_reason,
) ),
ids=request.ids,
)
call.result.data_model = DequeueManyResponse(
succeeded=[
DequeueBatchItem(id=_id, dequeued=bool(dequeued), **res)
for _id, (dequeued, res) in results
],
failed=failures,
) )
res.dequeued = 1
call.result.data_model = res
@endpoint( @endpoint(
@ -932,25 +945,13 @@ def reset(call: APICall, company_id, request: ResetRequest):
call.result.data_model = res call.result.data_model = res
@attr.s(auto_attribs=True) @endpoint(
class ResetRes: "tasks.reset_many",
reset: int = 0 request_data_model=ResetManyRequest,
dequeued: int = 0 response_data_model=ResetManyResponse,
cleanup_res: CleanupResult = None )
def __add__(self, other: Tuple[dict, CleanupResult, dict]):
dequeued, other_res, _ = other
dequeued = dequeued.get("removed", 0) if dequeued else 0
return ResetRes(
reset=self.reset + 1,
dequeued=self.dequeued + dequeued,
cleanup_res=self.cleanup_res + other_res if self.cleanup_res else other_res,
)
@endpoint("tasks.reset_many", request_data_model=ResetManyRequest)
def reset_many(call: APICall, company_id, request: ResetManyRequest): def reset_many(call: APICall, company_id, request: ResetManyRequest):
res, failures = run_batch_operation( results, failures = run_batch_operation(
func=partial( func=partial(
reset_task, reset_task,
company_id=company_id, company_id=company_id,
@ -960,18 +961,26 @@ def reset_many(call: APICall, company_id, request: ResetManyRequest):
clear_all=request.clear_all, clear_all=request.clear_all,
), ),
ids=request.ids, ids=request.ids,
init_res=ResetRes(),
) )
if res.cleanup_res: def clean_res(res: dict) -> dict:
cleanup_res = dict( # do not return artifacts since they are not serializable
deleted_models=res.cleanup_res.deleted_models, fields = res.get("fields")
urls=attr.asdict(res.cleanup_res.urls) if res.cleanup_res.urls else {}, if fields:
) fields.pop("execution.artifacts", None)
else: return res
cleanup_res = {}
call.result.data = dict( call.result.data_model = ResetManyResponse(
succeeded=res.reset, dequeued=res.dequeued, **cleanup_res, failures=failures, succeeded=[
ResetBatchItem(
id=_id,
dequeued=bool(dequeued.get("removed")) if dequeued else False,
**attr.asdict(cleanup),
**clean_res(res),
)
for _id, (dequeued, cleanup, res) in results
],
failed=failures,
) )
@ -1004,7 +1013,7 @@ def archive(call: APICall, company_id, request: ArchiveRequest):
response_data_model=BatchResponse, response_data_model=BatchResponse,
) )
def archive_many(call: APICall, company_id, request: TaskBatchRequest): def archive_many(call: APICall, company_id, request: TaskBatchRequest):
archived, failures = run_batch_operation( results, failures = run_batch_operation(
func=partial( func=partial(
archive_task, archive_task,
company_id=company_id, company_id=company_id,
@ -1012,9 +1021,11 @@ def archive_many(call: APICall, company_id, request: TaskBatchRequest):
status_reason=request.status_reason, status_reason=request.status_reason,
), ),
ids=request.ids, ids=request.ids,
init_res=0,
) )
call.result.data_model = BatchResponse(succeeded=archived, failures=failures) call.result.data_model = BatchResponse(
succeeded=[dict(id=_id, archived=bool(archived)) for _id, archived in results],
failed=failures,
)
@endpoint( @endpoint(
@ -1023,7 +1034,7 @@ def archive_many(call: APICall, company_id, request: TaskBatchRequest):
response_data_model=BatchResponse, response_data_model=BatchResponse,
) )
def unarchive_many(call: APICall, company_id, request: TaskBatchRequest): def unarchive_many(call: APICall, company_id, request: TaskBatchRequest):
unarchived, failures = run_batch_operation( results, failures = run_batch_operation(
func=partial( func=partial(
unarchive_task, unarchive_task,
company_id=company_id, company_id=company_id,
@ -1031,9 +1042,13 @@ def unarchive_many(call: APICall, company_id, request: TaskBatchRequest):
status_reason=request.status_reason, status_reason=request.status_reason,
), ),
ids=request.ids, ids=request.ids,
init_res=0,
) )
call.result.data_model = BatchResponse(succeeded=unarchived, failures=failures) call.result.data_model = BatchResponse(
succeeded=[
dict(id=_id, unarchived=bool(unarchived)) for _id, unarchived in results
],
failed=failures,
)
@endpoint("tasks.delete", request_data_model=DeleteRequest) @endpoint("tasks.delete", request_data_model=DeleteRequest)
@ -1051,25 +1066,9 @@ def delete(call: APICall, company_id, request: DeleteRequest):
call.result.data = dict(deleted=bool(deleted), **attr.asdict(cleanup_res)) call.result.data = dict(deleted=bool(deleted), **attr.asdict(cleanup_res))
@attr.s(auto_attribs=True)
class DeleteRes:
deleted: int = 0
projects: Set = set()
cleanup_res: CleanupResult = None
def __add__(self, other: Tuple[int, Task, CleanupResult]):
del_count, task, other_res = other
return DeleteRes(
deleted=self.deleted + del_count,
projects=self.projects | {task.project},
cleanup_res=self.cleanup_res + other_res if self.cleanup_res else other_res,
)
@endpoint("tasks.delete_many", request_data_model=DeleteManyRequest) @endpoint("tasks.delete_many", request_data_model=DeleteManyRequest)
def delete_many(call: APICall, company_id, request: DeleteManyRequest): def delete_many(call: APICall, company_id, request: DeleteManyRequest):
res, failures = run_batch_operation( results, failures = run_batch_operation(
func=partial( func=partial(
delete_task, delete_task,
company_id=company_id, company_id=company_id,
@ -1079,20 +1078,25 @@ def delete_many(call: APICall, company_id, request: DeleteManyRequest):
delete_output_models=request.delete_output_models, delete_output_models=request.delete_output_models,
), ),
ids=request.ids, ids=request.ids,
init_res=DeleteRes(),
) )
if res.deleted: if results:
_reset_cached_tags(company_id, projects=list(res.projects)) projects = set(task.project for _, (_, task, _) in results)
_reset_cached_tags(company_id, projects=list(projects))
cleanup_res = attr.asdict(res.cleanup_res) if res.cleanup_res else {} call.result.data = dict(
call.result.data = dict(succeeded=res.deleted, **cleanup_res, failures=failures) succeeded=[
dict(id=_id, deleted=bool(deleted), **attr.asdict(cleanup_res))
for _id, (deleted, _, cleanup_res) in results
],
failed=failures,
)
@endpoint( @endpoint(
"tasks.publish", "tasks.publish",
request_data_model=PublishRequest, request_data_model=PublishRequest,
response_data_model=PublishResponse, response_data_model=UpdateResponse,
) )
def publish(call: APICall, company_id, request: PublishRequest): def publish(call: APICall, company_id, request: PublishRequest):
updates = publish_task( updates = publish_task(
@ -1103,24 +1107,16 @@ def publish(call: APICall, company_id, request: PublishRequest):
status_reason=request.status_reason, status_reason=request.status_reason,
status_message=request.status_message, status_message=request.status_message,
) )
call.result.data_model = PublishResponse(**updates) call.result.data_model = UpdateResponse(**updates)
@attr.s(auto_attribs=True)
class PublishRes:
published: int = 0
def __add__(self, other: dict):
return PublishRes(published=self.published + 1)
@endpoint( @endpoint(
"tasks.publish_many", "tasks.publish_many",
request_data_model=PublishManyRequest, request_data_model=PublishManyRequest,
response_data_model=BatchResponse, response_data_model=UpdateBatchResponse,
) )
def publish_many(call: APICall, company_id, request: PublishManyRequest): def publish_many(call: APICall, company_id, request: PublishManyRequest):
res, failures = run_batch_operation( results, failures = run_batch_operation(
func=partial( func=partial(
publish_task, publish_task,
company_id=company_id, company_id=company_id,
@ -1132,10 +1128,12 @@ def publish_many(call: APICall, company_id, request: PublishManyRequest):
status_message=request.status_message, status_message=request.status_message,
), ),
ids=request.ids, ids=request.ids,
init_res=PublishRes(),
) )
call.result.data_model = BatchResponse(succeeded=res.published, failures=failures) call.result.data_model = UpdateBatchResponse(
succeeded=[UpdateBatchItem(id=_id, **res) for _id, res in results],
failed=failures,
)
@endpoint( @endpoint(

View File

@ -21,8 +21,8 @@ class TestBatchOperations(TestService):
# enqueue # enqueue
res = self.api.tasks.enqueue_many(ids=ids) res = self.api.tasks.enqueue_many(ids=ids)
self.assertEqual(res.succeeded, 2) self._assert_succeeded(res, tasks)
self._assert_failures(res, [missing_id]) self._assert_failed(res, [missing_id])
data = self.api.tasks.get_all_ex(id=ids).tasks data = self.api.tasks.get_all_ex(id=ids).tasks
self.assertEqual({t.status for t in data}, {"queued"}) self.assertEqual({t.status for t in data}, {"queued"})
@ -30,15 +30,15 @@ class TestBatchOperations(TestService):
for t in tasks: for t in tasks:
self.api.tasks.started(task=t) self.api.tasks.started(task=t)
res = self.api.tasks.stop_many(ids=ids) res = self.api.tasks.stop_many(ids=ids)
self.assertEqual(res.succeeded, 2) self._assert_succeeded(res, tasks)
self._assert_failures(res, [missing_id]) self._assert_failed(res, [missing_id])
data = self.api.tasks.get_all_ex(id=ids).tasks data = self.api.tasks.get_all_ex(id=ids).tasks
self.assertEqual({t.status for t in data}, {"stopped"}) self.assertEqual({t.status for t in data}, {"stopped"})
# publish # publish
res = self.api.tasks.publish_many(ids=ids, publish_model=False) res = self.api.tasks.publish_many(ids=ids, publish_model=False)
self.assertEqual(res.succeeded, 2) self._assert_succeeded(res, tasks)
self._assert_failures(res, [missing_id]) self._assert_failed(res, [missing_id])
data = self.api.tasks.get_all_ex(id=ids).tasks data = self.api.tasks.get_all_ex(id=ids).tasks
self.assertEqual({t.status for t in data}, {"published"}) self.assertEqual({t.status for t in data}, {"published"})
@ -46,23 +46,26 @@ class TestBatchOperations(TestService):
res = self.api.tasks.reset_many( res = self.api.tasks.reset_many(
ids=ids, delete_output_models=True, return_file_urls=True, force=True ids=ids, delete_output_models=True, return_file_urls=True, force=True
) )
self.assertEqual(res.succeeded, 2) self._assert_succeeded(res, tasks)
self.assertEqual(res.deleted_models, 2) self.assertEqual(sum(t.deleted_models for t in res.succeeded), 2)
self.assertEqual(set(res.urls.model_urls), {"uri_0", "uri_1"}) self.assertEqual(
self._assert_failures(res, [missing_id]) set(url for t in res.succeeded for url in t.urls.model_urls),
{"uri_0", "uri_1"},
)
self._assert_failed(res, [missing_id])
data = self.api.tasks.get_all_ex(id=ids).tasks data = self.api.tasks.get_all_ex(id=ids).tasks
self.assertEqual({t.status for t in data}, {"created"}) self.assertEqual({t.status for t in data}, {"created"})
# archive/unarchive # archive/unarchive
res = self.api.tasks.archive_many(ids=ids) res = self.api.tasks.archive_many(ids=ids)
self.assertEqual(res.succeeded, 2) self._assert_succeeded(res, tasks)
self._assert_failures(res, [missing_id]) self._assert_failed(res, [missing_id])
data = self.api.tasks.get_all_ex(id=ids).tasks data = self.api.tasks.get_all_ex(id=ids).tasks
self.assertTrue(all("archived" in t.system_tags for t in data)) self.assertTrue(all("archived" in t.system_tags for t in data))
res = self.api.tasks.unarchive_many(ids=ids) res = self.api.tasks.unarchive_many(ids=ids)
self.assertEqual(res.succeeded, 2) self._assert_succeeded(res, tasks)
self._assert_failures(res, [missing_id]) self._assert_failed(res, [missing_id])
data = self.api.tasks.get_all_ex(id=ids).tasks data = self.api.tasks.get_all_ex(id=ids).tasks
self.assertFalse(any("archived" in t.system_tags for t in data)) self.assertFalse(any("archived" in t.system_tags for t in data))
@ -70,8 +73,8 @@ class TestBatchOperations(TestService):
res = self.api.tasks.delete_many( res = self.api.tasks.delete_many(
ids=ids, delete_output_models=True, return_file_urls=True ids=ids, delete_output_models=True, return_file_urls=True
) )
self.assertEqual(res.succeeded, 2) self._assert_succeeded(res, tasks)
self._assert_failures(res, [missing_id]) self._assert_failed(res, [missing_id])
data = self.api.tasks.get_all_ex(id=ids).tasks data = self.api.tasks.get_all_ex(id=ids).tasks
self.assertEqual(data, []) self.assertEqual(data, [])
@ -90,33 +93,36 @@ class TestBatchOperations(TestService):
res = self.api.models.publish_many( res = self.api.models.publish_many(
ids=ids, publish_task=True, force_publish_task=True ids=ids, publish_task=True, force_publish_task=True
) )
self.assertEqual(res.succeeded, 1) self._assert_succeeded(res, [ids[0]])
self.assertEqual(res.published_tasks[0].id, task) self.assertEqual(res.succeeded[0].published_task.id, task)
self._assert_failures(res, [ids[1], missing_id]) self._assert_failed(res, [ids[1], missing_id])
# archive/unarchive # archive/unarchive
res = self.api.models.archive_many(ids=ids) res = self.api.models.archive_many(ids=ids)
self.assertEqual(res.succeeded, 2) self._assert_succeeded(res, models)
self._assert_failures(res, [missing_id]) self._assert_failed(res, [missing_id])
data = self.api.models.get_all_ex(id=ids).models data = self.api.models.get_all_ex(id=ids).models
self.assertTrue(all("archived" in m.system_tags for m in data)) self.assertTrue(all("archived" in m.system_tags for m in data))
res = self.api.models.unarchive_many(ids=ids) res = self.api.models.unarchive_many(ids=ids)
self.assertEqual(res.succeeded, 2) self._assert_succeeded(res, models)
self._assert_failures(res, [missing_id]) self._assert_failed(res, [missing_id])
data = self.api.models.get_all_ex(id=ids).models data = self.api.models.get_all_ex(id=ids).models
self.assertFalse(any("archived" in m.system_tags for m in data)) self.assertFalse(any("archived" in m.system_tags for m in data))
# delete # delete
res = self.api.models.delete_many(ids=[*models, missing_id], force=True) res = self.api.models.delete_many(ids=[*models, missing_id], force=True)
self.assertEqual(res.succeeded, 2) self._assert_succeeded(res, models)
self.assertEqual(set(res.urls), set(uris)) self.assertEqual(set(m.url for m in res.succeeded), set(uris))
self._assert_failures(res, [missing_id]) self._assert_failed(res, [missing_id])
data = self.api.models.get_all_ex(id=ids).models data = self.api.models.get_all_ex(id=ids).models
self.assertEqual(data, []) self.assertEqual(data, [])
def _assert_failures(self, res, failed_ids): def _assert_succeeded(self, res, succeeded_ids):
self.assertEqual(set(f.id for f in res.failures), set(failed_ids)) self.assertEqual(set(f.id for f in res.succeeded), set(succeeded_ids))
def _assert_failed(self, res, failed_ids):
self.assertEqual(set(f.id for f in res.failed), set(failed_ids))
def _temp_model(self, **kwargs): def _temp_model(self, **kwargs):
self.update_missing(kwargs, name=self.name, uri="file:///a/b", labels={}) self.update_missing(kwargs, name=self.name, uri="file:///a/b", labels={})