# Mirror of https://github.com/clearml/clearml-server
# Synced 2025-01-31 10:56:48 +00:00
# 197 lines, 6.4 KiB, Python
from datetime import datetime
|
|
from typing import Sequence, Union
|
|
|
|
import attr
|
|
import six
|
|
|
|
from apiserver.apierrors import errors
|
|
from apiserver.database.errors import translate_errors_context
|
|
from apiserver.database.model.project import Project
|
|
from apiserver.database.model.task.task import Task, TaskStatus, TaskSystemTags
|
|
from apiserver.database.utils import get_options
|
|
from apiserver.timing_context import TimingContext
|
|
from apiserver.utilities.attrs import typed_attrs
|
|
|
|
# Status values produced by get_options(TaskStatus); used in this module to
# validate requested and current task statuses.
valid_statuses = get_options(TaskStatus)

# NOTE(review): presumably a marker prefix for deleted entities - it is not
# referenced in this part of the file, confirm against its callers.
deleted_prefix = "__DELETED__"
|
|
|
|
|
|
@typed_attrs
class ChangeStatusRequest(object):
    """
    Request object describing a single task status change.

    Call execute() to validate the transition and apply it to the database.
    """

    # The task whose status should be changed
    task = attr.ib(type=Task)
    # Requested status; must be one of valid_statuses
    new_status = attr.ib(
        type=six.string_types, validator=attr.validators.in_(valid_statuses)
    )
    # Free-text reason and message stored on the task together with the new status
    status_reason = attr.ib(type=six.string_types, default="")
    status_message = attr.ib(type=six.string_types, default="")
    # When set, the transition is not validated at all
    force = attr.ib(type=bool, default=False)
    # When unset, requesting the status the task is already in is an error
    allow_same_state_transition = attr.ib(type=bool, default=True)
    # When truthy, used as the expected current status instead of task.status
    current_status_override = attr.ib(default=None)

    def execute(self, **kwargs):
        """
        Validate the status transition and apply it to the task.

        The update is issued against a query that matches the task id together
        with the expected current status, so nothing is updated if the stored
        status is no longer the expected one.

        :param kwargs: additional update fields passed to the task update; a key
            that collides with one of the update control options is prefixed with "__"
        :return dict with "updated" (result of the update call) and "fields"
            (the update fields that were sent, without any "__raw__" entry)
        :raise errors.bad_request.FailedChangingTaskStatus: if no task was updated
        """
        current_status = self.current_status_override or self.task.status
        project_id = self.task.project

        # Raises if the new status cannot be reached from the current one
        self.validate_transition(current_status)

        # Options for the update call itself (not document fields)
        update_options = {
            "upsert": False,
            "multi": False,
            "write_concern": None,
            "full_result": False,
        }

        timestamp = datetime.utcnow()
        update_fields = {
            "status": self.new_status,
            "status_reason": self.status_reason,
            "status_message": self.status_message,
            "status_changed": timestamp,
            "last_update": timestamp,
            "last_change": timestamp,
        }

        if self.new_status == TaskStatus.queued:
            update_fields["pull__system_tags"] = TaskSystemTags.development

        # Caller-supplied fields must not shadow the update options, so any
        # colliding key gets a "__" prefix
        update_fields.update(
            {
                (f"__{name}" if name in update_options else name): value
                for name, value in kwargs.items()
            }
        )

        with translate_errors_context(), TimingContext("mongo", "task_status"):
            # atomic change of task status by querying the task with the EXPECTED status before modifying it
            call_params = {**update_fields, **update_options}
            matching_tasks = Task.objects(id=self.task.id, status=current_status)
            updated = matching_tasks.update(**call_params)

        if not updated:
            # failed to change status (someone else beat us to it?)
            raise errors.bad_request.FailedChangingTaskStatus(
                task_id=self.task.id,
                current_status=current_status,
                new_status=self.new_status,
            )

        update_project_time(project_id)

        # make sure that _raw_ queries are not returned back to the client
        update_fields.pop("__raw__", None)

        return {"updated": updated, "fields": update_fields}

    def validate_transition(self, current_status):
        """
        Check that moving from current_status to the requested status is allowed.

        Nothing is checked when force is set.

        :param current_status: the status the task is expected to be in
        :raise errors.bad_request.InvalidTaskStatus: if the change is not allowed,
            or if the task is already in the requested status and
            allow_same_state_transition is not set
        """
        if self.force:
            return

        if self.new_status == current_status:
            if self.allow_same_state_transition:
                return
            raise errors.bad_request.InvalidTaskStatus(
                "Task already in requested status",
                current_status=current_status,
                new_status=self.new_status,
            )

        validate_status_change(current_status, self.new_status)
