From 0ad049573337afa4aa7b63c0487932c7bd1743fc Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Tue, 5 Jan 2021 17:49:08 +0200 Subject: [PATCH] Add tasks.archive support --- apiserver/apimodels/tasks.py | 15 +++- apiserver/bll/task/task_bll.py | 76 ++++++++++++++-- apiserver/schema/services/tasks.conf | 37 ++++++++ apiserver/services/tasks.py | 93 ++++++++++---------- apiserver/tests/automated/test_tasks_edit.py | 44 ++++++++- 5 files changed, 213 insertions(+), 52 deletions(-) diff --git a/apiserver/apimodels/tasks.py b/apiserver/apimodels/tasks.py index d6894f4..23d8647 100644 --- a/apiserver/apimodels/tasks.py +++ b/apiserver/apimodels/tasks.py @@ -7,7 +7,11 @@ from jsonmodels.validators import Enum, Length from apiserver.apimodels import DictField, ListField from apiserver.apimodels.base import UpdateResponse -from apiserver.database.model.task.task import TaskType, ArtifactModes, DEFAULT_ARTIFACT_MODE +from apiserver.database.model.task.task import ( + TaskType, + ArtifactModes, + DEFAULT_ARTIFACT_MODE, +) from apiserver.database.utils import get_options @@ -199,3 +203,12 @@ class EditConfigurationRequest(TaskRequest): class DeleteConfigurationRequest(TaskRequest): configuration: Sequence[str] = ListField([str], validators=Length(minimum_value=1)) + + +class ArchiveRequest(MultiTaskRequest): + status_reason = StringField(default="") + status_message = StringField(default="") + + +class ArchiveResponse(models.Base): + archived = IntField() diff --git a/apiserver/bll/task/task_bll.py b/apiserver/bll/task/task_bll.py index 42df3f1..b5afa81 100644 --- a/apiserver/bll/task/task_bll.py +++ b/apiserver/bll/task/task_bll.py @@ -10,6 +10,7 @@ from six import string_types import apiserver.database.utils as dbutils from apiserver.apierrors import errors from apiserver.bll.organization import OrgBLL, Tags +from apiserver.bll.queue import QueueBLL from apiserver.config_repo import config from apiserver.database.errors import translate_errors_context from apiserver.database.model.model import Model @@ -327,14 +328,26 @@ class TaskBLL(object): @staticmethod def set_last_update( - task_ids: Collection[str], company_id: str, last_update: datetime, **extra_updates + task_ids: Collection[str], + company_id: str, + last_update: datetime, + **extra_updates, ): - tasks = Task.objects(id__in=task_ids, company=company_id).only("status", "started") + tasks = Task.objects(id__in=task_ids, company=company_id).only( + "status", "started" + ) for task in tasks: updates = extra_updates if task.status == TaskStatus.in_progress and task.started: - updates = {"active_duration": (datetime.utcnow() - task.started).total_seconds(), **extra_updates} - Task.objects(id=task.id, company=company_id).update(upsert=False, last_update=last_update, **updates) + updates = { + "active_duration": ( + datetime.utcnow() - task.started + ).total_seconds(), + **extra_updates, + } + Task.objects(id=task.id, company=company_id).update( + upsert=False, last_update=last_update, **updates + ) @staticmethod def update_statistics( @@ -398,7 +411,10 @@ class TaskBLL(object): extra_updates["metric_stats"] = metric_stats TaskBLL.set_last_update( - task_ids=[task_id], company_id=company_id, last_update=last_update, **extra_updates + task_ids=[task_id], + company_id=company_id, + last_update=last_update, + **extra_updates, ) @classmethod @@ -613,3 +629,53 @@ class TaskBLL(object): remaining = max(0, total - (len(results) + page * page_size)) return total, remaining, results + + @classmethod + def dequeue_and_change_status( + cls, + task: Task, + company_id: str, + status_message: str, + status_reason: str, + silent_dequeue_fail=False, + ): + cls.dequeue(task, company_id, silent_dequeue_fail) + + return ChangeStatusRequest( + task=task, + new_status=TaskStatus.created, + status_reason=status_reason, + status_message=status_message, + ).execute(unset__execution__queue=1) + + @classmethod + def dequeue(cls, task: Task, company_id: str, silent_fail=False): + """ + Dequeue the task from the queue + :param task: task to dequeue + :param company_id: task's company ID. + :param silent_fail: do not throw exceptions. APIError is still thrown + :raise errors.bad_request.InvalidTaskId: if the task's status is not queued + :raise errors.bad_request.MissingRequiredFields: if the task is not queued + :raise APIError or errors.server_error.TransactionError: if internal call to queues.remove_task fails + :return: the result of queues.remove_task call. None in case of silent failure + """ + if task.status not in (TaskStatus.queued,): + if silent_fail: + return + raise errors.bad_request.InvalidTaskId( + status=task.status, expected=TaskStatus.queued + ) + + if not task.execution or not task.execution.queue: + if silent_fail: + return + raise errors.bad_request.MissingRequiredFields( + "task has no queue value", field="execution.queue" + ) + + return { + "removed": QueueBLL().remove_task( + company_id=company_id, queue_id=task.execution.queue, task_id=task.id + ) + } diff --git a/apiserver/schema/services/tasks.conf b/apiserver/schema/services/tasks.conf index 4fe276a..c597e14 100644 --- a/apiserver/schema/services/tasks.conf +++ b/apiserver/schema/services/tasks.conf @@ -1214,6 +1214,43 @@ delete { } } } +archive { + "2.11" { + description: """Archive tasks. + If a task is queued it will first be dequeued and then archived. + """ + request = { + type: object + required: [ + tasks + ] + properties { + tasks { + description: "List of task ids" + type: array + items { type: string } + } + status_reason { + description: Reason for status change + type: string + } + status_message { + description: Extra information regarding status change + type: string + } + } + } + response { + type: object + properties { + archived { + description: "Indicates number of archived tasks" + type: integer + } + } + } + } +} started { "2.1" { description: "Mark a task status as in_progress. Optionally allows to set the task's execution progress." diff --git a/apiserver/services/tasks.py b/apiserver/services/tasks.py index 7659481..d6a7ba3 100644 --- a/apiserver/services/tasks.py +++ b/apiserver/services/tasks.py @@ -39,6 +39,8 @@ from apiserver.apimodels.tasks import ( DeleteConfigurationRequest, GetConfigurationNamesRequest, DeleteArtifactsRequest, + ArchiveResponse, + ArchiveRequest, ) from apiserver.bll.event import EventBLL from apiserver.bll.organization import OrgBLL, Tags @@ -63,6 +65,7 @@ from apiserver.bll.task.param_utils import ( ) from apiserver.bll.util import SetFieldsResolver from apiserver.database.errors import translate_errors_context +from apiserver.database.model import EntityVisibility from apiserver.database.model.model import Model from apiserver.database.model.task.output import Output from apiserver.database.model.task.task import ( @@ -74,12 +77,18 @@ from apiserver.database.model.task.task import ( ) from apiserver.database.utils import get_fields_attr, parse_from_call from apiserver.service_repo import APICall, endpoint -from apiserver.services.utils import conform_tag_fields, conform_output_tags, validate_tags +from apiserver.services.utils import ( + conform_tag_fields, + conform_output_tags, + validate_tags, +) from apiserver.timing_context import TimingContext from apiserver.utilities.partial_version import PartialVersion task_fields = set(Task.get_fields()) -task_script_stripped_fields = set([f for f, v in get_fields_attr(Script, 'strip').items() if v]) +task_script_stripped_fields = set( + [f for f, v in get_fields_attr(Script, "strip").items() if v] +) task_bll = TaskBLL() event_bll = EventBLL() @@ -172,9 +181,7 @@ def get_by_id_ex(call: APICall, company_id, _): with translate_errors_context(): with TimingContext("mongo", "task_get_by_id_ex"): tasks = Task.get_many_with_join( - company=company_id, - query_dict=call.data, - allow_public=True, + company=company_id, query_dict=call.data, allow_public=True, ) unprepare_from_saved(call, tasks) @@ -782,51 +789,14 @@ def dequeue(call: APICall, company_id, req_model: UpdateRequest): only=("id", "execution", "status", "project"), requires_write_access=True, ) - if task.status not in (TaskStatus.queued,): - raise errors.bad_request.InvalidTaskId( - status=task.status, expected=TaskStatus.queued - ) - - _dequeue(task, company_id) - - status_message = req_model.status_message - status_reason = req_model.status_reason res = DequeueResponse( - **ChangeStatusRequest( - task=task, - new_status=TaskStatus.created, - status_reason=status_reason, - status_message=status_message, - ).execute(unset__execution__queue=1) + **TaskBLL.dequeue_and_change_status(task, company_id, req_model) ) + res.dequeued = 1 - call.result.data_model = res -def _dequeue(task: Task, company_id: str, silent_fail=False): - """ - Dequeue the task from the queue - :param task: task to dequeue - :param silent_fail: do not throw exceptions. APIError is still thrown - :raise errors.bad_request.MissingRequiredFields: if the task is not queued - :raise APIError or errors.server_error.TransactionError: if internal call to queues.remove_task fails - :return: the result of queues.remove_task call. None in case of silent failure - """ - if not task.execution or not task.execution.queue: - if silent_fail: - return - raise errors.bad_request.MissingRequiredFields( - "task has no queue value", field="execution.queue" - ) - - return { - "removed": queue_bll.remove_task( - company_id=company_id, queue_id=task.execution.queue, task_id=task.id - ) - } - - @endpoint( "tasks.reset", request_data_model=ResetRequest, response_data_model=ResetResponse ) @@ -844,7 +814,7 @@ def reset(call: APICall, company_id, request: ResetRequest): updates = {} try: - dequeued = _dequeue(task, company_id, silent_fail=True) + dequeued = TaskBLL.dequeue(task, company_id, silent_fail=True) except APIError: # dequeue may fail if the task was not enqueued pass @@ -897,6 +867,39 @@ def reset(call: APICall, company_id, request: ResetRequest): call.result.data_model = res +@endpoint( + "tasks.archive", + request_data_model=ArchiveRequest, + response_data_model=ArchiveResponse, +) +def archive(call: APICall, company_id, request: ArchiveRequest): + archived = 0 + tasks = TaskBLL.assert_exists( + company_id, + task_ids=request.tasks, + only=("id", "execution", "status", "project", "system_tags"), + ) + for task in tasks: + TaskBLL.dequeue_and_change_status( + task, + company_id, + request.status_message, + request.status_reason, + silent_dequeue_fail=True, + ) + task.update( + status_message=request.status_message, + status_reason=request.status_reason, + system_tags=sorted( + set(task.system_tags) | {EntityVisibility.archived.value} + ) + ) + + archived += 1 + + call.result.data_model = ArchiveResponse(archived=archived) + + class DocumentGroup(list): """ Operate on a list of documents as if they were a query result diff --git a/apiserver/tests/automated/test_tasks_edit.py b/apiserver/tests/automated/test_tasks_edit.py index 37ee91e..76ab677 100644 --- a/apiserver/tests/automated/test_tasks_edit.py +++ b/apiserver/tests/automated/test_tasks_edit.py @@ -1,9 +1,9 @@ from apiserver.apierrors.errors.bad_request import InvalidModelId, ValidationError, InvalidTaskId from apiserver.apierrors.errors.forbidden import NoWritePermission from apiserver.config_repo import config +from apiserver.tests.api_client import APIError from apiserver.tests.automated import TestService - log = config.logger(__file__) @@ -21,6 +21,10 @@ class TestTasksEdit(TestService): self.update_missing(kwargs, name="test", uri="file:///a/b", labels={}) return self.create_temp("models", **kwargs) + def new_queue(self, **kwargs): + self.update_missing(kwargs, name="test") + return self.create_temp("queues", **kwargs) + def test_task_types(self): with self.api.raises(ValidationError): task = self.new_task(type="Unsupported") @@ -171,3 +175,41 @@ class TestTasksEdit(TestService): res = self.api.tasks.get_all_ex(id=[task]) self.assertEqual([t.id for t in res.tasks], [task]) self.api.tasks.stopped(task=task) + + def test_archive_task(self): + # non-existing task throws an exception + with self.assertRaises(APIError): + self.api.tasks.archive(tasks=["fake-task-id"]) + + system_tag = "existing-system-tag" + status_message = "test-status-message" + status_reason = "test-status-reason" + queue_id = self.new_queue() + + # Create two tasks with system_tags and enqueue one of them + dequeued_task_id = self.new_task(system_tags=[system_tag]) + enqueued_task_id = self.new_task(system_tags=[system_tag]) + self.api.tasks.enqueue(task=enqueued_task_id, queue=queue_id) + + self.api.tasks.archive( + tasks=[enqueued_task_id, dequeued_task_id], + status_message=status_message, + status_reason=status_reason, + ) + + tasks = self.api.tasks.get_all_ex(id=[enqueued_task_id, dequeued_task_id]).tasks + + for task in tasks: + self.assertIn(system_tag, task.system_tags) + self.assertIn("archived", task.system_tags) + self.assertNotIn("queue", task.execution) + self.assertIn(status_message, task.status_message) + self.assertIn(status_reason, task.status_reason) + + # Check that the queue does not contain the enqueued task anymore + queue = self.api.queues.get_by_id(queue=queue_id).queue + task_in_queue = next( + (True for entry in queue.entries if entry["task"] == enqueued_task_id), + False, + ) + self.assertFalse(task_in_queue)