from collections import defaultdict
from typing import Optional, Any, Sequence, Callable, Mapping, Union, Dict, Iterable, Generator

from ...backend_api import Session
from ...backend_api.services import tasks


class HyperParams(object):
    """
    Hyper-parameter access helper bound to a single task.

    All server communication goes through ``task.session``; after a successful edit or
    delete the task is reloaded via ``task.reload()``.
    """

    def __init__(self, task):
        # The task object must expose ``task_id``, ``session``, ``log`` and ``reload()``,
        # which are the only attributes used by the methods below.
        self.task = task

    def get_hyper_params(
        self,
        sections=None,  # type: Optional[Sequence[str]]
        selector=None,  # type: Optional[Callable[[dict], bool]]
        projector=None,  # type: Optional[Callable[[Any], Any]]
        return_obj=False  # type: Optional[bool]
    ):
        # type: (...) -> Dict[str, Dict[str, Union[dict, Any]]]
        """
        Get hyper-parameters for this task.
        Returns a dictionary mapping section name to a dictionary mapping hyper-parameter name
        to the hyper-parameter details.

        :param sections: Return only hyper-params in the provided sections
        :param selector: A callable selecting which hyper-parameters should be returned. It is called
            with the raw hyper-parameter details dict and should return True to keep the entry.
        :param projector: A callable to project values before they are returned. It is applied exactly
            once per hyper-parameter.
        :param return_obj: If True, returned dictionary values are API objects (tasks.ParamsItem).
            If ``projector`` is also provided, it receives the API object instead of the details dict.
        :return: A nested dictionary ``{section: {name: details}}``. Empty if the request failed or
            the task has no hyper-parameters.
        :raise ValueError: If the server API version is lower than 2.9.
        """
        if not Session.check_min_api_version("2.9"):
            raise ValueError("Not supported by server")

        task_id = self.task.task_id
        res = self.task.session.send(tasks.GetHyperParamsRequest(tasks=[task_id]))
        hyperparams = defaultdict(defaultdict)
        if not (res.ok() and res.response.params):
            return hyperparams

        for entry in res.response.params:
            if entry.get("task") != task_id:
                continue
            for item in entry.get("hyperparams") or []:
                # Best effort: a single malformed entry (or a failing user callback) is logged
                # and skipped instead of failing the whole call.
                # noinspection PyBroadException
                try:
                    if sections and item.get("section") not in sections:
                        continue
                    if selector and not selector(item):
                        continue
                    # Read the key from the raw dict *before* converting/projecting, since the
                    # converted or projected value is not guaranteed to be a dict.
                    section, name = item.get("section"), item.get("name")
                    value = tasks.ParamsItem(**item) if return_obj else item
                    if projector:
                        value = projector(value)
                    hyperparams[section][name] = value
                except Exception:
                    self.task.log.exception("Failed processing hyper-parameter")
        return hyperparams

    def edit_hyper_params(
        self,
        *iterables,  # type: Union[Mapping[str, Union[str, dict, None]], Iterable[Union[dict, tasks.ParamsItem]]]
        replace=None,  # type: Optional[str]
        default_section=None,  # type: Optional[str]
        force_section=None  # type: Optional[str]
    ):
        # type: (...) -> bool
        """
        Set hyper-parameters for this task.

        :param iterables: Hyper parameter iterables, each can be:
            * A dictionary of string key (name) to either a string value (value), a tasks.ParamsItem or a dict
              (hyperparam details). If ``default_section`` is not provided, each dict must contain a "section" field.
            * An iterable of tasks.ParamsItem or dicts (each representing hyperparam details).
              Each dict must contain a "name" field. If ``default_section`` is not provided, each dict must
              also contain a "section" field.
            Note that tasks.ParamsItem objects passed in are updated in place (their section, and name when
            given by a dictionary key, are assigned).
        :param replace: Optional replace strategy, values are:
            * 'all' - provided hyper-params replace all existing hyper-params in task
            * 'section' - only sections present in the provided hyper-params are replaced
            * 'none' (default) - provided hyper-params will be merged into existing task hyper-params (i.e. will be
              added or update existing hyper-params)
            Any value not recognized by tasks.ReplaceHyperparamsEnum is ignored (treated as not provided).
        :param default_section: Optional section name to be used when section is not explicitly provided.
        :param force_section: Optional section name to be used for all hyper-params.
        :return: True if the server accepted the request (the task is then reloaded), False otherwise.
        :raise ValueError: If the server API version is lower than 2.9, or if a hyper-parameter is missing
            its name or section.
        """
        if not Session.check_min_api_version("2.9"):
            raise ValueError("Not supported by server")

        # Servers older than API 2.11 need reserved names/sections to be escaped
        escape_unsafe = not Session.check_min_api_version("2.11")

        if not tasks.ReplaceHyperparamsEnum.has_value(replace):
            replace = None

        def make_item(value, name=None):
            """Normalize a single entry into a tasks.ParamsItem with a resolved name and section."""
            if isinstance(value, tasks.ParamsItem):
                item = value
            elif isinstance(value, dict):
                item = tasks.ParamsItem(**value)
            else:
                item = tasks.ParamsItem(value=str(value))
            if name:
                item.name = str(name)
            if not item.name:
                raise ValueError("Missing hyper-param name for '{}'".format(value))
            # Precedence: forced section, then the item's own section, then the default section
            section = force_section or item.section or default_section
            if not section:
                raise ValueError("Missing hyper-param section for '{}'".format(value))
            if escape_unsafe:
                item.section, item.name = self._escape_unsafe_values(section, item.name)
            else:
                item.section = section
            return item

        # Key by (section, name): a hyper-parameter is identified by both, so keying by name alone
        # would silently drop same-named parameters that live in different sections.
        props = {}
        for i in iterables:
            if isinstance(i, dict):
                items = (make_item(value, name) for name, value in i.items())
            else:
                items = map(make_item, i)
            for item in items:
                props[(item.section, item.name)] = item

        res = self.task.session.send(
            tasks.EditHyperParamsRequest(
                task=self.task.task_id,
                # Materialize as a list rather than passing a dict view to the request object
                hyperparams=list(props.values()),
                replace_hyperparams=replace,
            ),
        )
        if res.ok():
            self.task.reload()
            return True

        return False

    def delete_hyper_params(self, *iterables):
        # type: (Iterable[Union[dict, Iterable[str], tasks.ParamKey, tasks.ParamsItem]]) -> bool
        """
        Delete hyper-parameters for this task.

        :param iterables: Hyper parameter key iterables. Each an iterable whose possible values each represent
            a hyper-parameter entry to delete, value formats are:
            * A dictionary containing a 'section' and 'name' fields
            * An iterable (e.g. tuple, list etc.) whose first two items denote 'section' and 'name'.
              A plain string is not accepted here, since iterating it would yield single characters.
            * An API object of type tasks.ParamKey or tasks.ParamsItem whose section and name fields are not empty
        :return: True if the server accepted the request (the task is then reloaded), False otherwise.
        :raise ValueError: If the server API version is lower than 2.9, or if an entry does not provide
            both a section and a name.
        """
        if not Session.check_min_api_version("2.9"):
            raise ValueError("Not supported by server")

        def get_key(value):
            """Extract a validated (section, name) pair from a single entry."""
            if isinstance(value, dict):
                key = (value.get("section"), value.get("name"))
            elif isinstance(value, (tasks.ParamKey, tasks.ParamsItem)):
                key = (value.section, value.name)
            elif isinstance(value, str):
                # A bare string would be split into characters ("Args" -> ("A", "r")) and target the
                # wrong hyper-parameter; treat it as an invalid entry instead.
                key = ()
            else:
                # Keep None as None (str(None) is the truthy "None" and would defeat validation)
                key = tuple(None if part is None else str(part) for part in tuple(value)[:2])
            if len(key) != 2 or not all(key):
                raise ValueError("Missing section or name in '{}'".format(value))
            return key

        # A set is used so that duplicate keys are sent only once
        keys = {get_key(value) for iterable in iterables for value in iterable}

        res = self.task.session.send(
            tasks.DeleteHyperParamsRequest(
                task=self.task.task_id,
                hyperparams=[tasks.ParamKey(section=section, name=name) for section, name in keys],
            ),
        )
        if res.ok():
            self.task.reload()
            return True

        return False

    def _escape_unsafe_values(self, *values):
        # type: (*str) -> Generator[str, None, None]
        """
        Escape unsafe values (name, section name) for API version 2.10 and below.

        Yields each value in order; a value found in ``UNSAFE_NAMES_2_10`` is yielded with a leading
        underscore (and the conversion is logged), any other value is yielded unchanged.
        """
        for value in values:
            if value not in UNSAFE_NAMES_2_10:
                yield value
            else:
                self.task.log.info(
                    "Converting unsafe hyper parameter name/section '{}' to '{}'".format(value, "_" + value)
                )
                yield "_" + value


# Hyper-parameter names / section names that must be escaped (prefixed with "_") when talking to
# servers with API version 2.10 and below.
# NOTE(review): these look like database query operator names - confirm the origin of this list.
UNSAFE_NAMES_2_10 = {
    "ne",
    "gt",
    "gte",
    "lt",
    "lte",
    "in",
    "nin",
    "mod",
    "all",
    "size",
    "exists",
    "not",
    "elemMatch",
    "type",
    "within_distance",
    "within_spherical_distance",
    "within_box",
    "within_polygon",
    "near",
    "near_sphere",
    "max_distance",
    "min_distance",
    "geo_within",
    "geo_within_box",
    "geo_within_polygon",
    "geo_within_center",
    "geo_within_sphere",
    "geo_intersects",
    "contains",
    "icontains",
    "startswith",
    "istartswith",
    "endswith",
    "iendswith",
    "exact",
    "iexact",
    "match",
}