clearml-server/apiserver/bll/task/task_bll.py
2021-01-05 16:29:25 +02:00

702 lines
25 KiB
Python

from collections import OrderedDict
from datetime import datetime
from operator import attrgetter
from random import random
from time import sleep
from typing import Collection, Sequence, Tuple, Any, Optional, List, Dict
import dpath
import pymongo.results
import six
from mongoengine import Q
from six import string_types
import apiserver.database.utils as dbutils
from apiserver.es_factory import es_factory
from apiserver.apierrors import errors
from apiserver.apimodels.tasks import Artifact as ApiArtifact
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.config import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.metrics import EventStats, MetricEventStats
from apiserver.database.model.task.output import Output
from apiserver.database.model.task.task import (
Task,
TaskStatus,
TaskStatusMessage,
TaskSystemTags,
ArtifactModes,
Artifact,
external_task_types,
)
from apiserver.database.utils import get_company_or_none_constraint, id as create_id
from apiserver.service_repo import APICall
from apiserver.timing_context import TimingContext
from apiserver.utilities.dicts import deep_merge
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
from .param_utils import params_prepare_for_save
from .utils import ChangeStatusRequest, validate_status_change
# Module-level logger, obtained from the server configuration for this file.
log = config.logger(__file__)
# Shared OrgBLL instance; this module uses it to update the stored task tags
# after a task is cloned.
org_bll = OrgBLL()
class TaskBLL(object):
def __init__(self, events_es=None):
    """
    Initialize the task business-logic object.
    :param events_es: client to use for the "events" store. When None, a new
        connection is obtained through es_factory.connect("events").
    """
    if events_es is None:
        events_es = es_factory.connect("events")
    self.events_es = events_es
@classmethod
def get_types(cls, company, project_ids: Optional[Sequence]) -> set:
    """
    Return the set of unique task types used by company and public tasks.
    Only types that also appear in external_task_types are returned.
    :param company: company ID used to build the company-or-none constraint
    :param project_ids: if passed (and not empty) then only tasks from these
        projects are considered
    """
    constraint = get_company_or_none_constraint(company)
    if project_ids:
        constraint = constraint & Q(project__in=project_ids)
    used_types = Task.objects(constraint).distinct(field="type")
    return set(used_types).intersection(external_task_types)
@staticmethod
def get_task_with_access(
    task_id, company_id, only=None, allow_public=False, requires_write_access=False
) -> Task:
    """
    Fetch a single task by ID and company, optionally requiring write access.
    :param task_id: ID of the task to fetch
    :param company_id: company the task is looked up in
    :param only: optional fields projection, forwarded as `_only`
    :param allow_public: forwarded as `include_public` to Task.get. Not used
        when requires_write_access is set.
    :param requires_write_access: when set, the task is fetched through
        Task.get_for_writing
    :except errors.bad_request.InvalidTaskId: if the task is not found
    :except errors.forbidden.NoWritePermission: if write_access was required and the task cannot be modified
    """
    lookup = {"id": task_id, "company": company_id}
    with translate_errors_context():
        with TimingContext("mongo", "task_with_access"):
            if requires_write_access:
                found = Task.get_for_writing(_only=only, **lookup)
            else:
                found = Task.get(_only=only, include_public=allow_public, **lookup)
        if found:
            return found
        raise errors.bad_request.InvalidTaskId(**lookup)
@staticmethod
def get_by_id(
    company_id, task_id, required_status=None, only_fields=None, allow_public=False,
):
    """
    Fetch a single task by ID, optionally verifying its status.
    :param company_id: company the task is looked up in
    :param task_id: ID of the task to fetch
    :param required_status: if passed, the task status must equal this value
    :param only_fields: a field name or a collection of field names to project.
        "status" is always appended so the status check can be performed.
    :param allow_public: forwarded to Task.get_many
    :except errors.bad_request.InvalidTaskId: if no task was found
    :except errors.bad_request.InvalidTaskStatus: if required_status was passed
        and the task has a different status
    """
    projection = only_fields
    if projection:
        if isinstance(projection, string_types):
            projection = [projection, "status"]
        else:
            projection = [*projection, "status"]
    with TimingContext("mongo", "task_by_id_all"):
        found = Task.get_many(
            company=company_id,
            query=Q(id=task_id),
            allow_public=allow_public,
            override_projection=projection,
            return_dicts=False,
        )
        task = found[0] if found else None
    if not task:
        raise errors.bad_request.InvalidTaskId(id=task_id)
    if required_status and task.status != required_status:
        raise errors.bad_request.InvalidTaskStatus(expected=required_status)
    return task
