# mirror of https://github.com/clearml/clearml-server
# synced 2025-01-31 10:56:48 +00:00
# 279 lines, 9.3 KiB, Python
from datetime import datetime
|
|
from typing import Sequence, Union
|
|
|
|
import attr
|
|
import six
|
|
|
|
from apiserver.apierrors import errors
|
|
from apiserver.config_repo import config
|
|
from apiserver.database.errors import translate_errors_context
|
|
from apiserver.database.model.model import Model
|
|
from apiserver.database.model.project import Project
|
|
from apiserver.database.model.task.task import Task, TaskStatus, TaskSystemTags
|
|
from apiserver.database.utils import get_options
|
|
from apiserver.utilities.attrs import typed_attrs
|
|
|
|
# All status names accepted by the task state machine, derived from TaskStatus.
valid_statuses = get_options(TaskStatus)
# Prefix prepended to names of soft-deleted entities.
deleted_prefix = "__DELETED__"
@typed_attrs
class ChangeStatusRequest(object):
    """
    A single task status-change operation, applied atomically against the DB.
    """

    task = attr.ib(type=Task)
    new_status = attr.ib(
        type=six.string_types, validator=attr.validators.in_(valid_statuses)
    )
    status_reason = attr.ib(type=six.string_types, default="")
    status_message = attr.ib(type=six.string_types, default="")
    force = attr.ib(type=bool, default=False)
    allow_same_state_transition = attr.ib(type=bool, default=True)
    current_status_override = attr.ib(default=None)
    user_id = attr.ib(type=str, default=None)

    def execute(self, **kwargs):
        """
        Atomically move the task to the new status and return the fields
        that were written. Extra update expressions may be passed via kwargs.
        """
        status_now = self.current_status_override or self.task.status
        project_id = self.task.project

        # Raises if moving from status_now to new_status is not permitted
        self.validate_transition(status_now)

        # mongoengine update() control arguments (must not clash with field updates)
        control = dict(upsert=False, multi=False, write_concern=None, full_result=False)

        timestamp = datetime.utcnow()
        update_fields = dict(
            status=self.new_status,
            status_reason=self.status_reason,
            status_message=self.status_message,
            status_changed=timestamp,
            last_update=timestamp,
            last_change=timestamp,
            last_changed_by=self.user_id,
        )

        if self.new_status == TaskStatus.queued:
            # A queued task is no longer considered a development task
            update_fields["pull__system_tags"] = TaskSystemTags.development

        def escape_key(key):
            # Prefix keys that would collide with mongoengine's control kwargs
            return f"__{key}" if key in control else key

        for key, value in kwargs.items():
            update_fields[escape_key(key)] = value

        with translate_errors_context():
            # atomic change of task status by querying the task with the EXPECTED status before modifying it
            updated = Task.objects(id=self.task.id, status=status_now).update(
                **update_fields, **control
            )

            if not updated:
                # failed to change status (someone else beat us to it?)
                raise errors.bad_request.FailedChangingTaskStatus(
                    task_id=self.task.id,
                    current_status=status_now,
                    new_status=self.new_status,
                )

            update_project_time(project_id)

        # make sure that _raw_ queries are not returned back to the client
        update_fields.pop("__raw__", None)

        return dict(updated=updated, fields=update_fields)

    def validate_transition(self, current_status):
        """Raise InvalidTaskStatus unless the requested transition is allowed."""
        if self.force:
            # Forced changes bypass the state machine entirely
            return

        if self.new_status == current_status:
            if not self.allow_same_state_transition:
                raise errors.bad_request.InvalidTaskStatus(
                    "Task already in requested status",
                    current_status=current_status,
                    new_status=self.new_status,
                )
            return

        validate_status_change(current_status, self.new_status)
def validate_status_change(current_status, new_status):
    """
    Raise InvalidTaskStatus unless new_status is directly reachable from
    current_status according to the state machine. Both statuses must be
    valid TaskStatus values.
    """
    assert current_status in valid_statuses
    assert new_status in valid_statuses

    if new_status not in get_possible_status_changes(current_status):
        raise errors.bad_request.InvalidTaskStatus(
            "Invalid status change",
            current_status=current_status,
            new_status=new_status,
        )
