mirror of
https://github.com/clearml/clearml-server
synced 2025-01-31 10:56:48 +00:00
174 lines
5.5 KiB
Python
174 lines
5.5 KiB
Python
from datetime import datetime
|
|
from typing import TypeVar, Callable, Tuple, Sequence
|
|
|
|
import attr
|
|
import six
|
|
|
|
from apiserver.apierrors import errors
|
|
from apiserver.database.errors import translate_errors_context
|
|
from apiserver.database.model.project import Project
|
|
from apiserver.database.model.task.task import Task, TaskStatus, TaskSystemTags
|
|
from apiserver.database.utils import get_options
|
|
from apiserver.timing_context import TimingContext
|
|
from apiserver.utilities.attrs import typed_attrs
|
|
|
|
# Status values accepted for ChangeStatusRequest.new_status (see its attrs
# validator); presumably every option declared on TaskStatus, as collected by
# get_options — TODO confirm against apiserver.database.utils.
valid_statuses = get_options(TaskStatus)
|
|
|
|
|
|
@typed_attrs
class ChangeStatusRequest(object):
    """
    Request to move a single task to a new status.

    Fields (set through the attrs-generated constructor):
    - task: the task whose status is changed
    - new_status: requested status; must be one of `valid_statuses`
    - status_reason / status_message: text stored together with the status
    - force: when True, the state-machine transition check is skipped
    - allow_same_state_transition: when False, requesting the status the task
      is already in is rejected
    - current_status_override: when truthy, used instead of task.status as the
      status the task is expected to be in
    """

    task = attr.ib(type=Task)
    new_status = attr.ib(
        type=six.string_types, validator=attr.validators.in_(valid_statuses)
    )
    status_reason = attr.ib(type=six.string_types, default="")
    status_message = attr.ib(type=six.string_types, default="")
    force = attr.ib(type=bool, default=False)
    allow_same_state_transition = attr.ib(type=bool, default=True)
    current_status_override = attr.ib(default=None)

    def execute(self, **kwargs):
        """
        Validate the transition and atomically write the new status to the task.

        :param kwargs: extra field updates forwarded to the mongoengine update;
            keys equal to one of update()'s control arguments (upsert, multi,
            write_concern, full_result) are stored with a "__" prefix instead
        :return: dict with "updated" (the update() result) and "fields" (the
            applied field updates, minus any "__raw__" entry)
        :raises errors.bad_request.FailedChangingTaskStatus: no task matched the
            task id together with the expected current status
        """
        expected_status = self.current_status_override or self.task.status
        project_id = self.task.project

        # Raises if new_status may not follow the expected status
        self.validate_transition(expected_status)

        control = dict(upsert=False, multi=False, write_concern=None, full_result=False)

        timestamp = datetime.utcnow()
        fields = dict(
            status=self.new_status,
            status_reason=self.status_reason,
            status_message=self.status_message,
            status_changed=timestamp,
            last_update=timestamp,
        )
        if self.new_status == TaskStatus.queued:
            # mongoengine `pull__` operator: drop the development tag from system_tags
            fields["pull__system_tags"] = TaskSystemTags.development

        for key, value in kwargs.items():
            # The control arguments travel in the same kwargs dict as the field
            # updates, so colliding caller keys are escaped with a "__" prefix
            fields[f"__{key}" if key in control else key] = value

        with translate_errors_context(), TimingContext("mongo", "task_status"):
            # Atomic compare-and-set: the query also matches on the EXPECTED
            # status, so nothing is written if the status changed in the meantime
            updated = Task.objects(id=self.task.id, status=expected_status).update(
                **{**fields, **control}
            )

        if not updated:
            # Nothing matched; presumably another request changed the status first
            raise errors.bad_request.FailedChangingTaskStatus(
                task_id=self.task.id,
                current_status=expected_status,
                new_status=self.new_status,
            )

        update_project_time(project_id)

        # make sure that _raw_ queries are not returned back to the client
        fields.pop("__raw__", None)
        return dict(updated=updated, fields=fields)

    def validate_transition(self, current_status):
        """
        Check that moving from current_status to self.new_status is permitted.

        Nothing is checked when `force` is set. Requesting the current status
        again is accepted unless `allow_same_state_transition` is False; any other
        change is delegated to validate_status_change.

        :raises errors.bad_request.InvalidTaskStatus: the transition is rejected
        """
        if self.force:
            return
        if self.new_status == current_status:
            if not self.allow_same_state_transition:
                raise errors.bad_request.InvalidTaskStatus(
                    "Task already in requested status",
                    current_status=current_status,
                    new_status=self.new_status,
                )
            return
        validate_status_change(current_status, self.new_status)
