Add batch operations support

allegroai 2021-05-03 17:52:54 +03:00
parent eab33de97e
commit a75534ec34
17 changed files with 1444 additions and 674 deletions

View File

@ -0,0 +1,14 @@
from typing import Sequence

from jsonmodels.models import Base
from jsonmodels.validators import Length

from apiserver.apimodels import ListField


class BatchRequest(Base):
    ids: Sequence[str] = ListField([str], validators=Length(minimum_value=1))


class BatchResponse(Base):
    failures: Sequence[dict] = ListField([dict])
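These two models are the base for every batch endpoint added in this commit: requests carry a non-empty list of entity IDs, responses carry per-ID failures. A minimal sketch of how an endpoint-specific request is derived (the `ExamplesDeleteManyRequest` name is hypothetical; the real subclasses appear in the model and task apimodels below):

```python
from jsonmodels import fields

# Hypothetical subclass for illustration only; mirrors the pattern of the
# real ModelsDeleteManyRequest / TaskBatchRequest classes in this commit.
class ExamplesDeleteManyRequest(BatchRequest):
    force = fields.BoolField(default=False)
```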

View File

@ -3,6 +3,7 @@ from six import string_types
from apiserver.apimodels import ListField, DictField
from apiserver.apimodels.base import UpdateResponse
from apiserver.apimodels.batch import BatchRequest, BatchResponse
from apiserver.apimodels.metadata import (
    MetadataItem,
    DeleteMetadata,
@ -46,6 +47,23 @@ class DeleteModelRequest(ModelRequest):
    force = fields.BoolField(default=False)


class ModelsDeleteManyRequest(BatchRequest):
    force = fields.BoolField(default=False)


class ModelsArchiveManyRequest(BatchRequest):
    pass


class ModelsArchiveManyResponse(BatchResponse):
    archived = fields.IntField(required=True)


class ModelsDeleteManyResponse(BatchResponse):
    deleted = fields.IntField()
    urls = fields.ListField([str])


class PublishModelRequest(ModelRequest):
    force_publish_task = fields.BoolField(default=False)
    publish_task = fields.BoolField(default=True)
@ -58,7 +76,16 @@ class ModelTaskPublishResponse(models.Base):
class PublishModelResponse(UpdateResponse):
    published_task = fields.EmbeddedField(ModelTaskPublishResponse)
    updated = fields.IntField()


class ModelsPublishManyRequest(BatchRequest):
    force_publish_task = fields.BoolField(default=False)
    publish_task = fields.BoolField(default=True)


class ModelsPublishManyResponse(BatchResponse):
    published = fields.IntField(required=True)
    published_tasks = fields.ListField([ModelTaskPublishResponse])


class DeleteMetadataRequest(DeleteMetadata):

View File

@ -7,6 +7,7 @@ from jsonmodels.validators import Enum, Length
from apiserver.apimodels import DictField, ListField
from apiserver.apimodels.base import UpdateResponse
from apiserver.apimodels.batch import BatchRequest, BatchResponse
from apiserver.database.model.task.task import (
    TaskType,
    ArtifactModes,
@ -52,7 +53,7 @@ class ResetResponse(UpdateResponse):
    dequeued = DictField()
    frames = DictField()
    events = DictField()
    deleted_models = IntField()
    urls = DictField()
@ -230,6 +231,54 @@ class ArchiveResponse(models.Base):
    archived = IntField()


class TaskBatchRequest(BatchRequest):
    status_reason = StringField(default="")
    status_message = StringField(default="")


class StopManyRequest(TaskBatchRequest):
    force = BoolField(default=False)


class StopManyResponse(BatchResponse):
    stopped = IntField(required=True)


class ArchiveManyRequest(TaskBatchRequest):
    pass


class ArchiveManyResponse(BatchResponse):
    archived = IntField(required=True)


class EnqueueManyRequest(TaskBatchRequest):
    queue = StringField()


class EnqueueManyResponse(BatchResponse):
    queued = IntField()


class DeleteManyRequest(TaskBatchRequest):
    move_to_trash = BoolField(default=True)
    return_file_urls = BoolField(default=False)
    delete_output_models = BoolField(default=True)
    force = BoolField(default=False)


class ResetManyRequest(TaskBatchRequest):
    clear_all = BoolField(default=False)
    return_file_urls = BoolField(default=False)
    delete_output_models = BoolField(default=True)
    force = BoolField(default=False)


class PublishManyRequest(TaskBatchRequest):
    publish_model = BoolField(default=True)
    force = BoolField(default=False)


class ModelItemType(object):
    input = "input"
    output = "output"

View File

@ -0,0 +1,116 @@
from datetime import datetime
from typing import Callable, Tuple

from apiserver.apierrors import errors
from apiserver.apimodels.models import ModelTaskPublishResponse
from apiserver.bll.task.utils import deleted_prefix
from apiserver.database.model import EntityVisibility
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task, TaskStatus


class ModelBLL:
    @classmethod
    def get_company_model_by_id(
        cls, company_id: str, model_id: str, only_fields=None
    ) -> Model:
        query = dict(company=company_id, id=model_id)
        qs = Model.objects(**query)
        if only_fields:
            qs = qs.only(*only_fields)
        model = qs.first()
        if not model:
            raise errors.bad_request.InvalidModelId(**query)
        return model

    @classmethod
    def publish_model(
        cls,
        model_id: str,
        company_id: str,
        force_publish_task: bool = False,
        publish_task_func: Callable[[str, str, bool], dict] = None,
    ) -> Tuple[int, ModelTaskPublishResponse]:
        model = cls.get_company_model_by_id(company_id=company_id, model_id=model_id)
        if model.ready:
            raise errors.bad_request.ModelIsReady(company=company_id, model=model_id)

        published_task = None
        if model.task and publish_task_func:
            task = (
                Task.objects(id=model.task, company=company_id)
                .only("id", "status")
                .first()
            )
            if task and task.status != TaskStatus.published:
                task_publish_res = publish_task_func(
                    model.task, company_id, force_publish_task
                )
                published_task = ModelTaskPublishResponse(
                    id=model.task, data=task_publish_res
                )

        updated = model.update(upsert=False, ready=True)
        return updated, published_task

    @classmethod
    def delete_model(
        cls, model_id: str, company_id: str, force: bool
    ) -> Tuple[int, Model]:
        model = cls.get_company_model_by_id(
            company_id=company_id,
            model_id=model_id,
            only_fields=("id", "task", "project", "uri"),
        )
        deleted_model_id = f"{deleted_prefix}{model_id}"

        using_tasks = Task.objects(models__input__model=model_id).only("id")
        if using_tasks:
            if not force:
                raise errors.bad_request.ModelInUse(
                    "as execution model, use force=True to delete",
                    num_tasks=len(using_tasks),
                )
            # update deleted model id in using tasks
            Task._get_collection().update_many(
                filter={"_id": {"$in": [t.id for t in using_tasks]}},
                update={"$set": {"models.input.$[elem].model": deleted_model_id}},
                array_filters=[{"elem.model": model_id}],
                upsert=False,
            )

        if model.task:
            task = Task.objects(id=model.task).first()
            if task and task.status == TaskStatus.published:
                if not force:
                    raise errors.bad_request.ModelCreatingTaskExists(
                        "and published, use force=True to delete", task=model.task
                    )
                if task.models.output and model_id in task.models.output:
                    now = datetime.utcnow()
                    Task._get_collection().update_one(
                        filter={"_id": model.task, "models.output.model": model_id},
                        update={
                            # last_change must live inside $set: top-level keys of a
                            # MongoDB update document have to be update operators
                            "$set": {
                                "models.output.$[elem].model": deleted_model_id,
                                "output.error": f"model deleted on {now.isoformat()}",
                                "last_change": now,
                            },
                        },
                        array_filters=[{"elem.model": model_id}],
                        upsert=False,
                    )

        del_count = Model.objects(id=model_id, company=company_id).delete()
        return del_count, model

    @classmethod
    def archive_model(cls, model_id: str, company_id: str):
        cls.get_company_model_by_id(
            company_id=company_id, model_id=model_id, only_fields=("id",)
        )
        archived = Model.objects(company=company_id, id=model_id).update(
            add_to_set__system_tags=EntityVisibility.archived.value
        )
        return archived
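Each ModelBLL operation returns either a count or a (count, payload) tuple, which is what lets the batch endpoints below fold per-model results together with `+`. A hedged usage sketch (the IDs are illustrative):

```python
# Single-model operations, as the services layer calls them:
archived = ModelBLL.archive_model(model_id="model-1", company_id="company-0")  # -> 0 or 1
del_count, model = ModelBLL.delete_model(
    model_id="model-2", company_id="company-0", force=True
)
if del_count:
    print(f"deleted model, file url: {model.uri}")
```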

View File

@ -11,6 +11,7 @@ from six import string_types
import apiserver.database.utils as dbutils
from apiserver.apierrors import errors
from apiserver.apimodels.tasks import TaskInputModel
from apiserver.bll.queue import QueueBLL
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL, project_ids_with_children
@ -23,7 +24,6 @@ from apiserver.database.model.task.output import Output
from apiserver.database.model.task.task import (
    Task,
    TaskStatus,
    TaskStatusMessage,
    TaskSystemTags,
    ArtifactModes,
    ModelItem,
@ -41,11 +41,9 @@ from .artifacts import artifacts_prepare_for_save
from .param_utils import params_prepare_for_save
from .utils import (
    ChangeStatusRequest,
    validate_status_change,
    update_project_time,
    deleted_prefix,
)
from ...apimodels.tasks import TaskInputModel

log = config.logger(__file__)

org_bll = OrgBLL()
@ -482,147 +480,6 @@ class TaskBLL:
        **extra_updates,
    )
    @classmethod
    def model_set_ready(
        cls,
        model_id: str,
        company_id: str,
        publish_task: bool,
        force_publish_task: bool = False,
    ) -> tuple:
        with translate_errors_context():
            query = dict(id=model_id, company=company_id)
            model = Model.objects(**query).first()
            if not model:
                raise errors.bad_request.InvalidModelId(**query)
            elif model.ready:
                raise errors.bad_request.ModelIsReady(**query)

            published_task_data = {}
            if model.task and publish_task:
                task = (
                    Task.objects(id=model.task, company=company_id)
                    .only("id", "status")
                    .first()
                )
                if task and task.status != TaskStatus.published:
                    published_task_data["data"] = cls.publish_task(
                        task_id=model.task,
                        company_id=company_id,
                        publish_model=False,
                        force=force_publish_task,
                    )
                    published_task_data["id"] = model.task

            updated = model.update(upsert=False, ready=True)
            return updated, published_task_data

    @classmethod
    def publish_task(
        cls,
        task_id: str,
        company_id: str,
        publish_model: bool,
        force: bool,
        status_reason: str = "",
        status_message: str = "",
    ) -> dict:
        task = cls.get_task_with_access(
            task_id, company_id=company_id, requires_write_access=True
        )
        if not force:
            validate_status_change(task.status, TaskStatus.published)

        previous_task_status = task.status
        output = task.output or Output()
        publish_failed = False

        try:
            # set state to publishing
            task.status = TaskStatus.publishing
            task.save()

            # publish task models
            if task.models.output and publish_model:
                model_ids = [m.model for m in task.models.output]
                for model in Model.objects(id__in=model_ids, ready__ne=True).only("id"):
                    cls.model_set_ready(
                        model_id=model.id, company_id=company_id, publish_task=False,
                    )

            # set task status to published, and update (or set) its new output (view and models)
            return ChangeStatusRequest(
                task=task,
                new_status=TaskStatus.published,
                force=force,
                status_reason=status_reason,
                status_message=status_message,
            ).execute(published=datetime.utcnow(), output=output)
        except Exception as ex:
            publish_failed = True
            raise ex
        finally:
            if publish_failed:
                task.status = previous_task_status
                task.save()

    @classmethod
    def stop_task(
        cls,
        task_id: str,
        company_id: str,
        user_name: str,
        status_reason: str,
        force: bool,
    ) -> dict:
        """
        Stop a running task. Requires task status 'in_progress' and
        execution_progress 'running', or force=True. A development task, or
        a task with no associated worker, is stopped immediately.
        For a non-development task with a worker, only the status message
        is set to 'stopping' to allow the worker to stop the task and report by itself.
        :return: updated task fields
        """
        task = cls.get_task_with_access(
            task_id,
            company_id=company_id,
            only=(
                "status",
                "project",
                "tags",
                "system_tags",
                "last_worker",
                "last_update",
            ),
            requires_write_access=True,
        )

        def is_run_by_worker(t: Task) -> bool:
            """Checks if there is an active worker running the task"""
            update_timeout = config.get("apiserver.workers.task_update_timeout", 600)
            return (
                t.last_worker
                and t.last_update
                and (datetime.utcnow() - t.last_update).total_seconds() < update_timeout
            )

        if TaskSystemTags.development in task.system_tags or not is_run_by_worker(task):
            new_status = TaskStatus.stopped
            status_message = f"Stopped by {user_name}"
        else:
            new_status = task.status
            status_message = TaskStatusMessage.stopping

        return ChangeStatusRequest(
            task=task,
            new_status=new_status,
            status_reason=status_reason,
            status_message=status_message,
            force=force,
        ).execute()

    @staticmethod
    def get_aggregated_project_parameters(
        company_id,

View File

@ -68,6 +68,16 @@ class TaskUrls:
    event_urls: Sequence[str]
    artifact_urls: Sequence[str]

    def __add__(self, other: "TaskUrls"):
        if not other:
            return self

        return TaskUrls(
            model_urls=list(set(self.model_urls) | set(other.model_urls)),
            event_urls=list(set(self.event_urls) | set(other.event_urls)),
            artifact_urls=list(set(self.artifact_urls) | set(other.artifact_urls)),
        )


@attr.s(auto_attribs=True)
class CleanupResult:
@ -80,6 +90,17 @@ class CleanupResult:
    deleted_models: int
    urls: TaskUrls = None

    def __add__(self, other: "CleanupResult"):
        if not other:
            return self

        return CleanupResult(
            updated_children=self.updated_children + other.updated_children,
            updated_models=self.updated_models + other.updated_models,
            deleted_models=self.deleted_models + other.deleted_models,
            urls=self.urls + other.urls if self.urls else other.urls,
        )
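The `__add__` overloads make cleanup results from several tasks foldable into one summary: counters add up and URL sets union, with `None` handled on either side. A small sketch of the merge semantics (field values are illustrative):

```python
a = CleanupResult(
    updated_children=1,
    updated_models=0,
    deleted_models=2,
    urls=TaskUrls(model_urls=["m1"], event_urls=[], artifact_urls=["a1"]),
)
b = CleanupResult(updated_children=0, updated_models=3, deleted_models=1, urls=None)

total = a + b  # counters summed; b.urls is None, so a.urls is kept as-is
assert total.deleted_models == 3 and total.urls.model_urls == ["m1"]
```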
def collect_plot_image_urls(company: str, task: str) -> Set[str]:
    urls = set()
@ -224,7 +245,7 @@ def verify_task_children_and_ouptuts(task: Task, force: bool) -> TaskOutputs[Mod
            models=len(models.published),
        )

    if task.models and task.models.output:
        with TimingContext("mongo", "get_task_output_model"):
            model_ids = [m.model for m in task.models.output]
            for output_model in Model.objects(id__in=model_ids):
@ -243,11 +264,13 @@ def verify_task_children_and_ouptuts(task: Task, force: bool) -> TaskOutputs[Mod
    with TimingContext("mongo", "get_execution_models"):
        model_ids = models.draft.ids
        dependent_tasks = Task.objects(models__input__model__in=model_ids).only(
            "id", "models"
        )
        input_models = {
            m.model
            for m in chain.from_iterable(
                t.models.input for t in dependent_tasks if t.models
            )
        }
        if input_models:
            models.draft = DocumentGroup(

View File

@ -0,0 +1,329 @@
from datetime import datetime
from typing import Callable, Any, Tuple, Union

from apiserver.apierrors import errors, APIError
from apiserver.bll.queue import QueueBLL
from apiserver.bll.task import (
    TaskBLL,
    validate_status_change,
    ChangeStatusRequest,
    update_project_time,
)
from apiserver.bll.task.task_cleanup import cleanup_task, CleanupResult
from apiserver.config_repo import config
from apiserver.database.model import EntityVisibility
from apiserver.database.model.model import Model
from apiserver.database.model.task.output import Output
from apiserver.database.model.task.task import (
    TaskStatus,
    Task,
    TaskSystemTags,
    TaskStatusMessage,
    ArtifactModes,
    Execution,
    DEFAULT_LAST_ITERATION,
)
from apiserver.utilities.dicts import nested_set

queue_bll = QueueBLL()


def archive_task(
    task: Union[str, Task], company_id: str, status_message: str, status_reason: str,
) -> int:
    """
    Dequeue and archive task
    Return 1 if successful
    """
    if isinstance(task, str):
        task = TaskBLL.get_task_with_access(
            task,
            company_id=company_id,
            only=("id", "execution", "status", "project", "system_tags"),
            requires_write_access=True,
        )
    try:
        TaskBLL.dequeue_and_change_status(
            task, company_id, status_message, status_reason,
        )
    except APIError:
        # dequeue may fail if the task was not enqueued
        pass
    task.update(
        status_message=status_message,
        status_reason=status_reason,
        add_to_set__system_tags={EntityVisibility.archived.value},
        last_change=datetime.utcnow(),
    )
    return 1


def enqueue_task(
    task_id: str,
    company_id: str,
    queue_id: str,
    status_message: str,
    status_reason: str,
) -> Tuple[int, dict]:
    if not queue_id:
        # try to get default queue
        queue_id = queue_bll.get_default(company_id).id

    query = dict(id=task_id, company=company_id)
    task = Task.get_for_writing(
        _only=("type", "script", "execution", "status", "project", "id"), **query
    )
    if not task:
        raise errors.bad_request.InvalidTaskId(**query)

    res = ChangeStatusRequest(
        task=task,
        new_status=TaskStatus.queued,
        status_reason=status_reason,
        status_message=status_message,
        allow_same_state_transition=False,
    ).execute()

    try:
        queue_bll.add_task(company_id=company_id, queue_id=queue_id, task_id=task.id)
    except Exception:
        # failed enqueueing, revert to previous state
        ChangeStatusRequest(
            task=task,
            current_status_override=TaskStatus.queued,
            new_status=task.status,
            force=True,
            status_reason="failed enqueueing",
        ).execute()
        raise

    # set the current queue ID in the task
    if task.execution:
        Task.objects(**query).update(execution__queue=queue_id, multi=False)
    else:
        Task.objects(**query).update(execution=Execution(queue=queue_id), multi=False)

    nested_set(res, ("fields", "execution.queue"), queue_id)
    return 1, res


def delete_task(
    task_id: str,
    company_id: str,
    move_to_trash: bool,
    force: bool,
    return_file_urls: bool,
    delete_output_models: bool,
) -> Tuple[int, Task, CleanupResult]:
    task = TaskBLL.get_task_with_access(
        task_id, company_id=company_id, requires_write_access=True
    )

    if (
        task.status != TaskStatus.created
        and EntityVisibility.archived.value not in task.system_tags
        and not force
    ):
        raise errors.bad_request.TaskCannotBeDeleted(
            "due to status, use force=True",
            task=task.id,
            expected=TaskStatus.created,
            current=task.status,
        )

    cleanup_res = cleanup_task(
        task,
        force=force,
        return_file_urls=return_file_urls,
        delete_output_models=delete_output_models,
    )

    if move_to_trash:
        collection_name = task._get_collection_name()
        archived_collection = "{}__trash".format(collection_name)
        task.switch_collection(archived_collection)
        try:
            # A simple save() won't do due to mongoengine caching (nothing will be saved), so we have to force
            # an insert. However, if for some reason such an ID exists, let's make sure we'll keep going.
            task.save(force_insert=True)
        except Exception:
            pass
        task.switch_collection(collection_name)

    task.delete()
    update_project_time(task.project)
    return 1, task, cleanup_res


def reset_task(
    task_id: str,
    company_id: str,
    force: bool,
    return_file_urls: bool,
    delete_output_models: bool,
    clear_all: bool,
) -> Tuple[dict, CleanupResult, dict]:
    task = TaskBLL.get_task_with_access(
        task_id, company_id=company_id, requires_write_access=True
    )

    if not force and task.status == TaskStatus.published:
        raise errors.bad_request.InvalidTaskStatus(task_id=task.id, status=task.status)

    dequeued = {}
    updates = {}

    try:
        dequeued = TaskBLL.dequeue(task, company_id, silent_fail=True)
    except APIError:
        # dequeue may fail if the task was not enqueued
        pass

    cleaned_up = cleanup_task(
        task,
        force=force,
        update_children=False,
        return_file_urls=return_file_urls,
        delete_output_models=delete_output_models,
    )

    updates.update(
        set__last_iteration=DEFAULT_LAST_ITERATION,
        set__last_metrics={},
        set__metric_stats={},
        set__models__output=[],
        unset__output__result=1,
        unset__output__error=1,
        unset__last_worker=1,
        unset__last_worker_report=1,
    )

    if clear_all:
        updates.update(
            set__execution=Execution(), unset__script=1,
        )
    else:
        updates.update(unset__execution__queue=1)
        if task.execution and task.execution.artifacts:
            updates.update(
                set__execution__artifacts={
                    key: artifact
                    for key, artifact in task.execution.artifacts.items()
                    if artifact.mode == ArtifactModes.input
                }
            )

    res = ChangeStatusRequest(
        task=task,
        new_status=TaskStatus.created,
        force=force,
        status_reason="reset",
        status_message="reset",
    ).execute(
        started=None, completed=None, published=None, active_duration=None, **updates,
    )

    return dequeued, cleaned_up, res


def publish_task(
    task_id: str,
    company_id: str,
    force: bool,
    publish_model_func: Callable[[str, str], Any] = None,
    status_message: str = "",
    status_reason: str = "",
) -> dict:
    task = TaskBLL.get_task_with_access(
        task_id, company_id=company_id, requires_write_access=True
    )
    if not force:
        validate_status_change(task.status, TaskStatus.published)

    previous_task_status = task.status
    output = task.output or Output()
    publish_failed = False

    try:
        # set state to publishing
        task.status = TaskStatus.publishing
        task.save()

        # publish task models
        if task.models and task.models.output and publish_model_func:
            model_id = task.models.output[-1].model
            model = (
                Model.objects(id=model_id, company=company_id)
                .only("id", "ready")
                .first()
            )
            if model and not model.ready:
                publish_model_func(model.id, company_id)

        # set task status to published, and update (or set) its new output (view and models)
        return ChangeStatusRequest(
            task=task,
            new_status=TaskStatus.published,
            force=force,
            status_reason=status_reason,
            status_message=status_message,
        ).execute(published=datetime.utcnow(), output=output)
    except Exception as ex:
        publish_failed = True
        raise ex
    finally:
        if publish_failed:
            task.status = previous_task_status
            task.save()


def stop_task(
    task_id: str, company_id: str, user_name: str, status_reason: str, force: bool,
) -> dict:
    """
    Stop a running task. Requires task status 'in_progress' and
    execution_progress 'running', or force=True. A development task, or
    a task with no associated worker, is stopped immediately.
    For a non-development task with a worker, only the status message
    is set to 'stopping' to allow the worker to stop the task and report by itself.
    :return: updated task fields
    """
    task = TaskBLL.get_task_with_access(
        task_id,
        company_id=company_id,
        only=(
            "status",
            "project",
            "tags",
            "system_tags",
            "last_worker",
            "last_update",
        ),
        requires_write_access=True,
    )

    def is_run_by_worker(t: Task) -> bool:
        """Checks if there is an active worker running the task"""
        update_timeout = config.get("apiserver.workers.task_update_timeout", 600)
        return (
            t.last_worker
            and t.last_update
            and (datetime.utcnow() - t.last_update).total_seconds() < update_timeout
        )

    if TaskSystemTags.development in task.system_tags or not is_run_by_worker(task):
        new_status = TaskStatus.stopped
        status_message = f"Stopped by {user_name}"
    else:
        new_status = task.status
        status_message = TaskStatusMessage.stopping

    return ChangeStatusRequest(
        task=task,
        new_status=new_status,
        status_reason=status_reason,
        status_message=status_message,
        force=force,
    ).execute()

View File

@ -1,10 +1,21 @@
import functools
import itertools
from concurrent.futures.thread import ThreadPoolExecutor
from typing import (
    Optional,
    Callable,
    Dict,
    Any,
    Set,
    Iterable,
    Tuple,
    Sequence,
    TypeVar,
)

from boltons import iterutils

from apiserver.apierrors import APIError
from apiserver.database.model import AttributedDocument
from apiserver.database.model.settings import Settings
@ -96,3 +107,28 @@ def parallel_chunked_decorator(func: Callable = None, chunk_size: int = 100):
        )

    return wrapper


T = TypeVar("T")


def run_batch_operation(
    func: Callable[[str], T], init_res: T, ids: Sequence[str]
) -> Tuple[T, Sequence]:
    res = init_res
    failures = list()
    for _id in ids:
        try:
            res += func(_id)
        except APIError as err:
            failures.append(
                {
                    "id": _id,
                    "error": {
                        "codes": [err.code, err.subcode],
                        "msg": err.msg,
                        "data": err.error_data,
                    },
                }
            )
    return res, failures
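This helper is the heart of the commit: it applies `func` to each ID, folds the per-ID results into `init_res` with `+=`, and records any `APIError` as a structured failure instead of aborting the whole batch. A self-contained usage sketch (the `archive_one` function and the IDs are illustrative, not part of the commit):

```python
from functools import partial

def archive_one(entity_id: str, company_id: str) -> int:
    # Illustrative stand-in for e.g. ModelBLL.archive_model:
    # returns 1 on success, raises APIError on failure.
    return 1

archived, failures = run_batch_operation(
    func=partial(archive_one, company_id="company-0"),
    init_res=0,  # a plain int works as the accumulator via int.__add__
    ids=["id-1", "id-2"],
)
assert archived == 2 and failures == []
```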

View File

@ -465,6 +465,7 @@ class PrePopulate:
        task_models = chain.from_iterable(
            models
            for task in entities[cls.task_cls]
            if task.models
            for models in (task.models.input, task.models.output)
            if models
        )

View File

@ -31,3 +31,45 @@ credentials {
        }
    }
}
batch_operation {
    request {
        type: object
        required: [ids]
        properties {
            ids {
                description: IDs of the entities to process
                type: array
                items {type: string}
            }
        }
    }
    response {
        failures {
            type: array
            items {
                type: object
                id {
                    description: ID of the failed entity
                    type: string
                }
                error {
                    description: Error info
                    type: object
                    properties {
                        codes {
                            type: array
                            items {type: integer}
                        }
                        msg {
                            type: string
                        }
                        data {
                            type: object
                            additionalProperties: true
                        }
                    }
                }
            }
        }
    }
}
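Every `*_many` endpoint in this commit extends this definition, so a partial failure surfaces in the response rather than failing the whole call. A hypothetical request/response pair matching the schema (endpoint name, IDs, and error values are illustrative; the usual meta/data response envelope is omitted):

```
POST /v2.13/tasks.archive_many
{"ids": ["task-1", "task-2"]}

200 OK
{
    "archived": 1,
    "failures": [
        {"id": "task-2", "error": {"codes": [400, 101], "msg": "Invalid task id", "data": {}}}
    ]
}
```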

View File

@ -94,6 +94,32 @@ _definitions {
            }
        }
    }
    published_task_item {
        description: "Result of publishing of the model's associated task (if exists). Returned only if the task was published successfully as part of the model publishing."
        type: object
        properties {
            id {
                description: "Task id"
                type: string
            }
            data {
                description: "Data returned from the task publishing operation."
                type: object
                properties {
                    updated {
                        description: "Number of tasks updated (0 or 1)"
                        type: integer
                        enum: [ 0, 1 ]
                    }
                    fields {
                        description: "Updated fields names and values"
                        type: object
                        additionalProperties: true
                    }
                }
            }
        }
    }
}
get_by_id {
@ -628,6 +654,33 @@ update {
            }
        }
    }
}
publish_many {
    "2.13": ${_definitions.batch_operation} {
        description: Publish models
        request {
            force_publish_task {
                description: "Publish the associated tasks (if exist) even if they are not in the 'stopped' state. Optional, the default value is False."
                type: boolean
            }
            publish_task {
                description: "Indicates that the associated tasks (if exist) should be published. Optional, the default value is True."
                type: boolean
            }
        }
        response {
            properties {
                published {
                    description: "Number of models published"
                    type: integer
                }
                published_tasks {
                    type: array
                    items: ${_definitions.published_task_item}
                }
            }
        }
    }
}
set_ready {
    "2.1" {
        description: "Set the model ready flag to True. If the model is an output model of a task then try to publish the task."
@ -657,40 +710,45 @@ set_ready {
                    type: integer
                    enum: [0, 1]
                }
                published_task: ${_definitions.published_task_item}
            }
        }
    }
}
archive_many {
    "2.13": ${_definitions.batch_operation} {
        description: Archive models
        response {
            properties {
                archived {
                    description: "Number of models archived"
                    type: integer
                }
            }
        }
    }
}
delete_many {
    "2.13": ${_definitions.batch_operation} {
        description: Delete models
        request {
            force {
                description: """Force. Required if there are tasks that use the model as an execution model, or if the model's creating task is published."""
                type: boolean
            }
        }
        response {
            properties {
                deleted {
                    description: "Number of models deleted"
                    type: integer
                }
                urls {
                    description: "The urls of the deleted model files"
                    type: array
                    items {type: string}
                }
            }
        }
    }
}
@ -876,4 +934,3 @@ delete_metadata {
        }
    }
}

View File

@ -26,6 +26,35 @@ _references {
    }
}
_definitions {
    include "_common.conf"
    change_many_request: ${_definitions.batch_operation} {
        request {
            properties {
                status_reason {
                    description: Reason for status change
                    type: string
                }
                status_message {
                    description: Extra information regarding status change
                    type: string
                }
            }
        }
    }
    update_response {
        type: object
        properties {
            updated {
                description: "Number of tasks updated (0 or 1)"
                type: integer
                enum: [ 0, 1 ]
            }
            fields {
                description: "Updated fields names and values"
                type: object
                additionalProperties: true
            }
        }
    }
    multi_field_pattern_data {
        type: object
        properties {
@ -1216,21 +1245,7 @@ update {
                }
            }
        }
        response: ${_definitions.update_response}
    }
}
update_batch {
@ -1328,21 +1343,7 @@ edit {
                }
            }
        }
        response: ${_definitions.update_response}
    }
    "2.13": ${edit."2.1"} {
        request {
@ -1376,8 +1377,7 @@ reset {
                default: false
            }
        } ${_references.status_change_request}
        response: ${_definitions.update_response} {
            properties {
                deleted_indices {
                    description: "List of deleted ES indices that were removed as part of the reset process"
@ -1403,16 +1403,6 @@ reset {
                    description: "Number of output models deleted by the reset"
                    type: integer
                }
            }
        }
    }
@ -1435,6 +1425,101 @@ reset {
        }
    }
}
reset_many {
    "2.13": ${_definitions.batch_operation} {
        description: Reset tasks
        request {
            properties {
                force = ${_references.force_arg} {
                    description: "If not true, call fails if the task status is 'completed'"
                }
                clear_all {
                    description: "Clear script and execution sections completely"
                    type: boolean
                    default: false
                }
                return_file_urls {
                    description: "If set to 'true' then return the urls of the files that were uploaded by the tasks. Default value is 'false'"
                    type: boolean
                }
                delete_output_models {
                    description: "If set to 'true' then delete output models of the tasks that are not referenced by other tasks. Default value is 'true'"
                    type: boolean
                }
            }
        }
        response {
            properties {
                reset {
                    description: "Number of tasks reset"
                    type: integer
                }
                dequeued {
                    description: "Number of tasks dequeued"
                    type: object
                    additionalProperties: true
                }
                deleted_models {
                    description: "Number of output models deleted by the reset"
                    type: integer
                }
                urls {
                    description: "The urls of the files that were uploaded by the tasks. Returned if the 'return_file_urls' was set to 'true'"
                    "$ref": "#/definitions/task_urls"
                }
            }
        }
    }
}
delete_many {
    "2.13": ${_definitions.batch_operation} {
        description: Delete tasks
        request {
            properties {
                move_to_trash {
                    description: "Move task to trash instead of deleting it. For internal use only, tasks in the trash are not visible from the API and cannot be restored!"
                    type: boolean
                    default: false
                }
                force = ${_references.force_arg} {
                    description: "If not true, call fails if the task status is 'in_progress'"
                }
                return_file_urls {
                    description: "If set to 'true' then return the urls of the files that were uploaded by the tasks. Default value is 'false'"
                    type: boolean
                }
                delete_output_models {
                    description: "If set to 'true' then delete output models of the tasks that are not referenced by other tasks. Default value is 'true'"
                    type: boolean
                }
            }
        }
        response {
            properties {
                deleted {
                    description: "Number of tasks deleted"
                    type: integer
                }
                updated_children {
                    description: "Number of child tasks whose parent property was updated"
                    type: integer
                }
                updated_models {
                    description: "Number of models whose task property was updated"
                    type: integer
                }
                deleted_models {
                    description: "Number of deleted output models"
                    type: integer
                }
                urls {
                    description: "The urls of the files that were uploaded by the tasks. Returned if the 'return_file_urls' was set to 'true'"
                    "$ref": "#/definitions/task_urls"
                }
            }
        }
    }
}
delete {
    "2.1" {
        description: """Delete a task along with any information stored for it (statistics, frame updates etc.)
@ -1472,15 +1557,6 @@ delete {
                description: "Number of models whose task property was updated"
                type: integer
            }
            updated_versions {
                description: "Number of dataset versions whose task property was updated"
                type: integer
            }
            frames {
                description: "Response from frames.rollback"
                type: object
                additionalProperties: true
            }
            events {
                description: "Response from events.delete_for_task"
                type: object
@ -1545,6 +1621,19 @@ archive {
        }
    }
}
archive_many {
    "2.13": ${_definitions.change_many_request} {
        description: Archive tasks
        response {
            properties {
                archived {
                    description: "Number of tasks archived"
                    type: integer
                }
            }
        }
    }
}
started {
    "2.1" {
        description: "Mark a task status as in_progress. Optionally allows to set the task's execution progress."
@ -1557,24 +1646,13 @@ started {
                description: "If not true, call fails if the task status is not 'not_started'"
            }
        } ${_references.status_change_request}
        response: ${_definitions.update_response} {
            properties {
                started {
                    description: "Number of tasks started (0 or 1)"
                    type: integer
                    enum: [ 0, 1 ]
                }
            }
        }
    }
}
@ -1591,18 +1669,24 @@ stop {
            description: "If not true, call fails if the task status is not 'in_progress'"
        }
    } ${_references.status_change_request}
    response: ${_definitions.update_response}
    }
}
stop_many {
    "2.13": ${_definitions.change_many_request} {
        description: "Request to stop running tasks"
        request {
            properties {
                force = ${_references.force_arg} {
                    description: "If not true, call fails if the task status is not 'in_progress'"
                }
            }
        }
        response {
            properties {
                stopped {
                    description: "Number of tasks stopped"
                    type: integer
                }
            }
        }
    }
}
@ -1620,21 +1704,7 @@ stopped {
                description: "If not true, call fails if the task status is not 'stopped'"
            }
        } ${_references.status_change_request}
        response: ${_definitions.update_response}
    }
}
failed {
@ -1647,21 +1717,7 @@ failed {
            ]
            properties.force = ${_references.force_arg}
        } ${_references.status_change_request}
        response: ${_definitions.update_response}
    }
}
close {
@ -1674,21 +1730,7 @@ close {
            ]
            properties.force = ${_references.force_arg}
        } ${_references.status_change_request}
        response: ${_definitions.update_response}
    }
}
publish {
@ -1713,26 +1755,28 @@ publish {
            }
        }
    } ${_references.status_change_request}
    response: ${_definitions.update_response}
    }
}
publish_many {
    "2.13": ${_definitions.change_many_request} {
        description: Publish tasks
        request {
            properties {
                force = ${_references.force_arg} {
                    description: "If not true, call fails if the task status is not 'stopped'"
                }
                publish_model {
                    description: "Indicates that the task output model (if exists) should be published. Optional, the default value is True."
                    type: boolean
                }
            }
        }
        response {
            properties {
                published {
                    description: "Number of tasks published"
                    type: integer
                }
            }
        }
    }
}
@ -1763,23 +1807,25 @@ Fails if the following parameters in the task were not filled:
            }
        } ${_references.status_change_request}
        response: ${_definitions.update_response} {
            properties {
                queued {
                    description: "Number of tasks queued (0 or 1)"
                    type: integer
                    enum: [ 0, 1 ]
                }
            }
        }
    }
}
enqueue_many {
    "2.13": ${_definitions.change_many_request} {
        description: Enqueue tasks
        response {
            properties {
                enqueued {
                    description: "Number of tasks enqueued"
                    type: integer
                }
            }
        }
    }
}
@ -1795,24 +1841,13 @@ dequeue {
            task
        ]
    } ${_references.status_change_request}
    response: ${_definitions.update_response} {
        properties {
            dequeued {
                description: "Number of tasks dequeued (0 or 1)"
                type: integer
                enum: [ 0, 1 ]
            }
        }
    }
}
@ -1837,21 +1872,7 @@ set_requirements {
            }
        }
    }
    response: ${_definitions.update_response}
    }
}
@ -1867,21 +1888,7 @@ completed {
            description: "If not true, call fails if the task status is not in_progress/stopped"
        }
    } ${_references.status_change_request}
    response: ${_definitions.update_response}
    }
}

View File

@ -1,6 +1,8 @@
from datetime import datetime
from functools import partial
from typing import Sequence, Tuple, Set

import attr
from mongoengine import Q, EmbeddedDocument

from apiserver import database
@ -17,11 +19,19 @@ from apiserver.apimodels.models import (
    DeleteModelRequest,
    DeleteMetadataRequest,
    AddOrUpdateMetadataRequest,
    ModelsPublishManyRequest,
    ModelsPublishManyResponse,
    ModelsDeleteManyRequest,
    ModelsDeleteManyResponse,
    ModelsArchiveManyRequest,
    ModelsArchiveManyResponse,
)
from apiserver.bll.model import ModelBLL
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL, project_ids_with_children
from apiserver.bll.task import TaskBLL
from apiserver.bll.task.task_operations import publish_task
from apiserver.bll.util import run_batch_operation
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model import validate_id
@ -80,7 +90,7 @@ def get_by_task_id(call: APICall, company_id, _):
    task = Task.get(_only=["models"], **query)
    if not task:
        raise errors.bad_request.InvalidTaskId(**query)
    if not task.models or not task.models.output:
        raise errors.bad_request.MissingTaskFields(field="models.output")

    model_id = task.models.output[-1].model
@ -198,17 +208,6 @@ def _reset_cached_tags(company: str, projects: Sequence[str]):
    )


@endpoint("models.update_for_task", required_fields=["task"])
def update_for_task(call: APICall, company_id, _):
    if call.requested_endpoint_version > ModelsBackwardsCompatibility.max_version:
@ -242,7 +241,7 @@ def update_for_task(call: APICall, company_id, _):
    )

    if override_model_id:
        model = ModelBLL.get_company_model_by_id(
            company_id=company_id, model_id=override_model_id
        )
    else:
@ -253,7 +252,7 @@ def update_for_task(call: APICall, company_id, _):
        if "comment" not in call.data:
            call.data["comment"] = f"Created by task `{task.name}` ({task.id})"

        if task.models and task.models.output:
            # model exists, update
            model_id = task.models.output[-1].model
            res = _update_model(call, company_id, model_id=model_id).to_struct()
@ -272,7 +271,9 @@ def update_for_task(call: APICall, company_id, _):
            company=company_id,
            project=task.project,
            framework=task.execution.framework,
            parent=task.models.input[0].model
            if task.models and task.models.input
            else None,
            design=task.execution.model_desc,
            labels=task.execution.model_labels,
            ready=(task.status == TaskStatus.published),
@ -377,7 +378,9 @@ def edit(call: APICall, company_id, _):
    model_id = call.data["model"]

    with translate_errors_context():
        model = ModelBLL.get_company_model_by_id(
            company_id=company_id, model_id=model_id
        )

        fields = parse_model_fields(call, create_fields)
        fields = prepare_update_fields(call, company_id, fields)
@ -423,7 +426,9 @@ def _update_model(call: APICall, company_id, model_id=None):
    model_id = model_id or call.data["model"]

    with translate_errors_context():
        model = ModelBLL.get_company_model_by_id(
            company_id=company_id, model_id=model_id
        )

        data = prepare_update_fields(call, company_id, call.data)
@ -463,81 +468,119 @@ def update(call, company_id, _):
    request_data_model=PublishModelRequest,
    response_data_model=PublishModelResponse,
)
def set_ready(call: APICall, company_id: str, request: PublishModelRequest):
    updated, published_task = ModelBLL.publish_model(
        model_id=request.model,
        company_id=company_id,
        force_publish_task=request.force_publish_task,
        publish_task_func=publish_task if request.publish_task else None,
    )
    call.result.data_model = PublishModelResponse(
        updated=updated, published_task=published_task
    )


@attr.s(auto_attribs=True)
class PublishRes:
    published: int = 0
    published_tasks: Sequence = []

    def __add__(self, other: Tuple[int, ModelTaskPublishResponse]):
        published, response = other
        return PublishRes(
            published=self.published + published,
            published_tasks=[*self.published_tasks, *([response] if response else [])],
        )


@endpoint(
    "models.publish_many",
    request_data_model=ModelsPublishManyRequest,
    response_data_model=ModelsPublishManyResponse,
)
def publish_many(call: APICall, company_id, request: ModelsPublishManyRequest):
    res, failures = run_batch_operation(
        func=partial(
            ModelBLL.publish_model,
            company_id=company_id,
            force_publish_task=request.force_publish_task,
            publish_task_func=publish_task if request.publish_task else None,
        ),
        ids=request.ids,
        init_res=PublishRes(),
    )
    call.result.data_model = ModelsPublishManyResponse(
        published=res.published, published_tasks=res.published_tasks, failures=failures,
    )
@endpoint("models.delete", request_data_model=DeleteModelRequest) @endpoint("models.delete", request_data_model=DeleteModelRequest)
def delete(call: APICall, company_id, request: DeleteModelRequest): def delete(call: APICall, company_id, request: DeleteModelRequest):
model_id = request.model del_count, model = ModelBLL.delete_model(
force = request.force model_id=request.model, company_id=company_id, force=request.force
with translate_errors_context():
model = _get_company_model(
company_id=company_id,
model_id=model_id,
only_fields=("id", "task", "project", "uri"),
) )
deleted_model_id = f"{deleted_prefix}{model_id}"
using_tasks = Task.objects(models__input__model=model_id).only("id")
if using_tasks:
if not force:
raise errors.bad_request.ModelInUse(
"as execution model, use force=True to delete",
num_tasks=len(using_tasks),
)
# update deleted model id in using tasks
Task._get_collection().update_many(
filter={"_id": {"$in": [t.id for t in using_tasks]}},
update={"$set": {"models.input.$[elem].model": deleted_model_id}},
array_filters=[{"elem.model": model_id}],
upsert=False,
)
if model.task:
task: Task = Task.objects(id=model.task).first()
if task and task.status == TaskStatus.published:
if not force:
raise errors.bad_request.ModelCreatingTaskExists(
"and published, use force=True to delete", task=model.task
)
if task.models.output and model_id in task.models.output:
now = datetime.utcnow()
Task._get_collection().update_one(
filter={"_id": model.task, "models.output.model": model_id},
update={
"$set": {
"models.output.$[elem].model": deleted_model_id,
"output.error": f"model deleted on {now.isoformat()}",
},
"last_change": now,
},
array_filters=[{"elem.model": model_id}],
upsert=False,
)
del_count = Model.objects(id=model_id, company=company_id).delete()
if del_count: if del_count:
_reset_cached_tags(company_id, projects=[model.project]) _reset_cached_tags(
call.result.data = dict(deleted=del_count > 0, url=model.uri,) company_id, projects=[model.project] if model.project else []
)
call.result.data = dict(deleted=del_count > 0, url=model.uri)
@attr.s(auto_attribs=True)
class DeleteRes:
    deleted: int = 0
    projects: Set = set()
    urls: Set = set()

    def __add__(self, other: Tuple[int, Model]):
        del_count, model = other
        return DeleteRes(
            deleted=self.deleted + del_count,
            projects=self.projects | {model.project},
            urls=self.urls | {model.uri},
        )


@endpoint(
    "models.delete_many",
    request_data_model=ModelsDeleteManyRequest,
    response_data_model=ModelsDeleteManyResponse,
)
def delete_many(call: APICall, company_id, request: ModelsDeleteManyRequest):
    res, failures = run_batch_operation(
        func=partial(ModelBLL.delete_model, company_id=company_id, force=request.force),
        ids=request.ids,
        init_res=DeleteRes(),
    )
    if res.deleted:
        _reset_cached_tags(company_id, projects=list(res.projects))
    res.urls.discard(None)
    call.result.data_model = ModelsDeleteManyResponse(
        deleted=res.deleted, urls=list(res.urls), failures=failures,
    )


@endpoint(
    "models.archive_many",
    request_data_model=ModelsArchiveManyRequest,
    response_data_model=ModelsArchiveManyResponse,
)
def archive_many(call: APICall, company_id, request: ModelsArchiveManyRequest):
    archived, failures = run_batch_operation(
        func=partial(ModelBLL.archive_model, company_id=company_id),
        ids=request.ids,
        init_res=0,
    )
    call.result.data_model = ModelsArchiveManyResponse(
        archived=archived, failures=failures,
    )
@endpoint("models.make_public", min_version="2.9", request_data_model=MakePublicRequest) @endpoint("models.make_public", min_version="2.9", request_data_model=MakePublicRequest)
def make_public(call: APICall, company_id, request: MakePublicRequest): def make_public(call: APICall, company_id, request: MakePublicRequest):
with translate_errors_context():
call.result.data = Model.set_public( call.result.data = Model.set_public(
company_id, ids=request.ids, invalid_cls=InvalidModelId, enabled=True company_id, ids=request.ids, invalid_cls=InvalidModelId, enabled=True
) )
@ -547,7 +590,6 @@ def make_public(call: APICall, company_id, request: MakePublicRequest):
"models.make_private", min_version="2.9", request_data_model=MakePublicRequest "models.make_private", min_version="2.9", request_data_model=MakePublicRequest
) )
def make_public(call: APICall, company_id, request: MakePublicRequest): def make_public(call: APICall, company_id, request: MakePublicRequest):
with translate_errors_context():
call.result.data = Model.set_public( call.result.data = Model.set_public(
company_id, request.ids, invalid_cls=InvalidModelId, enabled=False company_id, request.ids, invalid_cls=InvalidModelId, enabled=False
) )
@ -560,7 +602,6 @@ def move(call: APICall, company_id: str, request: MoveRequest):
"project or project_name is required" "project or project_name is required"
) )
with translate_errors_context():
return { return {
"project_id": project_bll.move_under_project( "project_id": project_bll.move_under_project(
entity_cls=Model, entity_cls=Model,
@ -578,7 +619,7 @@ def add_or_update_metadata(
_: APICall, company_id: str, request: AddOrUpdateMetadataRequest _: APICall, company_id: str, request: AddOrUpdateMetadataRequest
): ):
model_id = request.model model_id = request.model
_get_company_model(company_id=company_id, model_id=model_id) ModelBLL.get_company_model_by_id(company_id=company_id, model_id=model_id)
return { return {
"updated": metadata_add_or_update( "updated": metadata_add_or_update(
@ -590,6 +631,8 @@ def add_or_update_metadata(
@endpoint("models.delete_metadata", min_version="2.13") @endpoint("models.delete_metadata", min_version="2.13")
def delete_metadata(_: APICall, company_id: str, request: DeleteMetadataRequest): def delete_metadata(_: APICall, company_id: str, request: DeleteMetadataRequest):
model_id = request.model model_id = request.model
_get_company_model(company_id=company_id, model_id=model_id, only_fields=("id",)) ModelBLL.get_company_model_by_id(
company_id=company_id, model_id=model_id, only_fields=("id",)
)
return {"updated": metadata_delete(cls=Model, _id=model_id, keys=request.keys)} return {"updated": metadata_delete(cls=Model, _id=model_id, keys=request.keys)}

View File

@ -1,6 +1,7 @@
from copy import deepcopy from copy import deepcopy
from datetime import datetime from datetime import datetime
from typing import Sequence, Union, Tuple from functools import partial
from typing import Sequence, Union, Tuple, Set
import attr import attr
import dpath import dpath
@ -8,7 +9,7 @@ from mongoengine import EmbeddedDocument, Q
from mongoengine.queryset.transform import COMPARISON_OPERATORS from mongoengine.queryset.transform import COMPARISON_OPERATORS
from pymongo import UpdateOne from pymongo import UpdateOne
from apiserver.apierrors import errors, APIError from apiserver.apierrors import errors
from apiserver.apierrors.errors.bad_request import InvalidTaskId from apiserver.apierrors.errors.bad_request import InvalidTaskId
from apiserver.apimodels.base import ( from apiserver.apimodels.base import (
UpdateResponse, UpdateResponse,
@ -47,8 +48,18 @@ from apiserver.apimodels.tasks import (
AddUpdateModelRequest, AddUpdateModelRequest,
DeleteModelsRequest, DeleteModelsRequest,
ModelItemType, ModelItemType,
StopManyResponse,
StopManyRequest,
EnqueueManyRequest,
EnqueueManyResponse,
ResetManyRequest,
ArchiveManyRequest,
ArchiveManyResponse,
DeleteManyRequest,
PublishManyRequest,
) )
from apiserver.bll.event import EventBLL from apiserver.bll.event import EventBLL
from apiserver.bll.model import ModelBLL
from apiserver.bll.organization import OrgBLL, Tags from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL, project_ids_with_children from apiserver.bll.project import ProjectBLL, project_ids_with_children
from apiserver.bll.queue import QueueBLL from apiserver.bll.queue import QueueBLL
@@ -69,19 +80,23 @@ from apiserver.bll.task.param_utils import (
     params_unprepare_from_saved,
     escape_paths,
 )
-from apiserver.bll.task.task_cleanup import cleanup_task
+from apiserver.bll.task.task_cleanup import CleanupResult
+from apiserver.bll.task.task_operations import (
+    stop_task,
+    enqueue_task,
+    reset_task,
+    archive_task,
+    delete_task,
+    publish_task,
+)
 from apiserver.bll.task.utils import update_task, deleted_prefix, get_task_for_update
-from apiserver.bll.util import SetFieldsResolver
+from apiserver.bll.util import SetFieldsResolver, run_batch_operation
 from apiserver.database.errors import translate_errors_context
-from apiserver.database.model import EntityVisibility
 from apiserver.database.model.task.output import Output
 from apiserver.database.model.task.task import (
     Task,
     TaskStatus,
     Script,
-    DEFAULT_LAST_ITERATION,
-    Execution,
-    ArtifactModes,
     ModelItem,
 )
 from apiserver.database.utils import get_fields_attr, parse_from_call, get_options
@@ -199,9 +214,7 @@ def get_all_ex(call: APICall, company_id, _):
     with TimingContext("mongo", "task_get_all_ex"):
         _process_include_subprojects(call_data)
         tasks = Task.get_many_with_join(
-            company=company_id,
-            query_dict=call_data,
-            allow_public=True,  # required in case projection is requested for public dataset/versions
+            company=company_id, query_dict=call_data, allow_public=True,
         )
         unprepare_from_saved(call, tasks)
         call.result.data = {"tasks": tasks}
@@ -235,7 +248,7 @@ def get_all(call: APICall, company_id, _):
             company=company_id,
             parameters=call_data,
             query_dict=call_data,
-            allow_public=True,  # required in case projection is requested for public dataset/versions
+            allow_public=True,
         )
         unprepare_from_saved(call, tasks)
         call.result.data = {"tasks": tasks}
@@ -263,7 +276,7 @@ def stop(call: APICall, company_id, req_model: UpdateRequest):
     """
     call.result.data_model = UpdateResponse(
-        **TaskBLL.stop_task(
+        **stop_task(
             task_id=req_model.task,
             company_id=company_id,
             user_name=call.identity.user_name,
@@ -273,6 +286,34 @@ def stop(call: APICall, company_id, req_model: UpdateRequest):
     )


+@attr.s(auto_attribs=True)
+class StopRes:
+    stopped: int = 0
+
+    def __add__(self, other: dict):
+        return StopRes(stopped=self.stopped + 1)
+
+
+@endpoint(
+    "tasks.stop_many",
+    request_data_model=StopManyRequest,
+    response_data_model=StopManyResponse,
+)
+def stop_many(call: APICall, company_id, request: StopManyRequest):
+    res, failures = run_batch_operation(
+        func=partial(
+            stop_task,
+            company_id=company_id,
+            user_name=call.identity.user_name,
+            status_reason=request.status_reason,
+            force=request.force,
+        ),
+        ids=request.ids,
+        init_res=StopRes(),
+    )
+
+    call.result.data_model = StopManyResponse(stopped=res.stopped, failures=failures)
+
+
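Note: run_batch_operation itself (imported above from apiserver.bll.util) is not shown in this diff. A minimal sketch of its assumed behavior, inferred only from the call sites here: apply func to each id, merge per-id results with "+", and collect per-id failures. The exact failure payload is an assumption; the tests below only rely on the "id" field.

    def run_batch_operation(func, ids, init_res):
        # Hedged sketch, not the actual implementation.
        res = init_res
        failures = []
        for _id in ids:
            try:
                # each accumulator type defines __add__ to fold one result in
                res = res + func(_id)
            except Exception as ex:
                failures.append({"id": _id, "error": str(ex)})
        return res, failures

This is why each *_many endpoint defines a small accumulator class whose __add__ folds a single per-task result into the running totals.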
 @endpoint(
     "tasks.stopped",
     request_data_model=UpdateRequest,
@@ -792,61 +833,44 @@ def delete_configuration(
     request_data_model=EnqueueRequest,
     response_data_model=EnqueueResponse,
 )
-def enqueue(call: APICall, company_id, req_model: EnqueueRequest):
-    task_id = req_model.task
-    queue_id = req_model.queue
-    status_message = req_model.status_message
-    status_reason = req_model.status_reason
-
-    if not queue_id:
-        # try to get default queue
-        queue_id = queue_bll.get_default(company_id).id
-
-    with translate_errors_context():
-        query = dict(id=task_id, company=company_id)
-        task = Task.get_for_writing(
-            _only=("type", "script", "execution", "status", "project", "id"), **query
-        )
-        if not task:
-            raise errors.bad_request.InvalidTaskId(**query)
-
-        res = EnqueueResponse(
-            **ChangeStatusRequest(
-                task=task,
-                new_status=TaskStatus.queued,
-                status_reason=status_reason,
-                status_message=status_message,
-                allow_same_state_transition=False,
-            ).execute()
-        )
-        try:
-            queue_bll.add_task(
-                company_id=company_id, queue_id=queue_id, task_id=task.id
-            )
-        except Exception:
-            # failed enqueueing, revert to previous state
-            ChangeStatusRequest(
-                task=task,
-                current_status_override=TaskStatus.queued,
-                new_status=task.status,
-                force=True,
-                status_reason="failed enqueueing",
-            ).execute()
-            raise
-
-        # set the current queue ID in the task
-        if task.execution:
-            Task.objects(**query).update(execution__queue=queue_id, multi=False)
-        else:
-            Task.objects(**query).update(
-                execution=Execution(queue=queue_id), multi=False
-            )
-
-        res.queued = 1
-        res.fields.update(**{"execution.queue": queue_id})
-
-    call.result.data_model = res
+def enqueue(call: APICall, company_id, request: EnqueueRequest):
+    queued, res = enqueue_task(
+        task_id=request.task,
+        company_id=company_id,
+        queue_id=request.queue,
+        status_message=request.status_message,
+        status_reason=request.status_reason,
+    )
+    call.result.data_model = EnqueueResponse(queued=queued, **res)
+
+
+@attr.s(auto_attribs=True)
+class EnqueueRes:
+    queued: int = 0
+
+    def __add__(self, other: Tuple[int, dict]):
+        queued, _ = other
+        return EnqueueRes(queued=self.queued + queued)
+
+
+@endpoint(
+    "tasks.enqueue_many",
+    request_data_model=EnqueueManyRequest,
+    response_data_model=EnqueueManyResponse,
+)
+def enqueue_many(call: APICall, company_id, request: EnqueueManyRequest):
+    res, failures = run_batch_operation(
+        func=partial(
+            enqueue_task,
+            company_id=company_id,
+            queue_id=request.queue,
+            status_message=request.status_message,
+            status_reason=request.status_reason,
+        ),
+        ids=request.ids,
+        init_res=EnqueueRes(),
+    )
+
+    call.result.data_model = EnqueueManyResponse(queued=res.queued, failures=failures)
 @endpoint(
@@ -878,164 +902,161 @@ def dequeue(call: APICall, company_id, request: UpdateRequest):
     "tasks.reset", request_data_model=ResetRequest, response_data_model=ResetResponse
 )
 def reset(call: APICall, company_id, request: ResetRequest):
-    task = TaskBLL.get_task_with_access(
-        request.task, company_id=company_id, requires_write_access=True
-    )
-
-    force = request.force
-    if not force and task.status == TaskStatus.published:
-        raise errors.bad_request.InvalidTaskStatus(task_id=task.id, status=task.status)
-
-    api_results = {}
-    updates = {}
-
-    try:
-        dequeued = TaskBLL.dequeue(task, company_id, silent_fail=True)
-    except APIError:
-        # dequeue may fail if the task was not enqueued
-        pass
-    else:
-        if dequeued:
-            api_results.update(dequeued=dequeued)
-
-    cleaned_up = cleanup_task(
-        task,
-        force=force,
-        update_children=False,
-        return_file_urls=request.return_file_urls,
-        delete_output_models=request.delete_output_models,
-    )
-    api_results.update(attr.asdict(cleaned_up))
-
-    updates.update(
-        set__last_iteration=DEFAULT_LAST_ITERATION,
-        set__last_metrics={},
-        set__metric_stats={},
-        set__models__output=[],
-        unset__output__result=1,
-        unset__output__error=1,
-        unset__last_worker=1,
-        unset__last_worker_report=1,
-    )
-
-    if request.clear_all:
-        updates.update(
-            set__execution=Execution(), unset__script=1,
-        )
-    else:
-        updates.update(unset__execution__queue=1)
-        if task.execution and task.execution.artifacts:
-            updates.update(
-                set__execution__artifacts={
-                    key: artifact
-                    for key, artifact in task.execution.artifacts.items()
-                    if artifact.mode == ArtifactModes.input
-                }
-            )
-
-    res = ResetResponse(
-        **ChangeStatusRequest(
-            task=task,
-            new_status=TaskStatus.created,
-            force=force,
-            status_reason="reset",
-            status_message="reset",
-        ).execute(
-            started=None,
-            completed=None,
-            published=None,
-            active_duration=None,
-            **updates,
-        )
-    )
+    dequeued, cleanup_res, updates = reset_task(
+        task_id=request.task,
+        company_id=company_id,
+        force=request.force,
+        return_file_urls=request.return_file_urls,
+        delete_output_models=request.delete_output_models,
+        clear_all=request.clear_all,
+    )
+
+    res = ResetResponse(**updates, dequeued=dequeued)

     # do not return artifacts since they are not serializable
     res.fields.pop("execution.artifacts", None)

-    for key, value in api_results.items():
+    for key, value in attr.asdict(cleanup_res).items():
         setattr(res, key, value)

     call.result.data_model = res
+
+
+@attr.s(auto_attribs=True)
+class ResetRes:
+    reset: int = 0
+    dequeued: int = 0
+    cleanup_res: CleanupResult = None
+
+    def __add__(self, other: Tuple[dict, CleanupResult, dict]):
+        dequeued, other_res, _ = other
+        dequeued = dequeued.get("removed", 0) if dequeued else 0
+        return ResetRes(
+            reset=self.reset + 1,
+            dequeued=self.dequeued + dequeued,
+            cleanup_res=self.cleanup_res + other_res if self.cleanup_res else other_res,
+        )
+
+
+@endpoint("tasks.reset_many", request_data_model=ResetManyRequest)
+def reset_many(call: APICall, company_id, request: ResetManyRequest):
+    res, failures = run_batch_operation(
+        func=partial(
+            reset_task,
+            company_id=company_id,
+            force=request.force,
+            return_file_urls=request.return_file_urls,
+            delete_output_models=request.delete_output_models,
+            clear_all=request.clear_all,
+        ),
+        ids=request.ids,
+        init_res=ResetRes(),
+    )
+    if res.cleanup_res:
+        cleanup_res = dict(
+            deleted_models=res.cleanup_res.deleted_models,
+            urls=attr.asdict(res.cleanup_res.urls),
+        )
+    else:
+        cleanup_res = {}
+    call.result.data = dict(
+        reset=res.reset, dequeued=res.dequeued, **cleanup_res, failures=failures,
+    )
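Note: ResetRes.__add__ (and DeleteRes further down) folds CleanupResult objects together with "+", which assumes CleanupResult itself defines __add__. A hedged sketch of the assumed shape; the field names are taken from the responses in this diff, and a plain dict stands in for the urls attrs object (the real class in apiserver/bll/task/task_cleanup.py is passed through attr.asdict above and may well differ):

    import attr

    @attr.s(auto_attribs=True)
    class CleanupResult:
        deleted_models: int = 0
        urls: dict = None  # stand-in: the real code uses an attrs object of url lists

        def __add__(self, other: "CleanupResult") -> "CleanupResult":
            # hedged sketch: sum the counters, merge the collected urls
            return CleanupResult(
                deleted_models=self.deleted_models + other.deleted_models,
                urls={**(self.urls or {}), **(other.urls or {})},
            )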
 @endpoint(
     "tasks.archive",
     request_data_model=ArchiveRequest,
     response_data_model=ArchiveResponse,
 )
 def archive(call: APICall, company_id, request: ArchiveRequest):
-    archived = 0
     tasks = TaskBLL.assert_exists(
         company_id,
         task_ids=request.tasks,
         only=("id", "execution", "status", "project", "system_tags"),
     )
+    archived = 0
     for task in tasks:
-        try:
-            TaskBLL.dequeue_and_change_status(
-                task, company_id, request.status_message, request.status_reason,
-            )
-        except APIError:
-            # dequeue may fail if the task was not enqueued
-            pass
-        task.update(
+        archived += archive_task(
+            company_id=company_id,
+            task=task,
             status_message=request.status_message,
             status_reason=request.status_reason,
-            system_tags=sorted(
-                set(task.system_tags) | {EntityVisibility.archived.value}
-            ),
-            last_change=datetime.utcnow(),
         )
-        archived += 1

     call.result.data_model = ArchiveResponse(archived=archived)
+
+
+@endpoint(
+    "tasks.archive_many",
+    request_data_model=ArchiveManyRequest,
+    response_data_model=ArchiveManyResponse,
+)
+def archive_many(call: APICall, company_id, request: ArchiveManyRequest):
+    archived, failures = run_batch_operation(
+        func=partial(
+            archive_task,
+            company_id=company_id,
+            status_message=request.status_message,
+            status_reason=request.status_reason,
+        ),
+        ids=request.ids,
+        init_res=0,
+    )
+
+    call.result.data_model = ArchiveManyResponse(archived=archived, failures=failures)
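Note: archive_many gets away with a plain integer accumulator (init_res=0) rather than a dedicated result class, because archive_task returns the number of tasks it archived (see the archived += archive_task(...) call above) and int already supports "+".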
@endpoint("tasks.delete", request_data_model=DeleteRequest) @endpoint("tasks.delete", request_data_model=DeleteRequest)
def delete(call: APICall, company_id, request: DeleteRequest): def delete(call: APICall, company_id, request: DeleteRequest):
task = TaskBLL.get_task_with_access( deleted, task, cleanup_res = delete_task(
request.task, company_id=company_id, requires_write_access=True task_id=request.task,
) company_id=company_id,
move_to_trash=request.move_to_trash,
move_to_trash = request.move_to_trash force=request.force,
force = request.force
if task.status != TaskStatus.created and not force:
raise errors.bad_request.TaskCannotBeDeleted(
"due to status, use force=True",
task=task.id,
expected=TaskStatus.created,
current=task.status,
)
with translate_errors_context():
result = cleanup_task(
task,
force=force,
return_file_urls=request.return_file_urls, return_file_urls=request.return_file_urls,
delete_output_models=request.delete_output_models, delete_output_models=request.delete_output_models,
) )
if deleted:
_reset_cached_tags(company_id, projects=[task.project] if task.project else [])
call.result.data = dict(deleted=bool(deleted), **attr.asdict(cleanup_res))
if move_to_trash:
collection_name = task._get_collection_name()
archived_collection = "{}__trash".format(collection_name)
task.switch_collection(archived_collection)
try:
# A simple save() won't do due to mongoengine caching (nothing will be saved), so we have to force
# an insert. However, if for some reason such an ID exists, let's make sure we'll keep going.
with TimingContext("mongo", "save_task"):
task.save(force_insert=True)
except Exception:
pass
task.switch_collection(collection_name)
task.delete() @attr.s(auto_attribs=True)
_reset_cached_tags(company_id, projects=[task.project]) class DeleteRes:
update_project_time(task.project) deleted: int = 0
projects: Set = set()
cleanup_res: CleanupResult = None
call.result.data = dict(deleted=True, **attr.asdict(result)) def __add__(self, other: Tuple[int, Task, CleanupResult]):
del_count, task, other_res = other
return DeleteRes(
deleted=self.deleted + del_count,
projects=self.projects | {task.project},
cleanup_res=self.cleanup_res + other_res if self.cleanup_res else other_res,
)
@endpoint("tasks.delete_many", request_data_model=DeleteManyRequest)
def delete_many(call: APICall, company_id, request: DeleteManyRequest):
res, failures = run_batch_operation(
func=partial(
delete_task,
company_id=company_id,
move_to_trash=request.move_to_trash,
force=request.force,
return_file_urls=request.return_file_urls,
delete_output_models=request.delete_output_models,
),
ids=request.ids,
init_res=DeleteRes(),
)
if res.deleted:
_reset_cached_tags(company_id, projects=list(res.projects))
cleanup_res = attr.asdict(res.cleanup_res) if res.cleanup_res else {}
call.result.data = dict(deleted=res.deleted, **cleanup_res, failures=failures)
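For reference, a partially failed delete_many response combines the aggregated counters with the per-id failures. Based on the response models and the tests below, which only read the id field of each failure (the error payload shown here is an assumption):

    {
        "deleted": 2,
        "urls": {"model_urls": ["file:///0", "file:///1"]},
        "failures": [{"id": "<missing-id>", "error": "..."}]
    }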
 @endpoint(
@@ -1043,18 +1064,45 @@ def delete(call: APICall, company_id, request: DeleteRequest):
     request_data_model=PublishRequest,
     response_data_model=PublishResponse,
 )
-def publish(call: APICall, company_id, req_model: PublishRequest):
-    call.result.data_model = PublishResponse(
-        **TaskBLL.publish_task(
-            task_id=req_model.task,
-            company_id=company_id,
-            publish_model=req_model.publish_model,
-            force=req_model.force,
-            status_reason=req_model.status_reason,
-            status_message=req_model.status_message,
-        )
-    )
+def publish(call: APICall, company_id, request: PublishRequest):
+    updates = publish_task(
+        task_id=request.task,
+        company_id=company_id,
+        force=request.force,
+        publish_model_func=ModelBLL.publish_model if request.publish_model else None,
+        status_reason=request.status_reason,
+        status_message=request.status_message,
+    )
+    call.result.data_model = PublishResponse(**updates)
+
+
+@attr.s(auto_attribs=True)
+class PublishRes:
+    published: int = 0
+
+    def __add__(self, other: dict):
+        return PublishRes(published=self.published + 1)
+
+
+@endpoint("tasks.publish_many", request_data_model=PublishManyRequest)
+def publish_many(call: APICall, company_id, request: PublishManyRequest):
+    res, failures = run_batch_operation(
+        func=partial(
+            publish_task,
+            company_id=company_id,
+            force=request.force,
+            publish_model_func=ModelBLL.publish_model
+            if request.publish_model
+            else None,
+            status_reason=request.status_reason,
+            status_message=request.status_message,
+        ),
+        ids=request.ids,
+        init_res=PublishRes(),
+    )
+    call.result.data = dict(published=res.published, failures=failures)
 @endpoint(
     "tasks.completed",

View File

@@ -28,17 +28,24 @@ def get_tags_response(ret: dict) -> dict:
 def conform_output_tags(call: APICall, documents: Union[dict, Sequence[dict]]):
     """
+    Make sure that tags are always returned sorted
     For old clients both tags and system tags are returned in 'tags' field
     """
-    if call.requested_endpoint_version >= PartialVersion("2.3"):
-        return
     if isinstance(documents, dict):
         documents = [documents]
+    merge_tags = call.requested_endpoint_version < PartialVersion("2.3")
     for doc in documents:
-        system_tags = doc.get("system_tags")
-        if system_tags:
-            doc["tags"] = list(set(doc.get("tags", [])) | set(system_tags))
+        if merge_tags:
+            system_tags = doc.get("system_tags")
+            if system_tags:
+                doc["tags"] = list(set(doc.get("tags", [])) | set(system_tags))
+        for field in ("system_tags", "tags"):
+            tags = doc.get(field)
+            if tags:
+                doc[field] = sorted(tags)
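A small worked example of the new conform_output_tags behavior for a pre-2.3 client (merge_tags is True), using a made-up document:

    doc = {"tags": ["b"], "system_tags": ["archived", "a"]}
    # after conform_output_tags:
    #   doc["tags"] == ["a", "archived", "b"]    (system tags merged in, then sorted)
    #   doc["system_tags"] == ["a", "archived"]  (sorted in place)

For 2.3+ clients the merge step is skipped and only the sorting applies.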
 def conform_tag_fields(call: APICall, document: dict, validate=False):
     """

View File

@@ -69,16 +69,6 @@ class TestService(TestCase, TestServiceInterface):
             delete_params=delete_params,
         )

-    def create_temp_version(self, *, client=None, **kwargs) -> str:
-        return self._create_temp_helper(
-            service="datasets",
-            create_endpoint="create_version",
-            delete_endpoint="delete_version",
-            object_name="version",
-            create_params=kwargs,
-            client=client,
-        )
-
     def setUp(self, version="1.7"):
         self._api = APIClient(base_url=f"http://localhost:8008/v{version}")
         self._deferred = []

View File

@@ -0,0 +1,124 @@
+from apiserver.database.utils import id as db_id
+from apiserver.tests.automated import TestService
+
+
+class TestBatchOperations(TestService):
+    name = "batch operation test"
+    comment = "this is a comment"
+    delete_params = dict(can_fail=True, force=True)
+
+    def setUp(self, version="2.13"):
+        super().setUp(version=version)
+
+    def test_tasks(self):
+        tasks = [self._temp_task() for _ in range(2)]
+        models = [
+            self._temp_task_model(task=t, uri=f"uri_{idx}")
+            for idx, t in enumerate(tasks)
+        ]
+        missing_id = db_id()
+        ids = [*tasks, missing_id]
+
+        # enqueue
+        res = self.api.tasks.enqueue_many(ids=ids)
+        self.assertEqual(res.queued, 2)
+        self._assert_failures(res, [missing_id])
+        data = self.api.tasks.get_all_ex(id=ids).tasks
+        self.assertEqual({t.status for t in data}, {"queued"})
+
+        # stop
+        for t in tasks:
+            self.api.tasks.started(task=t)
+        res = self.api.tasks.stop_many(ids=ids)
+        self.assertEqual(res.stopped, 2)
+        self._assert_failures(res, [missing_id])
+        data = self.api.tasks.get_all_ex(id=ids).tasks
+        self.assertEqual({t.status for t in data}, {"stopped"})
+
+        # publish
+        res = self.api.tasks.publish_many(ids=ids, publish_model=False)
+        self.assertEqual(res.published, 2)
+        self._assert_failures(res, [missing_id])
+        data = self.api.tasks.get_all_ex(id=ids).tasks
+        self.assertEqual({t.status for t in data}, {"published"})
+
+        # reset
+        res = self.api.tasks.reset_many(
+            ids=ids, delete_output_models=True, return_file_urls=True, force=True
+        )
+        self.assertEqual(res.reset, 2)
+        self.assertEqual(res.deleted_models, 2)
+        self.assertEqual(set(res.urls.model_urls), {"uri_0", "uri_1"})
+        self._assert_failures(res, [missing_id])
+        data = self.api.tasks.get_all_ex(id=ids).tasks
+        self.assertEqual({t.status for t in data}, {"created"})
+
+        # archive
+        res = self.api.tasks.archive_many(ids=ids)
+        self.assertEqual(res.archived, 2)
+        self._assert_failures(res, [missing_id])
+        data = self.api.tasks.get_all_ex(id=ids).tasks
+        self.assertTrue(all("archived" in t.system_tags for t in data))
+
+        # delete
+        res = self.api.tasks.delete_many(
+            ids=ids, delete_output_models=True, return_file_urls=True
+        )
+        self.assertEqual(res.deleted, 2)
+        self._assert_failures(res, [missing_id])
+        data = self.api.tasks.get_all_ex(id=ids).tasks
+        self.assertEqual(data, [])
+
+    def test_models(self):
+        uris = [f"file:///{i}" for i in range(2)]
+        models = [self._temp_model(uri=uri) for uri in uris]
+        missing_id = db_id()
+        ids = [*models, missing_id]
+
+        # publish
+        task = self._temp_task()
+        self.api.models.edit(model=ids[0], ready=False, task=task)
+        self.api.tasks.add_or_update_model(
+            task=task, name="output", type="input", model=ids[0]
+        )
+        res = self.api.models.publish_many(
+            ids=ids, publish_task=True, force_publish_task=True
+        )
+        self.assertEqual(res.published, 1)
+        self.assertEqual(res.published_tasks[0].id, task)
+        self._assert_failures(res, [ids[1], missing_id])
+
+        # archive
+        res = self.api.models.archive_many(ids=ids)
+        self.assertEqual(res.archived, 2)
+        self._assert_failures(res, [missing_id])
+        data = self.api.models.get_all_ex(id=ids).models
+        for m in data:
+            self.assertIn("archived", m.system_tags)
+
+        # delete
+        res = self.api.models.delete_many(ids=[*models, missing_id], force=True)
+        self.assertEqual(res.deleted, 2)
+        self.assertEqual(set(res.urls), set(uris))
+        self._assert_failures(res, [missing_id])
+        data = self.api.models.get_all_ex(id=ids).models
+        self.assertEqual(data, [])
+
+    def _assert_failures(self, res, failed_ids):
+        self.assertEqual(set(f.id for f in res.failures), set(failed_ids))
+
+    def _temp_model(self, **kwargs):
+        self.update_missing(kwargs, name=self.name, uri="file:///a/b", labels={})
+        return self.create_temp("models", delete_params=self.delete_params, **kwargs)
+
+    def _temp_task(self):
+        return self.create_temp(
+            service="tasks", type="testing", name=self.name, input=dict(view={}),
+        )
+
+    def _temp_task_model(self, task, **kwargs) -> str:
+        model = self._temp_model(ready=False, task=task, **kwargs)
+        self.api.tasks.add_or_update_model(
+            task=task, name="output", type="output", model=model
+        )
+        return model