# Allowed task status transitions: maps each status to the set of statuses
# it may move to directly. Consulted by get_possible_status_changes; the
# `force` flag on a status-change request bypasses this table entirely.
state_machine = {
    TaskStatus.created: {TaskStatus.queued, TaskStatus.in_progress},
    TaskStatus.queued: {TaskStatus.created, TaskStatus.in_progress, TaskStatus.stopped},
    TaskStatus.in_progress: {
        TaskStatus.stopped,
        TaskStatus.failed,
        TaskStatus.created,
        TaskStatus.completed,
    },
    TaskStatus.stopped: {
        TaskStatus.closed,
        TaskStatus.created,
        TaskStatus.failed,
        TaskStatus.queued,
        TaskStatus.in_progress,
        TaskStatus.published,
        TaskStatus.publishing,
        TaskStatus.completed,
    },
    TaskStatus.closed: {
        TaskStatus.created,
        TaskStatus.failed,
        TaskStatus.published,
        TaskStatus.publishing,
        TaskStatus.stopped,
    },
    TaskStatus.failed: {TaskStatus.created, TaskStatus.stopped, TaskStatus.published},
    TaskStatus.publishing: {TaskStatus.published},
    # Terminal state: a published task cannot change status (except via force)
    TaskStatus.published: set(),
    TaskStatus.completed: {
        TaskStatus.published,
        TaskStatus.in_progress,
        TaskStatus.created,
    },
}
def get_possible_status_changes(current_status):
    """
    Return the set of statuses directly reachable from current_status.

    :param current_status: the status to look up in the state machine
    :return: set of statuses reachable from current_status (may be empty)
    :raises errors.server_error.InternalError: if the status is unknown
    """
    possible_targets = state_machine.get(current_status)
    if possible_targets is not None:
        return possible_targets

    raise errors.server_error.InternalError(
        f"Current status {current_status} not supported by state machine"
    )
def update_project_time(project_ids: Union[str, Sequence[str]]):
    """
    Touch last_update on the given project(s). Accepts a single id or a
    sequence of ids; a falsy argument is a no-op.
    """
    if not project_ids:
        return

    ids = [project_ids] if isinstance(project_ids, str) else project_ids
    return Project.objects(id__in=ids).update(last_update=datetime.utcnow())
def get_task_for_update(
    company_id: str, task_id: str, allow_all_statuses: bool = False, force: bool = False
) -> Task:
    """
    Load a minimal task (id and status only) and return it only if its
    current status permits updates.

    :param company_id: company the task must belong to
    :param task_id: id of the task to load
    :param allow_all_statuses: when True, skip the status check entirely
    :param force: when True, an in-progress task is also considered updatable
    :raises errors.bad_request.InvalidTaskId: if no such task exists
    :raises errors.bad_request.InvalidTaskStatus: if the status disallows updates
    """
    task = Task.get_for_writing(company=company_id, id=task_id, _only=("id", "status"))
    if not task:
        raise errors.bad_request.InvalidTaskId(id=task_id)

    if allow_all_statuses:
        return task

    allowed_statuses = (
        [TaskStatus.created, TaskStatus.in_progress] if force else [TaskStatus.created]
    )
    if task.status not in allowed_statuses:
        # Report the statuses actually allowed: with force=True more than just
        # 'created' is acceptable, so the previous hard-coded
        # expected=TaskStatus.created produced a misleading error payload
        raise errors.bad_request.InvalidTaskStatus(
            expected=allowed_statuses, status=task.status
        )
    return task
def update_task(
    task: Task, user_id: str, update_cmds: dict, set_last_update: bool = True
):
    """
    Apply update_cmds to the task while stamping the change-tracking fields
    (last_change, last_changed_by and optionally last_update).
    """
    timestamp = datetime.utcnow()
    tracking = {"last_change": timestamp, "last_changed_by": user_id}
    if set_last_update:
        tracking["last_update"] = timestamp
    return task.update(**update_cmds, **tracking)
