clearml-server/apiserver/bll/task/utils.py
allegroai ae4c33fa0e Add support for allow_public flag in get_all_ex endpoint
Add `last_changed_by` field on task updates
Fix reports support
2022-12-21 18:32:56 +02:00

198 lines
6.4 KiB
Python

from datetime import datetime
from typing import Sequence, Union
import attr
import six
from apiserver.apierrors import errors
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task, TaskStatus, TaskSystemTags
from apiserver.database.utils import get_options
from apiserver.utilities.attrs import typed_attrs
# Status values collected from TaskStatus via get_options(); used below both as the
# attrs validator for ChangeStatusRequest.new_status and in validate_status_change().
# NOTE(review): presumably get_options() returns the option (value) list of the
# TaskStatus class — confirm in apiserver.database.utils.
valid_statuses = get_options(TaskStatus)
# NOTE(review): not referenced in the visible code; presumably a marker prefix for
# deleted entities used elsewhere — confirm against callers before changing it.
deleted_prefix = "__DELETED__"
@typed_attrs
class ChangeStatusRequest(object):
    """
    Request object describing a status change for a single task.

    ``execute()`` checks the requested transition (unless ``force`` is set) and then
    performs a conditional database update that only matches the task while it is
    still in the status observed before the change.
    """

    task = attr.ib(type=Task)
    new_status = attr.ib(
        type=six.string_types, validator=attr.validators.in_(valid_statuses)
    )
    status_reason = attr.ib(type=six.string_types, default="")
    status_message = attr.ib(type=six.string_types, default="")
    # skip transition validation entirely
    force = attr.ib(type=bool, default=False)
    # when False, requesting the status the task already has is an error
    allow_same_state_transition = attr.ib(type=bool, default=True)
    # if truthy, used instead of task.status as the expected current status
    current_status_override = attr.ib(default=None)
    # written to the task's last_changed_by field
    user_id = attr.ib(type=str, default=None)

    def execute(self, **kwargs):
        """
        Validate and apply the status change.

        :param kwargs: additional fields for the Task update. Keys that collide with
            the update control options (upsert, multi, write_concern, full_result)
            are prefixed with "__" before being sent.
        :return: dict with "updated" (value returned by the update call) and
            "fields" (the fields that were set, with any "__raw__" entry removed)
        :raises errors.bad_request.FailedChangingTaskStatus: if the update matched
            nothing, i.e. no task with this id was found in the expected status
        """
        expected_status = self.current_status_override or self.task.status
        project_id = self.task.project

        # raises when the transition is not permitted
        self.validate_transition(expected_status)

        timestamp = datetime.utcnow()
        update_fields = {
            "status": self.new_status,
            "status_reason": self.status_reason,
            "status_message": self.status_message,
            "status_changed": timestamp,
            "last_update": timestamp,
            "last_change": timestamp,
            "last_changed_by": self.user_id,
        }
        if self.new_status == TaskStatus.queued:
            update_fields["pull__system_tags"] = TaskSystemTags.development

        update_options = {
            "upsert": False,
            "multi": False,
            "write_concern": None,
            "full_result": False,
        }
        for key, value in kwargs.items():
            # a caller field named like an update() option must not be taken as one
            escaped = "__" + key if key in update_options else key
            update_fields[escaped] = value

        with translate_errors_context():
            # the status filter makes this an atomic compare-and-set on the status
            query_args = dict(update_fields, **update_options)
            updated = Task.objects(id=self.task.id, status=expected_status).update(
                **query_args
            )

        if not updated:
            # nothing matched: presumably someone else changed the status first
            raise errors.bad_request.FailedChangingTaskStatus(
                task_id=self.task.id,
                current_status=expected_status,
                new_status=self.new_status,
            )

        update_project_time(project_id)

        # raw queries are internal and must not be echoed back to the client
        update_fields.pop("__raw__", None)
        return {"updated": updated, "fields": update_fields}

    def validate_transition(self, current_status):
        """
        Check that moving from ``current_status`` to ``self.new_status`` is allowed.

        Nothing is checked when ``force`` is set. Staying in the same status is
        accepted unless ``allow_same_state_transition`` is False.

        :raises errors.bad_request.InvalidTaskStatus: on a disallowed same-status
            request, or (via validate_status_change) an invalid transition
        """
        if self.force:
            return
        if self.new_status == current_status:
            if not self.allow_same_state_transition:
                raise errors.bad_request.InvalidTaskStatus(
                    "Task already in requested status",
                    current_status=current_status,
                    new_status=self.new_status,
                )
            return
        validate_status_change(current_status, self.new_status)
