clearml-server/apiserver/bll/task/hyperparams.py
allegroai 0303c3525f API version bump
Update internal tests
Allow edit/delete task artifacts/hyperparams/configs using force flag
Improve lists query support for get_all calls
2021-01-05 17:57:58 +02:00

243 lines
8.3 KiB
Python

from datetime import datetime
from itertools import chain
from operator import attrgetter
from typing import Sequence, Dict
from boltons import iterutils
from apiserver.apierrors import errors
from apiserver.apimodels.tasks import (
HyperParamKey,
HyperParamItem,
ReplaceHyperparams,
Configuration,
)
from apiserver.bll.task import TaskBLL
from apiserver.bll.task.utils import get_task_for_update
from apiserver.config_repo import config
from apiserver.database.model.task.task import ParamsItem, Task, ConfigurationItem
from apiserver.timing_context import TimingContext
from apiserver.utilities.parameter_key_escaper import (
ParameterKeyEscaper,
mongoengine_safe,
)
# Module-level logger obtained from the server configuration
log = config.logger(__file__)
# Shared TaskBLL instance; HyperParams uses it for task lookups via assert_exists
task_bll = TaskBLL()
class HyperParams:
    """
    Business logic for reading and editing task hyperparameters and
    configuration items.

    Section, parameter and configuration names are escaped with
    ParameterKeyEscaper before being used as database field names.
    """

    # Hyperparameter section matched case-insensitively by _normalize_params.
    # When it is the only section an edit/delete touches, get_task_for_update
    # is called with allow_all_statuses=True.
    _properties_section = "properties"

    @classmethod
    def get_params(cls, company_id: str, task_ids: Sequence[str]) -> Dict[str, dict]:
        """
        Return the hyperparameters of the requested tasks.

        :param company_id: company of the caller (public tasks are also allowed)
        :param task_ids: ids of the tasks to read
        :return: mapping of task id -> {"hyperparams": [param dicts]}, where
            the params are sorted by section and then by name
        """
        only = ("id", "hyperparams")
        tasks = task_bll.assert_exists(
            company_id=company_id, task_ids=task_ids, only=only, allow_public=True,
        )
        return {
            task.id: {"hyperparams": cls._get_params_list(items=task.hyperparams)}
            for task in tasks
        }

    @classmethod
    def _get_params_list(
        cls, items: Dict[str, Dict[str, ParamsItem]]
    ) -> Sequence[dict]:
        """
        Flatten a {section: {name: ParamsItem}} mapping into a list of dicts
        sorted by (section, name).
        """
        ret = list(chain.from_iterable(v.values() for v in items.values()))
        return [
            p.to_proper_dict() for p in sorted(ret, key=attrgetter("section", "name"))
        ]

    @classmethod
    def _normalize_params(cls, params: Sequence) -> bool:
        """
        Lower case properties section and return True if it is the only section.

        The items in ``params`` are modified in place: any section equal to
        the properties section ignoring case is rewritten in lower case.
        An empty ``params`` yields True (``all`` of an empty sequence).
        """
        for p in params:
            if p.section.lower() == cls._properties_section:
                p.section = cls._properties_section
        return all(p.section == cls._properties_section for p in params)

    @classmethod
    def delete_params(
        cls,
        company_id: str,
        task_id: str,
        hyperparams: Sequence[HyperParamKey],
        force: bool,
    ) -> int:
        """
        Delete hyperparameters or whole hyperparameter sections from a task.

        Items with an empty name delete their entire section. Items with a
        name delete only that parameter.

        :param company_id: company of the caller
        :param task_id: id of the task to update
        :param hyperparams: sections / parameters to delete
        :param force: forwarded to get_task_for_update
        :return: the value returned by task.update
        :raises errors.bad_request.FieldsConflict: if a parameter is requested
            for deletion inside a section that is also deleted as a whole
        """
        with TimingContext("mongo", "delete_hyperparams"):
            properties_only = cls._normalize_params(hyperparams)
            task = get_task_for_update(
                company_id=company_id,
                task_id=task_id,
                allow_all_statuses=properties_only,
                force=force,
            )

            with_param, without_param = iterutils.partition(
                hyperparams, key=lambda p: bool(p.name)
            )
            sections_to_delete = {p.section for p in without_param}

            # The last path segment goes through mongoengine_safe, exactly as
            # in the edit methods. Otherwise a key such as "size" or "type"
            # would be taken by mongoengine as an operator suffix, and the
            # parent field would be unset instead of the intended key.
            delete_cmds = {
                f"unset__hyperparams__{mongoengine_safe(ParameterKeyEscaper.escape(section))}": 1
                for section in sections_to_delete
            }

            for item in with_param:
                section = ParameterKeyEscaper.escape(item.section)
                if item.section in sections_to_delete:
                    raise errors.bad_request.FieldsConflict(
                        "Cannot delete section field if the whole section was scheduled for deletion"
                    )
                name = ParameterKeyEscaper.escape(item.name)
                delete_cmds[
                    f"unset__hyperparams__{section}__{mongoengine_safe(name)}"
                ] = 1

            return task.update(**delete_cmds, last_update=datetime.utcnow())

    @classmethod
    def edit_params(
        cls,
        company_id: str,
        task_id: str,
        hyperparams: Sequence[HyperParamItem],
        replace_hyperparams: str,
        force: bool,
    ) -> int:
        """
        Add or update task hyperparameters.

        :param company_id: company of the caller
        :param task_id: id of the task to update
        :param hyperparams: parameters to write
        :param replace_hyperparams: ReplaceHyperparams.all replaces the whole
            hyperparams field. ReplaceHyperparams.section replaces every
            section present in ``hyperparams``. Any other value sets only the
            individual parameters given.
        :param force: forwarded to get_task_for_update
        :return: the value returned by task.update
        """
        with TimingContext("mongo", "edit_hyperparams"):
            properties_only = cls._normalize_params(hyperparams)
            task = get_task_for_update(
                company_id=company_id,
                task_id=task_id,
                allow_all_statuses=properties_only,
                force=force,
            )

            update_cmds = dict()
            # {escaped section: {escaped name: ParamsItem}}
            db_params = cls._db_dicts_from_list(hyperparams)
            if replace_hyperparams == ReplaceHyperparams.all:
                update_cmds["set__hyperparams"] = db_params
            elif replace_hyperparams == ReplaceHyperparams.section:
                for section, value in db_params.items():
                    update_cmds[
                        f"set__hyperparams__{mongoengine_safe(section)}"
                    ] = value
            else:
                for section, section_params in db_params.items():
                    for name, value in section_params.items():
                        update_cmds[
                            f"set__hyperparams__{section}__{mongoengine_safe(name)}"
                        ] = value

            return task.update(**update_cmds, last_update=datetime.utcnow())

    @classmethod
    def _db_dicts_from_list(cls, items: Sequence[HyperParamItem]) -> Dict[str, dict]:
        """
        Group API hyperparameter items into a nested
        {escaped section: {escaped name: ParamsItem}} mapping.
        """
        sections = iterutils.bucketize(items, key=attrgetter("section"))
        return {
            ParameterKeyEscaper.escape(section): {
                ParameterKeyEscaper.escape(param.name): ParamsItem(**param.to_struct())
                for param in params
            }
            for section, params in sections.items()
        }

    @classmethod
    def get_configurations(
        cls, company_id: str, task_ids: Sequence[str], names: Sequence[str]
    ) -> Dict[str, dict]:
        """
        Return configuration items of the requested tasks.

        :param company_id: company of the caller (public tasks are also allowed)
        :param task_ids: ids of the tasks to read
        :param names: if non-empty, only these configuration items are
            projected. Otherwise the whole configuration is loaded.
        :return: mapping of task id -> {"configuration": [item dicts sorted by name]}
        """
        only = ["id"]
        if names:
            only.extend(
                f"configuration.{ParameterKeyEscaper.escape(name)}" for name in names
            )
        else:
            only.append("configuration")
        tasks = task_bll.assert_exists(
            company_id=company_id, task_ids=task_ids, only=only, allow_public=True,
        )

        return {
            task.id: {
                "configuration": [
                    c.to_proper_dict()
                    for c in sorted(task.configuration.values(), key=attrgetter("name"))
                ]
            }
            for task in tasks
        }

    @classmethod
    def get_configuration_names(
        cls, company_id: str, task_ids: Sequence[str]
    ) -> Dict[str, list]:
        """
        Return the configuration item names of the requested tasks.

        Matches tasks of the caller's company, and tasks whose company is
        None or empty. Tasks without configuration items produce no
        document after $unwind and so are absent from the result.

        :return: mapping of task id -> {"names": sorted unescaped names}
        """
        with TimingContext("mongo", "get_configuration_names"):
            pipeline = [
                {
                    "$match": {
                        "company": {"$in": [None, "", company_id]},
                        "_id": {"$in": task_ids},
                    }
                },
                {"$project": {"items": {"$objectToArray": "$configuration"}}},
                {"$unwind": "$items"},
                {"$group": {"_id": "$_id", "names": {"$addToSet": "$items.k"}}},
            ]

            tasks = Task.aggregate(pipeline)

            return {
                task["_id"]: {
                    "names": sorted(
                        ParameterKeyEscaper.unescape(name) for name in task["names"]
                    )
                }
                for task in tasks
            }

    @classmethod
    def edit_configuration(
        cls,
        company_id: str,
        task_id: str,
        configuration: Sequence[Configuration],
        replace_configuration: bool,
        force: bool,
    ) -> int:
        """
        Add or update task configuration items.

        :param company_id: company of the caller
        :param task_id: id of the task to update
        :param configuration: configuration items to write
        :param replace_configuration: if True the whole configuration field is
            replaced. Otherwise only the given items are set.
        :param force: forwarded to get_task_for_update
        :return: the value returned by task.update
        """
        with TimingContext("mongo", "edit_configuration"):
            task = get_task_for_update(
                company_id=company_id, task_id=task_id, force=force
            )

            update_cmds = dict()
            db_configuration = {
                ParameterKeyEscaper.escape(c.name): ConfigurationItem(**c.to_struct())
                for c in configuration
            }
            if replace_configuration:
                update_cmds["set__configuration"] = db_configuration
            else:
                for name, value in db_configuration.items():
                    update_cmds[f"set__configuration__{mongoengine_safe(name)}"] = value

            return task.update(**update_cmds, last_update=datetime.utcnow())

    @classmethod
    def delete_configuration(
        cls, company_id: str, task_id: str, configuration: Sequence[str], force: bool
    ) -> int:
        """
        Delete the named configuration items from a task.

        :param company_id: company of the caller
        :param task_id: id of the task to update
        :param configuration: names of the configuration items to delete
            (duplicates are ignored)
        :param force: forwarded to get_task_for_update
        :return: the value returned by task.update
        """
        with TimingContext("mongo", "delete_configuration"):
            task = get_task_for_update(
                company_id=company_id, task_id=task_id, force=force
            )

            # mongoengine_safe on the last segment, as in edit_configuration,
            # so operator-like names (e.g. "type") do not unset the whole field
            delete_cmds = {
                f"unset__configuration__{mongoengine_safe(ParameterKeyEscaper.escape(name))}": 1
                for name in set(configuration)
            }

            return task.update(**delete_cmds, last_update=datetime.utcnow())