Add tasks.archive support

This commit is contained in:
allegroai 2021-01-05 17:49:08 +02:00
parent 03ae90c4a6
commit 0ad0495733
5 changed files with 213 additions and 52 deletions

View File

@ -7,7 +7,11 @@ from jsonmodels.validators import Enum, Length
from apiserver.apimodels import DictField, ListField
from apiserver.apimodels.base import UpdateResponse
from apiserver.database.model.task.task import TaskType, ArtifactModes, DEFAULT_ARTIFACT_MODE
from apiserver.database.model.task.task import (
TaskType,
ArtifactModes,
DEFAULT_ARTIFACT_MODE,
)
from apiserver.database.utils import get_options
@ -199,3 +203,12 @@ class EditConfigurationRequest(TaskRequest):
class DeleteConfigurationRequest(TaskRequest):
    # Names of the configuration items to delete; at least one is required.
    configuration: Sequence[str] = ListField([str], validators=Length(minimum_value=1))
class ArchiveRequest(MultiTaskRequest):
    # Optional free-text reason/message recorded with the status change
    # applied to each archived task; both default to an empty string.
    status_reason = StringField(default="")
    status_message = StringField(default="")
class ArchiveResponse(models.Base):
    # Number of tasks that were archived by the call.
    archived = IntField()

View File