def validate_status_change(current_status, new_status):
    """
    Ensure ``state_machine`` permits a task to go from ``current_status`` to ``new_status``.

    Both statuses are asserted to be members of ``valid_statuses`` (AssertionError when
    assertions are enabled).

    :raises errors.bad_request.InvalidTaskStatus: if the transition is not allowed
    :raises errors.server_error.InternalError: if ``current_status`` has no entry in
        ``state_machine`` (raised by get_possible_status_changes)
    """
    for status in (current_status, new_status):
        assert status in valid_statuses

    if new_status in get_possible_status_changes(current_status):
        return

    raise errors.bad_request.InvalidTaskStatus(
        "Invalid status change",
        current_status=current_status,
        new_status=new_status,
    )
# Allowed task status transitions: each current status maps to the set of statuses
# it may move to. Looked up by get_possible_status_changes(); a status with no entry
# here makes that lookup raise an InternalError.
state_machine = {
    TaskStatus.created: {TaskStatus.in_progress, TaskStatus.queued},
    TaskStatus.queued: {TaskStatus.stopped, TaskStatus.in_progress, TaskStatus.created},
    TaskStatus.in_progress: {
        TaskStatus.completed,
        TaskStatus.created,
        TaskStatus.failed,
        TaskStatus.stopped,
    },
    TaskStatus.stopped: {
        TaskStatus.completed,
        TaskStatus.publishing,
        TaskStatus.published,
        TaskStatus.in_progress,
        TaskStatus.queued,
        TaskStatus.failed,
        TaskStatus.created,
        TaskStatus.closed,
    },
    TaskStatus.closed: {
        TaskStatus.stopped,
        TaskStatus.publishing,
        TaskStatus.published,
        TaskStatus.failed,
        TaskStatus.created,
    },
    TaskStatus.failed: {TaskStatus.published, TaskStatus.stopped, TaskStatus.created},
    # publishing may only move on to published
    TaskStatus.publishing: {TaskStatus.published},
    # no transition out of published is permitted (short of a forced change)
    TaskStatus.published: set(),
    TaskStatus.completed: {
        TaskStatus.created,
        TaskStatus.in_progress,
        TaskStatus.published,
    },
}
def get_possible_status_changes(current_status):
    """
    Return the statuses a task may move to from ``current_status``.

    :param current_status: the task's current status
    :return: the target-status set stored in ``state_machine`` (not a copy)
    :raises errors.server_error.InternalError: if ``state_machine`` has no entry
        for ``current_status``
    """
    allowed = state_machine.get(current_status)
    if allowed is not None:
        return allowed
    raise errors.server_error.InternalError(
        f"Current status {current_status} not supported by state machine"
    )
def update_project_time(project_ids: Union[str, Sequence[str]]):
    """
    Stamp ``last_update`` with the current UTC time on one or more projects.

    :param project_ids: a single project id or a sequence of ids; a falsy value
        (None, "", empty sequence) means there is nothing to update
    :return: None when nothing was updated, otherwise the result of the
        ``Project.objects(...).update()`` call
    """
    if not project_ids:
        return

    ids = [project_ids] if isinstance(project_ids, str) else project_ids
    return Project.objects(id__in=ids).update(last_update=datetime.utcnow())
def get_task_for_update(
    company_id: str, task_id: str, allow_all_statuses: bool = False, force: bool = False
) -> Task:
    """
    Load a task for writing and make sure its status allows it to be updated.

    Only the "id" and "status" fields are loaded.

    :param company_id: passed to Task.get_for_writing as ``company``
    :param task_id: id of the task to load
    :param allow_all_statuses: skip the status check entirely
    :param force: accept in_progress tasks in addition to created ones
    :return: the partially loaded task
    :raises errors.bad_request.InvalidTaskId: if Task.get_for_writing returns nothing
    :raises errors.bad_request.InvalidTaskStatus: if the status is not updatable
    """
    task = Task.get_for_writing(company=company_id, id=task_id, _only=("id", "status"))
    if not task:
        raise errors.bad_request.InvalidTaskId(id=task_id)

    status_ok = (
        allow_all_statuses
        or task.status == TaskStatus.created
        or (force and task.status == TaskStatus.in_progress)
    )
    if not status_ok:
        raise errors.bad_request.InvalidTaskStatus(
            expected=TaskStatus.created, status=task.status
        )
    return task
def update_task(task: Task, user_id: str, update_cmds: dict, set_last_update: bool = True):
    """
    Apply ``update_cmds`` to ``task`` along with change-tracking fields.

    ``last_change`` and ``last_changed_by`` are always set. ``last_update`` is also set
    (to the same timestamp) when ``set_last_update`` is True.

    :return: whatever ``task.update()`` returns
    """
    timestamp = datetime.utcnow()
    tracking = {"last_change": timestamp, "last_changed_by": user_id}
    if set_last_update:
        tracking["last_update"] = timestamp
    return task.update(**update_cmds, **tracking)