@staticmethod
def assert_exists(
    company_id, task_ids, only=None, allow_public=False, return_tasks=True
) -> Optional[Sequence[Task]]:
    """
    Verify that all the passed task IDs exist.
    :param company_id: company the tasks are looked up in
    :param task_ids: a single task ID (string) or a collection of task IDs.
        Duplicate IDs are counted once.
    :param only: optional names of the fields to load for the returned tasks
    :param allow_public: forwarded to Task.get_many
    :param return_tasks: when set, the found tasks are returned as a list;
        otherwise None is returned
    :except errors.bad_request.InvalidTaskId: if the number of found tasks
        differs from the number of unique IDs requested
    """
    if isinstance(task_ids, six.string_types):
        task_ids = [task_ids]
    with translate_errors_context(), TimingContext("mongo", "task_exists"):
        unique_ids = set(task_ids)
        found = Task.get_many(
            company=company_id,
            query=Q(id__in=unique_ids),
            allow_public=allow_public,
            return_dicts=False,
        )
        if only:
            # Internal call asking for specific fields: reset the fields filters first
            # (some fields are excluded by default) and then apply the requested ones.
            found = found.all_fields().only(*only)
        if found.count() != len(unique_ids):
            raise errors.bad_request.InvalidTaskId(ids=task_ids)
        return list(found) if return_tasks else None
@staticmethod
def create(call: APICall, fields: dict):
    """
    Build a new Task document owned by the caller. The document is not saved here.
    :param call: the API call; its identity provides the task user and company
    :param fields: additional Task fields, passed as keyword arguments
    :return: the new Task object with a freshly generated ID and with both
        `created` and `last_update` set to the current UTC time
    """
    timestamp = datetime.utcnow()
    caller = call.identity
    return Task(
        id=create_id(),
        user=caller.user,
        company=caller.company,
        created=timestamp,
        last_update=timestamp,
        **fields,
    )
@staticmethod
def validate_execution_model(task, allow_only_public=False):
    """
    Verify that the model referenced by the task execution exists.
    :param task: the task whose `execution.model` is checked
    :param allow_only_public: when set, the model is looked up with a None
        company constraint instead of the task company
    :return: the model found, or None if the task has no execution model set
    :except errors.bad_request.InvalidModelId: if the referenced model was not found
    """
    if not (task.execution and task.execution.model):
        return None
    model_id = task.execution.model
    company = task.company if not allow_only_public else None
    model_query = Q(id=model_id) & get_company_or_none_constraint(company)
    found = Model.objects(model_query).first()
    if found:
        return found
    raise errors.bad_request.InvalidModelId(model=model_id)
