Add user properties support in Task object

This commit is contained in:
allegroai 2020-11-03 10:50:00 +02:00
parent 7bf208eb08
commit c328fbf345
5 changed files with 342 additions and 20 deletions

View File

@ -3838,10 +3838,8 @@ class DeleteHyperParamsRequest(Request):
self._property_hyperparams = None
return
self.assert_isinstance(value, "hyperparams", dict)
self.assert_isinstance(value.keys(), "hyperparams_keys", six.string_types, is_array=True)
self.assert_isinstance(value.values(), "hyperparams_values", (SectionParams, dict), is_array=True)
value = dict((k, SectionParams(**v) if isinstance(v, dict) else v) for k, v in value.items())
self.assert_isinstance(value, "hyperparams", (ParamKey, dict), is_array=True)
value = [(ParamKey(**v) if isinstance(v, dict) else v) for v in value]
self._property_hyperparams = value
@ -4745,11 +4743,8 @@ class EditConfigurationRequest(Request):
self._property_configuration = None
return
self.assert_isinstance(value, "configuration", dict)
self.assert_isinstance(value.keys(), "configuration_keys", six.string_types, is_array=True)
self.assert_isinstance(value.values(), "configuration_values", (ConfigurationItem, dict), is_array=True)
value = dict((k, ConfigurationItem(**v) if isinstance(v, dict) else v) for k, v in value.items())
self.assert_isinstance(value, "configuration", (dict, ConfigurationItem), is_array=True)
value = [(ConfigurationItem(**v) if isinstance(v, dict) else v) for v in value]
self._property_configuration = value
@ -4905,10 +4900,8 @@ class EditHyperParamsRequest(Request):
self._property_hyperparams = None
return
self.assert_isinstance(value, "hyperparams", dict)
self.assert_isinstance(value.keys(), "hyperparams_keys", six.string_types, is_array=True)
self.assert_isinstance(value.values(), "hyperparams_values", (SectionParams, dict), is_array=True)
value = dict((k, SectionParams(**v) if isinstance(v, dict) else v) for k, v in value.items())
self.assert_isinstance(value, "hyperparams", (dict, ParamsItem), is_array=True)
value = [(ParamsItem(**v) if isinstance(v, dict) else v) for v in value]
self._property_hyperparams = value
@ -7000,7 +6993,8 @@ class GetHyperParamsResponse(Response):
'properties': {
'params': {
'description': 'Hyper parameters (keyed by task ID)',
'type': ['object', 'null'],
'type': 'array',
'items': {'type': 'object'}
},
},
'type': 'object',
@ -7021,8 +7015,8 @@ class GetHyperParamsResponse(Response):
self._property_params = None
return
self.assert_isinstance(value, "params", (dict,))
self._property_params = value
self.assert_isinstance(value, "params", (dict,), is_array=True)
self._property_params = list(value)
class GetTypesRequest(Request):

View File

@ -173,3 +173,7 @@ class StringEnum(Enum):
def __str__(self):
    """Use the member's ``value`` as its string representation."""
    return self.value
@classmethod
def has_value(cls, value):
    """
    Check whether ``value`` is the value of one of this enum's members.

    :param value: Candidate value (must be hashable, since it is looked up in the enum's value map).
    :return: True if a member with this value exists, False otherwise.
    """
    known_values = cls._value2member_map_
    return value in known_values

View File

