mirror of
https://github.com/clearml/clearml-server
synced 2025-06-26 23:15:47 +00:00
Add tasks.archive support
This commit is contained in:
parent
03ae90c4a6
commit
0ad0495733
@ -7,7 +7,11 @@ from jsonmodels.validators import Enum, Length
|
||||
|
||||
from apiserver.apimodels import DictField, ListField
|
||||
from apiserver.apimodels.base import UpdateResponse
|
||||
from apiserver.database.model.task.task import TaskType, ArtifactModes, DEFAULT_ARTIFACT_MODE
|
||||
from apiserver.database.model.task.task import (
|
||||
TaskType,
|
||||
ArtifactModes,
|
||||
DEFAULT_ARTIFACT_MODE,
|
||||
)
|
||||
from apiserver.database.utils import get_options
|
||||
|
||||
|
||||
@ -199,3 +203,12 @@ class EditConfigurationRequest(TaskRequest):
|
||||
|
||||
class DeleteConfigurationRequest(TaskRequest):
|
||||
configuration: Sequence[str] = ListField([str], validators=Length(minimum_value=1))
|
||||
|
||||
|
||||
class ArchiveRequest(MultiTaskRequest):
|
||||
status_reason = StringField(default="")
|
||||
status_message = StringField(default="")
|
||||
|
||||
|
||||
class ArchiveResponse(models.Base):
|
||||
archived = IntField()
|
||||
|
@ -10,6 +10,7 @@ from six import string_types
|
||||
import apiserver.database.utils as dbutils
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.bll.organization import OrgBLL, Tags
|
||||
from apiserver.bll.queue import QueueBLL
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.model import Model
|
||||
@ -327,14 +328,26 @@ class TaskBLL(object):
|
||||
|
||||
@staticmethod
|
||||
def set_last_update(
|
||||
task_ids: Collection[str], company_id: str, last_update: datetime, **extra_updates
|
||||
task_ids: Collection[str],
|
||||
company_id: str,
|
||||
last_update: datetime,
|
||||
**extra_updates,
|
||||
):
|
||||
tasks = Task.objects(id__in=task_ids, company=company_id).only("status", "started")
|
||||
tasks = Task.objects(id__in=task_ids, company=company_id).only(
|
||||
"status", "started"
|
||||
)
|
||||
for task in tasks:
|
||||
updates = extra_updates
|
||||
if task.status == TaskStatus.in_progress and task.started:
|
||||
updates = {"active_duration": (datetime.utcnow() - task.started).total_seconds(), **extra_updates}
|
||||
Task.objects(id=task.id, company=company_id).update(upsert=False, last_update=last_update, **updates)
|
||||
updates = {
|
||||
"active_duration": (
|
||||
datetime.utcnow() - task.started
|
||||
).total_seconds(),
|
||||
**extra_updates,
|
||||
}
|
||||
Task.objects(id=task.id, company=company_id).update(
|
||||
upsert=False, last_update=last_update, **updates
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def update_statistics(
|
||||
@ -398,7 +411,10 @@ class TaskBLL(object):
|
||||
extra_updates["metric_stats"] = metric_stats
|
||||
|
||||
TaskBLL.set_last_update(
|
||||
task_ids=[task_id], company_id=company_id, last_update=last_update, **extra_updates
|
||||
task_ids=[task_id],
|
||||
company_id=company_id,
|
||||
last_update=last_update,
|
||||
**extra_updates,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -613,3 +629,53 @@ class TaskBLL(object):
|
||||
remaining = max(0, total - (len(results) + page * page_size))
|
||||
|
||||
return total, remaining, results
|
||||
|
||||
@classmethod
|
||||
def dequeue_and_change_status(
|
||||
cls,
|
||||
task: Task,
|
||||
company_id: str,
|
||||
status_message: str,
|
||||
status_reason: str,
|
||||
silent_dequeue_fail=False,
|
||||
):
|
||||
cls.dequeue(task, company_id, silent_dequeue_fail)
|
||||
|
||||
return ChangeStatusRequest(
|
||||
task=task,
|
||||
new_status=TaskStatus.created,
|
||||
status_reason=status_reason,
|
||||
status_message=status_message,
|
||||
).execute(unset__execution__queue=1)
|
||||
|
||||
@classmethod
|
||||
def dequeue(cls, task: Task, company_id: str, silent_fail=False):
|
||||
"""
|
||||
Dequeue the task from the queue
|
||||
:param task: task to dequeue
|
||||
:param company_id: task's company ID.
|
||||
:param silent_fail: do not throw exceptions. APIError is still thrown
|
||||
:raise errors.bad_request.InvalidTaskId: if the task's status is not queued
|
||||
:raise errors.bad_request.MissingRequiredFields: if the task is not queued
|
||||
:raise APIError or errors.server_error.TransactionError: if internal call to queues.remove_task fails
|
||||
:return: the result of queues.remove_task call. None in case of silent failure
|
||||
"""
|
||||
if task.status not in (TaskStatus.queued,):
|
||||
if silent_fail:
|
||||
return
|
||||
raise errors.bad_request.InvalidTaskId(
|
||||
status=task.status, expected=TaskStatus.queued
|
||||
)
|
||||
|
||||
if not task.execution or not task.execution.queue:
|
||||
if silent_fail:
|
||||
return
|
||||
raise errors.bad_request.MissingRequiredFields(
|
||||
"task has no queue value", field="execution.queue"
|
||||
)
|
||||
|
||||
return {
|
||||
"removed": QueueBLL().remove_task(
|
||||
company_id=company_id, queue_id=task.execution.queue, task_id=task.id
|
||||
)
|
||||
}
|
||||
|
@ -1214,6 +1214,43 @@ delete {
|
||||
}
|
||||
}
|
||||
}
|
||||
archive {
|
||||
"2.11" {
|
||||
description: """Archive tasks.
|
||||
If a task is queued it will first be dequeued and then archived.
|
||||
"""
|
||||
request = {
|
||||
type: object
|
||||
required: [
|
||||
tasks
|
||||
]
|
||||
properties {
|
||||
tasks {
|
||||
description: "List of task ids"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
status_reason {
|
||||
description: Reason for status change
|
||||
type: string
|
||||
}
|
||||
status_message {
|
||||
description: Extra information regarding status change
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
archived {
|
||||
description: "Indicates number of archived tasks"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
started {
|
||||
"2.1" {
|
||||
description: "Mark a task status as in_progress. Optionally allows to set the task's execution progress."
|
||||
|
@ -39,6 +39,8 @@ from apiserver.apimodels.tasks import (
|
||||
DeleteConfigurationRequest,
|
||||
GetConfigurationNamesRequest,
|
||||
DeleteArtifactsRequest,
|
||||
ArchiveResponse,
|
||||
ArchiveRequest,
|
||||
)
|
||||
from apiserver.bll.event import EventBLL
|
||||
from apiserver.bll.organization import OrgBLL, Tags
|
||||
@ -63,6 +65,7 @@ from apiserver.bll.task.param_utils import (
|
||||
)
|
||||
from apiserver.bll.util import SetFieldsResolver
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.task.output import Output
|
||||
from apiserver.database.model.task.task import (
|
||||
@ -74,12 +77,18 @@ from apiserver.database.model.task.task import (
|
||||
)
|
||||
from apiserver.database.utils import get_fields_attr, parse_from_call
|
||||
from apiserver.service_repo import APICall, endpoint
|
||||
from apiserver.services.utils import conform_tag_fields, conform_output_tags, validate_tags
|
||||
from apiserver.services.utils import (
|
||||
conform_tag_fields,
|
||||
conform_output_tags,
|
||||
validate_tags,
|
||||
)
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.utilities.partial_version import PartialVersion
|
||||
|
||||
task_fields = set(Task.get_fields())
|
||||
task_script_stripped_fields = set([f for f, v in get_fields_attr(Script, 'strip').items() if v])
|
||||
task_script_stripped_fields = set(
|
||||
[f for f, v in get_fields_attr(Script, "strip").items() if v]
|
||||
)
|
||||
|
||||
task_bll = TaskBLL()
|
||||
event_bll = EventBLL()
|
||||
@ -172,9 +181,7 @@ def get_by_id_ex(call: APICall, company_id, _):
|
||||
with translate_errors_context():
|
||||
with TimingContext("mongo", "task_get_by_id_ex"):
|
||||
tasks = Task.get_many_with_join(
|
||||
company=company_id,
|
||||
query_dict=call.data,
|
||||
allow_public=True,
|
||||
company=company_id, query_dict=call.data, allow_public=True,
|
||||
)
|
||||
|
||||
unprepare_from_saved(call, tasks)
|
||||
@ -782,51 +789,14 @@ def dequeue(call: APICall, company_id, req_model: UpdateRequest):
|
||||
only=("id", "execution", "status", "project"),
|
||||
requires_write_access=True,
|
||||
)
|
||||
if task.status not in (TaskStatus.queued,):
|
||||
raise errors.bad_request.InvalidTaskId(
|
||||
status=task.status, expected=TaskStatus.queued
|
||||
)
|
||||
|
||||
_dequeue(task, company_id)
|
||||
|
||||
status_message = req_model.status_message
|
||||
status_reason = req_model.status_reason
|
||||
res = DequeueResponse(
|
||||
**ChangeStatusRequest(
|
||||
task=task,
|
||||
new_status=TaskStatus.created,
|
||||
status_reason=status_reason,
|
||||
status_message=status_message,
|
||||
).execute(unset__execution__queue=1)
|
||||
**TaskBLL.dequeue_and_change_status(task, company_id, req_model)
|
||||
)
|
||||
|
||||
res.dequeued = 1
|
||||
|
||||
call.result.data_model = res
|
||||
|
||||
|
||||
def _dequeue(task: Task, company_id: str, silent_fail=False):
|
||||
"""
|
||||
Dequeue the task from the queue
|
||||
:param task: task to dequeue
|
||||
:param silent_fail: do not throw exceptions. APIError is still thrown
|
||||
:raise errors.bad_request.MissingRequiredFields: if the task is not queued
|
||||
:raise APIError or errors.server_error.TransactionError: if internal call to queues.remove_task fails
|
||||
:return: the result of queues.remove_task call. None in case of silent failure
|
||||
"""
|
||||
if not task.execution or not task.execution.queue:
|
||||
if silent_fail:
|
||||
return
|
||||
raise errors.bad_request.MissingRequiredFields(
|
||||
"task has no queue value", field="execution.queue"
|
||||
)
|
||||
|
||||
return {
|
||||
"removed": queue_bll.remove_task(
|
||||
company_id=company_id, queue_id=task.execution.queue, task_id=task.id
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@endpoint(
|
||||
"tasks.reset", request_data_model=ResetRequest, response_data_model=ResetResponse
|
||||
)
|
||||
@ -844,7 +814,7 @@ def reset(call: APICall, company_id, request: ResetRequest):
|
||||
updates = {}
|
||||
|
||||
try:
|
||||
dequeued = _dequeue(task, company_id, silent_fail=True)
|
||||
dequeued = TaskBLL.dequeue(task, company_id, silent_fail=True)
|
||||
except APIError:
|
||||
# dequeue may fail if the task was not enqueued
|
||||
pass
|
||||
@ -897,6 +867,39 @@ def reset(call: APICall, company_id, request: ResetRequest):
|
||||
call.result.data_model = res
|
||||
|
||||
|
||||
@endpoint(
|
||||
"tasks.archive",
|
||||
request_data_model=ArchiveRequest,
|
||||
response_data_model=ArchiveResponse,
|
||||
)
|
||||
def archive(call: APICall, company_id, request: ArchiveRequest):
|
||||
archived = 0
|
||||
tasks = TaskBLL.assert_exists(
|
||||
company_id,
|
||||
task_ids=request.tasks,
|
||||
only=("id", "execution", "status", "project", "system_tags"),
|
||||
)
|
||||
for task in tasks:
|
||||
TaskBLL.dequeue_and_change_status(
|
||||
task,
|
||||
company_id,
|
||||
request.status_message,
|
||||
request.status_reason,
|
||||
silent_dequeue_fail=True,
|
||||
)
|
||||
task.update(
|
||||
status_message=request.status_message,
|
||||
status_reason=request.status_reason,
|
||||
system_tags=sorted(
|
||||
set(task.system_tags) | {EntityVisibility.archived.value}
|
||||
)
|
||||
)
|
||||
|
||||
archived += 1
|
||||
|
||||
call.result.data_model = ArchiveResponse(archived=archived)
|
||||
|
||||
|
||||
class DocumentGroup(list):
|
||||
"""
|
||||
Operate on a list of documents as if they were a query result
|
||||
|
@ -1,9 +1,9 @@
|
||||
from apiserver.apierrors.errors.bad_request import InvalidModelId, ValidationError, InvalidTaskId
|
||||
from apiserver.apierrors.errors.forbidden import NoWritePermission
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.tests.api_client import APIError
|
||||
from apiserver.tests.automated import TestService
|
||||
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
@ -21,6 +21,10 @@ class TestTasksEdit(TestService):
|
||||
self.update_missing(kwargs, name="test", uri="file:///a/b", labels={})
|
||||
return self.create_temp("models", **kwargs)
|
||||
|
||||
def new_queue(self, **kwargs):
|
||||
self.update_missing(kwargs, name="test")
|
||||
return self.create_temp("queues", **kwargs)
|
||||
|
||||
def test_task_types(self):
|
||||
with self.api.raises(ValidationError):
|
||||
task = self.new_task(type="Unsupported")
|
||||
@ -171,3 +175,41 @@ class TestTasksEdit(TestService):
|
||||
res = self.api.tasks.get_all_ex(id=[task])
|
||||
self.assertEqual([t.id for t in res.tasks], [task])
|
||||
self.api.tasks.stopped(task=task)
|
||||
|
||||
def test_archive_task(self):
|
||||
# non-existing task throws an exception
|
||||
with self.assertRaises(APIError):
|
||||
self.api.tasks.archive(tasks=["fake-task-id"])
|
||||
|
||||
system_tag = "existing-system-tag"
|
||||
status_message = "test-status-message"
|
||||
status_reason = "test-status-reason"
|
||||
queue_id = self.new_queue()
|
||||
|
||||
# Create two tasks with system_tags and enqueue one of them
|
||||
dequeued_task_id = self.new_task(system_tags=[system_tag])
|
||||
enqueued_task_id = self.new_task(system_tags=[system_tag])
|
||||
self.api.tasks.enqueue(task=enqueued_task_id, queue=queue_id)
|
||||
|
||||
self.api.tasks.archive(
|
||||
tasks=[enqueued_task_id, dequeued_task_id],
|
||||
status_message=status_message,
|
||||
status_reason=status_reason,
|
||||
)
|
||||
|
||||
tasks = self.api.tasks.get_all_ex(id=[enqueued_task_id, dequeued_task_id]).tasks
|
||||
|
||||
for task in tasks:
|
||||
self.assertIn(system_tag, task.system_tags)
|
||||
self.assertIn("archived", task.system_tags)
|
||||
self.assertNotIn("queue", task.execution)
|
||||
self.assertIn(status_message, task.status_message)
|
||||
self.assertIn(status_reason, task.status_reason)
|
||||
|
||||
# Check that the queue does not contain the enqueued task anymore
|
||||
queue = self.api.queues.get_by_id(queue=queue_id).queue
|
||||
task_in_queue = next(
|
||||
(True for entry in queue.entries if entry["task"] == enqueued_task_id),
|
||||
False,
|
||||
)
|
||||
self.assertFalse(task_in_queue)
|
||||
|
Loading…
Reference in New Issue
Block a user