@classmethod
def clone_task(
    cls,
    company_id,
    user_id,
    task_id,
    name: Optional[str] = None,
    comment: Optional[str] = None,
    parent: Optional[str] = None,
    project: Optional[str] = None,
    tags: Optional[Sequence[str]] = None,
    system_tags: Optional[Sequence[str]] = None,
    hyperparams: Optional[dict] = None,
    configuration: Optional[dict] = None,
    execution_overrides: Optional[dict] = None,
    validate_references: bool = False,
) -> Task:
    """
    Create and save a new task that is a copy of an existing (own or public) task.
    :param company_id: company of the new task; also used to look up the source task
    :param user_id: user set on the new task
    :param task_id: ID of the task to clone
    :param name: new task name; defaults to the source task name
    :param comment: new task comment; defaults to the source task comment
    :param parent: new task parent; defaults to the source task parent
    :param project: new task project; defaults to the source task project
    :param tags: new task tags; defaults to the source task tags
    :param system_tags: new task system tags; defaults to an empty list
        (system tags are not copied from the source task)
    :param hyperparams: if passed, used instead of the source task hyperparams
    :param configuration: if passed, used instead of the source task configuration
    :param execution_overrides: values deep-merged over the source task execution.
        The legacy "parameters" and "configuration" keys are popped from this
        dict (it is modified in place) and handed to params_prepare_for_save.
    :param validate_references: when set, the model, parent and project of the
        new task are all validated. Otherwise only the references that were
        explicitly overridden in this call are validated.
    :return: the saved new task
    """
    task = cls.get_by_id(company_id=company_id, task_id=task_id, allow_public=True)
    execution_dict = task.execution.to_proper_dict() if task.execution else {}
    execution_model_overriden = False
    params_dict = {
        field: value
        for field, value in (
            ("hyperparams", hyperparams),
            ("configuration", configuration),
        )
        if value is not None
    }
    if execution_overrides:
        params_dict["execution"] = {}
        for legacy_param in ("parameters", "configuration"):
            legacy_value = execution_overrides.pop(legacy_param, None)
            if legacy_value is not None:
                # Store each legacy value under its own key. Assigning the value to
                # params_dict["execution"] itself would replace the whole dict and
                # let one legacy value clobber the other.
                # NOTE(review): presumably params_prepare_for_save reads the legacy
                # values from params_dict["execution"][<name>] - confirm.
                params_dict["execution"][legacy_param] = legacy_value
        execution_dict = deep_merge(execution_dict, execution_overrides)
        execution_model_overriden = execution_overrides.get("model") is not None
    params_prepare_for_save(params_dict, previous_task=task)

    # Output artifacts of the source task are not carried over to the clone
    artifacts = execution_dict.get("artifacts")
    if artifacts:
        execution_dict["artifacts"] = [
            a for a in artifacts if a.get("mode") != ArtifactModes.output
        ]
    now = datetime.utcnow()
    with translate_errors_context():
        new_task = Task(
            id=create_id(),
            user=user_id,
            company=company_id,
            created=now,
            last_update=now,
            name=name or task.name,
            comment=comment or task.comment,
            parent=parent or task.parent,
            project=project or task.project,
            tags=tags or task.tags,
            system_tags=system_tags or [],
            type=task.type,
            script=task.script,
            output=Output(destination=task.output.destination)
            if task.output
            else None,
            execution=execution_dict,
            configuration=params_dict.get("configuration") or task.configuration,
            hyperparams=params_dict.get("hyperparams") or task.hyperparams,
        )
        cls.validate(
            new_task,
            validate_model=validate_references or execution_model_overriden,
            validate_parent=validate_references or parent,
            validate_project=validate_references or project,
        )
        new_task.save()

        if task.project == new_task.project:
            # Same project: only the tags passed to this call are reported
            updated_tags = tags
            updated_system_tags = system_tags
        else:
            # Different project: report all the tags of the new task
            updated_tags = new_task.tags
            updated_system_tags = new_task.system_tags
        org_bll.update_tags(
            company_id,
            Tags.Task,
            project=new_task.project,
            tags=updated_tags,
            system_tags=updated_system_tags,
        )
    return new_task
@classmethod
def validate(
    cls,
    task: Task,
    validate_model=True,
    validate_parent=True,
    validate_project=True,
):
    """
    Validate the objects referenced by the task.
    :param task: the task to validate
    :param validate_model: check the execution model through validate_execution_model
    :param validate_parent: check that the parent task (if set) can be fetched
        for the task company, public tasks included
    :param validate_project: check that the project (if set) can be fetched
        for writing for the task company
    :except errors.bad_request.InvalidTaskId: if the parent task was not found
    :except errors.bad_request.InvalidProjectId: if the project was not found
    """
    if validate_parent and task.parent:
        parent_task = Task.get(
            company=task.company, id=task.parent, _only=("id",), include_public=True
        )
        if not parent_task:
            raise errors.bad_request.InvalidTaskId(
                "invalid parent", parent=task.parent
            )

    if validate_project and task.project:
        target_project = Project.get_for_writing(
            company=task.company, id=task.project
        )
        if not target_project:
            raise errors.bad_request.InvalidProjectId(id=task.project)

    if validate_model:
        cls.validate_execution_model(task)
