From c328fbf3453960a123a356b1be69d0f77d41e9a3 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Tue, 3 Nov 2020 10:50:00 +0200 Subject: [PATCH] Add user properties support in Task object --- trains/backend_api/services/v2_9/tasks.py | 26 +-- trains/backend_api/session/datamodel.py | 4 + trains/backend_interface/task/hyperparams.py | 216 +++++++++++++++++++ trains/backend_interface/task/task.py | 10 +- trains/task.py | 106 +++++++++ 5 files changed, 342 insertions(+), 20 deletions(-) create mode 100644 trains/backend_interface/task/hyperparams.py diff --git a/trains/backend_api/services/v2_9/tasks.py b/trains/backend_api/services/v2_9/tasks.py index 2027d610..4e17c174 100644 --- a/trains/backend_api/services/v2_9/tasks.py +++ b/trains/backend_api/services/v2_9/tasks.py @@ -3838,10 +3838,8 @@ class DeleteHyperParamsRequest(Request): self._property_hyperparams = None return - self.assert_isinstance(value, "hyperparams", dict) - self.assert_isinstance(value.keys(), "hyperparams_keys", six.string_types, is_array=True) - self.assert_isinstance(value.values(), "hyperparams_values", (SectionParams, dict), is_array=True) - value = dict((k, SectionParams(**v) if isinstance(v, dict) else v) for k, v in value.items()) + self.assert_isinstance(value, "hyperparams", (ParamKey, dict), is_array=True) + value = [(ParamKey(**v) if isinstance(v, dict) else v) for v in value] self._property_hyperparams = value @@ -4745,11 +4743,8 @@ class EditConfigurationRequest(Request): self._property_configuration = None return - self.assert_isinstance(value, "configuration", dict) - self.assert_isinstance(value.keys(), "configuration_keys", six.string_types, is_array=True) - self.assert_isinstance(value.values(), "configuration_values", (ConfigurationItem, dict), is_array=True) - - value = dict((k, ConfigurationItem(**v) if isinstance(v, dict) else v) for k, v in value.items()) + self.assert_isinstance(value, "configuration", (dict, ConfigurationItem), is_array=True) + value = [(ConfigurationItem(**v) if isinstance(v, dict) else v) for v in value] self._property_configuration = value @@ -4905,10 +4900,8 @@ class EditHyperParamsRequest(Request): self._property_hyperparams = None return - self.assert_isinstance(value, "hyperparams", dict) - self.assert_isinstance(value.keys(), "hyperparams_keys", six.string_types, is_array=True) - self.assert_isinstance(value.values(), "hyperparams_values", (SectionParams, dict), is_array=True) - value = dict((k, SectionParams(**v) if isinstance(v, dict) else v) for k, v in value.items()) + self.assert_isinstance(value, "hyperparams", (dict, ParamsItem), is_array=True) + value = [(ParamsItem(**v) if isinstance(v, dict) else v) for v in value] self._property_hyperparams = value @@ -7000,7 +6993,8 @@ class GetHyperParamsResponse(Response): 'properties': { 'params': { 'description': 'Hyper parameters (keyed by task ID)', - 'type': ['object', 'null'], + 'type': 'array', + 'items': {'type': 'object'} }, }, 'type': 'object', @@ -7021,8 +7015,8 @@ class GetHyperParamsResponse(Response): self._property_params = None return - self.assert_isinstance(value, "params", (dict,)) - self._property_params = value + self.assert_isinstance(value, "params", (dict,), is_array=True) + self._property_params = list(value) class GetTypesRequest(Request): diff --git a/trains/backend_api/session/datamodel.py b/trains/backend_api/session/datamodel.py index 72d96382..f688062a 100644 --- a/trains/backend_api/session/datamodel.py +++ b/trains/backend_api/session/datamodel.py @@ -173,3 +173,7 @@ class StringEnum(Enum): def __str__(self): return self.value + + @classmethod + def has_value(cls, value): + return value in cls._value2member_map_ diff --git a/trains/backend_interface/task/hyperparams.py b/trains/backend_interface/task/hyperparams.py new file mode 100644 index 00000000..6744e3c7 --- /dev/null +++ b/trains/backend_interface/task/hyperparams.py @@ -0,0 +1,216 @@ +from collections import defaultdict +from typing import Optional, Any, Sequence, Callable, Mapping, Union, Dict, Iterable, Generator + +from ...backend_api import Session +from ...backend_api.services import tasks + + +class HyperParams(object): + def __init__(self, task): + self.task = task + + def get_hyper_params( + self, + sections=None, # type: Optional[Sequence[str]] + selector=None, # type: Optional[Callable[[dict], bool]] + projector=None, # type: Optional[Callable[[dict], Any]] + return_obj=False # type: Optional[bool] + ): + # type: (...) -> Dict[str, Dict[Union[dict, Any]]] + """ + Get hyper-parameters for this task. + Returns a dictionary mapping user property name to user property details dict. + :param sections: Return only hyper-params in the provided sections + :param selector: A callable selecting which hyper-parameters should be returned + :param projector: A callable to project values before they are returned + :param return_obj: If True, returned dictionary values are API objects (tasks.ParamsItem). If ``projeictor + """ + if not Session.check_min_api_version("2.9"): + raise ValueError("Not supported by server") + + task_id = self.task.task_id + + res = self.task.session.send(tasks.GetHyperParamsRequest(tasks=[task_id])) + hyperparams = defaultdict(defaultdict) + if res.ok() and res.response.params: + for entry in res.response.params: + if entry.get("task") == task_id: + for item in entry.get("hyperparams", []): + # noinspection PyBroadException + try: + if (sections and item.get("section") not in sections) or ( + selector and not selector(item) + ): + continue + item = item if not projector else projector(item) + if return_obj: + item = tasks.ParamsItem() + hyperparams[item.get("section")][item.get("name")] = ( + item if not projector else projector(item) + ) + except Exception: + self.task.log.exception("Failed processing hyper-parameter") + return hyperparams + + def edit_hyper_params( + self, + *iterables, # type: Union[Mapping[str, Union[str, dict, None]], Iterable[dict, tasks.ParamsItem]] + replace=None, # type: Optional[str] + default_section=None, # type: Optional[str] + force_section=None # type: Optional[str] + ): + # type: (...) -> bool + """ + Set hyper-parameters for this task. + :param iterables: Hyper parameter iterables, each can be: + * A dictionary of string key (name) to either a string value (value), a tasks.ParamsItem or a dict + (hyperparam details). If ``default_section`` is not provided, each dict must contain a "section" field. + * An iterable of tasks.ParamsItem or dicts (each representing hyperparam details). + Each dict must contain a "name" field. If ``default_section`` is not provided, each dict must + also contain a "section" field. + :param replace: Optional replace strategy, values are: + * 'all' - provided hyper-params replace all existing hyper-params in task + * 'section' - only sections present in the provided hyper-params are replaced + * 'none' (default) - provided hyper-params will be merged into existing task hyper-params (i.e. will be + added or update existing hyper-params) + :param default_section: Optional section name to be used when section is not explicitly provided. + :param force_section: Optional section name to be used for all hyper-params. + """ + if not Session.check_min_api_version("2.9"): + raise ValueError("Not supported by server") + + escape_unsafe = not Session.check_min_api_version("2.11") + + if not tasks.ReplaceHyperparamsEnum.has_value(replace): + replace = None + + def make_item(value, name=None): + if isinstance(value, tasks.ParamsItem): + item = value + elif isinstance(value, dict): + item = tasks.ParamsItem(**value) + else: + item = tasks.ParamsItem(value=str(value)) + if name: + item.name = str(name) + if not item.name: + raise ValueError("Missing hyper-param name for '{}'".format(value)) + section = force_section or item.section or default_section + if not section: + raise ValueError("Missing hyper-param section for '{}'".format(value)) + if escape_unsafe: + item.section, item.name = self._escape_unsafe_values(section, item.name) + else: + item.section = section + return item + + props = {} + for i in iterables: + if isinstance(i, dict): + props.update({name: make_item(value, name) for name, value in i.items()}) + else: + props.update({item.name: item for item in map(make_item, i)}) + + res = self.task.session.send( + tasks.EditHyperParamsRequest( + task=self.task.task_id, + hyperparams=props.values(), + replace_hyperparams=replace, + ), + ) + if res.ok(): + self.task.reload() + return True + + return False + + def delete_hyper_params(self, *iterables): + # type: (Iterable[Union[dict, Iterable[str, str], tasks.ParamKey, tasks.ParamsItem]]) -> bool + """ + Delete hyper-parameters for this task. + :param iterables: Hyper parameter key iterables. Each an iterable whose possible values each represent + a hyper-parameter entry to delete, value formats are: + * A dictionary containing a 'section' and 'name' fields + * An iterable (e.g. tuple, list etc.) whose first two items denote 'section' and 'name' + * An API object of type tasks.ParamKey or tasks.ParamsItem whose section and name fields are not empty + """ + if not Session.check_min_api_version("2.9"): + raise ValueError("Not supported by server") + + def get_key(value): + if isinstance(value, dict): + key = (value.get("section"), value.get("name")) + elif isinstance(value, (tasks.ParamKey, tasks.ParamsItem)): + key = (value.section, value.name) + else: + key = tuple(map(str, value))[:2] + if not all(key): + raise ValueError("Missing section or name in '{}'".format(value)) + return key + + keys = {get_key(value) for iterable in iterables for value in iterable} + + res = self.task.session.send( + tasks.DeleteHyperParamsRequest( + task=self.task.task_id, + hyperparams=[tasks.ParamKey(section=section, name=name) for section, name in keys], + ), + ) + if res.ok(): + self.task.reload() + return True + + return False + + def _escape_unsafe_values(self, *values): + # type: (str) -> Generator[str] + """ Escape unsafe values (name, section name) for API version 2.10 and below """ + for value in values: + if value not in UNSAFE_NAMES_2_10: + yield value + else: + self.task.log.info( + "Converting unsafe hyper parameter name/section '{}' to '{}'".format(value, "_" + value) + ) + yield "_" + value + + +UNSAFE_NAMES_2_10 = { + "ne", + "gt", + "gte", + "lt", + "lte", + "in", + "nin", + "mod", + "all", + "size", + "exists", + "not", + "elemMatch", + "type", + "within_distance", + "within_spherical_distance", + "within_box", + "within_polygon", + "near", + "near_sphere", + "max_distance", + "min_distance", + "geo_within", + "geo_within_box", + "geo_within_polygon", + "geo_within_center", + "geo_within_sphere", + "geo_intersects", + "contains", + "icontains", + "startswith", + "istartswith", + "endswith", + "iendswith", + "exact", + "iexact", + "match", +} \ No newline at end of file diff --git a/trains/backend_interface/task/task.py b/trains/backend_interface/task/task.py index 13bbc6d3..65fd1bd5 100644 --- a/trains/backend_interface/task/task.py +++ b/trains/backend_interface/task/task.py @@ -3,17 +3,18 @@ import itertools import json import logging import os -import sys import re +import sys from copy import copy from enum import Enum -from tempfile import gettempdir from multiprocessing import RLock -from pathlib2 import Path +from tempfile import gettempdir from threading import Thread from typing import Optional, Any, Sequence, Callable, Mapping, Union, List from uuid import uuid4 +from pathlib2 import Path + try: # noinspection PyCompatibility from collections.abc import Iterable @@ -49,7 +50,7 @@ from ...storage.helper import StorageHelper, StorageError from .access import AccessMixin from .log import TaskHandler from .repo import ScriptInfo, pip_freeze -from .repo.util import get_command_output +from .hyperparams import HyperParams from ...config import config, PROC_MASTER_ID_ENV_VAR, SUPPRESS_UPDATE_MESSAGE_ENV_VAR @@ -171,6 +172,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): self._log_to_backend = log_to_backend self._setup_log(default_log_to_backend=log_to_backend) self._artifacts_manager = Artifacts(self) + self._hyper_params_manager = HyperParams(self) def _setup_log(self, default_log_to_backend=None, replace_existing=False): """ diff --git a/trains/task.py b/trains/task.py index 2b67146d..477b3e4f 100644 --- a/trains/task.py +++ b/trains/task.py @@ -7,6 +7,7 @@ import sys import threading import time from argparse import ArgumentParser +from operator import attrgetter from tempfile import mkstemp, mkdtemp from zipfile import ZipFile, ZIP_DEFLATED @@ -1520,6 +1521,111 @@ class Task(_Task): """ self._arguments.copy_from_dict(flatten_dictionary(dictionary)) + def get_user_properties(self, value_only=False): + # type: (bool) -> Dict[str, Union[str, dict]] + """ + Get user properties for this task. + Returns a dictionary mapping user property name to user property details dict. + :param value_only: If True, returned user property details will be a string representing the property value. + """ + if not Session.check_min_api_version("2.9"): + self.log.info("User properties are not supported by the server") + return {} + + section = "properties" + + params = self._hyper_params_manager.get_hyper_params( + sections=[section], projector=attrgetter("value") if value_only else None + ) + + return dict(params.get(section, {})) + + def set_user_properties( + self, + *iterables, # type: Union[Mapping[str, Union[str, dict, None]], Iterable[dict]] + **properties # type: Union[str, dict, None] + ): + # type: (...) -> bool + """ + Set user properties for this task. + A user property ca contain the following fields (all of type string): + * name + * value + * description + * type + + :param iterables: Properties iterables, each can be: + * A dictionary of string key (name) to either a string value (value) a dict (property details). If the value + is a dict, it must contain a "value" field. For example: + + .. code-block:: py + + { + "property_name": {"description": "This is a user property", "value": "property value"}, + "another_property_name": {"description": "This is another user property", "value": "another value"}, + "yet_another_property_name": "some value" + } + + * An iterable of dicts (each representing property details). Each dict must contain a "name" field and a + "value" field. For example: + + .. code-block:: py + + [ + { + "name": "property_name", + "description": "This is a user property", + "value": "property value" + }, + { + "name": "another_property_name", + "description": "This is another user property", + "value": "another value" + } + ] + + :param properties: Additional properties keyword arguments. Key is the property name, and value can be + a string (property value) or a dict (property details). If the value is a dict, it must contain a "value" + field. For example: + + .. code-block:: py + + { + "property_name": "string as property value", + "another_property_name": + { + "type": "string", + "description": "This is user property", + "value": "another value" + } + } + """ + if not Session.check_min_api_version("2.9"): + self.log.info("User properties are not supported by the server") + return False + + return self._hyper_params_manager.edit_hyper_params( + properties, + *iterables, + replace='none', + force_section="properties", + ) + + def delete_user_properties(self, *iterables): + # type: (Iterable[Union[dict, Iterable[str, str]]]) -> bool + """ + Delete hyper-parameters for this task. + :param iterables: Hyper parameter key iterables. Each an iterable whose possible values each represent + a hyper-parameter entry to delete, value formats are: + * A dictionary containing a 'section' and 'name' fields + * An iterable (e.g. tuple, list etc.) whose first two items denote 'section' and 'name' + """ + if not Session.check_min_api_version("2.9"): + self.log.info("User properties are not supported by the server") + return False + + return self._hyper_params_manager.delete_hyper_params(*iterables) + def set_base_docker(self, docker_cmd): # type: (str) -> () """