Add tasks.archive support

This commit is contained in:
allegroai 2021-01-05 17:49:08 +02:00
parent 03ae90c4a6
commit 0ad0495733
5 changed files with 213 additions and 52 deletions

View File

@ -7,7 +7,11 @@ from jsonmodels.validators import Enum, Length
from apiserver.apimodels import DictField, ListField from apiserver.apimodels import DictField, ListField
from apiserver.apimodels.base import UpdateResponse from apiserver.apimodels.base import UpdateResponse
from apiserver.database.model.task.task import TaskType, ArtifactModes, DEFAULT_ARTIFACT_MODE from apiserver.database.model.task.task import (
TaskType,
ArtifactModes,
DEFAULT_ARTIFACT_MODE,
)
from apiserver.database.utils import get_options from apiserver.database.utils import get_options
@ -199,3 +203,12 @@ class EditConfigurationRequest(TaskRequest):
class DeleteConfigurationRequest(TaskRequest): class DeleteConfigurationRequest(TaskRequest):
configuration: Sequence[str] = ListField([str], validators=Length(minimum_value=1)) configuration: Sequence[str] = ListField([str], validators=Length(minimum_value=1))
class ArchiveRequest(MultiTaskRequest):
status_reason = StringField(default="")
status_message = StringField(default="")
class ArchiveResponse(models.Base):
archived = IntField()

View File

