clearml-server/apiserver/bll/task/utils.py

from datetime import datetime
from typing import Sequence, Union

import attr
import six

from apiserver.apierrors import errors
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task, TaskStatus, TaskSystemTags
from apiserver.database.utils import get_options
from apiserver.utilities.attrs import typed_attrs

valid_statuses = get_options(TaskStatus)
deleted_prefix = "__DELETED__"


@typed_attrs
class ChangeStatusRequest(object):
    task = attr.ib(type=Task)
    new_status = attr.ib(
        type=six.string_types, validator=attr.validators.in_(valid_statuses)
    )
    status_reason = attr.ib(type=six.string_types, default="")
    status_message = attr.ib(type=six.string_types, default="")
    force = attr.ib(type=bool, default=False)
    allow_same_state_transition = attr.ib(type=bool, default=True)
    current_status_override = attr.ib(default=None)
    user_id = attr.ib(type=str, default=None)

    def execute(self, **kwargs):
        current_status = self.current_status_override or self.task.status
        project_id = self.task.project

        # Verify the new status is allowed from the current status (raises if not valid)
        self.validate_transition(current_status)

        control = dict(upsert=False, multi=False, write_concern=None, full_result=False)
        now = datetime.utcnow()
        fields = dict(
            status=self.new_status,
            status_reason=self.status_reason,
            status_message=self.status_message,
            status_changed=now,
            last_update=now,
            last_change=now,
            last_changed_by=self.user_id,
        )

        if self.new_status == TaskStatus.queued:
            fields["pull__system_tags"] = TaskSystemTags.development

        def safe_mongoengine_key(key):
            return f"__{key}" if key in control else key

        fields.update({safe_mongoengine_key(k): v for k, v in kwargs.items()})

        with translate_errors_context():
            # atomic change of task status by querying the task with the EXPECTED status before modifying it
            params = fields.copy()
            params.update(control)
            updated = Task.objects(id=self.task.id, status=current_status).update(
                **params
            )
            if not updated:
                # failed to change status (someone else beat us to it?)
                raise errors.bad_request.FailedChangingTaskStatus(
                    task_id=self.task.id,
                    current_status=current_status,
                    new_status=self.new_status,
                )

        update_project_time(project_id)

        # make sure that _raw_ queries are not returned back to the client
        fields.pop("__raw__", None)

        return dict(updated=updated, fields=fields)

    def validate_transition(self, current_status):
        if self.force:
            return
        if self.new_status != current_status:
            validate_status_change(current_status, self.new_status)
        elif not self.allow_same_state_transition:
            raise errors.bad_request.InvalidTaskStatus(
                "Task already in requested status",
                current_status=current_status,
                new_status=self.new_status,
            )
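

# Example (sketch, not part of the original module): how a caller in the task
# BLL might drive a status change. The `task` object and `user_id` below are
# assumed to be resolved by the calling service handler:
#
#     result = ChangeStatusRequest(
#         task=task,
#         new_status=TaskStatus.stopped,
#         status_reason="stopped by user",
#         user_id=user_id,
#     ).execute()
#     # on success, result is of the form {"updated": 1, "fields": {...}}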


def validate_status_change(current_status, new_status):
    assert current_status in valid_statuses
    assert new_status in valid_statuses

    allowed_statuses = get_possible_status_changes(current_status)
    if new_status not in allowed_statuses:
        raise errors.bad_request.InvalidTaskStatus(
            "Invalid status change",
            current_status=current_status,
            new_status=new_status,
        )


state_machine = {
    TaskStatus.created: {TaskStatus.queued, TaskStatus.in_progress},
    TaskStatus.queued: {TaskStatus.created, TaskStatus.in_progress, TaskStatus.stopped},
    TaskStatus.in_progress: {
        TaskStatus.stopped,
        TaskStatus.failed,
        TaskStatus.created,
        TaskStatus.completed,
    },
    TaskStatus.stopped: {
        TaskStatus.closed,
        TaskStatus.created,
        TaskStatus.failed,
        TaskStatus.queued,
        TaskStatus.in_progress,
        TaskStatus.published,
        TaskStatus.publishing,
        TaskStatus.completed,
    },
    TaskStatus.closed: {
        TaskStatus.created,
        TaskStatus.failed,
        TaskStatus.published,
        TaskStatus.publishing,
        TaskStatus.stopped,
    },
    TaskStatus.failed: {TaskStatus.created, TaskStatus.stopped, TaskStatus.published},
    TaskStatus.publishing: {TaskStatus.published},
    TaskStatus.published: set(),
    TaskStatus.completed: {
        TaskStatus.published,
        TaskStatus.in_progress,
        TaskStatus.created,
    },
}


def get_possible_status_changes(current_status):
    """
    :param current_status: the status to transition from
    :return: the set of statuses that can be reached from the current status
    """
    possible = state_machine.get(current_status)
    if possible is None:
        raise errors.server_error.InternalError(
            f"Current status {current_status} not supported by state machine"
        )
    return possible
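

# For illustration only (comment not in the original file): the transition map
# above means, for example, that a "stopped" task can move back to "queued" or
# "in_progress", or forward to "publishing"/"published", while a "published"
# task has no allowed transitions at all (force=True bypasses this check in
# ChangeStatusRequest.validate_transition):
#
#     get_possible_status_changes(TaskStatus.published)   # -> set()
#     get_possible_status_changes(TaskStatus.publishing)  # -> {TaskStatus.published}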


def update_project_time(project_ids: Union[str, Sequence[str]]):
    if not project_ids:
        return
    if isinstance(project_ids, str):
        project_ids = [project_ids]
    return Project.objects(id__in=project_ids).update(last_update=datetime.utcnow())


def get_task_for_update(
    company_id: str, task_id: str, allow_all_statuses: bool = False, force: bool = False
) -> Task:
    """
    Load only the task id and status, and return the task only if it is updatable
    (status == 'created', or 'in_progress' as well when force is set)
    """
    task = Task.get_for_writing(company=company_id, id=task_id, _only=("id", "status"))
    if not task:
        raise errors.bad_request.InvalidTaskId(id=task_id)

    if allow_all_statuses:
        return task

    allowed_statuses = (
        [TaskStatus.created, TaskStatus.in_progress] if force else [TaskStatus.created]
    )
    if task.status not in allowed_statuses:
        raise errors.bad_request.InvalidTaskStatus(
            expected=TaskStatus.created, status=task.status
        )
    return task


def update_task(
    task: Task, user_id: str, update_cmds: dict, set_last_update: bool = True
):
    now = datetime.utcnow()
    last_updates = dict(last_change=now, last_changed_by=user_id)
    if set_last_update:
        last_updates.update(last_update=now)
    return task.update(**update_cmds, **last_updates)
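

# Example (sketch, not part of the original module): get_task_for_update and
# update_task are typically used as a pair, with update_cmds holding
# mongoengine update modifiers. `company_id`, `task_id` and `user_id` are
# assumed to come from the API call context, and `set__comment` is only an
# illustrative field name:
#
#     task = get_task_for_update(company_id=company_id, task_id=task_id, force=True)
#     update_task(task, user_id=user_id, update_cmds=dict(set__comment="edited"))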


def get_last_metric_updates(
    task_id: str,
    last_scalar_events: dict,
    raw_updates: dict,
    extra_updates: dict,
    model_events: bool = False,
):
    max_values = config.get("services.tasks.max_last_metrics", 2000)
    total_metrics = set()

    if max_values:
        query = dict(id=task_id)
        to_add = sum(len(v) for m, v in last_scalar_events.items())
        if to_add <= max_values:
            query[f"unique_metrics__{max_values - to_add}__exists"] = True
        db_cls = Model if model_events else Task
        task = db_cls.objects(**query).only("unique_metrics").first()
        if task and task.unique_metrics:
            total_metrics = set(task.unique_metrics)

    new_metrics = []

    def add_last_metric_conditional_update(
        metric_path: str, metric_value, iter_value: int, is_min: bool
    ):
        """
        Build an aggregation for an atomic update of the min or max value and the corresponding iteration
        """
        if is_min:
            field_prefix = "min"
            op = "$gt"
        else:
            field_prefix = "max"
            op = "$lt"

        value_field = f"{metric_path}__{field_prefix}_value".replace("__", ".")
        condition = {
            "$or": [
                {"$lte": [f"${value_field}", None]},
                {op: [f"${value_field}", metric_value]},
            ]
        }
        raw_updates[value_field] = {
            "$cond": [condition, metric_value, f"${value_field}"]
        }

        value_iteration_field = f"{metric_path}__{field_prefix}_value_iteration".replace(
            "__", "."
        )
        raw_updates[value_iteration_field] = {
            "$cond": [condition, iter_value, f"${value_iteration_field}"]
        }
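
    # Illustration (comment added for clarity, not in the original code): for a
    # metric_path of "last_metrics__m1__v1", metric_value 0.5 and is_min=True,
    # the helper above produces a raw pipeline update roughly like:
    #
    #     raw_updates["last_metrics.m1.v1.min_value"] = {
    #         "$cond": [
    #             {"$or": [
    #                 {"$lte": ["$last_metrics.m1.v1.min_value", None]},
    #                 {"$gt": ["$last_metrics.m1.v1.min_value", 0.5]},
    #             ]},
    #             0.5,
    #             "$last_metrics.m1.v1.min_value",
    #         ]
    #     }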

    for metric_key, metric_data in last_scalar_events.items():
        for variant_key, variant_data in metric_data.items():
            metric = f"{variant_data.get('metric')}/{variant_data.get('variant')}"
            if max_values:
                if len(total_metrics) >= max_values and metric not in total_metrics:
                    continue
                total_metrics.add(metric)

            new_metrics.append(metric)
            path = f"last_metrics__{metric_key}__{variant_key}"
            for key, value in variant_data.items():
                if key in ("min_value", "max_value"):
                    add_last_metric_conditional_update(
                        metric_path=path,
                        metric_value=value,
                        iter_value=variant_data.get(f"{key}_iter", 0),
                        is_min=(key == "min_value"),
                    )
                elif key in ("metric", "variant", "value"):
                    extra_updates[f"set__{path}__{key}"] = value

    if new_metrics:
        extra_updates["add_to_set__unique_metrics"] = new_metrics