@staticmethod
def get_unique_metric_variants(company_id, project_ids=None):
    """
    Collect the unique metric/variant pairs found in the `last_metrics` field
    of the company tasks.
    :param company_id: company whose tasks are scanned
    :param project_ids: if passed (and not empty) then only tasks from these
        projects are scanned
    :return: list of dicts with the keys metric, metric_hash, variant and
        variant_hash, one per metric/variant pair, sorted by metric and variant
    """
    task_filter = {"company": company_id}
    if project_ids:
        task_filter["project"] = {"$in": project_ids}

    variant_value = "$variants.v"
    # The grouping key: the metric and variant names stored in the variant value
    pair_key = {
        "metric": f"{variant_value}.metric",
        "variant": f"{variant_value}.variant",
    }
    pair_details = {
        "metric": f"{variant_value}.metric",
        "metric_hash": "$metric",
        "variant": f"{variant_value}.variant",
        "variant_hash": "$variants.k",
    }
    sort_order = OrderedDict([("_id.metric", 1), ("_id.variant", 1)])

    pipeline = [
        {"$match": task_filter},
        {"$project": {"metrics": {"$objectToArray": "$last_metrics"}}},
        {"$unwind": "$metrics"},
        {
            "$project": {
                "metric": "$metrics.k",
                "variants": {"$objectToArray": "$metrics.v"},
            }
        },
        {"$unwind": "$variants"},
        {"$group": {"_id": pair_key, "metrics": {"$addToSet": pair_details}}},
        {"$sort": sort_order},
    ]
    with translate_errors_context():
        # Each group may hold several entries; only the first one is returned
        return [group["metrics"][0] for group in Task.aggregate(pipeline)]
@staticmethod
def set_last_update(
    task_ids: Collection[str], company_id: str, last_update: datetime
):
    """
    Set the `last_update` field of the passed company tasks.
    :param task_ids: IDs of the tasks to update
    :param company_id: company the tasks belong to
    :param last_update: the value to set
    :return: the result of the update call
    """
    matching_tasks = Task.objects(id__in=task_ids, company=company_id)
    return matching_tasks.update(upsert=False, last_update=last_update)
@staticmethod
def update_statistics(
    task_id: str,
    company_id: str,
    last_update: datetime = None,
    last_iteration: int = None,
    last_iteration_max: int = None,
    last_scalar_values: Sequence[Tuple[Tuple[str, ...], Any]] = None,
    last_events: Dict[str, Dict[str, dict]] = None,
    **extra_updates,
):
    """
    Update the statistics fields of a task in a single update call.
    :param task_id: Task's ID.
    :param company_id: Task's company ID.
    :param last_update: Last update time. If not provided, defaults to datetime.utcnow().
    :param last_iteration: Last reported iteration. Sets the value regardless of the
        task's current last iteration. Takes precedence over last_iteration_max.
    :param last_iteration_max: Last reported iteration. Applied as a `max` update so the
        value is only set if the task's current last iteration is smaller.
    :param last_scalar_values: Pairs of (path, value) to store under `last_metrics`.
        A path ending with "min_value" is applied as a `min` update, a path ending with
        "max_value" as a `max` update and any other path as a plain `set`.
    :param last_events: Per metric, a mapping of event type to the last event. The
        "timestamp" of each event is stored as the last update of that event type.
    :param extra_updates: Extra task updates to include in this update call.
    :return:
    """
    last_update = last_update or datetime.utcnow()

    if last_iteration is not None:
        extra_updates["last_iteration"] = last_iteration
    elif last_iteration_max is not None:
        extra_updates["max__last_iteration"] = last_iteration_max

    if last_scalar_values is not None:
        for path, value in last_scalar_values:
            leaf = path[-1]
            if leaf == "min_value":
                operator = "min"
            elif leaf == "max_value":
                operator = "max"
            else:
                operator = "set"
            update_key = "__".join((operator, "last_metrics", *path))
            extra_updates[update_key] = value

    if last_events is not None:
        metric_stats = {}
        for metric_key, metric_data in last_events.items():
            stats_by_type = {
                event_type: EventStats(last_update=event["timestamp"])
                for event_type, event in metric_data.items()
            }
            # Metric names are hashed before being used as field names
            metric_stats[dbutils.hash_field_name(metric_key)] = MetricEventStats(
                metric=metric_key, event_stats_by_type=stats_by_type
            )
        extra_updates["metric_stats"] = metric_stats

    Task.objects(id=task_id, company=company_id).update(
        upsert=False, last_update=last_update, **extra_updates
    )
