clearml-server/apiserver/bll/task/hyperparams.py

231 lines
8.1 KiB
Python
Raw Normal View History

from datetime import datetime
from itertools import chain
from operator import attrgetter
from typing import Sequence, Dict
from boltons import iterutils
2021-01-05 14:28:49 +00:00
from apiserver.apierrors import errors
from apiserver.apimodels.tasks import (
HyperParamKey,
HyperParamItem,
ReplaceHyperparams,
Configuration,
)
2021-01-05 14:28:49 +00:00
from apiserver.bll.task import TaskBLL
2021-01-05 14:40:35 +00:00
from apiserver.bll.task.utils import get_task_for_update
2021-01-05 14:44:31 +00:00
from apiserver.config_repo import config
2021-01-05 14:40:35 +00:00
from apiserver.database.model.task.task import ParamsItem, Task, ConfigurationItem
from apiserver.timing_context import TimingContext
from apiserver.utilities.parameter_key_escaper import (
ParameterKeyEscaper,
mongoengine_safe,
)
log = config.logger(__file__)
task_bll = TaskBLL()
class HyperParams:
_properties_section = "properties"
@classmethod
def get_params(cls, company_id: str, task_ids: Sequence[str]) -> Dict[str, dict]:
only = ("id", "hyperparams")
tasks = task_bll.assert_exists(
company_id=company_id, task_ids=task_ids, only=only, allow_public=True,
)
return {
task.id: {"hyperparams": cls._get_params_list(items=task.hyperparams)}
for task in tasks
}
@classmethod
def _get_params_list(
cls, items: Dict[str, Dict[str, ParamsItem]]
) -> Sequence[dict]:
ret = list(chain.from_iterable(v.values() for v in items.values()))
return [
p.to_proper_dict() for p in sorted(ret, key=attrgetter("section", "name"))
]
@classmethod
def _normalize_params(cls, params: Sequence) -> bool:
"""
Lower case properties section and return True if it is the only section
"""
for p in params:
if p.section.lower() == cls._properties_section:
p.section = cls._properties_section
return all(p.section == cls._properties_section for p in params)
@classmethod
def delete_params(
2021-01-05 14:40:35 +00:00
cls, company_id: str, task_id: str, hyperparams: Sequence[HyperParamKey]
) -> int:
2021-01-05 14:40:35 +00:00
with TimingContext("mongo", "delete_hyperparams"):
properties_only = cls._normalize_params(hyperparams)
task = get_task_for_update(
company_id=company_id,
task_id=task_id,
allow_all_statuses=properties_only,
2021-01-05 14:40:35 +00:00
)
2021-01-05 14:40:35 +00:00
with_param, without_param = iterutils.partition(
hyperparams, key=lambda p: bool(p.name)
)
sections_to_delete = {p.section for p in without_param}
delete_cmds = {
f"unset__hyperparams__{ParameterKeyEscaper.escape(section)}": 1
for section in sections_to_delete
}
2021-01-05 14:40:35 +00:00
for item in with_param:
section = ParameterKeyEscaper.escape(item.section)
if item.section in sections_to_delete:
raise errors.bad_request.FieldsConflict(
"Cannot delete section field if the whole section was scheduled for deletion"
)
name = ParameterKeyEscaper.escape(item.name)
delete_cmds[f"unset__hyperparams__{section}__{name}"] = 1
2021-01-05 14:40:35 +00:00
return task.update(**delete_cmds, last_update=datetime.utcnow())
@classmethod
def edit_params(
cls,
company_id: str,
task_id: str,
hyperparams: Sequence[HyperParamItem],
replace_hyperparams: str,
) -> int:
2021-01-05 14:40:35 +00:00
with TimingContext("mongo", "edit_hyperparams"):
properties_only = cls._normalize_params(hyperparams)
task = get_task_for_update(
company_id=company_id,
task_id=task_id,
allow_all_statuses=properties_only,
2021-01-05 14:40:35 +00:00
)
2021-01-05 14:40:35 +00:00
update_cmds = dict()
hyperparams = cls._db_dicts_from_list(hyperparams)
if replace_hyperparams == ReplaceHyperparams.all:
update_cmds["set__hyperparams"] = hyperparams
elif replace_hyperparams == ReplaceHyperparams.section:
for section, value in hyperparams.items():
update_cmds[
f"set__hyperparams__{mongoengine_safe(section)}"
] = value
2021-01-05 14:40:35 +00:00
else:
for section, section_params in hyperparams.items():
for name, value in section_params.items():
update_cmds[
f"set__hyperparams__{section}__{mongoengine_safe(name)}"
] = value
2021-01-05 14:40:35 +00:00
return task.update(**update_cmds, last_update=datetime.utcnow())
@classmethod
def _db_dicts_from_list(cls, items: Sequence[HyperParamItem]) -> Dict[str, dict]:
sections = iterutils.bucketize(items, key=attrgetter("section"))
return {
ParameterKeyEscaper.escape(section): {
ParameterKeyEscaper.escape(param.name): ParamsItem(**param.to_struct())
for param in params
}
for section, params in sections.items()
}
@classmethod
def get_configurations(
cls, company_id: str, task_ids: Sequence[str], names: Sequence[str]
) -> Dict[str, dict]:
only = ["id"]
if names:
only.extend(
f"configuration.{ParameterKeyEscaper.escape(name)}" for name in names
)
else:
only.append("configuration")
tasks = task_bll.assert_exists(
company_id=company_id, task_ids=task_ids, only=only, allow_public=True,
)
return {
task.id: {
"configuration": [
c.to_proper_dict()
for c in sorted(task.configuration.values(), key=attrgetter("name"))
]
}
for task in tasks
}
@classmethod
def get_configuration_names(
cls, company_id: str, task_ids: Sequence[str]
) -> Dict[str, list]:
2021-01-05 14:40:35 +00:00
with TimingContext("mongo", "get_configuration_names"):
pipeline = [
{
"$match": {
"company": {"$in": [None, "", company_id]},
"_id": {"$in": task_ids},
}
},
{"$project": {"items": {"$objectToArray": "$configuration"}}},
{"$unwind": "$items"},
{"$group": {"_id": "$_id", "names": {"$addToSet": "$items.k"}}},
]
tasks = Task.aggregate(pipeline)
return {
task["_id"]: {
"names": sorted(
ParameterKeyEscaper.unescape(name) for name in task["names"]
)
}
2021-01-05 14:40:35 +00:00
for task in tasks
}
@classmethod
def edit_configuration(
cls,
company_id: str,
task_id: str,
configuration: Sequence[Configuration],
replace_configuration: bool,
) -> int:
2021-01-05 14:40:35 +00:00
with TimingContext("mongo", "edit_configuration"):
task = get_task_for_update(company_id=company_id, task_id=task_id)
2021-01-05 14:40:35 +00:00
update_cmds = dict()
configuration = {
ParameterKeyEscaper.escape(c.name): ConfigurationItem(**c.to_struct())
for c in configuration
}
if replace_configuration:
update_cmds["set__configuration"] = configuration
else:
for name, value in configuration.items():
update_cmds[f"set__configuration__{mongoengine_safe(name)}"] = value
2021-01-05 14:40:35 +00:00
return task.update(**update_cmds, last_update=datetime.utcnow())
@classmethod
def delete_configuration(
2021-01-05 14:40:35 +00:00
cls, company_id: str, task_id: str, configuration: Sequence[str]
) -> int:
2021-01-05 14:40:35 +00:00
with TimingContext("mongo", "delete_configuration"):
task = get_task_for_update(company_id=company_id, task_id=task_id)
2021-01-05 14:40:35 +00:00
delete_cmds = {
f"unset__configuration__{ParameterKeyEscaper.escape(name)}": 1
for name in set(configuration)
}
2021-01-05 14:40:35 +00:00
return task.update(**delete_cmds, last_update=datetime.utcnow())