|
|
|
|
|
|
def validate_status_change(current_status, new_status):
    """
    Verify that the state machine allows moving from current_status to new_status.

    :param current_status: the status the task is currently in
    :param new_status: the requested status
    :raises errors.server_error.InternalError: current_status has no entry in the
        state machine (raised by get_possible_status_changes)
    :raises errors.bad_request.InvalidTaskStatus: new_status is not reachable from
        current_status; this includes any value that is not a known status
    """
    # Inputs are deliberately not validated with `assert`. Asserts are stripped
    # under `python -O`, and otherwise they surface bad input as a bare
    # AssertionError. Both invalid cases are rejected below with proper API
    # errors: an unknown current_status has no state-machine entry, and an
    # unknown new_status can never be in an allowed-transitions set.
    allowed_statuses = get_possible_status_changes(current_status)
    if new_status not in allowed_statuses:
        raise errors.bad_request.InvalidTaskStatus(
            "Invalid status change",
            current_status=current_status,
            new_status=new_status,
        )
|
|
|
|
|
|
# Allowed task status transitions: each status maps to the set of statuses a
# task may move to from it. Consulted through get_possible_status_changes(),
# which raises InternalError for a status that has no entry here. Staying in the
# same status is not listed; ChangeStatusRequest.validate_transition handles it.
state_machine = {
    TaskStatus.created: {TaskStatus.queued, TaskStatus.in_progress},
    TaskStatus.queued: {TaskStatus.created, TaskStatus.in_progress},
    TaskStatus.in_progress: {
        TaskStatus.stopped,
        TaskStatus.failed,
        TaskStatus.created,
        TaskStatus.completed,
    },
    TaskStatus.stopped: {
        TaskStatus.closed,
        TaskStatus.created,
        TaskStatus.failed,
        TaskStatus.in_progress,
        TaskStatus.published,
        TaskStatus.publishing,
        TaskStatus.completed,
    },
    TaskStatus.closed: {
        TaskStatus.created,
        TaskStatus.failed,
        TaskStatus.published,
        TaskStatus.publishing,
        TaskStatus.stopped,
    },
    TaskStatus.failed: {TaskStatus.created, TaskStatus.stopped, TaskStatus.published},
    TaskStatus.publishing: {TaskStatus.published},
    # No transitions out of published (unless forced by the caller)
    TaskStatus.published: set(),
    TaskStatus.completed: {
        TaskStatus.published,
        TaskStatus.in_progress,
        TaskStatus.created,
    },
}
|
|
|
|
|
|
def get_possible_status_changes(current_status):
    """
    Look up which statuses a task may move to from current_status.

    :param current_status: status to look up in `state_machine`
    :return: the allowed target statuses (the state machine's own set object,
        so callers should not mutate it)
    :raises errors.server_error.InternalError: current_status has no entry
    """
    transitions = state_machine.get(current_status)
    if transitions is not None:
        return transitions
    raise errors.server_error.InternalError(
        f"Current status {current_status} not supported by state machine"
    )
|
|
|
|
|
|
def update_project_time(project_id):
    """
    Set the project's last_update field to the current UTC time.

    Does nothing when project_id is falsy (None, empty string).
    """
    if not project_id:
        return
    Project.objects(id=project_id).update(last_update=datetime.utcnow())
|
|
|
|
|
|
# Element type shared by split_by's input and outputs
T = TypeVar("T")


def split_by(
    condition: Callable[[T], bool], items: Sequence[T]
) -> Tuple[Sequence[T], Sequence[T]]:
    """
    Split "items" into two lists by "condition".

    :param condition: predicate, evaluated exactly once per item
    :param items: items to split; relative order is kept in both results
    :return: (items for which condition is truthy, items for which it is falsy)
    """
    matching, not_matching = [], []
    # A single pass. zip()/map() return one-shot iterators, so feeding one
    # iterator to two comprehensions leaves the second one empty.
    for item in items:
        (matching if condition(item) else not_matching).append(item)
    return matching, not_matching
|