@classmethod
def model_set_ready(
    cls,
    model_id: str,
    company_id: str,
    publish_task: bool,
    force_publish_task: bool = False,
) -> tuple:
    """
    Mark a model as ready, optionally publishing the task associated with it.
    :param model_id: ID of the model
    :param company_id: company the model belongs to
    :param publish_task: when set and the model has an associated task that is
        not published yet, that task is published too (without publishing models)
    :param force_publish_task: forwarded as `force` to publish_task
    :return: a tuple of the model update result and a dict describing the
        published task ("data" and "id" keys). The dict is empty if no task
        was published.
    :except errors.bad_request.InvalidModelId: if the model was not found
    :except errors.bad_request.ModelIsReady: if the model is already ready
    """
    with translate_errors_context():
        lookup = {"id": model_id, "company": company_id}
        model = Model.objects(**lookup).first()
        if not model:
            raise errors.bad_request.InvalidModelId(**lookup)
        if model.ready:
            raise errors.bad_request.ModelIsReady(**lookup)

        published_task_data = {}
        if model.task and publish_task:
            model_task = Task.objects(id=model.task, company=company_id)
            model_task = model_task.only("id", "status").first()
            needs_publishing = (
                model_task and model_task.status != TaskStatus.published
            )
            if needs_publishing:
                published_task_data["data"] = cls.publish_task(
                    task_id=model.task,
                    company_id=company_id,
                    publish_model=False,
                    force=force_publish_task,
                )
                published_task_data["id"] = model.task

        return model.update(upsert=False, ready=True), published_task_data
@classmethod
def publish_task(
    cls,
    task_id: str,
    company_id: str,
    publish_model: bool,
    force: bool,
    status_reason: str = "",
    status_message: str = "",
) -> dict:
    """
    Publish a task, optionally marking its output model as ready.
    :param task_id: ID of the task to publish
    :param company_id: company the task belongs to
    :param publish_model: when set and the task has an output model that is not
        ready yet, the model is set ready too (without publishing its task again)
    :param force: when set, the status change validation is skipped here and
        `force` is passed on to the status change request
    :param status_reason: forwarded to the status change request
    :param status_message: forwarded to the status change request
    :return: the result of executing the status change request
    On any failure the task is saved again with its previous status and the
    original exception is re-raised.
    """
    task = cls.get_task_with_access(
        task_id, company_id=company_id, requires_write_access=True
    )
    if not force:
        validate_status_change(task.status, TaskStatus.published)
    previous_task_status = task.status
    # The task may have no output at all; use an empty one in that case
    output = task.output or Output()
    try:
        # set state to publishing
        task.status = TaskStatus.publishing
        task.save()

        # publish task models
        # (read the model from `output`, since task.output itself may be None)
        if output.model and publish_model:
            output_model = (
                Model.objects(id=output.model)
                .only("id", "task", "ready")
                .first()
            )
            if output_model and not output_model.ready:
                cls.model_set_ready(
                    model_id=output.model,
                    company_id=company_id,
                    publish_task=False,
                )

        # set task status to published, and update (or set) its new output (view and models)
        return ChangeStatusRequest(
            task=task,
            new_status=TaskStatus.published,
            force=force,
            status_reason=status_reason,
            status_message=status_message,
        ).execute(published=datetime.utcnow(), output=output)
    except Exception:
        # Publishing failed: restore the previous status and propagate the error
        task.status = previous_task_status
        task.save()
        raise
