mirror of
https://github.com/clearml/clearml
synced 2025-04-15 21:12:54 +00:00
Add user properties support in Task object
This commit is contained in:
parent
7bf208eb08
commit
c328fbf345
@ -3838,10 +3838,8 @@ class DeleteHyperParamsRequest(Request):
|
||||
self._property_hyperparams = None
|
||||
return
|
||||
|
||||
self.assert_isinstance(value, "hyperparams", dict)
|
||||
self.assert_isinstance(value.keys(), "hyperparams_keys", six.string_types, is_array=True)
|
||||
self.assert_isinstance(value.values(), "hyperparams_values", (SectionParams, dict), is_array=True)
|
||||
value = dict((k, SectionParams(**v) if isinstance(v, dict) else v) for k, v in value.items())
|
||||
self.assert_isinstance(value, "hyperparams", (ParamKey, dict), is_array=True)
|
||||
value = [(ParamKey(**v) if isinstance(v, dict) else v) for v in value]
|
||||
|
||||
self._property_hyperparams = value
|
||||
|
||||
@ -4745,11 +4743,8 @@ class EditConfigurationRequest(Request):
|
||||
self._property_configuration = None
|
||||
return
|
||||
|
||||
self.assert_isinstance(value, "configuration", dict)
|
||||
self.assert_isinstance(value.keys(), "configuration_keys", six.string_types, is_array=True)
|
||||
self.assert_isinstance(value.values(), "configuration_values", (ConfigurationItem, dict), is_array=True)
|
||||
|
||||
value = dict((k, ConfigurationItem(**v) if isinstance(v, dict) else v) for k, v in value.items())
|
||||
self.assert_isinstance(value, "configuration", (dict, ConfigurationItem), is_array=True)
|
||||
value = [(ConfigurationItem(**v) if isinstance(v, dict) else v) for v in value]
|
||||
|
||||
self._property_configuration = value
|
||||
|
||||
@ -4905,10 +4900,8 @@ class EditHyperParamsRequest(Request):
|
||||
self._property_hyperparams = None
|
||||
return
|
||||
|
||||
self.assert_isinstance(value, "hyperparams", dict)
|
||||
self.assert_isinstance(value.keys(), "hyperparams_keys", six.string_types, is_array=True)
|
||||
self.assert_isinstance(value.values(), "hyperparams_values", (SectionParams, dict), is_array=True)
|
||||
value = dict((k, SectionParams(**v) if isinstance(v, dict) else v) for k, v in value.items())
|
||||
self.assert_isinstance(value, "hyperparams", (dict, ParamsItem), is_array=True)
|
||||
value = [(ParamsItem(**v) if isinstance(v, dict) else v) for v in value]
|
||||
|
||||
self._property_hyperparams = value
|
||||
|
||||
@ -7000,7 +6993,8 @@ class GetHyperParamsResponse(Response):
|
||||
'properties': {
|
||||
'params': {
|
||||
'description': 'Hyper parameters (keyed by task ID)',
|
||||
'type': ['object', 'null'],
|
||||
'type': 'array',
|
||||
'items': {'type': 'object'}
|
||||
},
|
||||
},
|
||||
'type': 'object',
|
||||
@ -7021,8 +7015,8 @@ class GetHyperParamsResponse(Response):
|
||||
self._property_params = None
|
||||
return
|
||||
|
||||
self.assert_isinstance(value, "params", (dict,))
|
||||
self._property_params = value
|
||||
self.assert_isinstance(value, "params", (dict,), is_array=True)
|
||||
self._property_params = list(value)
|
||||
|
||||
|
||||
class GetTypesRequest(Request):
|
||||
|
@ -173,3 +173,7 @@ class StringEnum(Enum):
|
||||
|
||||
def __str__(self):
|
||||
return self.value
|
||||
|
||||
@classmethod
|
||||
def has_value(cls, value):
|
||||
return value in cls._value2member_map_
|
||||
|
216
trains/backend_interface/task/hyperparams.py
Normal file
216
trains/backend_interface/task/hyperparams.py
Normal file
@ -0,0 +1,216 @@
|
||||
from collections import defaultdict
|
||||
from typing import Optional, Any, Sequence, Callable, Mapping, Union, Dict, Iterable, Generator
|
||||
|
||||
from ...backend_api import Session
|
||||
from ...backend_api.services import tasks
|
||||
|
||||
|
||||
class HyperParams(object):
|
||||
def __init__(self, task):
|
||||
self.task = task
|
||||
|
||||
def get_hyper_params(
|
||||
self,
|
||||
sections=None, # type: Optional[Sequence[str]]
|
||||
selector=None, # type: Optional[Callable[[dict], bool]]
|
||||
projector=None, # type: Optional[Callable[[dict], Any]]
|
||||
return_obj=False # type: Optional[bool]
|
||||
):
|
||||
# type: (...) -> Dict[str, Dict[Union[dict, Any]]]
|
||||
"""
|
||||
Get hyper-parameters for this task.
|
||||
Returns a dictionary mapping user property name to user property details dict.
|
||||
:param sections: Return only hyper-params in the provided sections
|
||||
:param selector: A callable selecting which hyper-parameters should be returned
|
||||
:param projector: A callable to project values before they are returned
|
||||
:param return_obj: If True, returned dictionary values are API objects (tasks.ParamsItem). If ``projeictor
|
||||
"""
|
||||
if not Session.check_min_api_version("2.9"):
|
||||
raise ValueError("Not supported by server")
|
||||
|
||||
task_id = self.task.task_id
|
||||
|
||||
res = self.task.session.send(tasks.GetHyperParamsRequest(tasks=[task_id]))
|
||||
hyperparams = defaultdict(defaultdict)
|
||||
if res.ok() and res.response.params:
|
||||
for entry in res.response.params:
|
||||
if entry.get("task") == task_id:
|
||||
for item in entry.get("hyperparams", []):
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
if (sections and item.get("section") not in sections) or (
|
||||
selector and not selector(item)
|
||||
):
|
||||
continue
|
||||
item = item if not projector else projector(item)
|
||||
if return_obj:
|
||||
item = tasks.ParamsItem()
|
||||
hyperparams[item.get("section")][item.get("name")] = (
|
||||
item if not projector else projector(item)
|
||||
)
|
||||
except Exception:
|
||||
self.task.log.exception("Failed processing hyper-parameter")
|
||||
return hyperparams
|
||||
|
||||
def edit_hyper_params(
|
||||
self,
|
||||
*iterables, # type: Union[Mapping[str, Union[str, dict, None]], Iterable[dict, tasks.ParamsItem]]
|
||||
replace=None, # type: Optional[str]
|
||||
default_section=None, # type: Optional[str]
|
||||
force_section=None # type: Optional[str]
|
||||
):
|
||||
# type: (...) -> bool
|
||||
"""
|
||||
Set hyper-parameters for this task.
|
||||
:param iterables: Hyper parameter iterables, each can be:
|
||||
* A dictionary of string key (name) to either a string value (value), a tasks.ParamsItem or a dict
|
||||
(hyperparam details). If ``default_section`` is not provided, each dict must contain a "section" field.
|
||||
* An iterable of tasks.ParamsItem or dicts (each representing hyperparam details).
|
||||
Each dict must contain a "name" field. If ``default_section`` is not provided, each dict must
|
||||
also contain a "section" field.
|
||||
:param replace: Optional replace strategy, values are:
|
||||
* 'all' - provided hyper-params replace all existing hyper-params in task
|
||||
* 'section' - only sections present in the provided hyper-params are replaced
|
||||
* 'none' (default) - provided hyper-params will be merged into existing task hyper-params (i.e. will be
|
||||
added or update existing hyper-params)
|
||||
:param default_section: Optional section name to be used when section is not explicitly provided.
|
||||
:param force_section: Optional section name to be used for all hyper-params.
|
||||
"""
|
||||
if not Session.check_min_api_version("2.9"):
|
||||
raise ValueError("Not supported by server")
|
||||
|
||||
escape_unsafe = not Session.check_min_api_version("2.11")
|
||||
|
||||
if not tasks.ReplaceHyperparamsEnum.has_value(replace):
|
||||
replace = None
|
||||
|
||||
def make_item(value, name=None):
|
||||
if isinstance(value, tasks.ParamsItem):
|
||||
item = value
|
||||
elif isinstance(value, dict):
|
||||
item = tasks.ParamsItem(**value)
|
||||
else:
|
||||
item = tasks.ParamsItem(value=str(value))
|
||||
if name:
|
||||
item.name = str(name)
|
||||
if not item.name:
|
||||
raise ValueError("Missing hyper-param name for '{}'".format(value))
|
||||
section = force_section or item.section or default_section
|
||||
if not section:
|
||||
raise ValueError("Missing hyper-param section for '{}'".format(value))
|
||||
if escape_unsafe:
|
||||
item.section, item.name = self._escape_unsafe_values(section, item.name)
|
||||
else:
|
||||
item.section = section
|
||||
return item
|
||||
|
||||
props = {}
|
||||
for i in iterables:
|
||||
if isinstance(i, dict):
|
||||
props.update({name: make_item(value, name) for name, value in i.items()})
|
||||
else:
|
||||
props.update({item.name: item for item in map(make_item, i)})
|
||||
|
||||
res = self.task.session.send(
|
||||
tasks.EditHyperParamsRequest(
|
||||
task=self.task.task_id,
|
||||
hyperparams=props.values(),
|
||||
replace_hyperparams=replace,
|
||||
),
|
||||
)
|
||||
if res.ok():
|
||||
self.task.reload()
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def delete_hyper_params(self, *iterables):
|
||||
# type: (Iterable[Union[dict, Iterable[str, str], tasks.ParamKey, tasks.ParamsItem]]) -> bool
|
||||
"""
|
||||
Delete hyper-parameters for this task.
|
||||
:param iterables: Hyper parameter key iterables. Each an iterable whose possible values each represent
|
||||
a hyper-parameter entry to delete, value formats are:
|
||||
* A dictionary containing a 'section' and 'name' fields
|
||||
* An iterable (e.g. tuple, list etc.) whose first two items denote 'section' and 'name'
|
||||
* An API object of type tasks.ParamKey or tasks.ParamsItem whose section and name fields are not empty
|
||||
"""
|
||||
if not Session.check_min_api_version("2.9"):
|
||||
raise ValueError("Not supported by server")
|
||||
|
||||
def get_key(value):
|
||||
if isinstance(value, dict):
|
||||
key = (value.get("section"), value.get("name"))
|
||||
elif isinstance(value, (tasks.ParamKey, tasks.ParamsItem)):
|
||||
key = (value.section, value.name)
|
||||
else:
|
||||
key = tuple(map(str, value))[:2]
|
||||
if not all(key):
|
||||
raise ValueError("Missing section or name in '{}'".format(value))
|
||||
return key
|
||||
|
||||
keys = {get_key(value) for iterable in iterables for value in iterable}
|
||||
|
||||
res = self.task.session.send(
|
||||
tasks.DeleteHyperParamsRequest(
|
||||
task=self.task.task_id,
|
||||
hyperparams=[tasks.ParamKey(section=section, name=name) for section, name in keys],
|
||||
),
|
||||
)
|
||||
if res.ok():
|
||||
self.task.reload()
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _escape_unsafe_values(self, *values):
|
||||
# type: (str) -> Generator[str]
|
||||
""" Escape unsafe values (name, section name) for API version 2.10 and below """
|
||||
for value in values:
|
||||
if value not in UNSAFE_NAMES_2_10:
|
||||
yield value
|
||||
else:
|
||||
self.task.log.info(
|
||||
"Converting unsafe hyper parameter name/section '{}' to '{}'".format(value, "_" + value)
|
||||
)
|
||||
yield "_" + value
|
||||
|
||||
|
||||
UNSAFE_NAMES_2_10 = {
|
||||
"ne",
|
||||
"gt",
|
||||
"gte",
|
||||
"lt",
|
||||
"lte",
|
||||
"in",
|
||||
"nin",
|
||||
"mod",
|
||||
"all",
|
||||
"size",
|
||||
"exists",
|
||||
"not",
|
||||
"elemMatch",
|
||||
"type",
|
||||
"within_distance",
|
||||
"within_spherical_distance",
|
||||
"within_box",
|
||||
"within_polygon",
|
||||
"near",
|
||||
"near_sphere",
|
||||
"max_distance",
|
||||
"min_distance",
|
||||
"geo_within",
|
||||
"geo_within_box",
|
||||
"geo_within_polygon",
|
||||
"geo_within_center",
|
||||
"geo_within_sphere",
|
||||
"geo_intersects",
|
||||
"contains",
|
||||
"icontains",
|
||||
"startswith",
|
||||
"istartswith",
|
||||
"endswith",
|
||||
"iendswith",
|
||||
"exact",
|
||||
"iexact",
|
||||
"match",
|
||||
}
|
@ -3,17 +3,18 @@ import itertools
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import re
|
||||
import sys
|
||||
from copy import copy
|
||||
from enum import Enum
|
||||
from tempfile import gettempdir
|
||||
from multiprocessing import RLock
|
||||
from pathlib2 import Path
|
||||
from tempfile import gettempdir
|
||||
from threading import Thread
|
||||
from typing import Optional, Any, Sequence, Callable, Mapping, Union, List
|
||||
from uuid import uuid4
|
||||
|
||||
from pathlib2 import Path
|
||||
|
||||
try:
|
||||
# noinspection PyCompatibility
|
||||
from collections.abc import Iterable
|
||||
@ -49,7 +50,7 @@ from ...storage.helper import StorageHelper, StorageError
|
||||
from .access import AccessMixin
|
||||
from .log import TaskHandler
|
||||
from .repo import ScriptInfo, pip_freeze
|
||||
from .repo.util import get_command_output
|
||||
from .hyperparams import HyperParams
|
||||
from ...config import config, PROC_MASTER_ID_ENV_VAR, SUPPRESS_UPDATE_MESSAGE_ENV_VAR
|
||||
|
||||
|
||||
@ -171,6 +172,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
||||
self._log_to_backend = log_to_backend
|
||||
self._setup_log(default_log_to_backend=log_to_backend)
|
||||
self._artifacts_manager = Artifacts(self)
|
||||
self._hyper_params_manager = HyperParams(self)
|
||||
|
||||
def _setup_log(self, default_log_to_backend=None, replace_existing=False):
|
||||
"""
|
||||
|
106
trains/task.py
106
trains/task.py
@ -7,6 +7,7 @@ import sys
|
||||
import threading
|
||||
import time
|
||||
from argparse import ArgumentParser
|
||||
from operator import attrgetter
|
||||
from tempfile import mkstemp, mkdtemp
|
||||
from zipfile import ZipFile, ZIP_DEFLATED
|
||||
|
||||
@ -1520,6 +1521,111 @@ class Task(_Task):
|
||||
"""
|
||||
self._arguments.copy_from_dict(flatten_dictionary(dictionary))
|
||||
|
||||
def get_user_properties(self, value_only=False):
|
||||
# type: (bool) -> Dict[str, Union[str, dict]]
|
||||
"""
|
||||
Get user properties for this task.
|
||||
Returns a dictionary mapping user property name to user property details dict.
|
||||
:param value_only: If True, returned user property details will be a string representing the property value.
|
||||
"""
|
||||
if not Session.check_min_api_version("2.9"):
|
||||
self.log.info("User properties are not supported by the server")
|
||||
return {}
|
||||
|
||||
section = "properties"
|
||||
|
||||
params = self._hyper_params_manager.get_hyper_params(
|
||||
sections=[section], projector=attrgetter("value") if value_only else None
|
||||
)
|
||||
|
||||
return dict(params.get(section, {}))
|
||||
|
||||
def set_user_properties(
|
||||
self,
|
||||
*iterables, # type: Union[Mapping[str, Union[str, dict, None]], Iterable[dict]]
|
||||
**properties # type: Union[str, dict, None]
|
||||
):
|
||||
# type: (...) -> bool
|
||||
"""
|
||||
Set user properties for this task.
|
||||
A user property ca contain the following fields (all of type string):
|
||||
* name
|
||||
* value
|
||||
* description
|
||||
* type
|
||||
|
||||
:param iterables: Properties iterables, each can be:
|
||||
* A dictionary of string key (name) to either a string value (value) a dict (property details). If the value
|
||||
is a dict, it must contain a "value" field. For example:
|
||||
|
||||
.. code-block:: py
|
||||
|
||||
{
|
||||
"property_name": {"description": "This is a user property", "value": "property value"},
|
||||
"another_property_name": {"description": "This is another user property", "value": "another value"},
|
||||
"yet_another_property_name": "some value"
|
||||
}
|
||||
|
||||
* An iterable of dicts (each representing property details). Each dict must contain a "name" field and a
|
||||
"value" field. For example:
|
||||
|
||||
.. code-block:: py
|
||||
|
||||
[
|
||||
{
|
||||
"name": "property_name",
|
||||
"description": "This is a user property",
|
||||
"value": "property value"
|
||||
},
|
||||
{
|
||||
"name": "another_property_name",
|
||||
"description": "This is another user property",
|
||||
"value": "another value"
|
||||
}
|
||||
]
|
||||
|
||||
:param properties: Additional properties keyword arguments. Key is the property name, and value can be
|
||||
a string (property value) or a dict (property details). If the value is a dict, it must contain a "value"
|
||||
field. For example:
|
||||
|
||||
.. code-block:: py
|
||||
|
||||
{
|
||||
"property_name": "string as property value",
|
||||
"another_property_name":
|
||||
{
|
||||
"type": "string",
|
||||
"description": "This is user property",
|
||||
"value": "another value"
|
||||
}
|
||||
}
|
||||
"""
|
||||
if not Session.check_min_api_version("2.9"):
|
||||
self.log.info("User properties are not supported by the server")
|
||||
return False
|
||||
|
||||
return self._hyper_params_manager.edit_hyper_params(
|
||||
properties,
|
||||
*iterables,
|
||||
replace='none',
|
||||
force_section="properties",
|
||||
)
|
||||
|
||||
def delete_user_properties(self, *iterables):
|
||||
# type: (Iterable[Union[dict, Iterable[str, str]]]) -> bool
|
||||
"""
|
||||
Delete hyper-parameters for this task.
|
||||
:param iterables: Hyper parameter key iterables. Each an iterable whose possible values each represent
|
||||
a hyper-parameter entry to delete, value formats are:
|
||||
* A dictionary containing a 'section' and 'name' fields
|
||||
* An iterable (e.g. tuple, list etc.) whose first two items denote 'section' and 'name'
|
||||
"""
|
||||
if not Session.check_min_api_version("2.9"):
|
||||
self.log.info("User properties are not supported by the server")
|
||||
return False
|
||||
|
||||
return self._hyper_params_manager.delete_hyper_params(*iterables)
|
||||
|
||||
def set_base_docker(self, docker_cmd):
|
||||
# type: (str) -> ()
|
||||
"""
|
||||
|
Loading…
Reference in New Issue
Block a user