mirror of
https://github.com/clearml/clearml-server
synced 2025-06-26 23:15:47 +00:00
Add tasks.archive support
This commit is contained in:
parent
03ae90c4a6
commit
0ad0495733
@ -7,7 +7,11 @@ from jsonmodels.validators import Enum, Length
|
|||||||
|
|
||||||
from apiserver.apimodels import DictField, ListField
|
from apiserver.apimodels import DictField, ListField
|
||||||
from apiserver.apimodels.base import UpdateResponse
|
from apiserver.apimodels.base import UpdateResponse
|
||||||
from apiserver.database.model.task.task import TaskType, ArtifactModes, DEFAULT_ARTIFACT_MODE
|
from apiserver.database.model.task.task import (
|
||||||
|
TaskType,
|
||||||
|
ArtifactModes,
|
||||||
|
DEFAULT_ARTIFACT_MODE,
|
||||||
|
)
|
||||||
from apiserver.database.utils import get_options
|
from apiserver.database.utils import get_options
|
||||||
|
|
||||||
|
|
||||||
@ -199,3 +203,12 @@ class EditConfigurationRequest(TaskRequest):
|
|||||||
|
|
||||||
class DeleteConfigurationRequest(TaskRequest):
|
class DeleteConfigurationRequest(TaskRequest):
|
||||||
configuration: Sequence[str] = ListField([str], validators=Length(minimum_value=1))
|
configuration: Sequence[str] = ListField([str], validators=Length(minimum_value=1))
|
||||||
|
|
||||||
|
|
||||||
|
class ArchiveRequest(MultiTaskRequest):
|
||||||
|
status_reason = StringField(default="")
|
||||||
|
status_message = StringField(default="")
|
||||||
|
|
||||||
|
|
||||||
|
class ArchiveResponse(models.Base):
|
||||||
|
archived = IntField()
|
||||||
|
@ -10,6 +10,7 @@ from six import string_types
|
|||||||
import apiserver.database.utils as dbutils
|
import apiserver.database.utils as dbutils
|
||||||
from apiserver.apierrors import errors
|
from apiserver.apierrors import errors
|
||||||
from apiserver.bll.organization import OrgBLL, Tags
|
from apiserver.bll.organization import OrgBLL, Tags
|
||||||
|
from apiserver.bll.queue import QueueBLL
|
||||||
from apiserver.config_repo import config
|
from apiserver.config_repo import config
|
||||||
from apiserver.database.errors import translate_errors_context
|
from apiserver.database.errors import translate_errors_context
|
||||||
from apiserver.database.model.model import Model
|
from apiserver.database.model.model import Model
|
||||||
@ -327,14 +328,26 @@ class TaskBLL(object):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def set_last_update(
|
def set_last_update(
|
||||||
task_ids: Collection[str], company_id: str, last_update: datetime, **extra_updates
|
task_ids: Collection[str],
|
||||||
|
company_id: str,
|
||||||
|
last_update: datetime,
|
||||||
|
**extra_updates,
|
||||||
):
|
):
|
||||||
tasks = Task.objects(id__in=task_ids, company=company_id).only("status", "started")
|
tasks = Task.objects(id__in=task_ids, company=company_id).only(
|
||||||
|
"status", "started"
|
||||||
|
)
|
||||||
for task in tasks:
|
for task in tasks:
|
||||||
updates = extra_updates
|
updates = extra_updates
|
||||||
if task.status == TaskStatus.in_progress and task.started:
|
if task.status == TaskStatus.in_progress and task.started:
|
||||||
updates = {"active_duration": (datetime.utcnow() - task.started).total_seconds(), **extra_updates}
|
updates = {
|
||||||
Task.objects(id=task.id, company=company_id).update(upsert=False, last_update=last_update, **updates)
|
"active_duration": (
|
||||||
|
datetime.utcnow() - task.started
|
||||||
|
).total_seconds(),
|
||||||
|
**extra_updates,
|
||||||
|
}
|
||||||
|
Task.objects(id=task.id, company=company_id).update(
|
||||||
|
upsert=False, last_update=last_update, **updates
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def update_statistics(
|
def update_statistics(
|
||||||
@ -398,7 +411,10 @@ class TaskBLL(object):
|
|||||||
extra_updates["metric_stats"] = metric_stats
|
extra_updates["metric_stats"] = metric_stats
|
||||||
|
|
||||||
TaskBLL.set_last_update(
|
TaskBLL.set_last_update(
|
||||||
task_ids=[task_id], company_id=company_id, last_update=last_update, **extra_updates
|
task_ids=[task_id],
|
||||||
|
company_id=company_id,
|
||||||
|
last_update=last_update,
|
||||||
|
**extra_updates,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -613,3 +629,53 @@ class TaskBLL(object):
|
|||||||
remaining = max(0, total - (len(results) + page * page_size))
|
remaining = max(0, total - (len(results) + page * page_size))
|
||||||
|
|
||||||
return total, remaining, results
|
return total, remaining, results
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def dequeue_and_change_status(
|
||||||
|
cls,
|
||||||
|
task: Task,
|
||||||
|
company_id: str,
|
||||||
|
status_message: str,
|
||||||
|
status_reason: str,
|
||||||
|
silent_dequeue_fail=False,
|
||||||
|
):
|
||||||
|
cls.dequeue(task, company_id, silent_dequeue_fail)
|
||||||
|
|
||||||
|
return ChangeStatusRequest(
|
||||||
|
task=task,
|
||||||
|
new_status=TaskStatus.created,
|
||||||
|
status_reason=status_reason,
|
||||||
|
status_message=status_message,
|
||||||
|
).execute(unset__execution__queue=1)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def dequeue(cls, task: Task, company_id: str, silent_fail=False):
|
||||||
|
"""
|
||||||
|
Dequeue the task from the queue
|
||||||
|
:param task: task to dequeue
|
||||||
|
:param company_id: task's company ID.
|
||||||
|
:param silent_fail: do not throw exceptions. APIError is still thrown
|
||||||
|
:raise errors.bad_request.InvalidTaskId: if the task's status is not queued
|
||||||
|
:raise errors.bad_request.MissingRequiredFields: if the task is not queued
|
||||||
|
:raise APIError or errors.server_error.TransactionError: if internal call to queues.remove_task fails
|
||||||
|
:return: the result of queues.remove_task call. None in case of silent failure
|
||||||
|
"""
|
||||||
|
if task.status not in (TaskStatus.queued,):
|
||||||
|
if silent_fail:
|
||||||
|
return
|
||||||
|
raise errors.bad_request.InvalidTaskId(
|
||||||
|
status=task.status, expected=TaskStatus.queued
|
||||||
|
)
|
||||||
|
|
||||||
|
if not task.execution or not task.execution.queue:
|
||||||
|
if silent_fail:
|
||||||
|
return
|
||||||
|
raise errors.bad_request.MissingRequiredFields(
|
||||||
|
"task has no queue value", field="execution.queue"
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"removed": QueueBLL().remove_task(
|
||||||
|
company_id=company_id, queue_id=task.execution.queue, task_id=task.id
|
||||||
|
)
|
||||||
|
}
|
||||||
|
@ -1214,6 +1214,43 @@ delete {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
archive {
|
||||||
|
"2.11" {
|
||||||
|
description: """Archive tasks.
|
||||||
|
If a task is queued it will first be dequeued and then archived.
|
||||||
|
"""
|
||||||
|
request = {
|
||||||
|
type: object
|
||||||
|
required: [
|
||||||
|
tasks
|
||||||
|
]
|
||||||
|
properties {
|
||||||
|
tasks {
|
||||||
|
description: "List of task ids"
|
||||||
|
type: array
|
||||||
|
items { type: string }
|
||||||
|
}
|
||||||
|
status_reason {
|
||||||
|
description: Reason for status change
|
||||||
|
type: string
|
||||||
|
}
|
||||||
|
status_message {
|
||||||
|
description: Extra information regarding status change
|
||||||
|
type: string
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
response {
|
||||||
|
type: object
|
||||||
|
properties {
|
||||||
|
archived {
|
||||||
|
description: "Indicates number of archived tasks"
|
||||||
|
type: integer
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
started {
|
started {
|
||||||
"2.1" {
|
"2.1" {
|
||||||
description: "Mark a task status as in_progress. Optionally allows to set the task's execution progress."
|
description: "Mark a task status as in_progress. Optionally allows to set the task's execution progress."
|
||||||
|
@ -39,6 +39,8 @@ from apiserver.apimodels.tasks import (
|
|||||||
DeleteConfigurationRequest,
|
DeleteConfigurationRequest,
|
||||||
GetConfigurationNamesRequest,
|
GetConfigurationNamesRequest,
|
||||||
DeleteArtifactsRequest,
|
DeleteArtifactsRequest,
|
||||||
|
ArchiveResponse,
|
||||||
|
ArchiveRequest,
|
||||||
)
|
)
|
||||||
from apiserver.bll.event import EventBLL
|
from apiserver.bll.event import EventBLL
|
||||||
from apiserver.bll.organization import OrgBLL, Tags
|
from apiserver.bll.organization import OrgBLL, Tags
|
||||||
@ -63,6 +65,7 @@ from apiserver.bll.task.param_utils import (
|
|||||||
)
|
)
|
||||||
from apiserver.bll.util import SetFieldsResolver
|
from apiserver.bll.util import SetFieldsResolver
|
||||||
from apiserver.database.errors import translate_errors_context
|
from apiserver.database.errors import translate_errors_context
|
||||||
|
from apiserver.database.model import EntityVisibility
|
||||||
from apiserver.database.model.model import Model
|
from apiserver.database.model.model import Model
|
||||||
from apiserver.database.model.task.output import Output
|
from apiserver.database.model.task.output import Output
|
||||||
from apiserver.database.model.task.task import (
|
from apiserver.database.model.task.task import (
|
||||||
@ -74,12 +77,18 @@ from apiserver.database.model.task.task import (
|
|||||||
)
|
)
|
||||||
from apiserver.database.utils import get_fields_attr, parse_from_call
|
from apiserver.database.utils import get_fields_attr, parse_from_call
|
||||||
from apiserver.service_repo import APICall, endpoint
|
from apiserver.service_repo import APICall, endpoint
|
||||||
from apiserver.services.utils import conform_tag_fields, conform_output_tags, validate_tags
|
from apiserver.services.utils import (
|
||||||
|
conform_tag_fields,
|
||||||
|
conform_output_tags,
|
||||||
|
validate_tags,
|
||||||
|
)
|
||||||
from apiserver.timing_context import TimingContext
|
from apiserver.timing_context import TimingContext
|
||||||
from apiserver.utilities.partial_version import PartialVersion
|
from apiserver.utilities.partial_version import PartialVersion
|
||||||
|
|
||||||
task_fields = set(Task.get_fields())
|
task_fields = set(Task.get_fields())
|
||||||
task_script_stripped_fields = set([f for f, v in get_fields_attr(Script, 'strip').items() if v])
|
task_script_stripped_fields = set(
|
||||||
|
[f for f, v in get_fields_attr(Script, "strip").items() if v]
|
||||||
|
)
|
||||||
|
|
||||||
task_bll = TaskBLL()
|
task_bll = TaskBLL()
|
||||||
event_bll = EventBLL()
|
event_bll = EventBLL()
|
||||||
@ -172,9 +181,7 @@ def get_by_id_ex(call: APICall, company_id, _):
|
|||||||
with translate_errors_context():
|
with translate_errors_context():
|
||||||
with TimingContext("mongo", "task_get_by_id_ex"):
|
with TimingContext("mongo", "task_get_by_id_ex"):
|
||||||
tasks = Task.get_many_with_join(
|
tasks = Task.get_many_with_join(
|
||||||
company=company_id,
|
company=company_id, query_dict=call.data, allow_public=True,
|
||||||
query_dict=call.data,
|
|
||||||
allow_public=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
unprepare_from_saved(call, tasks)
|
unprepare_from_saved(call, tasks)
|
||||||
@ -782,51 +789,14 @@ def dequeue(call: APICall, company_id, req_model: UpdateRequest):
|
|||||||
only=("id", "execution", "status", "project"),
|
only=("id", "execution", "status", "project"),
|
||||||
requires_write_access=True,
|
requires_write_access=True,
|
||||||
)
|
)
|
||||||
if task.status not in (TaskStatus.queued,):
|
|
||||||
raise errors.bad_request.InvalidTaskId(
|
|
||||||
status=task.status, expected=TaskStatus.queued
|
|
||||||
)
|
|
||||||
|
|
||||||
_dequeue(task, company_id)
|
|
||||||
|
|
||||||
status_message = req_model.status_message
|
|
||||||
status_reason = req_model.status_reason
|
|
||||||
res = DequeueResponse(
|
res = DequeueResponse(
|
||||||
**ChangeStatusRequest(
|
**TaskBLL.dequeue_and_change_status(task, company_id, req_model)
|
||||||
task=task,
|
|
||||||
new_status=TaskStatus.created,
|
|
||||||
status_reason=status_reason,
|
|
||||||
status_message=status_message,
|
|
||||||
).execute(unset__execution__queue=1)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
res.dequeued = 1
|
res.dequeued = 1
|
||||||
|
|
||||||
call.result.data_model = res
|
call.result.data_model = res
|
||||||
|
|
||||||
|
|
||||||
def _dequeue(task: Task, company_id: str, silent_fail=False):
|
|
||||||
"""
|
|
||||||
Dequeue the task from the queue
|
|
||||||
:param task: task to dequeue
|
|
||||||
:param silent_fail: do not throw exceptions. APIError is still thrown
|
|
||||||
:raise errors.bad_request.MissingRequiredFields: if the task is not queued
|
|
||||||
:raise APIError or errors.server_error.TransactionError: if internal call to queues.remove_task fails
|
|
||||||
:return: the result of queues.remove_task call. None in case of silent failure
|
|
||||||
"""
|
|
||||||
if not task.execution or not task.execution.queue:
|
|
||||||
if silent_fail:
|
|
||||||
return
|
|
||||||
raise errors.bad_request.MissingRequiredFields(
|
|
||||||
"task has no queue value", field="execution.queue"
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"removed": queue_bll.remove_task(
|
|
||||||
company_id=company_id, queue_id=task.execution.queue, task_id=task.id
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@endpoint(
|
@endpoint(
|
||||||
"tasks.reset", request_data_model=ResetRequest, response_data_model=ResetResponse
|
"tasks.reset", request_data_model=ResetRequest, response_data_model=ResetResponse
|
||||||
)
|
)
|
||||||
@ -844,7 +814,7 @@ def reset(call: APICall, company_id, request: ResetRequest):
|
|||||||
updates = {}
|
updates = {}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
dequeued = _dequeue(task, company_id, silent_fail=True)
|
dequeued = TaskBLL.dequeue(task, company_id, silent_fail=True)
|
||||||
except APIError:
|
except APIError:
|
||||||
# dequeue may fail if the task was not enqueued
|
# dequeue may fail if the task was not enqueued
|
||||||
pass
|
pass
|
||||||
@ -897,6 +867,39 @@ def reset(call: APICall, company_id, request: ResetRequest):
|
|||||||
call.result.data_model = res
|
call.result.data_model = res
|
||||||
|
|
||||||
|
|
||||||
|
@endpoint(
|
||||||
|
"tasks.archive",
|
||||||
|
request_data_model=ArchiveRequest,
|
||||||
|
response_data_model=ArchiveResponse,
|
||||||
|
)
|
||||||
|
def archive(call: APICall, company_id, request: ArchiveRequest):
|
||||||
|
archived = 0
|
||||||
|
tasks = TaskBLL.assert_exists(
|
||||||
|
company_id,
|
||||||
|
task_ids=request.tasks,
|
||||||
|
only=("id", "execution", "status", "project", "system_tags"),
|
||||||
|
)
|
||||||
|
for task in tasks:
|
||||||
|
TaskBLL.dequeue_and_change_status(
|
||||||
|
task,
|
||||||
|
company_id,
|
||||||
|
request.status_message,
|
||||||
|
request.status_reason,
|
||||||
|
silent_dequeue_fail=True,
|
||||||
|
)
|
||||||
|
task.update(
|
||||||
|
status_message=request.status_message,
|
||||||
|
status_reason=request.status_reason,
|
||||||
|
system_tags=sorted(
|
||||||
|
set(task.system_tags) | {EntityVisibility.archived.value}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
archived += 1
|
||||||
|
|
||||||
|
call.result.data_model = ArchiveResponse(archived=archived)
|
||||||
|
|
||||||
|
|
||||||
class DocumentGroup(list):
|
class DocumentGroup(list):
|
||||||
"""
|
"""
|
||||||
Operate on a list of documents as if they were a query result
|
Operate on a list of documents as if they were a query result
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
from apiserver.apierrors.errors.bad_request import InvalidModelId, ValidationError, InvalidTaskId
|
from apiserver.apierrors.errors.bad_request import InvalidModelId, ValidationError, InvalidTaskId
|
||||||
from apiserver.apierrors.errors.forbidden import NoWritePermission
|
from apiserver.apierrors.errors.forbidden import NoWritePermission
|
||||||
from apiserver.config_repo import config
|
from apiserver.config_repo import config
|
||||||
|
from apiserver.tests.api_client import APIError
|
||||||
from apiserver.tests.automated import TestService
|
from apiserver.tests.automated import TestService
|
||||||
|
|
||||||
|
|
||||||
log = config.logger(__file__)
|
log = config.logger(__file__)
|
||||||
|
|
||||||
|
|
||||||
@ -21,6 +21,10 @@ class TestTasksEdit(TestService):
|
|||||||
self.update_missing(kwargs, name="test", uri="file:///a/b", labels={})
|
self.update_missing(kwargs, name="test", uri="file:///a/b", labels={})
|
||||||
return self.create_temp("models", **kwargs)
|
return self.create_temp("models", **kwargs)
|
||||||
|
|
||||||
|
def new_queue(self, **kwargs):
|
||||||
|
self.update_missing(kwargs, name="test")
|
||||||
|
return self.create_temp("queues", **kwargs)
|
||||||
|
|
||||||
def test_task_types(self):
|
def test_task_types(self):
|
||||||
with self.api.raises(ValidationError):
|
with self.api.raises(ValidationError):
|
||||||
task = self.new_task(type="Unsupported")
|
task = self.new_task(type="Unsupported")
|
||||||
@ -171,3 +175,41 @@ class TestTasksEdit(TestService):
|
|||||||
res = self.api.tasks.get_all_ex(id=[task])
|
res = self.api.tasks.get_all_ex(id=[task])
|
||||||
self.assertEqual([t.id for t in res.tasks], [task])
|
self.assertEqual([t.id for t in res.tasks], [task])
|
||||||
self.api.tasks.stopped(task=task)
|
self.api.tasks.stopped(task=task)
|
||||||
|
|
||||||
|
def test_archive_task(self):
|
||||||
|
# non-existing task throws an exception
|
||||||
|
with self.assertRaises(APIError):
|
||||||
|
self.api.tasks.archive(tasks=["fake-task-id"])
|
||||||
|
|
||||||
|
system_tag = "existing-system-tag"
|
||||||
|
status_message = "test-status-message"
|
||||||
|
status_reason = "test-status-reason"
|
||||||
|
queue_id = self.new_queue()
|
||||||
|
|
||||||
|
# Create two tasks with system_tags and enqueue one of them
|
||||||
|
dequeued_task_id = self.new_task(system_tags=[system_tag])
|
||||||
|
enqueued_task_id = self.new_task(system_tags=[system_tag])
|
||||||
|
self.api.tasks.enqueue(task=enqueued_task_id, queue=queue_id)
|
||||||
|
|
||||||
|
self.api.tasks.archive(
|
||||||
|
tasks=[enqueued_task_id, dequeued_task_id],
|
||||||
|
status_message=status_message,
|
||||||
|
status_reason=status_reason,
|
||||||
|
)
|
||||||
|
|
||||||
|
tasks = self.api.tasks.get_all_ex(id=[enqueued_task_id, dequeued_task_id]).tasks
|
||||||
|
|
||||||
|
for task in tasks:
|
||||||
|
self.assertIn(system_tag, task.system_tags)
|
||||||
|
self.assertIn("archived", task.system_tags)
|
||||||
|
self.assertNotIn("queue", task.execution)
|
||||||
|
self.assertIn(status_message, task.status_message)
|
||||||
|
self.assertIn(status_reason, task.status_reason)
|
||||||
|
|
||||||
|
# Check that the queue does not contain the enqueued task anymore
|
||||||
|
queue = self.api.queues.get_by_id(queue=queue_id).queue
|
||||||
|
task_in_queue = next(
|
||||||
|
(True for entry in queue.entries if entry["task"] == enqueued_task_id),
|
||||||
|
False,
|
||||||
|
)
|
||||||
|
self.assertFalse(task_in_queue)
|
||||||
|
Loading…
Reference in New Issue
Block a user