@ -0,0 +1,216 @@
from collections import defaultdict
from typing import Optional, Any, Sequence, Callable, Mapping, Union, Dict, Iterable, Generator
from ...backend_api import Session
from ...backend_api.services import tasks
class HyperParams(object):
    """
    Read, edit and delete the hyper-parameters of a single task through the backend ``tasks`` service.

    Every public method checks that the server supports API version 2.9 and raises ``ValueError`` otherwise.
    """

    def __init__(self, task):
        """
        :param task: The owning task. This class reads its ``task_id``, ``session`` and ``log`` attributes
            and calls its ``reload()`` method after a successful edit or delete.
        """
        self.task = task

    @staticmethod
    def _is_selected(item, sections, selector):
        # type: (dict, Optional[Sequence[str]], Optional[Callable[[dict], bool]]) -> bool
        """ Return True if a raw hyper-parameter dict passes both the sections filter and the selector """
        if sections and item.get("section") not in sections:
            return False
        return not selector or bool(selector(item))

    def get_hyper_params(
        self,
        sections=None,  # type: Optional[Sequence[str]]
        selector=None,  # type: Optional[Callable[[dict], bool]]
        projector=None,  # type: Optional[Callable[[dict], Any]]
        return_obj=False  # type: Optional[bool]
    ):
        # type: (...) -> Dict[str, Dict[Union[dict, Any]]]
        """
        Get hyper-parameters for this task.
        Returns a dictionary mapping each section name to a dictionary mapping hyper-parameter name to
        hyper-parameter details.

        :param sections: Return only hyper-params in the provided sections
        :param selector: A callable selecting which hyper-parameters should be returned. It receives the raw
            hyper-parameter dict and should return a truthy value to keep it.
        :param projector: A callable to project values before they are returned. It is applied exactly once
            per hyper-parameter.
        :param return_obj: If True, returned dictionary values are API objects (tasks.ParamsItem) built from the
            raw hyper-parameter dicts. If ``projector`` is also provided, it receives the API object instead of
            the raw dict.
        :raise ValueError: If the server does not support API version 2.9
        """
        if not Session.check_min_api_version("2.9"):
            raise ValueError("Not supported by server")

        task_id = self.task.task_id

        res = self.task.session.send(tasks.GetHyperParamsRequest(tasks=[task_id]))
        hyperparams = defaultdict(defaultdict)
        if res.ok() and res.response.params:
            for entry in res.response.params:
                if entry.get("task") != task_id:
                    continue
                for item in entry.get("hyperparams", []):
                    # A single malformed entry (or a failing user callable) is logged and skipped so it
                    # does not prevent the remaining hyper-parameters from being returned
                    # noinspection PyBroadException
                    try:
                        if not self._is_selected(item, sections, selector):
                            continue
                        # Read the keys from the raw dict before the item is converted or projected:
                        # neither an API object nor a projected value is guaranteed to support .get()
                        section, name = item.get("section"), item.get("name")
                        if return_obj:
                            item = tasks.ParamsItem(**item)
                        hyperparams[section][name] = projector(item) if projector else item
                    except Exception:
                        self.task.log.exception("Failed processing hyper-parameter")
        return hyperparams

    def edit_hyper_params(
        self,
        *iterables,  # type: Union[Mapping[str, Union[str, dict, None]], Iterable[dict, tasks.ParamsItem]]
        replace=None,  # type: Optional[str]
        default_section=None,  # type: Optional[str]
        force_section=None  # type: Optional[str]
    ):
        # type: (...) -> bool
        """
        Set hyper-parameters for this task.

        :param iterables: Hyper parameter iterables, each can be:
            * A dictionary of string key (name) to either a string value (value), a tasks.ParamsItem or a dict
              (hyperparam details). If ``default_section`` is not provided, each dict must contain a "section" field.
            * An iterable of tasks.ParamsItem or dicts (each representing hyperparam details).
              Each dict must contain a "name" field. If ``default_section`` is not provided, each dict must
              also contain a "section" field.
            When the same (section, name) pair is provided more than once, the last occurrence is used.
            Note that tasks.ParamsItem objects passed in are updated in place (name and section).
        :param replace: Optional replace strategy, values are:
            * 'all' - provided hyper-params replace all existing hyper-params in task
            * 'section' - only sections present in the provided hyper-params are replaced
            * 'none' (default) - provided hyper-params will be merged into existing task hyper-params (i.e. will be
              added or update existing hyper-params)
            Any value that is not a valid tasks.ReplaceHyperparamsEnum value is sent as None.
        :param default_section: Optional section name to be used when section is not explicitly provided.
        :param force_section: Optional section name to be used for all hyper-params.
        :return: True if the server accepted the request (the task is then reloaded), False otherwise
        :raise ValueError: If the server does not support API version 2.9, or if a hyper-parameter is missing
            its name or its section
        """
        if not Session.check_min_api_version("2.9"):
            raise ValueError("Not supported by server")

        # Names/sections are escaped only for servers below API version 2.11
        escape_unsafe = not Session.check_min_api_version("2.11")

        if not tasks.ReplaceHyperparamsEnum.has_value(replace):
            replace = None

        def make_item(value, name=None):
            """ Build a tasks.ParamsItem with a resolved name and section from any supported value """
            if isinstance(value, tasks.ParamsItem):
                item = value
            elif isinstance(value, dict):
                item = tasks.ParamsItem(**value)
            else:
                item = tasks.ParamsItem(value=str(value))
            if name:
                item.name = str(name)
            if not item.name:
                raise ValueError("Missing hyper-param name for '{}'".format(value))
            section = force_section or item.section or default_section
            if not section:
                raise ValueError("Missing hyper-param section for '{}'".format(value))
            if escape_unsafe:
                item.section, item.name = self._escape_unsafe_values(section, item.name)
            else:
                item.section = section
            return item

        # Key by the final (section, name) pair: the same name may legitimately appear in several
        # sections, so keying by name alone would silently drop all but one of them
        props = {}
        for i in iterables:
            if isinstance(i, dict):
                items = (make_item(value, name) for name, value in i.items())
            else:
                items = map(make_item, i)
            for item in items:
                props[(item.section, item.name)] = item

        res = self.task.session.send(
            tasks.EditHyperParamsRequest(
                task=self.task.task_id,
                hyperparams=list(props.values()),
                replace_hyperparams=replace,
            ),
        )
        if res.ok():
            self.task.reload()
            return True

        return False

    def delete_hyper_params(self, *iterables):
        # type: (Iterable[Union[dict, Iterable[str, str], tasks.ParamKey, tasks.ParamsItem]]) -> bool
        """
        Delete hyper-parameters for this task.

        :param iterables: Hyper parameter key iterables. Each an iterable whose possible values each represent
            a hyper-parameter entry to delete, value formats are:
            * A dictionary containing a 'section' and 'name' fields
            * An iterable (e.g. tuple, list etc.) whose first two items denote 'section' and 'name'.
              A plain string is rejected, since iterating it would yield single characters.
            * An API object of type tasks.ParamKey or tasks.ParamsItem whose section and name fields are not empty
        :return: True if the server accepted the request (the task is then reloaded), False otherwise
        :raise ValueError: If the server does not support API version 2.9, or if a value does not provide both
            a section and a name
        """
        if not Session.check_min_api_version("2.9"):
            raise ValueError("Not supported by server")

        def get_key(value):
            """ Extract a (section, name) tuple from any supported value """
            if isinstance(value, dict):
                key = (value.get("section"), value.get("name"))
            elif isinstance(value, (tasks.ParamKey, tasks.ParamsItem)):
                key = (value.section, value.name)
            elif isinstance(value, str):
                # Typically the result of passing a single ("section", "name") tuple instead of an
                # iterable of such tuples - never treat the string's characters as section and name
                key = ()
            else:
                key = tuple(map(str, value))[:2]
            if len(key) != 2 or not all(key):
                raise ValueError("Missing section or name in '{}'".format(value))
            return key

        keys = {get_key(value) for iterable in iterables for value in iterable}

        res = self.task.session.send(
            tasks.DeleteHyperParamsRequest(
                task=self.task.task_id,
                hyperparams=[tasks.ParamKey(section=section, name=name) for section, name in keys],
            ),
        )
        if res.ok():
            self.task.reload()
            return True

        return False

    def _escape_unsafe_values(self, *values):
        # type: (str) -> Generator[str]
        """
        Escape unsafe values (name, section name) for API version 2.10 and below.
        Yields each value in order; values found in UNSAFE_NAMES_2_10 are prefixed with an underscore.
        """
        for value in values:
            if value not in UNSAFE_NAMES_2_10:
                yield value
            else:
                self.task.log.info(
                    "Converting unsafe hyper parameter name/section '{}' to '{}'".format(value, "_" + value)
                )
                yield "_" + value