@classmethod
def stop_task(
    cls,
    task_id: str,
    company_id: str,
    user_name: str,
    status_reason: str,
    force: bool,
) -> dict:
    """
    Stop a task.
    A development task (one that has the development system tag) or a task
    with no active worker is moved to the 'stopped' status immediately.
    For any other task the status is left unchanged and only the status
    message is set to 'stopping', to allow the worker to stop the task and
    report by itself.
    A worker is considered active if the task has a last worker and its last
    update is more recent than the `apiserver.workers.task_update_timeout`
    setting (600 seconds by default).
    :param task_id: ID of the task to stop
    :param company_id: company the task belongs to
    :param user_name: name of the stopping user, used in the status message
    :param status_reason: forwarded to the status change request
    :param force: forwarded to the status change request
    :return: updated task fields
    """
    required_fields = (
        "status",
        "project",
        "tags",
        "system_tags",
        "last_worker",
        "last_update",
    )
    task = cls.get_task_with_access(
        task_id,
        company_id=company_id,
        only=required_fields,
        requires_write_access=True,
    )

    def is_run_by_worker(t: Task) -> bool:
        """Checks if there is an active worker running the task"""
        update_timeout = config.get("apiserver.workers.task_update_timeout", 600)
        if not t.last_worker:
            return t.last_worker
        if not t.last_update:
            return t.last_update
        idle_seconds = (datetime.utcnow() - t.last_update).total_seconds()
        return idle_seconds < update_timeout

    is_development = TaskSystemTags.development in task.system_tags
    if is_development or not is_run_by_worker(task):
        new_status, status_message = TaskStatus.stopped, f"Stopped by {user_name}"
    else:
        new_status, status_message = task.status, TaskStatusMessage.stopping

    request = ChangeStatusRequest(
        task=task,
        new_status=new_status,
        status_reason=status_reason,
        status_message=status_message,
        force=force,
    )
    return request.execute()
@classmethod
def add_or_update_artifacts(
    cls, task_id: str, company_id: str, artifacts: List[ApiArtifact]
) -> Tuple[List[str], List[str]]:
    """
    Add new artifacts to a task and replace the existing ones in a single
    conditional update. An artifact is identified by its (key, mode) pair.
    The update only matches if the task's artifacts still have the same
    (key, mode) pairs that were read. Otherwise the task is re-read and the
    update is retried after a random wait, up to
    `services.tasks.artifacts.update_attempts` times (10 by default).
    :param task_id: ID of the task to update
    :param company_id: company the task belongs to
    :param artifacts: the artifacts to add or update
    :return: a tuple of (keys of the added artifacts, keys of the updated artifacts)
    :except errors.server_error.UpdateFailed: if all the attempts failed
    """
    artifact_id = attrgetter("key", "mode")
    if not artifacts:
        return [], []

    with translate_errors_context(), TimingContext("mongo", "update_artifacts"):
        db_artifacts: List[Artifact] = [
            Artifact(**artifact.to_struct()) for artifact in artifacts
        ]
        attempts = int(config.get("services.tasks.artifacts.update_attempts", 10))
        for retry in range(attempts):
            task = cls.get_task_with_access(
                task_id, company_id=company_id, requires_write_access=True
            )
            # A task may have no execution section at all
            existing = (task.execution.artifacts if task.execution else None) or []
            current = [artifact_id(a) for a in existing]
            current_ids = set(current)
            updated = [a for a in db_artifacts if artifact_id(a) in current_ids]
            added = [a for a in db_artifacts if artifact_id(a) not in current_ids]

            # Match the task only if its artifacts were not changed since they were read
            query_filter = {"_id": task_id, "company": company_id}
            if current:
                query_filter["execution.artifacts"] = {
                    "$size": len(current),
                    "$all": [
                        {"$elemMatch": {"key": current_key, "mode": current_mode}}
                        for current_key, current_mode in current
                    ],
                }
            else:
                query_filter["$or"] = [
                    {"execution.artifacts": {"$exists": False}},
                    {"execution.artifacts": {"$size": 0}},
                ]

            # NOTE(review): when both new and existing artifacts are passed, the
            # update combines $push on the array with $set on its elements.
            # MongoDB may reject this as a path conflict - confirm.
            update = {}
            array_filters = None
            if added:
                update["$push"] = {
                    "execution.artifacts": {"$each": [a.to_mongo() for a in added]}
                }
            if updated:
                update["$set"] = {
                    f"execution.artifacts.$[artifact{index}]": artifact.to_mongo()
                    for index, artifact in enumerate(updated)
                }
                array_filters = [
                    {
                        f"artifact{index}.key": artifact.key,
                        f"artifact{index}.mode": artifact.mode,
                    }
                    for index, artifact in enumerate(updated)
                ]
            if not update:
                return [], []

            result: pymongo.results.UpdateResult = Task._get_collection().update_one(
                filter=query_filter,
                update=update,
                array_filters=array_filters,
                upsert=False,
            )
            if result.matched_count >= 1:
                break

            if retry + 1 >= attempts:
                # No attempts left: do not wait, fall through to the failure below
                continue
            wait_msec = random() * int(
                config.get("services.tasks.artifacts.update_retry_msec", 500)
            )
            log.warning(
                f"Failed to update artifacts for task {task_id} (updated by another party),"
                f" retrying {retry+1}/{attempts} in {wait_msec}ms"
            )
            sleep(wait_msec / 1000)
        else:
            raise errors.server_error.UpdateFailed(
                "task artifacts updated by another party"
            )
        return [a.key for a in added], [a.key for a in updated]