def get_last_metric_updates(
    task_id: str,
    last_scalar_events: dict,
    raw_updates: dict,
    extra_updates: dict,
    model_events: bool = False,
):
    """
    Build mongo update expressions for a task's (or model's) last_metrics
    from the latest scalar events.

    Mutates the two output dicts in place:
    - raw_updates: receives {dotted_field: {"$cond": ...}} aggregation-pipeline
      updates that conditionally refresh min/max values and their iterations
    - extra_updates: receives plain mongoengine set__last_metrics__* updates
      and an add_to_set__unique_metrics update for newly seen metrics

    :param task_id: id of the Task (or Model when model_events=True) to update
    :param last_scalar_events: mapping {metric_key: {variant_key: variant_data}};
        variant_data is assumed to carry 'metric'/'variant'/'value' plus optional
        'min_value'/'max_value' and '<key>_iter' entries — TODO confirm with caller
    :param model_events: when True, query the Model collection instead of Task
    """
    # Cap on distinct metrics tracked per document (a falsy value disables the cap)
    max_values = config.get("services.tasks.max_last_metrics", 2000)
    total_metrics = set()
    if max_values:
        query = dict(id=task_id)
        to_add = sum(len(v) for m, v in last_scalar_events.items())
        if to_add <= max_values:
            # Only fetch existing unique_metrics when the document could already
            # be near the cap (it has at least max_values - to_add metrics);
            # otherwise the new metrics cannot possibly overflow it
            query[f"unique_metrics__{max_values - to_add}__exists"] = True
        db_cls = Model if model_events else Task
        task = db_cls.objects(**query).only("unique_metrics").first()
        if task and task.unique_metrics:
            total_metrics = set(task.unique_metrics)

    new_metrics = []

    def add_last_metric_conditional_update(
        metric_path: str, metric_value, iter_value: int, is_min: bool
    ):
        """
        Build an aggregation for an atomic update of the min or max value and the corresponding iteration
        """
        if is_min:
            field_prefix = "min"
            op = "$gt"  # stored min is replaced when greater than the new value
        else:
            field_prefix = "max"
            op = "$lt"  # stored max is replaced when less than the new value

        # mongoengine-style path -> dotted mongo field path
        value_field = f"{metric_path}__{field_prefix}_value".replace("__", ".")
        # Update when the stored value is missing/null or beaten by the new value
        condition = {
            "$or": [
                {"$lte": [f"${value_field}", None]},
                {op: [f"${value_field}", metric_value]},
            ]
        }
        raw_updates[value_field] = {
            "$cond": [condition, metric_value, f"${value_field}"]
        }

        # Keep the recorded iteration in sync with the stored min/max value
        value_iteration_field = f"{metric_path}__{field_prefix}_value_iteration".replace(
            "__", "."
        )
        raw_updates[value_iteration_field] = {
            "$cond": [condition, iter_value, f"${value_iteration_field}"]
        }

    for metric_key, metric_data in last_scalar_events.items():
        for variant_key, variant_data in metric_data.items():
            metric = f"{variant_data.get('metric')}/{variant_data.get('variant')}"
            if max_values:
                # Once the cap is reached, skip metrics not already tracked;
                # metrics already in the set keep getting updated
                if len(total_metrics) >= max_values and metric not in total_metrics:
                    continue
                total_metrics.add(metric)

            new_metrics.append(metric)
            path = f"last_metrics__{metric_key}__{variant_key}"
            for key, value in variant_data.items():
                if key in ("min_value", "max_value"):
                    # min/max need conditional (raw) updates so concurrent
                    # writers cannot regress the stored extremum
                    add_last_metric_conditional_update(
                        metric_path=path,
                        metric_value=value,
                        iter_value=variant_data.get(f"{key}_iter", 0),
                        is_min=(key == "min_value"),
                    )
                elif key in ("metric", "variant", "value"):
                    # plain last-value fields are overwritten unconditionally
                    extra_updates[f"set__{path}__{key}"] = value

    if new_metrics:
        extra_updates["add_to_set__unique_metrics"] = new_metrics