# Names that must be escaped (prefixed with "_") when used as a hyper-parameter name or section
# against servers with API version 2.10 and below.
# NOTE(review): these look like database query-operator keywords - confirm against the server.
UNSAFE_NAMES_2_10 = {
    # comparison / generic operators
    "ne", "gt", "gte", "lt", "lte", "in", "nin", "mod", "all", "size", "exists", "not", "elemMatch", "type",
    # distance / proximity operators
    "within_distance", "within_spherical_distance", "within_box", "within_polygon",
    "near", "near_sphere", "max_distance", "min_distance",
    # geo operators
    "geo_within", "geo_within_box", "geo_within_polygon", "geo_within_center", "geo_within_sphere",
    "geo_intersects",
    # string matching operators
    "contains", "icontains", "startswith", "istartswith", "endswith", "iendswith", "exact", "iexact", "match",
}

View File

@ -3,17 +3,18 @@ import itertools
import json
import logging
import os
import sys
import re
import sys
from copy import copy
from enum import Enum
from tempfile import gettempdir
from multiprocessing import RLock
from pathlib2 import Path
from tempfile import gettempdir
from threading import Thread
from typing import Optional, Any, Sequence, Callable, Mapping, Union, List
from uuid import uuid4
from pathlib2 import Path
try:
# noinspection PyCompatibility
from collections.abc import Iterable
@ -49,7 +50,7 @@ from ...storage.helper import StorageHelper, StorageError
from .access import AccessMixin
from .log import TaskHandler
from .repo import ScriptInfo, pip_freeze
from .repo.util import get_command_output
from .hyperparams import HyperParams
from ...config import config, PROC_MASTER_ID_ENV_VAR, SUPPRESS_UPDATE_MESSAGE_ENV_VAR
@ -171,6 +172,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
self._log_to_backend = log_to_backend
self._setup_log(default_log_to_backend=log_to_backend)
self._artifacts_manager = Artifacts(self)
self._hyper_params_manager = HyperParams(self)
def _setup_log(self, default_log_to_backend=None, replace_existing=False):
"""

View File

@ -7,6 +7,7 @@ import sys
import threading
import time
from argparse import ArgumentParser
from operator import attrgetter
from tempfile import mkstemp, mkdtemp
from zipfile import ZipFile, ZIP_DEFLATED
@ -1520,6 +1521,111 @@ class Task(_Task):
"""
self._arguments.copy_from_dict(flatten_dictionary(dictionary))
def get_user_properties(self, value_only=False):
    # type: (bool) -> Dict[str, Union[str, dict]]
    """
    Get user properties for this task.
    Returns a dictionary mapping user property name to user property details dict.

    :param value_only: If True, returned user property details will be a string representing the property value.
    :return: The task's user properties (hyper-parameters stored under the "properties" section), or an
        empty dictionary if the server does not support user properties (API version below 2.9).
    """
    if not Session.check_min_api_version("2.9"):
        self.log.info("User properties are not supported by the server")
        return {}

    section = "properties"

    # Hyper-parameter entries are plain dicts, so the value has to be read by key: an attribute
    # getter would raise AttributeError on every entry and no property would be returned
    projector = (lambda item: item.get("value")) if value_only else None

    params = self._hyper_params_manager.get_hyper_params(sections=[section], projector=projector)
    return dict(params.get(section, {}))