@staticmethod
def get_aggregated_project_parameters(
    company_id,
    project_ids: Sequence[str] = None,
    page: int = 0,
    page_size: int = 500,
) -> Tuple[int, int, Sequence[dict]]:
    """
    Return a page of the unique hyper parameter (section, name) pairs used
    by the company tasks, sorted by section and name.
    :param company_id: company whose tasks are scanned
    :param project_ids: if passed (and not empty) then only tasks from these
        projects are scanned
    :param page: zero based page number; negative values are treated as 0
    :param page_size: number of results per page; values below 1 are treated as 1
    :return: a tuple of (total number of unique pairs, number of pairs remaining
        after this page, the page results). Each result is a dict with the
        unescaped "section" and "name". (0, 0, []) if nothing was found.
    """
    page = max(0, page)
    page_size = max(1, page_size)
    offset = page * page_size

    task_filter = {
        "company": company_id,
        "hyperparams": {"$exists": True, "$gt": {}},
    }
    if project_ids:
        task_filter["project"] = {"$in": project_ids}

    pipeline = [
        {"$match": task_filter},
        {"$project": {"sections": {"$objectToArray": "$hyperparams"}}},
        {"$unwind": "$sections"},
        {
            "$project": {
                "section": "$sections.k",
                "names": {"$objectToArray": "$sections.v"},
            }
        },
        {"$unwind": "$names"},
        {"$group": {"_id": {"section": "$section", "name": "$names.k"}}},
        {"$sort": OrderedDict([("_id.section", 1), ("_id.name", 1)])},
        # Collapse everything into a single document to get the total count
        {
            "$group": {
                "_id": 1,
                "total": {"$sum": 1},
                "results": {"$push": "$$ROOT"},
            }
        },
        {
            "$project": {
                "total": 1,
                "results": {"$slice": ["$results", offset, page_size]},
            }
        },
    ]
    with translate_errors_context():
        aggregated = next(Task.aggregate(pipeline), None)

    if not aggregated:
        return 0, 0, []

    total = int(aggregated.get("total", -1))
    results = []
    for entry in aggregated.get("results", []):
        section = ParameterKeyEscaper.unescape(dpath.get(entry, "_id/section"))
        name = ParameterKeyEscaper.unescape(dpath.get(entry, "_id/name"))
        results.append({"section": section, "name": name})
    remaining = max(0, total - (len(results) + offset))
    return total, remaining, results