from itertools import chain
from operator import attrgetter
from typing import Sequence, Dict

from boltons import iterutils

from apiserver.apierrors import errors
from apiserver.apimodels.tasks import (
    HyperParamKey,
    HyperParamItem,
    ReplaceHyperparams,
    Configuration,
)
from apiserver.bll.task import TaskBLL
from apiserver.bll.task.utils import get_task_for_update, update_task
from apiserver.config_repo import config
from apiserver.database.model.task.task import ParamsItem, Task, ConfigurationItem
from apiserver.service_repo.auth import Identity
from apiserver.utilities.parameter_key_escaper import (
    ParameterKeyEscaper,
    mongoengine_safe,
)

log = config.logger(__file__)
task_bll = TaskBLL()


class HyperParams:
    """Read and modify the ``hyperparams`` and ``configuration`` fields of tasks.

    All operations are classmethods. Section and parameter names are passed
    through ``ParameterKeyEscaper.escape`` before being used as database keys.
    """

    # Name of the section that gets special treatment in ``_normalize_params``
    _properties_section = "properties"

    @classmethod
    def get_params(cls, company_id: str, task_ids: Sequence[str]) -> Dict[str, dict]:
        """Return the hyper parameters of the requested tasks.

        The tasks are loaded through ``task_bll.assert_exists`` (with
        ``allow_public=True``), fetching only the ``id`` and ``hyperparams``
        fields.

        :return: mapping of task id to ``{"hyperparams": [<param dict>, ...]}``
        """
        found_tasks = task_bll.assert_exists(
            company_id=company_id,
            task_ids=task_ids,
            only=("id", "hyperparams"),
            allow_public=True,
        )

        result = {}
        for task in found_tasks:
            task_id = task.id
            result[task_id] = {
                "hyperparams": cls._get_params_list(items=task.hyperparams)
            }
        return result

    @classmethod
    def _get_params_list(
        cls, items: Dict[str, Dict[str, ParamsItem]]
    ) -> Sequence[dict]:
        """Flatten a ``{section: {name: item}}`` mapping into a list of dicts.

        The items are ordered by ``(section, name)`` and each one is converted
        with its ``to_proper_dict`` method.
        """
        ordered = sorted(
            chain.from_iterable(section.values() for section in items.values()),
            key=attrgetter("section", "name"),
        )
        return [param.to_proper_dict() for param in ordered]

    @classmethod
    def _normalize_params(cls, params: Sequence) -> bool:
        """
        Rewrite any case variation of the properties section name to its
        canonical lower case form (the passed items are modified in place)
        and return True if no item belongs to another section
        """
        canonical = cls._properties_section
        only_properties = True
        for param in params:
            if param.section.lower() == canonical:
                param.section = canonical
            else:
                only_properties = False
        return only_properties

    @classmethod
    def delete_params(
        cls,
        company_id: str,
        identity: Identity,
        task_id: str,
        hyperparams: Sequence[HyperParamKey],
        force: bool,
    ) -> int:
        """Delete whole sections and/or single parameters from a task.

        A key without a name requests deletion of its entire section; a key
        with a name requests deletion of that single parameter.

        :raise errors.bad_request.FieldsConflict: a named key belongs to a
            section that is also requested for deletion as a whole
        :return: the result of ``update_task``
        """
        properties_only = cls._normalize_params(hyperparams)
        task = get_task_for_update(
            company_id=company_id,
            task_id=task_id,
            allow_all_statuses=properties_only,
            force=force,
            identity=identity,
        )

        escape = ParameterKeyEscaper.escape
        single_keys = [key for key in hyperparams if key.name]
        section_keys = [key for key in hyperparams if not key.name]

        sections_to_delete = {key.section for key in section_keys}
        # NOTE(review): unlike the set commands in edit_params, the unset keys
        # are not passed through mongoengine_safe - confirm this is intended
        delete_cmds = dict.fromkeys(
            (
                f"unset__hyperparams__{escape(section)}"
                for section in sections_to_delete
            ),
            1,
        )

        for key in single_keys:
            escaped_section = escape(key.section)
            if key.section in sections_to_delete:
                raise errors.bad_request.FieldsConflict(
                    "Cannot delete section field if the whole section was scheduled for deletion"
                )
            escaped_name = escape(key.name)
            delete_cmds[f"unset__hyperparams__{escaped_section}__{escaped_name}"] = 1

        return update_task(
            task,
            user_id=identity.user,
            update_cmds=delete_cmds,
            set_last_update=not properties_only,
        )

    @classmethod
    def edit_params(
        cls,
        company_id: str,
        identity: Identity,
        task_id: str,
        hyperparams: Sequence[HyperParamItem],
        replace_hyperparams: str,
        force: bool,
    ) -> int:
        """Add or replace hyper parameters of a task.

        ``replace_hyperparams`` selects the granularity of the update:

        - ``ReplaceHyperparams.all``: the whole ``hyperparams`` field is set
        - ``ReplaceHyperparams.section``: each passed section is set as a whole
        - any other value: each passed parameter is set individually

        :return: the result of ``update_task``
        """
        properties_only = cls._normalize_params(hyperparams)
        task = get_task_for_update(
            company_id=company_id,
            task_id=task_id,
            allow_all_statuses=properties_only,
            force=force,
            identity=identity,
        )

        by_section = cls._db_dicts_from_list(hyperparams)
        if replace_hyperparams == ReplaceHyperparams.all:
            update_cmds = {"set__hyperparams": by_section}
        elif replace_hyperparams == ReplaceHyperparams.section:
            update_cmds = {
                f"set__hyperparams__{mongoengine_safe(section)}": section_params
                for section, section_params in by_section.items()
            }
        else:
            update_cmds = {
                f"set__hyperparams__{section}__{mongoengine_safe(name)}": value
                for section, section_params in by_section.items()
                for name, value in section_params.items()
            }

        return update_task(
            task,
            user_id=identity.user,
            update_cmds=update_cmds,
            set_last_update=not properties_only,
        )

    @classmethod
    def _db_dicts_from_list(cls, items: Sequence[HyperParamItem]) -> Dict[str, dict]:
        """Convert a flat list of items into ``{section: {name: ParamsItem}}``.

        Both the section names and the parameter names are escaped. Each
        ``ParamsItem`` is built from the ``to_struct()`` output of its item.
        """
        escape = ParameterKeyEscaper.escape
        grouped = iterutils.bucketize(items, key=attrgetter("section"))

        result = {}
        for section, section_items in grouped.items():
            section_key = escape(section)
            section_params = {}
            for item in section_items:
                name_key = escape(item.name)
                section_params[name_key] = ParamsItem(**item.to_struct())
            result[section_key] = section_params
        return result

    @classmethod
    def get_configurations(
        cls, company_id: str, task_ids: Sequence[str], names: Sequence[str]
    ) -> Dict[str, dict]:
        """Return the configuration items of the requested tasks.

        When ``names`` is not empty only the ``configuration`` sub fields with
        these (escaped) names are requested; otherwise the whole
        ``configuration`` field is requested.

        :return: mapping of task id to ``{"configuration": [<item dict>, ...]}``
            where the items are ordered by name
        """
        if names:
            projection = [
                "id",
                *(f"configuration.{ParameterKeyEscaper.escape(name)}" for name in names),
            ]
        else:
            projection = ["id", "configuration"]

        found_tasks = task_bll.assert_exists(
            company_id=company_id,
            task_ids=task_ids,
            only=projection,
            allow_public=True,
        )

        result = {}
        for task in found_tasks:
            task_id = task.id
            ordered_items = sorted(
                task.configuration.values(), key=attrgetter("name")
            )
            result[task_id] = {
                "configuration": [item.to_proper_dict() for item in ordered_items]
            }
        return result

    @classmethod
    def get_configuration_names(
        cls, company_id: str, task_ids: Sequence[str], skip_empty: bool
    ) -> Dict[str, list]:
        """Return the sorted, unescaped configuration names per task.

        Runs an aggregation over the tasks whose id is in ``task_ids`` and
        whose company is ``company_id``, ``None`` or an empty string. With
        ``skip_empty`` set, configuration items whose value is ``None`` or an
        empty string are left out.

        :return: mapping of task id to ``{"names": [<name>, ...]}``; only the
            task ids returned by the aggregation are present
        """
        pipeline = [
            {
                "$match": {
                    "company": {"$in": [None, "", company_id]},
                    "_id": {"$in": task_ids},
                }
            },
            {"$project": {"items": {"$objectToArray": "$configuration"}}},
            {"$unwind": "$items"},
        ]
        if skip_empty:
            pipeline.append({"$match": {"items.v.value": {"$nin": [None, ""]}}})
        pipeline.append(
            {"$group": {"_id": "$_id", "names": {"$addToSet": "$items.k"}}}
        )

        result = {}
        for doc in Task.aggregate(pipeline):
            task_id = doc["_id"]
            names = [ParameterKeyEscaper.unescape(key) for key in doc["names"]]
            names.sort()
            result[task_id] = {"names": names}
        return result

    @classmethod
    def edit_configuration(
        cls,
        company_id: str,
        identity: Identity,
        task_id: str,
        configuration: Sequence[Configuration],
        replace_configuration: bool,
        force: bool,
    ) -> int:
        """Add or replace configuration items of a task.

        With ``replace_configuration`` set the whole ``configuration`` field
        is set to the passed items; otherwise each passed item is set
        individually under its escaped name.

        :return: the result of ``update_task``
        """
        task = get_task_for_update(
            company_id=company_id, task_id=task_id, force=force, identity=identity
        )

        new_items = {}
        for item in configuration:
            item_key = ParameterKeyEscaper.escape(item.name)
            new_items[item_key] = ConfigurationItem(**item.to_struct())

        if replace_configuration:
            update_cmds = {"set__configuration": new_items}
        else:
            update_cmds = {
                f"set__configuration__{mongoengine_safe(name)}": value
                for name, value in new_items.items()
            }

        return update_task(task, user_id=identity.user, update_cmds=update_cmds)

    @classmethod
    def delete_configuration(
        cls,
        company_id: str,
        identity: Identity,
        task_id: str,
        configuration: Sequence[str],
        force: bool,
    ) -> int:
        """Delete the configuration items with the passed names from a task.

        Duplicate names are collapsed before the unset commands are built.

        :return: the result of ``update_task``
        """
        task = get_task_for_update(
            company_id=company_id, task_id=task_id, force=force, identity=identity
        )

        unique_names = set(configuration)
        # NOTE(review): unlike the set commands in edit_configuration, the
        # unset keys are not passed through mongoengine_safe - confirm intended
        delete_cmds = dict.fromkeys(
            (
                f"unset__configuration__{ParameterKeyEscaper.escape(name)}"
                for name in unique_names
            ),
            1,
        )

        return update_task(task, user_id=identity.user, update_cmds=delete_cmds)