def set_user_properties(
    self,
    *iterables,  # type: Union[Mapping[str, Union[str, dict, None]], Iterable[dict]]
    **properties  # type: Union[str, dict, None]
):
    # type: (...) -> bool
    """
    Set user properties for this task.
    A user property can contain the following fields (all of type string):
    * name
    * value
    * description
    * type

    All provided properties are stored under the "properties" section and merged into the existing ones
    (replace strategy 'none'). Keyword ``properties`` are passed to the hyper-parameters manager first,
    followed by ``iterables`` in the order given.

    :param iterables: Properties iterables, each can be:
        * A dictionary of string key (name) to either a string value (value) or a dict (property details).
          If the value is a dict, it should contain a "value" field. For example:

          .. code-block:: py

             {
                 "property_name": {"description": "This is a user property", "value": "property value"},
                 "another_property_name": {"description": "This is another user property", "value": "another value"},
                 "yet_another_property_name": "some value"
             }

        * An iterable of dicts (each representing property details). Each dict must contain a "name" field
          and should contain a "value" field. For example:

          .. code-block:: py

             [
                 {
                     "name": "property_name",
                     "description": "This is a user property",
                     "value": "property value"
                 },
                 {
                     "name": "another_property_name",
                     "description": "This is another user property",
                     "value": "another value"
                 }
             ]

    :param properties: Additional properties keyword arguments. Key is the property name, and value can be
        a string (property value) or a dict (property details). If the value is a dict, it should contain
        a "value" field. For example:

        .. code-block:: py

           {
               "property_name": "string as property value",
               "another_property_name":
               {
                   "type": "string",
                   "description": "This is user property",
                   "value": "another value"
               }
           }

    :return: The result of the hyper-parameters edit request, or False if the server does not support
        user properties (API version below 2.9).
    """
    if Session.check_min_api_version("2.9"):
        manager = self._hyper_params_manager
        return manager.edit_hyper_params(
            properties,
            *iterables,
            replace='none',
            force_section="properties",
        )

    self.log.info("User properties are not supported by the server")
    return False
def delete_user_properties(self, *iterables):
    # type: (Iterable[Union[dict, Iterable[str, str]]]) -> bool
    """
    Delete user properties (hyper-parameter entries) for this task.

    The provided keys are forwarded unchanged to the hyper-parameters manager: no section is forced here,
    so every key must state its section explicitly.

    :param iterables: Hyper parameter key iterables. Each an iterable whose possible values each represent
        a hyper-parameter entry to delete, value formats are:
        * A dictionary containing a 'section' and 'name' fields
        * An iterable (e.g. tuple, list etc.) whose first two items denote 'section' and 'name'
    :return: The result of the hyper-parameters delete request, or False if the server does not support
        user properties (API version below 2.9).
    """
    if Session.check_min_api_version("2.9"):
        manager = self._hyper_params_manager
        return manager.delete_hyper_params(*iterables)

    self.log.info("User properties are not supported by the server")
    return False
def set_base_docker(self, docker_cmd):
# type: (str) -> ()
"""