mirror of
https://github.com/clearml/clearml-server
synced 2025-01-31 19:06:55 +00:00
ae4c33fa0e
Add `last_changed_by` field on task updates Fix reports support
250 lines
8.0 KiB
Python
250 lines
8.0 KiB
Python
from itertools import chain
|
|
from operator import attrgetter
|
|
from typing import Sequence, Dict
|
|
|
|
from boltons import iterutils
|
|
|
|
from apiserver.apierrors import errors
|
|
from apiserver.apimodels.tasks import (
|
|
HyperParamKey,
|
|
HyperParamItem,
|
|
ReplaceHyperparams,
|
|
Configuration,
|
|
)
|
|
from apiserver.bll.task import TaskBLL
|
|
from apiserver.bll.task.utils import get_task_for_update, update_task
|
|
from apiserver.config_repo import config
|
|
from apiserver.database.model.task.task import ParamsItem, Task, ConfigurationItem
|
|
from apiserver.utilities.parameter_key_escaper import (
|
|
ParameterKeyEscaper,
|
|
mongoengine_safe,
|
|
)
|
|
|
|
# Module-level logger (not referenced by the code in this part of the file).
log = config.logger(__file__)
# Shared TaskBLL instance; HyperParams uses it for task_bll.assert_exists()
# lookups when reading hyperparameters and configuration.
task_bll = TaskBLL()
|
|
|
|
|
|
class HyperParams:
    """
    Read and modify task hyperparameters and configuration items.

    Section, parameter and configuration names are stored in the DB in their
    ParameterKeyEscaper-escaped form. Whenever an escaped name becomes the
    LAST component of a mongoengine "set__"/"unset__" update key, it is also
    passed through mongoengine_safe. Presumably this keeps names that
    mongoengine could parse as an operator suffix (e.g. "size", "type", "in")
    from being interpreted as operators.
    NOTE(review): confirm against apiserver.utilities.parameter_key_escaper.
    """

    # Changes that touch only this section are passed to get_task_for_update
    # with allow_all_statuses=True and do not set the task's last update.
    _properties_section = "properties"

    @classmethod
    def get_params(cls, company_id: str, task_ids: Sequence[str]) -> Dict[str, dict]:
        """
        Return the hyperparameters of the requested tasks.

        :param company_id: company of the calling user
        :param task_ids: ids of the tasks to read (public tasks are allowed)
        :return: mapping of task id to {"hyperparams": [...]}, where the list
            holds the task's parameters as dicts sorted by (section, name)
        """
        only = ("id", "hyperparams")
        tasks = task_bll.assert_exists(
            company_id=company_id, task_ids=task_ids, only=only, allow_public=True,
        )

        return {
            task.id: {"hyperparams": cls._get_params_list(items=task.hyperparams)}
            for task in tasks
        }

    @classmethod
    def _get_params_list(
        cls, items: Dict[str, Dict[str, ParamsItem]]
    ) -> Sequence[dict]:
        """
        Flatten a {section: {name: ParamsItem}} mapping into a list of dicts
        sorted by section and then by name.

        A missing (None) mapping yields an empty list.
        """
        params = chain.from_iterable(section.values() for section in (items or {}).values())
        return [
            p.to_proper_dict() for p in sorted(params, key=attrgetter("section", "name"))
        ]

    @classmethod
    def _normalize_params(cls, params: Sequence) -> bool:
        """
        Lower-case (in place) any spelling of the properties section name.

        Return True only if params is non-empty and every item belongs to the
        properties section. An empty params list returns False. Nothing then
        restricts the change to properties, so the regular task-status checks
        must apply. Otherwise, e.g., replacing all hyperparameters with an
        empty list would skip them.
        """
        if not params:
            return False

        for p in params:
            if p.section.lower() == cls._properties_section:
                p.section = cls._properties_section

        return all(p.section == cls._properties_section for p in params)

    @classmethod
    def delete_params(
        cls,
        company_id: str,
        user_id: str,
        task_id: str,
        hyperparams: Sequence[HyperParamKey],
        force: bool,
    ) -> int:
        """
        Delete whole hyperparameter sections and/or individual parameters.

        Items without a name delete their entire section; items with a name
        delete only that parameter.

        :param company_id: company of the calling user
        :param user_id: user performing the change
        :param task_id: id of the task to update
        :param hyperparams: sections/parameters to delete
        :param force: passed to get_task_for_update
        :return: the result of update_task
        :raises errors.bad_request.FieldsConflict: if a parameter is listed in
            a section that is itself scheduled for deletion
        """
        properties_only = cls._normalize_params(hyperparams)
        task = get_task_for_update(
            company_id=company_id,
            task_id=task_id,
            allow_all_statuses=properties_only,
            force=force,
        )

        with_param, without_param = iterutils.partition(
            hyperparams, key=lambda p: bool(p.name)
        )
        sections_to_delete = {p.section for p in without_param}
        # The section is the last key component here, so it must be made
        # mongoengine-safe (same as the section-replace branch of edit_params).
        delete_cmds = {
            f"unset__hyperparams__{mongoengine_safe(ParameterKeyEscaper.escape(section))}": 1
            for section in sections_to_delete
        }

        for item in with_param:
            if item.section in sections_to_delete:
                raise errors.bad_request.FieldsConflict(
                    "Cannot delete section field if the whole section was scheduled for deletion"
                )
            section = ParameterKeyEscaper.escape(item.section)
            # The parameter name is the last key component.
            name = mongoengine_safe(ParameterKeyEscaper.escape(item.name))
            delete_cmds[f"unset__hyperparams__{section}__{name}"] = 1

        return update_task(
            task,
            user_id=user_id,
            update_cmds=delete_cmds,
            set_last_update=not properties_only,
        )

    @classmethod
    def edit_params(
        cls,
        company_id: str,
        user_id: str,
        task_id: str,
        hyperparams: Sequence[HyperParamItem],
        replace_hyperparams: str,
        force: bool,
    ) -> int:
        """
        Add or update task hyperparameters.

        :param company_id: company of the calling user
        :param user_id: user performing the change
        :param task_id: id of the task to update
        :param hyperparams: parameters to write
        :param replace_hyperparams: ReplaceHyperparams.all replaces all of the
            task's hyperparameters; ReplaceHyperparams.section replaces each
            supplied section as a whole; any other value sets only the supplied
            parameters and leaves the rest untouched
        :param force: passed to get_task_for_update
        :return: the result of update_task
        """
        properties_only = cls._normalize_params(hyperparams)
        task = get_task_for_update(
            company_id=company_id,
            task_id=task_id,
            allow_all_statuses=properties_only,
            force=force,
        )

        update_cmds = dict()
        hyperparams = cls._db_dicts_from_list(hyperparams)
        if replace_hyperparams == ReplaceHyperparams.all:
            update_cmds["set__hyperparams"] = hyperparams
        elif replace_hyperparams == ReplaceHyperparams.section:
            for section, value in hyperparams.items():
                update_cmds[f"set__hyperparams__{mongoengine_safe(section)}"] = value
        else:
            for section, section_params in hyperparams.items():
                for name, value in section_params.items():
                    # Only the final key component (the name) needs
                    # mongoengine_safe; the section is in the middle.
                    update_cmds[
                        f"set__hyperparams__{section}__{mongoengine_safe(name)}"
                    ] = value

        return update_task(
            task,
            user_id=user_id,
            update_cmds=update_cmds,
            set_last_update=not properties_only,
        )

    @classmethod
    def _db_dicts_from_list(cls, items: Sequence[HyperParamItem]) -> Dict[str, dict]:
        """
        Group API hyperparameter items into the DB layout
        {escaped section: {escaped name: ParamsItem}}.
        """
        sections = iterutils.bucketize(items, key=attrgetter("section"))
        return {
            ParameterKeyEscaper.escape(section): {
                ParameterKeyEscaper.escape(param.name): ParamsItem(**param.to_struct())
                for param in params
            }
            for section, params in sections.items()
        }

    @classmethod
    def get_configurations(
        cls, company_id: str, task_ids: Sequence[str], names: Sequence[str]
    ) -> Dict[str, dict]:
        """
        Return configuration items of the requested tasks.

        :param company_id: company of the calling user
        :param task_ids: ids of the tasks to read (public tasks are allowed)
        :param names: if non-empty, only these configuration items are
            fetched; otherwise all of them
        :return: mapping of task id to {"configuration": [...]} with the items
            as dicts sorted by name
        """
        only = ["id"]
        if names:
            only.extend(
                f"configuration.{ParameterKeyEscaper.escape(name)}" for name in names
            )
        else:
            only.append("configuration")
        tasks = task_bll.assert_exists(
            company_id=company_id, task_ids=task_ids, only=only, allow_public=True,
        )

        return {
            task.id: {
                "configuration": [
                    c.to_proper_dict()
                    for c in sorted(
                        (task.configuration or {}).values(), key=attrgetter("name")
                    )
                ]
            }
            for task in tasks
        }

    @classmethod
    def get_configuration_names(
        cls, company_id: str, task_ids: Sequence[str], skip_empty: bool
    ) -> Dict[str, list]:
        """
        Return the (unescaped, sorted) configuration item names per task.

        :param company_id: company of the calling user; tasks whose company is
            None or "" are matched as well
        :param task_ids: ids of the tasks to inspect
        :param skip_empty: if True, items whose value is None or "" are
            excluded
        :return: mapping of task id to {"names": [...]}; tasks with no
            (remaining) configuration items are absent from the result
        """
        skip_empty_condition = {"$match": {"items.v.value": {"$nin": [None, ""]}}}
        pipeline = [
            {
                "$match": {
                    "company": {"$in": [None, "", company_id]},
                    # materialize so any Sequence type encodes as a BSON array
                    "_id": {"$in": list(task_ids)},
                }
            },
            # turn the configuration map into [{k: name, v: item}, ...]
            {"$project": {"items": {"$objectToArray": "$configuration"}}},
            {"$unwind": "$items"},
            *([skip_empty_condition] if skip_empty else []),
            {"$group": {"_id": "$_id", "names": {"$addToSet": "$items.k"}}},
        ]

        tasks = Task.aggregate(pipeline)

        return {
            task["_id"]: {
                "names": sorted(
                    ParameterKeyEscaper.unescape(name) for name in task["names"]
                )
            }
            for task in tasks
        }

    @classmethod
    def edit_configuration(
        cls,
        company_id: str,
        user_id: str,
        task_id: str,
        configuration: Sequence[Configuration],
        replace_configuration: bool,
        force: bool,
    ) -> int:
        """
        Add or update task configuration items.

        :param company_id: company of the calling user
        :param user_id: user performing the change
        :param task_id: id of the task to update
        :param configuration: configuration items to write
        :param replace_configuration: if True, the task's whole configuration
            is replaced by the supplied items; otherwise only the supplied
            items are set
        :param force: passed to get_task_for_update
        :return: the result of update_task
        """
        task = get_task_for_update(company_id=company_id, task_id=task_id, force=force)

        update_cmds = dict()
        configuration = {
            ParameterKeyEscaper.escape(c.name): ConfigurationItem(**c.to_struct())
            for c in configuration
        }
        if replace_configuration:
            update_cmds["set__configuration"] = configuration
        else:
            for name, value in configuration.items():
                update_cmds[f"set__configuration__{mongoengine_safe(name)}"] = value

        return update_task(task, user_id=user_id, update_cmds=update_cmds)

    @classmethod
    def delete_configuration(
        cls,
        company_id: str,
        user_id: str,
        task_id: str,
        configuration: Sequence[str],
        force: bool,
    ) -> int:
        """
        Delete the named configuration items from a task.

        :param company_id: company of the calling user
        :param user_id: user performing the change
        :param task_id: id of the task to update
        :param configuration: names of the items to delete (duplicates ignored)
        :param force: passed to get_task_for_update
        :return: the result of update_task
        """
        task = get_task_for_update(company_id=company_id, task_id=task_id, force=force)

        # The item name is the last key component, so make it mongoengine-safe
        # exactly as edit_configuration does for "set__" keys.
        delete_cmds = {
            f"unset__configuration__{mongoengine_safe(ParameterKeyEscaper.escape(name))}": 1
            for name in set(configuration)
        }

        return update_task(task, user_id=user_id, update_cmds=delete_cmds)
|