@ -10,6 +10,7 @@ from six import string_types
import apiserver.database.utils as dbutils
from apiserver.apierrors import errors
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.queue import QueueBLL
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.model import Model
@ -327,14 +328,26 @@ class TaskBLL(object):
@staticmethod
def set_last_update(
task_ids: Collection[str], company_id: str, last_update: datetime, **extra_updates
task_ids: Collection[str],
company_id: str,
last_update: datetime,
**extra_updates,
):
tasks = Task.objects(id__in=task_ids, company=company_id).only("status", "started")
tasks = Task.objects(id__in=task_ids, company=company_id).only(
"status", "started"
)
for task in tasks:
updates = extra_updates
if task.status == TaskStatus.in_progress and task.started:
updates = {"active_duration": (datetime.utcnow() - task.started).total_seconds(), **extra_updates}
Task.objects(id=task.id, company=company_id).update(upsert=False, last_update=last_update, **updates)
updates = {
"active_duration": (
datetime.utcnow() - task.started
).total_seconds(),
**extra_updates,
}
Task.objects(id=task.id, company=company_id).update(
upsert=False, last_update=last_update, **updates
)
@staticmethod
def update_statistics(
@ -398,7 +411,10 @@ class TaskBLL(object):
extra_updates["metric_stats"] = metric_stats
TaskBLL.set_last_update(
task_ids=[task_id], company_id=company_id, last_update=last_update, **extra_updates
task_ids=[task_id],
company_id=company_id,
last_update=last_update,
**extra_updates,
)
@classmethod
@ -613,3 +629,53 @@ class TaskBLL(object):
remaining = max(0, total - (len(results) + page * page_size))
return total, remaining, results
@classmethod
def dequeue_and_change_status(
    cls,
    task: Task,
    company_id: str,
    status_message: str,
    status_reason: str,
    silent_dequeue_fail=False,
):
    """
    Remove the task from its queue (if any) and move it back to the
    'created' status, unsetting the queue reference in its execution info.
    :param task: the task to dequeue
    :param company_id: ID of the task's company
    :param status_message: message to record with the status change
    :param status_reason: reason to record with the status change
    :param silent_dequeue_fail: when True, a task that is not queued is
        skipped silently instead of raising (an APIError from the queue
        removal may still propagate — see dequeue)
    :return: the result of the ChangeStatusRequest execution
    """
    cls.dequeue(task, company_id, silent_dequeue_fail)
    status_change = ChangeStatusRequest(
        task=task,
        new_status=TaskStatus.created,
        status_message=status_message,
        status_reason=status_reason,
    )
    return status_change.execute(unset__execution__queue=1)
@classmethod
def dequeue(cls, task: Task, company_id: str, silent_fail=False):
    """
    Remove the task from the queue referenced by its execution info.
    :param task: task to dequeue
    :param company_id: task's company ID.
    :param silent_fail: do not throw exceptions. APIError is still thrown
    :raise errors.bad_request.InvalidTaskId: if the task's status is not queued
    :raise errors.bad_request.MissingRequiredFields: if the task is not queued
    :raise APIError or errors.server_error.TransactionError: if internal call to queues.remove_task fails
    :return: the result of queues.remove_task call. None in case of silent failure
    """
    # Guard 1: only a queued task can be dequeued
    if task.status not in (TaskStatus.queued,):
        if silent_fail:
            return
        raise errors.bad_request.InvalidTaskId(
            status=task.status, expected=TaskStatus.queued
        )

    # Guard 2: the task must reference the queue it sits in
    queue_id = task.execution.queue if task.execution else None
    if not queue_id:
        if silent_fail:
            return
        raise errors.bad_request.MissingRequiredFields(
            "task has no queue value", field="execution.queue"
        )

    removed = QueueBLL().remove_task(
        company_id=company_id, queue_id=queue_id, task_id=task.id
    )
    return {"removed": removed}

View File

@ -1214,6 +1214,43 @@ delete {
}
}
}
archive {
"2.11" {
description: """Archive tasks.
If a task is queued it will first be dequeued and then archived.
"""
request = {
type: object
required: [
tasks
]
properties {
tasks {
description: "List of task ids"
type: array
items { type: string }
}
status_reason {
    description: "Reason for status change"
    type: string
}
status_message {
    description: "Extra information regarding status change"
    type: string
}
}
}
response {
type: object
properties {
archived {
description: "Indicates number of archived tasks"
type: integer
}
}
}
}
}
started {
"2.1" {
description: "Mark a task status as in_progress. Optionally allows to set the task's execution progress."

View File

@ -39,6 +39,8 @@ from apiserver.apimodels.tasks import (
DeleteConfigurationRequest,
GetConfigurationNamesRequest,
DeleteArtifactsRequest,
ArchiveResponse,
ArchiveRequest,
)
from apiserver.bll.event import EventBLL
from apiserver.bll.organization import OrgBLL, Tags
@ -63,6 +65,7 @@ from apiserver.bll.task.param_utils import (
)
from apiserver.bll.util import SetFieldsResolver
from apiserver.database.errors import translate_errors_context
from apiserver.database.model import EntityVisibility
from apiserver.database.model.model import Model
from apiserver.database.model.task.output import Output
from apiserver.database.model.task.task import (
@ -74,12 +77,18 @@ from apiserver.database.model.task.task import (
)
from apiserver.database.utils import get_fields_attr, parse_from_call
from apiserver.service_repo import APICall, endpoint
from apiserver.services.utils import conform_tag_fields, conform_output_tags, validate_tags
from apiserver.services.utils import (
conform_tag_fields,
conform_output_tags,
validate_tags,
)
from apiserver.timing_context import TimingContext
from apiserver.utilities.partial_version import PartialVersion
task_fields = set(Task.get_fields())
task_script_stripped_fields = set([f for f, v in get_fields_attr(Script, 'strip').items() if v])
task_script_stripped_fields = set(
[f for f, v in get_fields_attr(Script, "strip").items() if v]
)
task_bll = TaskBLL()
event_bll = EventBLL()
@ -172,9 +181,7 @@ def get_by_id_ex(call: APICall, company_id, _):
with translate_errors_context():
with TimingContext("mongo", "task_get_by_id_ex"):
tasks = Task.get_many_with_join(
company=company_id,
query_dict=call.data,
allow_public=True,
company=company_id, query_dict=call.data, allow_public=True,
)
unprepare_from_saved(call, tasks)
@ -782,51 +789,14 @@ def dequeue(call: APICall, company_id, req_model: UpdateRequest):
only=("id", "execution", "status", "project"),
requires_write_access=True,
)
if task.status not in (TaskStatus.queued,):
raise errors.bad_request.InvalidTaskId(
status=task.status, expected=TaskStatus.queued
)
_dequeue(task, company_id)
status_message = req_model.status_message
status_reason = req_model.status_reason
res = DequeueResponse(
**ChangeStatusRequest(
task=task,
new_status=TaskStatus.created,
status_reason=status_reason,
status_message=status_message,
).execute(unset__execution__queue=1)
**TaskBLL.dequeue_and_change_status(task, company_id, req_model)
)
res.dequeued = 1
call.result.data_model = res
def _dequeue(task: Task, company_id: str, silent_fail=False):
    """
    Remove the task from the queue referenced by its execution info.
    :param task: task to dequeue
    :param company_id: ID of the company the task belongs to
    :param silent_fail: do not throw exceptions. APIError is still thrown
    :raise errors.bad_request.MissingRequiredFields: if the task is not queued
    :raise APIError or errors.server_error.TransactionError: if internal call to queues.remove_task fails
    :return: the result of queues.remove_task call. None in case of silent failure
    """
    # A task with no queue reference cannot be dequeued
    queue_id = task.execution.queue if task.execution else None
    if not queue_id:
        if silent_fail:
            return
        raise errors.bad_request.MissingRequiredFields(
            "task has no queue value", field="execution.queue"
        )
    return {
        "removed": queue_bll.remove_task(
            company_id=company_id, queue_id=queue_id, task_id=task.id
        )
    }
@endpoint(
"tasks.reset", request_data_model=ResetRequest, response_data_model=ResetResponse
)
@ -844,7 +814,7 @@ def reset(call: APICall, company_id, request: ResetRequest):
updates = {}
try:
dequeued = _dequeue(task, company_id, silent_fail=True)
dequeued = TaskBLL.dequeue(task, company_id, silent_fail=True)
except APIError:
# dequeue may fail if the task was not enqueued
pass
@ -897,6 +867,39 @@ def reset(call: APICall, company_id, request: ResetRequest):
call.result.data_model = res
@endpoint(
    "tasks.archive",
    request_data_model=ArchiveRequest,
    response_data_model=ArchiveResponse,
)
def archive(call: APICall, company_id, request: ArchiveRequest):
    """
    Archive the requested tasks by adding the 'archived' system tag.
    A queued task is first removed from its queue and moved back to
    'created'; the status of any other task is left untouched.
    """
    archived = 0
    tasks = TaskBLL.assert_exists(
        company_id,
        task_ids=request.tasks,
        only=("id", "execution", "status", "project", "system_tags"),
    )
    for task in tasks:
        # Only dequeue tasks that are actually queued: dequeue_and_change_status
        # unconditionally resets the task status to 'created', which must not
        # happen to e.g. stopped or published tasks being archived.
        if task.status == TaskStatus.queued:
            TaskBLL.dequeue_and_change_status(
                task,
                company_id,
                request.status_message,
                request.status_reason,
                silent_dequeue_fail=True,
            )
        task.update(
            status_message=request.status_message,
            status_reason=request.status_reason,
            system_tags=sorted(
                set(task.system_tags) | {EntityVisibility.archived.value}
            ),
        )
        archived += 1

    call.result.data_model = ArchiveResponse(archived=archived)
class DocumentGroup(list):
"""
Operate on a list of documents as if they were a query result

View File

@ -1,9 +1,9 @@
from apiserver.apierrors.errors.bad_request import InvalidModelId, ValidationError, InvalidTaskId
from apiserver.apierrors.errors.forbidden import NoWritePermission
from apiserver.config_repo import config
from apiserver.tests.api_client import APIError
from apiserver.tests.automated import TestService
log = config.logger(__file__)
@ -21,6 +21,10 @@ class TestTasksEdit(TestService):
self.update_missing(kwargs, name="test", uri="file:///a/b", labels={})
return self.create_temp("models", **kwargs)
def new_queue(self, **kwargs):
    # Create a temporary queue for the test, defaulting its name to "test".
    # NOTE(review): presumably create_temp registers it for teardown cleanup,
    # mirroring new_task/new_model above — verify against TestService.
    self.update_missing(kwargs, name="test")
    return self.create_temp("queues", **kwargs)
def test_task_types(self):
with self.api.raises(ValidationError):
task = self.new_task(type="Unsupported")
@ -171,3 +175,41 @@ class TestTasksEdit(TestService):
res = self.api.tasks.get_all_ex(id=[task])
self.assertEqual([t.id for t in res.tasks], [task])
self.api.tasks.stopped(task=task)
def test_archive_task(self):
    """tasks.archive: dequeues queued tasks and tags all tasks as archived."""
    # Archiving a non-existing task must fail
    with self.assertRaises(APIError):
        self.api.tasks.archive(tasks=["fake-task-id"])

    system_tag = "existing-system-tag"
    status_message = "test-status-message"
    status_reason = "test-status-reason"
    queue_id = self.new_queue()

    # Two tasks carrying a pre-existing system tag; one of them gets enqueued
    dequeued_task_id = self.new_task(system_tags=[system_tag])
    enqueued_task_id = self.new_task(system_tags=[system_tag])
    self.api.tasks.enqueue(task=enqueued_task_id, queue=queue_id)

    self.api.tasks.archive(
        tasks=[enqueued_task_id, dequeued_task_id],
        status_message=status_message,
        status_reason=status_reason,
    )

    # Both tasks keep their original tag, gain the 'archived' tag,
    # lose any queue reference, and record the status message/reason
    archived_tasks = self.api.tasks.get_all_ex(
        id=[enqueued_task_id, dequeued_task_id]
    ).tasks
    for task in archived_tasks:
        self.assertIn(system_tag, task.system_tags)
        self.assertIn("archived", task.system_tags)
        self.assertNotIn("queue", task.execution)
        self.assertIn(status_message, task.status_message)
        self.assertIn(status_reason, task.status_reason)

    # The queue should no longer contain the previously enqueued task
    queue = self.api.queues.get_by_id(queue=queue_id).queue
    self.assertFalse(
        any(entry["task"] == enqueued_task_id for entry in queue.entries)
    )