Added unarchive APIs

allegroai 2021-05-03 18:04:17 +03:00
parent e2f265b4bc
commit b99f620073
10 changed files with 129 additions and 24 deletions

View File

@ -51,10 +51,6 @@ class ModelsDeleteManyRequest(BatchRequest):
force = fields.BoolField(default=False)
class ModelsArchiveManyRequest(BatchRequest):
pass
class ModelsDeleteManyResponse(BatchResponse):
urls = fields.ListField([str])
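ModelsArchiveManyRequest added nothing on top of its base class, so the model archive/unarchive endpoints can take the plain BatchRequest directly. A minimal sketch of what that shared request model presumably looks like (shape inferred from how the endpoints use it, not copied from this commit):

from jsonmodels import fields, models

class BatchRequest(models.Base):
    # the only payload the batch archive/unarchive endpoints need: a list of entity ids
    ids = fields.ListField([str])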

View File

@ -241,12 +241,9 @@ class StopManyRequest(TaskBatchRequest):
force = BoolField(default=False)
class ArchiveManyRequest(TaskBatchRequest):
pass
class EnqueueManyRequest(TaskBatchRequest):
queue = StringField()
validate_tasks = BoolField(default=False)
class DeleteManyRequest(TaskBatchRequest):
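ArchiveManyRequest was likewise an empty subclass, which is why the task archive/unarchive endpoints below switch to TaskBatchRequest. Judging from how those endpoints read it, TaskBatchRequest presumably adds the status fields on top of the id list; a hedged sketch (field names assumed from the request.status_message / request.status_reason usage further down, not taken from this commit):

class TaskBatchRequest(BatchRequest):
    # assumed shape, for illustration only
    status_message = StringField()
    status_reason = StringField()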

View File

@ -115,3 +115,15 @@ class ModelBLL:
)
return archived
@classmethod
def unarchive_model(cls, model_id: str, company_id: str):
cls.get_company_model_by_id(
company_id=company_id, model_id=model_id, only_fields=("id",)
)
unarchived = Model.objects(company=company_id, id=model_id).update(
pull__system_tags=EntityVisibility.archived.value,
last_update=datetime.utcnow(),
)
return unarchived
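unarchive_model mirrors archive_model above: it first checks that the model belongs to the company, then pulls the archived marker out of system_tags with MongoEngine's pull__ operator and bumps last_update; the value returned is the update count reported by MongoEngine. A hedged usage sketch (both ids are placeholders):

# illustrative only
updated = ModelBLL.unarchive_model(model_id="<model-id>", company_id="<company-id>")
print("unarchived" if updated else "no document updated")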

View File

@ -335,6 +335,7 @@ class TaskBLL:
if (
validate_parent
and task.parent
and not task.parent.startswith(deleted_prefix)
and not Task.get(
company=task.company, id=task.parent, _only=("id",), include_public=True
)
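This extra guard makes parent validation tolerant of deleted parents: when a task is removed its references are apparently rewritten with the deleted_prefix marker, so a prefixed parent id should be accepted rather than looked up. A standalone sketch of the resulting predicate (the helper name is made up for illustration):

def parent_is_valid(task, deleted_prefix: str) -> bool:
    # no parent, or a parent explicitly marked as deleted, is considered valid
    if not task.parent or task.parent.startswith(deleted_prefix):
        return True
    # otherwise the parent must still exist (public tasks included)
    return bool(
        Task.get(company=task.company, id=task.parent, _only=("id",), include_public=True)
    )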

View File

@ -59,30 +59,49 @@ def archive_task(
task.update(
status_message=status_message,
status_reason=status_reason,
add_to_set__system_tags={EntityVisibility.archived.value},
add_to_set__system_tags=EntityVisibility.archived.value,
last_change=datetime.utcnow(),
)
return 1
def unarchive_task(
task: str, company_id: str, status_message: str, status_reason: str,
) -> int:
"""
Unarchive task. Return 1 if successful
"""
task = TaskBLL.get_task_with_access(
task, company_id=company_id, only=("id",), requires_write_access=True,
)
return task.update(
status_message=status_message,
status_reason=status_reason,
pull__system_tags=EntityVisibility.archived.value,
last_change=datetime.utcnow(),
)
def enqueue_task(
task_id: str,
company_id: str,
queue_id: str,
status_message: str,
status_reason: str,
validate: bool = False,
) -> Tuple[int, dict]:
if not queue_id:
# try to get default queue
queue_id = queue_bll.get_default(company_id).id
query = dict(id=task_id, company=company_id)
task = Task.get_for_writing(
_only=("type", "script", "execution", "status", "project", "id"), **query
)
task = Task.get_for_writing(**query)
if not task:
raise errors.bad_request.InvalidTaskId(**query)
if validate:
TaskBLL.validate(task)
res = ChangeStatusRequest(
task=task,
new_status=TaskStatus.queued,
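unarchive_task is the inverse of archive_task (which now passes the tag value itself to add_to_set__system_tags rather than a one-element set), and enqueue_task loads the full task so that TaskBLL.validate can inspect it when validate is set. A hedged round-trip sketch of the two toggles, assuming archive_task also accepts a task id like unarchive_task does (ids and messages are placeholders):

# archive and then restore the same task; both calls return the number of updated documents
archive_task(
    "<task-id>", company_id="<company-id>",
    status_message="archived by maintenance script", status_reason="cleanup",
)
unarchive_task(
    "<task-id>", company_id="<company-id>",
    status_message="restored by maintenance script", status_reason="cleanup",
)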

View File

@ -730,6 +730,21 @@ archive_many {
}
}
}
unarchive_many {
"2.13": ${_definitions.batch_operation} {
description: Unarchive models
request {
properties {
ids.description: "IDs of the models to unarchive"
}
}
response {
properties {
succeeded.description: "Number of models unarchived"
}
}
}
}
delete_many {
"2.13": ${_definitions.batch_operation} {
description: Delete models
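The new schema mirrors archive_many one block up: the request is the shared batch_operation payload (a list of ids) and the response reports how many models were unarchived plus any per-id failures. A hedged call sketch against a local apiserver (address, auth style and response layout are assumptions, not taken from this commit):

import requests

resp = requests.post(
    "http://localhost:8008/models.unarchive_many",   # assumed default apiserver address
    json={"ids": ["<model-id-1>", "<model-id-2>"]},
    auth=("<access_key>", "<secret_key>"),           # assumed key/secret basic auth
)
print(resp.json())  # expected to carry the succeeded count and per-id failure entries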

View File

@ -1623,6 +1623,21 @@ archive_many {
}
}
}
unarchive_many {
"2.13": ${_definitions.change_many_request} {
description: Unarchive tasks
request {
properties {
ids.description: "IDs of the tasks to unarchive"
}
}
response {
properties {
succeeded.description: "Number of tasks unarchived"
}
}
}
}
started {
"2.1" {
description: "Mark a task status as in_progress. Optionally allows to set the task's execution progress."
@ -1813,6 +1828,11 @@ enqueue_many {
description: "Queue id. If not provided, tasks are added to the default queue."
type: string
}
validate_tasks {
description: "If set then tasks are validated before enqueue"
type: boolean
default: false
}
}
}
response {
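tasks.unarchive_many follows the same change_many_request pattern, and enqueue_many gains the validate_tasks flag documented above. A hedged sketch of the extended enqueue payload (ids and queue are placeholders; the transport is the same as for any other batch call):

# request body for tasks.enqueue_many
payload = {
    "ids": ["<task-id-1>", "<task-id-2>"],
    "queue": "<queue-id>",       # optional: the default queue is used when omitted
    "validate_tasks": True,      # new flag: validate each task before it is queued
}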

View File

@ -9,7 +9,7 @@ from apiserver import database
from apiserver.apierrors import errors
from apiserver.apierrors.errors.bad_request import InvalidModelId
from apiserver.apimodels.base import UpdateResponse, MakePublicRequest, MoveRequest
from apiserver.apimodels.batch import BatchResponse
from apiserver.apimodels.batch import BatchResponse, BatchRequest
from apiserver.apimodels.models import (
CreateModelRequest,
CreateModelResponse,
@ -24,7 +24,6 @@ from apiserver.apimodels.models import (
ModelsPublishManyResponse,
ModelsDeleteManyRequest,
ModelsDeleteManyResponse,
ModelsArchiveManyRequest,
)
from apiserver.bll.model import ModelBLL
from apiserver.bll.organization import OrgBLL, Tags
@ -587,10 +586,10 @@ def delete(call: APICall, company_id, request: ModelsDeleteManyRequest):
@endpoint(
"models.archive_many",
request_data_model=ModelsArchiveManyRequest,
request_data_model=BatchRequest,
response_data_model=BatchResponse,
)
def archive_many(call: APICall, company_id, request: ModelsArchiveManyRequest):
def archive_many(call: APICall, company_id, request: BatchRequest):
archived, failures = run_batch_operation(
func=partial(ModelBLL.archive_model, company_id=company_id),
ids=request.ids,
@ -599,6 +598,20 @@ def archive_many(call: APICall, company_id, request: ModelsArchiveManyRequest):
call.result.data_model = BatchResponse(succeeded=archived, failures=failures)
@endpoint(
"models.unarchive_many",
request_data_model=BatchRequest,
response_data_model=BatchResponse,
)
def unarchive_many(call: APICall, company_id, request: BatchRequest):
unarchived, failures = run_batch_operation(
func=partial(ModelBLL.unarchive_model, company_id=company_id),
ids=request.ids,
init_res=0,
)
call.result.data_model = BatchResponse(succeeded=unarchived, failures=failures)
@endpoint("models.make_public", min_version="2.9", request_data_model=MakePublicRequest)
def make_public(call: APICall, company_id, request: MakePublicRequest):
call.result.data = Model.set_public(
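Both archive_many and unarchive_many are thin wrappers around run_batch_operation with the per-model BLL call bound via partial. That helper is not part of this commit; a hedged sketch of the behaviour its call sites imply (apply the function per id, accumulate the result, collect failures) looks roughly like this:

def run_batch_operation(func, ids, init_res):
    # assumed behaviour inferred from the call sites, not the actual implementation
    res = init_res
    failures = []
    for _id in ids:
        try:
            res += func(_id)     # for archive/unarchive, func returns the update count
        except Exception as ex:  # the real code presumably narrows this to API errors
            failures.append({"id": _id, "error": str(ex)})
    return res, failures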

View File

@ -51,9 +51,9 @@ from apiserver.apimodels.tasks import (
StopManyRequest,
EnqueueManyRequest,
ResetManyRequest,
ArchiveManyRequest,
DeleteManyRequest,
PublishManyRequest,
TaskBatchRequest,
)
from apiserver.bll.event import EventBLL
from apiserver.bll.model import ModelBLL
@ -85,8 +85,9 @@ from apiserver.bll.task.task_operations import (
archive_task,
delete_task,
publish_task,
unarchive_task,
)
from apiserver.bll.task.utils import update_task, deleted_prefix, get_task_for_update
from apiserver.bll.task.utils import update_task, get_task_for_update, deleted_prefix
from apiserver.bll.util import SetFieldsResolver, run_batch_operation
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.task.output import Output
@ -875,6 +876,7 @@ def enqueue_many(call: APICall, company_id, request: EnqueueManyRequest):
queue_id=request.queue,
status_message=request.status_message,
status_reason=request.status_reason,
validate=request.validate_tasks,
),
ids=request.ids,
init_res=EnqueueRes(),
@ -998,10 +1000,10 @@ def archive(call: APICall, company_id, request: ArchiveRequest):
@endpoint(
"tasks.archive_many",
request_data_model=ArchiveManyRequest,
request_data_model=TaskBatchRequest,
response_data_model=BatchResponse,
)
def archive_many(call: APICall, company_id, request: ArchiveManyRequest):
def archive_many(call: APICall, company_id, request: TaskBatchRequest):
archived, failures = run_batch_operation(
func=partial(
archive_task,
@ -1015,6 +1017,25 @@ def archive_many(call: APICall, company_id, request: ArchiveManyRequest):
call.result.data_model = BatchResponse(succeeded=archived, failures=failures)
@endpoint(
"tasks.unarchive_many",
request_data_model=TaskBatchRequest,
response_data_model=BatchResponse,
)
def unarchive_many(call: APICall, company_id, request: TaskBatchRequest):
unarchived, failures = run_batch_operation(
func=partial(
unarchive_task,
company_id=company_id,
status_message=request.status_message,
status_reason=request.status_reason,
),
ids=request.ids,
init_res=0,
)
call.result.data_model = BatchResponse(succeeded=unarchived, failures=failures)
@endpoint("tasks.delete", request_data_model=DeleteRequest)
def delete(call: APICall, company_id, request: DeleteRequest):
deleted, task, cleanup_res = delete_task(
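The tests below drive the new endpoints through the suite's internal API wrapper. Outside the test suite, a hedged equivalent with the ClearML SDK's APIClient would look like this (whether unarchive_many is exposed depends on the SDK version's generated services, so treat the call as an assumption):

from clearml.backend_api.session.client import APIClient

client = APIClient()  # picks up the api server address and credentials from clearml.conf
res = client.tasks.unarchive_many(ids=["<task-id-1>", "<task-id-2>"])
print(res.succeeded)  # number of tasks that were unarchived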

View File

@ -53,13 +53,19 @@ class TestBatchOperations(TestService):
data = self.api.tasks.get_all_ex(id=ids).tasks
self.assertEqual({t.status for t in data}, {"created"})
# archive
# archive/unarchive
res = self.api.tasks.archive_many(ids=ids)
self.assertEqual(res.succeeded, 2)
self._assert_failures(res, [missing_id])
data = self.api.tasks.get_all_ex(id=ids).tasks
self.assertTrue(all("archived" in t.system_tags for t in data))
res = self.api.tasks.unarchive_many(ids=ids)
self.assertEqual(res.succeeded, 2)
self._assert_failures(res, [missing_id])
data = self.api.tasks.get_all_ex(id=ids).tasks
self.assertFalse(any("archived" in t.system_tags for t in data))
# delete
res = self.api.tasks.delete_many(
ids=ids, delete_output_models=True, return_file_urls=True
@ -88,13 +94,18 @@ class TestBatchOperations(TestService):
self.assertEqual(res.published_tasks[0].id, task)
self._assert_failures(res, [ids[1], missing_id])
# archive
# archive/unarchive
res = self.api.models.archive_many(ids=ids)
self.assertEqual(res.succeeded, 2)
self._assert_failures(res, [missing_id])
data = self.api.models.get_all_ex(id=ids).models
for m in data:
self.assertIn("archived", m.system_tags)
self.assertTrue(all("archived" in m.system_tags for m in data))
res = self.api.models.unarchive_many(ids=ids)
self.assertEqual(res.succeeded, 2)
self._assert_failures(res, [missing_id])
data = self.api.models.get_all_ex(id=ids).models
self.assertFalse(any("archived" in m.system_tags for m in data))
# delete
res = self.api.models.delete_many(ids=[*models, missing_id], force=True)