|
|
|
|
|
|
def validate_status_change(current_status, new_status):
    """
    Verify that a task may move from current_status to new_status.

    :param current_status: the status the task is currently in
    :param new_status: the requested status
    :raise errors.bad_request.InvalidTaskStatus: if new_status is not one of the
        statuses reachable from current_status
    """
    # Both values are expected to be known statuses at this point
    assert current_status in valid_statuses
    assert new_status in valid_statuses

    if new_status in get_possible_status_changes(current_status):
        return

    raise errors.bad_request.InvalidTaskStatus(
        "Invalid status change",
        current_status=current_status,
        new_status=new_status,
    )
|
|
|
|
|
|
# Allowed task status transitions: each key is a current status and its value
# is the set of statuses that may directly follow it.
state_machine = {
    TaskStatus.created: {
        TaskStatus.queued,
        TaskStatus.in_progress,
    },
    TaskStatus.queued: {
        TaskStatus.created,
        TaskStatus.in_progress,
        TaskStatus.stopped,
    },
    TaskStatus.in_progress: {
        TaskStatus.stopped,
        TaskStatus.failed,
        TaskStatus.created,
        TaskStatus.completed,
    },
    TaskStatus.stopped: {
        TaskStatus.closed,
        TaskStatus.created,
        TaskStatus.failed,
        TaskStatus.queued,
        TaskStatus.in_progress,
        TaskStatus.published,
        TaskStatus.publishing,
        TaskStatus.completed,
    },
    TaskStatus.closed: {
        TaskStatus.created,
        TaskStatus.failed,
        TaskStatus.published,
        TaskStatus.publishing,
        TaskStatus.stopped,
    },
    TaskStatus.failed: {
        TaskStatus.created,
        TaskStatus.stopped,
        TaskStatus.published,
    },
    TaskStatus.publishing: {
        TaskStatus.published,
    },
    # No status may follow "published"
    TaskStatus.published: set(),
    TaskStatus.completed: {
        TaskStatus.published,
        TaskStatus.in_progress,
        TaskStatus.created,
    },
}
|
|
|
|
|
|
def get_possible_status_changes(current_status):
    """
    Look up the statuses that may follow the given status.

    :param current_status: the status to look up in the state machine
    :return the set of statuses reachable from current_status
    :raise errors.server_error.InternalError: if current_status has no entry
        in the state machine
    """
    if current_status not in state_machine:
        raise errors.server_error.InternalError(
            f"Current status {current_status} not supported by state machine"
        )

    return state_machine[current_status]
|
|
|
|
|
|
def update_project_time(project_ids: Union[str, Sequence[str]]):
    """
    Set last_update to the current UTC time on the given project(s).

    :param project_ids: a single project id or a sequence of project ids;
        nothing is done when it is empty or None
    :return the result of the projects update call, or None when no ids were given
    """
    if not project_ids:
        return

    # A single id is handled as a one-element list
    ids = [project_ids] if isinstance(project_ids, str) else project_ids

    projects = Project.objects(id__in=ids)
    return projects.update(last_update=datetime.utcnow())
|
|
|
|
|
|
def get_task_for_update(
    company_id: str, task_id: str, allow_all_statuses: bool = False, force: bool = False
) -> Task:
    """
    Load the task for writing (only its id and status fields) and return it if
    its status allows updating.

    :param company_id: company passed to the task lookup
    :param task_id: id of the task to load
    :param allow_all_statuses: when set, the task is returned whatever its status is
    :param force: when set, a task in 'in_progress' status is accepted in addition
        to a task in 'created' status
    :return the loaded task
    :raise errors.bad_request.InvalidTaskId: if the task was not found
    :raise errors.bad_request.InvalidTaskStatus: if the task status does not allow updating
    """
    task = Task.get_for_writing(company=company_id, id=task_id, _only=("id", "status"))
    if not task:
        raise errors.bad_request.InvalidTaskId(id=task_id)

    if allow_all_statuses:
        return task

    updatable_statuses = [TaskStatus.created]
    if force:
        updatable_statuses.append(TaskStatus.in_progress)

    if task.status in updatable_statuses:
        return task

    raise errors.bad_request.InvalidTaskStatus(
        expected=TaskStatus.created, status=task.status
    )
|
|
|
|
|
|
def update_task(task: Task, update_cmds: dict, set_last_update: bool = True):
    """
    Apply the update commands to the task together with its change timestamps.

    :param task: the task to update
    :param update_cmds: update commands passed to task.update() as keyword arguments
    :param set_last_update: when set, last_update is stamped in addition to last_change
    :return the result of task.update()
    """
    timestamp = datetime.utcnow()

    # last_change is always stamped, last_update only on request
    stamps = {"last_change": timestamp}
    if set_last_update:
        stamps["last_update"] = timestamp

    return task.update(**update_cmds, **stamps)
|