@ -10,6 +10,7 @@ from six import string_types
import apiserver.database.utils as dbutils import apiserver.database.utils as dbutils
from apiserver.apierrors import errors from apiserver.apierrors import errors
from apiserver.bll.organization import OrgBLL, Tags from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.queue import QueueBLL
from apiserver.config_repo import config from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context from apiserver.database.errors import translate_errors_context
from apiserver.database.model.model import Model from apiserver.database.model.model import Model
@ -327,14 +328,26 @@ class TaskBLL(object):
@staticmethod @staticmethod
def set_last_update( def set_last_update(
task_ids: Collection[str], company_id: str, last_update: datetime, **extra_updates task_ids: Collection[str],
company_id: str,
last_update: datetime,
**extra_updates,
): ):
tasks = Task.objects(id__in=task_ids, company=company_id).only("status", "started") tasks = Task.objects(id__in=task_ids, company=company_id).only(
"status", "started"
)
for task in tasks: for task in tasks:
updates = extra_updates updates = extra_updates
if task.status == TaskStatus.in_progress and task.started: if task.status == TaskStatus.in_progress and task.started:
updates = {"active_duration": (datetime.utcnow() - task.started).total_seconds(), **extra_updates} updates = {
Task.objects(id=task.id, company=company_id).update(upsert=False, last_update=last_update, **updates) "active_duration": (
datetime.utcnow() - task.started
).total_seconds(),
**extra_updates,
}
Task.objects(id=task.id, company=company_id).update(
upsert=False, last_update=last_update, **updates
)
@staticmethod @staticmethod
def update_statistics( def update_statistics(
@ -398,7 +411,10 @@ class TaskBLL(object):
extra_updates["metric_stats"] = metric_stats extra_updates["metric_stats"] = metric_stats
TaskBLL.set_last_update( TaskBLL.set_last_update(
task_ids=[task_id], company_id=company_id, last_update=last_update, **extra_updates task_ids=[task_id],
company_id=company_id,
last_update=last_update,
**extra_updates,
) )
@classmethod @classmethod
@ -613,3 +629,53 @@ class TaskBLL(object):
remaining = max(0, total - (len(results) + page * page_size)) remaining = max(0, total - (len(results) + page * page_size))
return total, remaining, results return total, remaining, results
@classmethod
def dequeue_and_change_status(
cls,
task: Task,
company_id: str,
status_message: str,
status_reason: str,
silent_dequeue_fail=False,
):
cls.dequeue(task, company_id, silent_dequeue_fail)
return ChangeStatusRequest(
task=task,
new_status=TaskStatus.created,
status_reason=status_reason,
status_message=status_message,
).execute(unset__execution__queue=1)
@classmethod
def dequeue(cls, task: Task, company_id: str, silent_fail=False):
"""
Dequeue the task from the queue
:param task: task to dequeue
:param company_id: task's company ID.
:param silent_fail: do not throw exceptions. APIError is still thrown
:raise errors.bad_request.InvalidTaskId: if the task's status is not queued
:raise errors.bad_request.MissingRequiredFields: if the task is not queued
:raise APIError or errors.server_error.TransactionError: if internal call to queues.remove_task fails
:return: the result of queues.remove_task call. None in case of silent failure
"""
if task.status not in (TaskStatus.queued,):
if silent_fail:
return
raise errors.bad_request.InvalidTaskId(
status=task.status, expected=TaskStatus.queued
)
if not task.execution or not task.execution.queue:
if silent_fail:
return
raise errors.bad_request.MissingRequiredFields(
"task has no queue value", field="execution.queue"
)
return {
"removed": QueueBLL().remove_task(
company_id=company_id, queue_id=task.execution.queue, task_id=task.id
)
}

View File

@ -1214,6 +1214,43 @@ delete {
} }
} }
} }
archive {
"2.11" {
description: """Archive tasks.
If a task is queued it will first be dequeued and then archived.
"""
request = {
type: object
required: [
tasks
]
properties {
tasks {
description: "List of task ids"
type: array
items { type: string }
}
status_reason {
description: Reason for status change
type: string
}
status_message {
description: Extra information regarding status change
type: string
}
}
}
response {
type: object
properties {
archived {
description: "Indicates number of archived tasks"
type: integer
}
}
}
}
}
started { started {
"2.1" { "2.1" {
description: "Mark a task status as in_progress. Optionally allows to set the task's execution progress." description: "Mark a task status as in_progress. Optionally allows to set the task's execution progress."

View File

@ -39,6 +39,8 @@ from apiserver.apimodels.tasks import (
DeleteConfigurationRequest, DeleteConfigurationRequest,
GetConfigurationNamesRequest, GetConfigurationNamesRequest,
DeleteArtifactsRequest, DeleteArtifactsRequest,
ArchiveResponse,
ArchiveRequest,
) )
from apiserver.bll.event import EventBLL from apiserver.bll.event import EventBLL
from apiserver.bll.organization import OrgBLL, Tags from apiserver.bll.organization import OrgBLL, Tags
@ -63,6 +65,7 @@ from apiserver.bll.task.param_utils import (
) )
from apiserver.bll.util import SetFieldsResolver from apiserver.bll.util import SetFieldsResolver
from apiserver.database.errors import translate_errors_context from apiserver.database.errors import translate_errors_context
from apiserver.database.model import EntityVisibility
from apiserver.database.model.model import Model from apiserver.database.model.model import Model
from apiserver.database.model.task.output import Output from apiserver.database.model.task.output import Output
from apiserver.database.model.task.task import ( from apiserver.database.model.task.task import (
@ -74,12 +77,18 @@ from apiserver.database.model.task.task import (
) )
from apiserver.database.utils import get_fields_attr, parse_from_call from apiserver.database.utils import get_fields_attr, parse_from_call
from apiserver.service_repo import APICall, endpoint from apiserver.service_repo import APICall, endpoint
from apiserver.services.utils import conform_tag_fields, conform_output_tags, validate_tags from apiserver.services.utils import (
conform_tag_fields,
conform_output_tags,
validate_tags,
)
from apiserver.timing_context import TimingContext from apiserver.timing_context import TimingContext
from apiserver.utilities.partial_version import PartialVersion from apiserver.utilities.partial_version import PartialVersion
task_fields = set(Task.get_fields()) task_fields = set(Task.get_fields())
task_script_stripped_fields = set([f for f, v in get_fields_attr(Script, 'strip').items() if v]) task_script_stripped_fields = set(
[f for f, v in get_fields_attr(Script, "strip").items() if v]
)
task_bll = TaskBLL() task_bll = TaskBLL()
event_bll = EventBLL() event_bll = EventBLL()
@ -172,9 +181,7 @@ def get_by_id_ex(call: APICall, company_id, _):
with translate_errors_context(): with translate_errors_context():
with TimingContext("mongo", "task_get_by_id_ex"): with TimingContext("mongo", "task_get_by_id_ex"):
tasks = Task.get_many_with_join( tasks = Task.get_many_with_join(
company=company_id, company=company_id, query_dict=call.data, allow_public=True,
query_dict=call.data,
allow_public=True,
) )
unprepare_from_saved(call, tasks) unprepare_from_saved(call, tasks)
@ -782,51 +789,14 @@ def dequeue(call: APICall, company_id, req_model: UpdateRequest):
only=("id", "execution", "status", "project"), only=("id", "execution", "status", "project"),
requires_write_access=True, requires_write_access=True,
) )
if task.status not in (TaskStatus.queued,):
raise errors.bad_request.InvalidTaskId(
status=task.status, expected=TaskStatus.queued
)
_dequeue(task, company_id)
status_message = req_model.status_message
status_reason = req_model.status_reason
res = DequeueResponse( res = DequeueResponse(
**ChangeStatusRequest( **TaskBLL.dequeue_and_change_status(task, company_id, req_model)
task=task,
new_status=TaskStatus.created,
status_reason=status_reason,
status_message=status_message,
).execute(unset__execution__queue=1)
) )
res.dequeued = 1 res.dequeued = 1
call.result.data_model = res call.result.data_model = res
def _dequeue(task: Task, company_id: str, silent_fail=False):
"""
Dequeue the task from the queue
:param task: task to dequeue
:param silent_fail: do not throw exceptions. APIError is still thrown
:raise errors.bad_request.MissingRequiredFields: if the task is not queued
:raise APIError or errors.server_error.TransactionError: if internal call to queues.remove_task fails
:return: the result of queues.remove_task call. None in case of silent failure
"""
if not task.execution or not task.execution.queue:
if silent_fail:
return
raise errors.bad_request.MissingRequiredFields(
"task has no queue value", field="execution.queue"
)
return {
"removed": queue_bll.remove_task(
company_id=company_id, queue_id=task.execution.queue, task_id=task.id
)
}
@endpoint( @endpoint(
"tasks.reset", request_data_model=ResetRequest, response_data_model=ResetResponse "tasks.reset", request_data_model=ResetRequest, response_data_model=ResetResponse
) )
@ -844,7 +814,7 @@ def reset(call: APICall, company_id, request: ResetRequest):
updates = {} updates = {}
try: try:
dequeued = _dequeue(task, company_id, silent_fail=True) dequeued = TaskBLL.dequeue(task, company_id, silent_fail=True)
except APIError: except APIError:
# dequeue may fail if the task was not enqueued # dequeue may fail if the task was not enqueued
pass pass
@ -897,6 +867,39 @@ def reset(call: APICall, company_id, request: ResetRequest):
call.result.data_model = res call.result.data_model = res
@endpoint(
"tasks.archive",
request_data_model=ArchiveRequest,
response_data_model=ArchiveResponse,
)
def archive(call: APICall, company_id, request: ArchiveRequest):
archived = 0
tasks = TaskBLL.assert_exists(
company_id,
task_ids=request.tasks,
only=("id", "execution", "status", "project", "system_tags"),
)
for task in tasks:
TaskBLL.dequeue_and_change_status(
task,
company_id,
request.status_message,
request.status_reason,
silent_dequeue_fail=True,
)
task.update(
status_message=request.status_message,
status_reason=request.status_reason,
system_tags=sorted(
set(task.system_tags) | {EntityVisibility.archived.value}
)
)
archived += 1
call.result.data_model = ArchiveResponse(archived=archived)
class DocumentGroup(list): class DocumentGroup(list):
""" """
Operate on a list of documents as if they were a query result Operate on a list of documents as if they were a query result

View File

@ -1,9 +1,9 @@
from apiserver.apierrors.errors.bad_request import InvalidModelId, ValidationError, InvalidTaskId from apiserver.apierrors.errors.bad_request import InvalidModelId, ValidationError, InvalidTaskId
from apiserver.apierrors.errors.forbidden import NoWritePermission from apiserver.apierrors.errors.forbidden import NoWritePermission
from apiserver.config_repo import config from apiserver.config_repo import config
from apiserver.tests.api_client import APIError
from apiserver.tests.automated import TestService from apiserver.tests.automated import TestService
log = config.logger(__file__) log = config.logger(__file__)
@ -21,6 +21,10 @@ class TestTasksEdit(TestService):
self.update_missing(kwargs, name="test", uri="file:///a/b", labels={}) self.update_missing(kwargs, name="test", uri="file:///a/b", labels={})
return self.create_temp("models", **kwargs) return self.create_temp("models", **kwargs)
def new_queue(self, **kwargs):
self.update_missing(kwargs, name="test")
return self.create_temp("queues", **kwargs)
def test_task_types(self): def test_task_types(self):
with self.api.raises(ValidationError): with self.api.raises(ValidationError):
task = self.new_task(type="Unsupported") task = self.new_task(type="Unsupported")
@ -171,3 +175,41 @@ class TestTasksEdit(TestService):
res = self.api.tasks.get_all_ex(id=[task]) res = self.api.tasks.get_all_ex(id=[task])
self.assertEqual([t.id for t in res.tasks], [task]) self.assertEqual([t.id for t in res.tasks], [task])
self.api.tasks.stopped(task=task) self.api.tasks.stopped(task=task)
def test_archive_task(self):
# non-existing task throws an exception
with self.assertRaises(APIError):
self.api.tasks.archive(tasks=["fake-task-id"])
system_tag = "existing-system-tag"
status_message = "test-status-message"
status_reason = "test-status-reason"
queue_id = self.new_queue()
# Create two tasks with system_tags and enqueue one of them
dequeued_task_id = self.new_task(system_tags=[system_tag])
enqueued_task_id = self.new_task(system_tags=[system_tag])
self.api.tasks.enqueue(task=enqueued_task_id, queue=queue_id)
self.api.tasks.archive(
tasks=[enqueued_task_id, dequeued_task_id],
status_message=status_message,
status_reason=status_reason,
)
tasks = self.api.tasks.get_all_ex(id=[enqueued_task_id, dequeued_task_id]).tasks
for task in tasks:
self.assertIn(system_tag, task.system_tags)
self.assertIn("archived", task.system_tags)
self.assertNotIn("queue", task.execution)
self.assertIn(status_message, task.status_message)
self.assertIn(status_reason, task.status_reason)
# Check that the queue does not contain the enqueued task anymore
queue = self.api.queues.get_by_id(queue=queue_id).queue
task_in_queue = next(
(True for entry in queue.entries if entry["task"] == enqueued_task_id),
False,
)
self.assertFalse(task_in_queue)