Add multi configuration section support (hyperparams and configurations)

Support setting offline mode API version using TRAINS_OFFLINE_MODE env var
This commit is contained in:
allegroai 2020-08-08 12:35:03 +03:00
parent 6d4e85de0a
commit e378de1e41
6 changed files with 355 additions and 99 deletions

View File

@ -64,7 +64,11 @@ class DataModel(object):
@classmethod @classmethod
def _to_base_type(cls, value): def _to_base_type(cls, value):
if isinstance(value, DataModel): if isinstance(value, dict):
# Note: this should come before DataModel to handle data models that are simply a dict
# (and thus are not expected to have additional named properties)
return {k: cls._to_base_type(v) for k, v in value.items()}
elif isinstance(value, DataModel):
return value.to_dict() return value.to_dict()
elif isinstance(value, enum.Enum): elif isinstance(value, enum.Enum):
return value.value return value.value

View File

@ -62,6 +62,7 @@ class Session(TokenManager):
default_files = "https://demofiles.trains.allegro.ai" default_files = "https://demofiles.trains.allegro.ai"
default_key = "EGRTCO8JMSIGI6S39GTP43NFWXDQOW" default_key = "EGRTCO8JMSIGI6S39GTP43NFWXDQOW"
default_secret = "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8" default_secret = "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8"
force_max_api_version = None
# TODO: add requests.codes.gateway_timeout once we support async commits # TODO: add requests.codes.gateway_timeout once we support async commits
_retry_codes = [ _retry_codes = [
@ -182,6 +183,9 @@ class Session(TokenManager):
self.__class__._sessions_created += 1 self.__class__._sessions_created += 1
if self.force_max_api_version and self.check_min_api_version(self.force_max_api_version):
Session.api_version = str(self.force_max_api_version)
def _send_request( def _send_request(
self, self,
service, service,
@ -539,7 +543,17 @@ class Session(TokenManager):
# If no session was created, create a default one, in order to get the backend api version. # If no session was created, create a default one, in order to get the backend api version.
if cls._sessions_created <= 0: if cls._sessions_created <= 0:
if cls._offline_mode: if cls._offline_mode:
cls.api_version = cls._offline_default_version # allow to change the offline mode version by setting ENV_OFFLINE_MODE to the required API version
if cls.api_version != cls._offline_default_version:
offline_api = ENV_OFFLINE_MODE.get(converter=lambda x: x)
if offline_api:
try:
# check cast to float, but leave original str if we pass it.
float(offline_api)
cls._offline_default_version = str(offline_api)
except ValueError:
pass
cls.api_version = cls._offline_default_version
else: else:
# noinspection PyBroadException # noinspection PyBroadException
try: try:

View File

@ -10,32 +10,35 @@ from ...utilities.args import call_original_argparser
class _Arguments(object): class _Arguments(object):
_prefix_sep = '/' _prefix_sep = '/'
# TODO: separate dict and argparse after we add UI support # TODO: separate dict and argparse after we add UI support
_prefix_dict = 'dict' + _prefix_sep _prefix_args = 'Args' + _prefix_sep
_prefix_args = 'argparse' + _prefix_sep
_prefix_tf_defines = 'TF_DEFINE' + _prefix_sep _prefix_tf_defines = 'TF_DEFINE' + _prefix_sep
class _ProxyDictWrite(dict): class _ProxyDictWrite(dict):
""" Dictionary wrapper that updates an arguments instance on any item set in the dictionary """ """ Dictionary wrapper that updates an arguments instance on any item set in the dictionary """
def __init__(self, arguments, *args, **kwargs): def __init__(self, __arguments, __section_name, *args, **kwargs):
super(_Arguments._ProxyDictWrite, self).__init__(*args, **kwargs) super(_Arguments._ProxyDictWrite, self).__init__(*args, **kwargs)
self._arguments = arguments self._arguments = __arguments
self._section_name = (__section_name.strip(_Arguments._prefix_sep) + _Arguments._prefix_sep) \
if __section_name else None
def __setitem__(self, key, value): def __setitem__(self, key, value):
super(_Arguments._ProxyDictWrite, self).__setitem__(key, value) super(_Arguments._ProxyDictWrite, self).__setitem__(key, value)
if self._arguments: if self._arguments:
self._arguments.copy_from_dict(self) self._arguments.copy_from_dict(self, prefix=self._section_name)
class _ProxyDictReadOnly(dict): class _ProxyDictReadOnly(dict):
""" Dictionary wrapper that prevents modifications to the dictionary """ """ Dictionary wrapper that prevents modifications to the dictionary """
def __init__(self, arguments, *args, **kwargs): def __init__(self, __arguments, __section_name, *args, **kwargs):
super(_Arguments._ProxyDictReadOnly, self).__init__(*args, **kwargs) super(_Arguments._ProxyDictReadOnly, self).__init__(*args, **kwargs)
self._arguments = arguments self._arguments = __arguments
self._section_name = (__section_name.strip(_Arguments._prefix_sep) + _Arguments._prefix_sep) \
if __section_name else None
def __setitem__(self, key, value): def __setitem__(self, key, value):
if self._arguments: if self._arguments:
param_dict = self._arguments.copy_to_dict({key: value}) param_dict = self._arguments.copy_to_dict({key: value}, prefix=self._section_name)
value = param_dict.get(key, value) value = param_dict.get(key, value)
super(_Arguments._ProxyDictReadOnly, self).__setitem__(key, value) super(_Arguments._ProxyDictReadOnly, self).__setitem__(key, value)
@ -45,24 +48,32 @@ class _Arguments(object):
self._exclude_parser_args = {} self._exclude_parser_args = {}
def exclude_parser_args(self, excluded_args): def exclude_parser_args(self, excluded_args):
"""
You can use a dictionary for fine-grained control of connected
arguments. The dictionary keys are argparse variable names and the values are booleans.
The ``False`` value excludes the specified argument from the Task's parameter section.
Keys missing from the dictionary default to ``True``, and an empty dictionary defaults to ``False``.
:param excluded_args: dict
"""
self._exclude_parser_args = excluded_args or {} self._exclude_parser_args = excluded_args or {}
def set_defaults(self, *dicts, **kwargs): def set_defaults(self, *dicts, **kwargs):
self._task.set_parameters(*dicts, **kwargs) # noinspection PyProtectedMember
self._task._set_parameters(*dicts, __parameters_prefix=self._prefix_args, **kwargs)
def add_argument(self, option_strings, type=None, default=None, help=None): def add_argument(self, option_strings, type=None, default=None, help=None):
if not option_strings: if not option_strings:
raise Exception('Expected at least one argument name (option string)') raise Exception('Expected at least one argument name (option string)')
name = option_strings[0].strip('- \t') if isinstance(option_strings, list) else option_strings.strip('- \t') name = option_strings[0].strip('- \t') if isinstance(option_strings, list) else option_strings.strip('- \t')
# TODO: add argparse prefix name = self._prefix_args + name
# name = self._prefix_args + name
self._task.set_parameter(name=name, value=default, description=help) self._task.set_parameter(name=name, value=default, description=help)
def connect(self, parser): def connect(self, parser):
self._task.connect_argparse(parser) self._task.connect_argparse(parser)
@classmethod @classmethod
def _add_to_defaults(cls, a_parser, defaults, a_args=None, a_namespace=None, a_parsed_args=None): def _add_to_defaults(cls, a_parser, defaults, descriptions, a_args=None, a_namespace=None, a_parsed_args=None):
actions = [ actions = [
a for a in a_parser._actions a for a in a_parser._actions
if isinstance(a, _StoreAction) or isinstance(a, _StoreConstAction) if isinstance(a, _StoreAction) or isinstance(a, _StoreConstAction)
@ -87,6 +98,9 @@ class _Arguments(object):
for a in actions for a in actions
} }
desc_ = {a.dest: a.help for a in actions}
descriptions.update(desc_)
full_args_dict = copy(defaults) full_args_dict = copy(defaults)
full_args_dict.update(args_dict) full_args_dict.update(args_dict)
defaults.update(defaults_) defaults.update(defaults_)
@ -101,17 +115,19 @@ class _Arguments(object):
defaults[sub_parser.dest] = full_args_dict.get(sub_parser.dest) or '' defaults[sub_parser.dest] = full_args_dict.get(sub_parser.dest) or ''
for choice in sub_parser.choices.values(): for choice in sub_parser.choices.values():
# recursively parse # recursively parse
defaults = cls._add_to_defaults( defaults, descriptions = cls._add_to_defaults(
a_parser=choice, a_parser=choice,
defaults=defaults, defaults=defaults,
descriptions=descriptions,
a_parsed_args=a_parsed_args or full_args_dict a_parsed_args=a_parsed_args or full_args_dict
) )
return defaults return defaults, descriptions
def copy_defaults_from_argparse(self, parser, args=None, namespace=None, parsed_args=None): def copy_defaults_from_argparse(self, parser, args=None, namespace=None, parsed_args=None):
task_defaults = {} task_defaults = {}
self._add_to_defaults(parser, task_defaults, args, namespace, parsed_args) task_defaults_descriptions = {}
self._add_to_defaults(parser, task_defaults, task_defaults_descriptions, args, namespace, parsed_args)
# Make sure we didn't miss anything # Make sure we didn't miss anything
if parsed_args: if parsed_args:
@ -133,12 +149,13 @@ class _Arguments(object):
except Exception: except Exception:
del task_defaults[k] del task_defaults[k]
# Skip excluded arguments, Add prefix, TODO: add argparse prefix # Skip excluded arguments, Add prefix.
# task_defaults = dict([(self._prefix_args + k, v) for k, v in task_defaults.items() task_defaults = dict([(self._prefix_args + k, v) for k, v in task_defaults.items()
# if k not in self._exclude_parser_args]) if self._exclude_parser_args.get(k, True)])
task_defaults = dict([(k, v) for k, v in task_defaults.items() if self._exclude_parser_args.get(k, True)]) task_defaults_descriptions = dict([(self._prefix_args + k, v) for k, v in task_defaults_descriptions.items()
if self._exclude_parser_args.get(k, True)])
# Store to task # Store to task
self._task.update_parameters(task_defaults) self._task.update_parameters(task_defaults, __parameters_descriptions=task_defaults_descriptions)
@classmethod @classmethod
def _find_parser_action(cls, a_parser, name): def _find_parser_action(cls, a_parser, name):
@ -158,11 +175,10 @@ class _Arguments(object):
return _actions return _actions
def copy_to_parser(self, parser, parsed_args): def copy_to_parser(self, parser, parsed_args):
# todo: change to argparse prefix only # Change to argparse prefix only
# task_arguments = dict([(k[len(self._prefix_args):], v) for k, v in self._task.get_parameters().items() task_arguments = dict([(k[len(self._prefix_args):], v) for k, v in self._task.get_parameters().items()
# if k.startswith(self._prefix_args)]) if k.startswith(self._prefix_args) and
task_arguments = dict([(k, v) for k, v in self._task.get_parameters().items() self._exclude_parser_args.get(k[len(self._prefix_args):], True)])
if not k.startswith(self._prefix_tf_defines) and self._exclude_parser_args.get(k, True)])
arg_parser_argeuments = {} arg_parser_argeuments = {}
for k, v in task_arguments.items(): for k, v in task_arguments.items():
# python2 unicode support # python2 unicode support
@ -304,25 +320,24 @@ class _Arguments(object):
parser.set_defaults(**arg_parser_argeuments) parser.set_defaults(**arg_parser_argeuments)
def copy_from_dict(self, dictionary, prefix=None): def copy_from_dict(self, dictionary, prefix=None):
# TODO: add dict prefix # add dict prefix
prefix = prefix or '' # self._prefix_dict prefix = prefix # or self._prefix_dict
if prefix: if prefix:
with self._task._edit_lock: prefix = prefix.strip(self._prefix_sep) + self._prefix_sep
prefix_dictionary = dict([(prefix + k, v) for k, v in dictionary.items()]) # this will only set the specific section
cur_params = dict([(k, v) for k, v in self._task.get_parameters().items() if not k.startswith(prefix)]) self._task.set_parameters(dictionary, __parameters_prefix=prefix)
cur_params.update(prefix_dictionary)
self._task.set_parameters(cur_params)
else: else:
self._task.update_parameters(dictionary) self._task.update_parameters(dictionary)
if not isinstance(dictionary, self._ProxyDictWrite): if not isinstance(dictionary, self._ProxyDictWrite):
return self._ProxyDictWrite(self, **dictionary) return self._ProxyDictWrite(self, prefix, **dictionary)
return dictionary return dictionary
def copy_to_dict(self, dictionary, prefix=None): def copy_to_dict(self, dictionary, prefix=None):
# iterate over keys and merge values according to parameter type in dictionary # iterate over keys and merge values according to parameter type in dictionary
# TODO: add dict prefix # add dict prefix
prefix = prefix or '' # self._prefix_dict prefix = prefix # or self._prefix_dict
if prefix: if prefix:
prefix = prefix.strip(self._prefix_sep) + self._prefix_sep
parameters = dict([(k[len(prefix):], v) for k, v in self._task.get_parameters().items() parameters = dict([(k[len(prefix):], v) for k, v in self._task.get_parameters().items()
if k.startswith(prefix)]) if k.startswith(prefix)])
else: else:
@ -396,5 +411,5 @@ class _Arguments(object):
# dictionary[k] = v # dictionary[k] = v
if not isinstance(dictionary, self._ProxyDictReadOnly): if not isinstance(dictionary, self._ProxyDictReadOnly):
return self._ProxyDictReadOnly(self, **dictionary) return self._ProxyDictReadOnly(self, prefix, **dictionary)
return dictionary return dictionary

View File

@ -30,13 +30,15 @@ from ...backend_api import Session
from ...backend_api.services import tasks, models, events, projects from ...backend_api.services import tasks, models, events, projects
from ...backend_api.session.defs import ENV_OFFLINE_MODE from ...backend_api.session.defs import ENV_OFFLINE_MODE
from ...utilities.pyhocon import ConfigTree, ConfigFactory from ...utilities.pyhocon import ConfigTree, ConfigFactory
from ...utilities.config import config_dict_to_text, text_to_config_dict
from ..base import IdObjectBase, InterfaceBase from ..base import IdObjectBase, InterfaceBase
from ..metrics import Metrics, Reporter from ..metrics import Metrics, Reporter
from ..model import Model from ..model import Model
from ..setupuploadmixin import SetupUploadMixin from ..setupuploadmixin import SetupUploadMixin
from ..util import make_message, get_or_create_project, get_single_result, \ from ..util import (
exact_match_regex make_message, get_or_create_project, get_single_result,
exact_match_regex, mutually_exclusive, )
from ...config import ( from ...config import (
get_config_for_bucket, get_remote_task_id, TASK_ID_ENV_VAR, get_log_to_backend, get_config_for_bucket, get_remote_task_id, TASK_ID_ENV_VAR, get_log_to_backend,
running_remotely, get_cache_dir, DOCKER_IMAGE_ENV_VAR, get_offline_dir) running_remotely, get_cache_dir, DOCKER_IMAGE_ENV_VAR, get_offline_dir)
@ -768,12 +770,52 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
self._edit(execution=self.data.execution) self._edit(execution=self.data.execution)
def get_parameters(self, backwards_compatibility=True):
# type: (bool) -> (Optional[dict])
"""
Get the parameters for a Task. This method returns a complete group of key-value parameter pairs, but does not
support parameter descriptions (the result is a dictionary of key-value pairs).
:param backwards_compatibility: If True (default) parameters without section name
(API version < 2.9, trains-server < 0.16) will be at dict root level.
If False, parameters without section name, will be nested under "general/" key.
:return: dict of the task parameters, all flattened to key/value.
Different sections with key prefix "section/"
"""
if not Session.check_min_api_version('2.9'):
return self._get_task_property('execution.parameters')
# The API makes sure we get old parameters under _legacy
parameters = dict()
hyperparams = self._get_task_property('hyperparams', default=dict())
if len(hyperparams) == 1 and '_legacy' in hyperparams and backwards_compatibility:
for section in ('_legacy', ):
for key, section_param in hyperparams[section].items():
parameters['{}'.format(key)] = section_param.value
else:
for section in hyperparams:
for key, section_param in hyperparams[section].items():
parameters['{}/{}'.format(section, key)] = section_param.value
return parameters
def set_parameters(self, *args, **kwargs): def set_parameters(self, *args, **kwargs):
# type: (*dict, **Any) -> () # type: (*dict, **Any) -> ()
""" """
Set the parameters for a Task. This method sets a complete group of key-value parameter pairs, but does not Set the parameters for a Task. This method sets a complete group of key-value parameter pairs, but does not
support parameter descriptions (the input is a dictionary of key-value pairs). support parameter descriptions (the input is a dictionary of key-value pairs).
:param args: Positional arguments, which are one or more dictionary or (key, value) iterable. They are
merged into a single key-value pair dictionary.
:param kwargs: Key-value pairs, merged into the parameters dictionary created from ``args``.
"""
return self._set_parameters(*args, __update=False, **kwargs)
def _set_parameters(self, *args, **kwargs):
# type: (*dict, **Any) -> ()
"""
Set the parameters for a Task. This method sets a complete group of key-value parameter pairs, but does not
support parameter descriptions (the input is a dictionary of key-value pairs).
:param args: Positional arguments, which are one or more dictionary or (key, value) iterable. They are :param args: Positional arguments, which are one or more dictionary or (key, value) iterable. They are
merged into a single key-value pair dictionary. merged into a single key-value pair dictionary.
:param kwargs: Key-value pairs, merged into the parameters dictionary created from ``args``. :param kwargs: Key-value pairs, merged into the parameters dictionary created from ``args``.
@ -781,58 +823,113 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
if not all(isinstance(x, (dict, Iterable)) for x in args): if not all(isinstance(x, (dict, Iterable)) for x in args):
raise ValueError('only dict or iterable are supported as positional arguments') raise ValueError('only dict or iterable are supported as positional arguments')
prefix = kwargs.pop('__parameters_prefix', None)
descriptions = kwargs.pop('__parameters_descriptions', None) or dict()
params_types = kwargs.pop('__parameters_types', None) or dict()
update = kwargs.pop('__update', False) update = kwargs.pop('__update', False)
# new parameters dict
new_parameters = dict(itertools.chain.from_iterable(x.items() if isinstance(x, dict) else x for x in args))
new_parameters.update(kwargs)
if prefix:
prefix = prefix.strip('/')
new_parameters = dict(('{}/{}'.format(prefix, k), v) for k, v in new_parameters.items())
# verify parameters type:
not_allowed = {
k: type(v).__name__
for k, v in new_parameters.items()
if not isinstance(v, self._parameters_allowed_types)
}
if not_allowed:
raise ValueError(
"Only builtin types ({}) are allowed for values (got {})".format(
', '.join(t.__name__ for t in self._parameters_allowed_types),
', '.join('%s=>%s' % p for p in not_allowed.items())),
)
use_hyperparams = Session.check_min_api_version('2.9')
with self._edit_lock: with self._edit_lock:
self.reload() self.reload()
if update: # if we have a specific prefix and we use hyperparameters, and we use set.
parameters = self.get_parameters() # overwrite only the prefix, leave the rest as is.
if not update and prefix:
parameters = dict(**(self.get_parameters() or {}))
parameters = dict((k, v) for k, v in parameters.items() if not k.startswith(prefix+'/'))
elif update:
parameters = dict(**(self.get_parameters() or {}))
else: else:
parameters = dict() parameters = dict()
parameters.update(itertools.chain.from_iterable(x.items() if isinstance(x, dict) else x for x in args))
parameters.update(kwargs)
not_allowed = { parameters.update(new_parameters)
k: type(v).__name__
for k, v in parameters.items()
if not isinstance(v, self._parameters_allowed_types)
}
if not_allowed:
raise ValueError(
"Only builtin types ({}) are allowed for values (got {})".format(
', '.join(t.__name__ for t in self._parameters_allowed_types),
', '.join('%s=>%s' % p for p in not_allowed.items())),
)
# force cast all variables to strings (so that we can later edit them in UI) # force cast all variables to strings (so that we can later edit them in UI)
parameters = {k: str(v) if v is not None else "" for k, v in parameters.items()} parameters = {k: str(v) if v is not None else "" for k, v in parameters.items()}
execution = self.data.execution if use_hyperparams:
if execution is None: # build nested dict from flat parameters dict:
execution = tasks.Execution( org_hyperparams = self.data.hyperparams or {}
parameters=parameters, artifacts=[], dataviews=[], model='', hyperparams = dict()
model_desc={}, model_labels={}, docker_cmd='')
else:
execution.parameters = parameters
self._edit(execution=execution)
def set_parameter(self, name, value, description=None): # if the task is a legacy task, we should put everything back under _legacy
# type: (str, str, Optional[str]) -> () if self.data.hyperparams and '_legacy' in self.data.hyperparams:
for k, v in parameters.items():
section_name, key = '_legacy', k
section = hyperparams.get(section_name, dict())
description = \
descriptions.get(k) or \
org_hyperparams.get(section_name, dict()).get(key, tasks.ParamsItem()).description
param_type = \
params_types.get(k) or \
org_hyperparams.get(section_name, dict()).get(key, tasks.ParamsItem()).type
section[key] = tasks.ParamsItem(
section=section_name, name=key, value=v, description=description, type=str(param_type))
hyperparams[section_name] = section
else:
for k, v in parameters.items():
org_k = k
if '/' not in k:
k = 'General/{}'.format(k)
section_name, key = k.split('/', 1)
section = hyperparams.get(section_name, dict())
description = \
descriptions.get(org_k) or \
org_hyperparams.get(section_name, dict()).get(key, tasks.ParamsItem()).description
section[key] = tasks.ParamsItem\
(section=section_name, name=key, value=v, description=description)
hyperparams[section_name] = section
self._edit(hyperparams=hyperparams)
else:
execution = self.data.execution
if execution is None:
execution = tasks.Execution(
parameters=parameters, artifacts=[], dataviews=[], model='',
model_desc={}, model_labels={}, docker_cmd='')
else:
execution.parameters = parameters
self._edit(execution=execution)
def set_parameter(self, name, value, description=None, value_type=None):
# type: (str, str, Optional[str], Optional[Any]) -> ()
""" """
Set a single Task parameter. This overrides any previous value for this parameter. Set a single Task parameter. This overrides any previous value for this parameter.
:param name: The parameter name. :param name: The parameter name.
:param value: The parameter value. :param value: The parameter value.
:param description: The parameter description. :param description: The parameter description.
:param value_type: The type of the parameter (cast to string and stored)
.. note::
The ``description`` is not yet in use.
""" """
# not supported yet if not Session.check_min_api_version('2.9'):
if description: # not supported yet
# noinspection PyUnusedLocal
description = None description = None
self.set_parameters({name: value}, __update=True) value_type = None
self._set_parameters(
{name: value}, __update=True,
__parameters_descriptions={name: description},
__parameters_types={name: value_type}
)
def get_parameter(self, name, default=None): def get_parameter(self, name, default=None):
# type: (str, Any) -> Any # type: (str, Any) -> Any
@ -856,7 +953,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
merged into a single key-value pair dictionary. merged into a single key-value pair dictionary.
:param kwargs: Key-value pairs, merged into the parameters dictionary created from ``args``. :param kwargs: Key-value pairs, merged into the parameters dictionary created from ``args``.
""" """
self.set_parameters(__update=True, *args, **kwargs) self._set_parameters(*args, __update=True, **kwargs)
def set_model_label_enumeration(self, enumeration=None): def set_model_label_enumeration(self, enumeration=None):
# type: (Mapping[str, int]) -> () # type: (Mapping[str, int]) -> ()
@ -1316,7 +1413,12 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
self._update_requirements('') self._update_requirements('')
if Session.check_min_api_version('2.3'): if Session.check_min_api_version('2.9'):
self._set_task_property("system_tags", system_tags)
self._edit(system_tags=self._data.system_tags, comment=self._data.comment,
script=self._data.script, execution=self._data.execution, output_dest='',
hyperparams=dict(), configuration=dict())
elif Session.check_min_api_version('2.3'):
self._set_task_property("system_tags", system_tags) self._set_task_property("system_tags", system_tags)
self._edit(system_tags=self._data.system_tags, comment=self._data.comment, self._edit(system_tags=self._data.system_tags, comment=self._data.comment,
script=self._data.script, execution=self._data.execution, output_dest='') script=self._data.script, execution=self._data.execution, output_dest='')
@ -1386,6 +1488,67 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
self.data.script = script self.data.script = script
self._edit(script=script) self._edit(script=script)
def _set_configuration(self, name, description=None, config_type=None, config_text=None, config_dict=None):
# type: (str, Optional[str], Optional[str], Optional[str], Optional[Mapping]) -> None
"""
Set Task configuration text/dict. Multiple configurations are supported.
:param str name: Configuration name.
:param str description: Configuration section description.
:param config_text: model configuration (unconstrained text string). usually the content
of a configuration file. If `config_text` is not None, `config_dict` must not be provided.
:param config_dict: model configuration parameters dictionary.
If `config_dict` is not None, `config_text` must not be provided.
"""
# make sure we have either dict or text
mutually_exclusive(config_dict=config_dict, config_text=config_text)
if not Session.check_min_api_version('2.9'):
raise ValueError("Multiple configurations are not supported with the current 'trains-server', "
"please upgrade to the latest version")
if description:
description = str(description)
config = config_dict_to_text(config_text or config_dict)
with self._edit_lock:
self.reload()
configuration = self.data.configuration or {}
configuration[name] = tasks.ConfigurationItem(
name=name, value=config, description=description or None, type=config_type or None)
self._edit(configuration=configuration)
def _get_configuration_text(self, name):
# type: (str) -> Optional[str]
"""
Get Task configuration section as text
:param str name: Configuration name.
:return: The Task configuration as text (unconstrained text string).
return None if configuration name is not valid.
"""
if not Session.check_min_api_version('2.9'):
raise ValueError("Multiple configurations are not supported with the current 'trains-server', "
"please upgrade to the latest version")
configuration = self.data.configuration or {}
if not configuration.get(name):
return None
return configuration[name].value
def _get_configuration_dict(self, name):
# type: (str) -> Optional[dict]
"""
Get Task configuration section as dictionary
:param str name: Configuration name.
:return: The Task configuration as dictionary.
return None if configuration name is not valid.
"""
config_text = self._get_configuration_text(name)
if not config_text:
return None
return text_to_config_dict(config_text)
def get_offline_mode_folder(self): def get_offline_mode_folder(self):
# type: () -> (Optional[Path]) # type: () -> (Optional[Path])
""" """

View File

@ -339,7 +339,7 @@ class BaseModel(object):
def _config_dict_to_text(config): def _config_dict_to_text(config):
if not isinstance(config, six.string_types) and not isinstance(config, dict): if not isinstance(config, six.string_types) and not isinstance(config, dict):
raise ValueError("Model configuration only supports dictionary or string objects") raise ValueError("Model configuration only supports dictionary or string objects")
return config_dict_to_text return config_dict_to_text(config)
@staticmethod @staticmethod
def _text_to_config_dict(text): def _text_to_config_dict(text):

View File

@ -124,6 +124,7 @@ class Task(_Task):
__task_id_reuse_time_window_in_hours = float(config.get('development.task_reuse_time_window_in_hours', 24.0)) __task_id_reuse_time_window_in_hours = float(config.get('development.task_reuse_time_window_in_hours', 24.0))
__detect_repo_async = config.get('development.vcs_repo_detect_async', False) __detect_repo_async = config.get('development.vcs_repo_detect_async', False)
__default_output_uri = config.get('development.default_output_uri', None) __default_output_uri = config.get('development.default_output_uri', None)
__default_configuration_name = 'General'
class _ConnectedParametersType(object): class _ConnectedParametersType(object):
argparse = "argument_parser" argparse = "argument_parser"
@ -904,8 +905,8 @@ class Task(_Task):
self.data.tags.extend(tags) self.data.tags.extend(tags)
self._edit(tags=list(set(self.data.tags))) self._edit(tags=list(set(self.data.tags)))
def connect(self, mutable): def connect(self, mutable, name=None):
# type: (Any) -> Any # type: (Any, Optional[str]) -> Any
""" """
Connect an object to a Task object. This connects an experiment component (part of an experiment) to the Connect an object to a Task object. This connects an experiment component (part of an experiment) to the
experiment. For example, connect hyperparameters or models. experiment. For example, connect hyperparameters or models.
@ -918,6 +919,12 @@ class Task(_Task):
- TaskParameters - A TaskParameters object. - TaskParameters - A TaskParameters object.
- model - A model object for initial model warmup, or for model update/snapshot uploading. - model - A model object for initial model warmup, or for model update/snapshot uploading.
:param str name: A section name associated with the connected object. Default: 'General'
Currently only supported for `dict` / `TaskParameter` objects
Examples:
name='General' will put the connected dictionary under the General section in the hyper-parameters
name='Train' will put the connected dictionary under the Train section in the hyper-parameters
:return: The result returned when connecting the object, if supported. :return: The result returned when connecting the object, if supported.
:raise: Raise an exception on unsupported objects. :raise: Raise an exception on unsupported objects.
@ -931,14 +938,22 @@ class Task(_Task):
(TaskParameters, self._connect_task_parameters), (TaskParameters, self._connect_task_parameters),
) )
multi_config_support = Session.check_min_api_version('2.9')
if multi_config_support and not name:
name = self.__default_configuration_name
if not multi_config_support and name and name != self.__default_configuration_name:
raise ValueError("Multiple configurations are not supported with the current 'trains-server', "
"please upgrade to the latest version")
for mutable_type, method in dispatch: for mutable_type, method in dispatch:
if isinstance(mutable, mutable_type): if isinstance(mutable, mutable_type):
return method(mutable) return method(mutable, name=name)
raise Exception('Unsupported mutable type %s: no connect function found' % type(mutable).__name__) raise Exception('Unsupported mutable type %s: no connect function found' % type(mutable).__name__)
def connect_configuration(self, configuration): def connect_configuration(self, configuration, name=None, description=None):
# type: (Union[Mapping, Path, str]) -> Union[Mapping, Path, str] # type: (Union[Mapping, Path, str], Optional[str], Optional[str]) -> Union[Mapping, Path, str]
""" """
Connect a configuration dictionary or configuration file (pathlib.Path / str) to a Task object. Connect a configuration dictionary or configuration file (pathlib.Path / str) to a Task object.
This method should be called before reading the configuration file. This method should be called before reading the configuration file.
@ -968,25 +983,54 @@ class Task(_Task):
A local path must be relative path. When executing a Task remotely in a worker, the contents brought A local path must be relative path. When executing a Task remotely in a worker, the contents brought
from the **Trains Server** (backend) overwrites the contents of the file. from the **Trains Server** (backend) overwrites the contents of the file.
:param str name: Configuration section name. default: 'General'
Allowing users to store multiple configuration dicts/files
:param str description: Configuration section description (text). default: None
:return: If a dictionary is specified, then a dictionary is returned. If pathlib2.Path / string is :return: If a dictionary is specified, then a dictionary is returned. If pathlib2.Path / string is
specified, then a path to a local configuration file is returned. Configuration object. specified, then a path to a local configuration file is returned. Configuration object.
""" """
pathlib_Path = None
if not isinstance(configuration, (dict, Path, six.string_types)): if not isinstance(configuration, (dict, Path, six.string_types)):
raise ValueError("connect_configuration supports `dict`, `str` and 'Path' types, " try:
"{} is not supported".format(type(configuration))) from pathlib import Path as pathlib_Path
except ImportError:
pass
if not pathlib_Path or not isinstance(configuration, pathlib_Path):
raise ValueError("connect_configuration supports `dict`, `str` and 'Path' types, "
"{} is not supported".format(type(configuration)))
multi_config_support = Session.check_min_api_version('2.9')
if multi_config_support and not name:
name = self.__default_configuration_name
if not multi_config_support and name and name != self.__default_configuration_name:
raise ValueError("Multiple configurations are not supported with the current 'trains-server', "
"please upgrade to the latest version")
# parameter dictionary # parameter dictionary
if isinstance(configuration, dict): if isinstance(configuration, dict):
def _update_config_dict(task, config_dict): def _update_config_dict(task, config_dict):
# noinspection PyProtectedMember if multi_config_support:
task._set_model_config(config_dict=config_dict) # noinspection PyProtectedMember
task._set_configuration(
name=name, description=description, config_type='dictionary', config_dict=config_dict)
else:
# noinspection PyProtectedMember
task._set_model_config(config_dict=config_dict)
if not running_remotely() or not self.is_main_task(): if not running_remotely() or not self.is_main_task():
self._set_model_config(config_dict=configuration) if multi_config_support:
self._set_configuration(
name=name, description=description, config_type='dictionary', config_dict=configuration)
else:
self._set_model_config(config_dict=configuration)
configuration = ProxyDictPostWrite(self, _update_config_dict, **configuration) configuration = ProxyDictPostWrite(self, _update_config_dict, **configuration)
else: else:
configuration.clear() configuration.clear()
configuration.update(self._get_model_config_dict()) configuration.update(self._get_configuration_dict(name=name) if multi_config_support
else self._get_model_config_dict())
configuration = ProxyDictPreWrite(False, False, **configuration) configuration = ProxyDictPreWrite(False, False, **configuration)
return configuration return configuration
@ -1002,16 +1046,26 @@ class Task(_Task):
except Exception: except Exception:
raise ValueError("Could not connect configuration file {}, file could not be read".format( raise ValueError("Could not connect configuration file {}, file could not be read".format(
configuration_path.as_posix())) configuration_path.as_posix()))
self._set_model_config(config_text=configuration_text) if multi_config_support:
self._set_configuration(
name=name, description=description,
config_type=configuration_path.suffixes[-1].lstrip('.')
if configuration_path.suffixes and configuration_path.suffixes[-1] else 'file',
config_text=configuration_text)
else:
self._set_model_config(config_text=configuration_text)
return configuration return configuration
else: else:
configuration_text = self._get_model_config_text() configuration_text = self._get_configuration_text(name=name) if multi_config_support \
else self._get_model_config_text()
configuration_path = Path(configuration) configuration_path = Path(configuration)
fd, local_filename = mkstemp(prefix='trains_task_config_', fd, local_filename = mkstemp(prefix='trains_task_config_',
suffix=configuration_path.suffixes[-1] if suffix=configuration_path.suffixes[-1] if
configuration_path.suffixes else '.txt') configuration_path.suffixes else '.txt')
os.write(fd, configuration_text.encode('utf-8')) os.write(fd, configuration_text.encode('utf-8'))
os.close(fd) os.close(fd)
if pathlib_Path:
return pathlib_Path(local_filename)
return Path(local_filename) if isinstance(configuration, Path) else local_filename return Path(local_filename) if isinstance(configuration, Path) else local_filename
def connect_label_enumeration(self, enumeration): def connect_label_enumeration(self, enumeration):
@ -1651,6 +1705,7 @@ class Task(_Task):
# noinspection PyProtectedMember # noinspection PyProtectedMember
offline_folder = os.path.join(export_data.get('offline_folder', ''), 'data/') offline_folder = os.path.join(export_data.get('offline_folder', ''), 'data/')
# noinspection PyProtectedMember
remote_url = task._get_default_report_storage_uri() remote_url = task._get_default_report_storage_uri()
if remote_url and remote_url.endswith('/'): if remote_url and remote_url.endswith('/'):
remote_url = remote_url[:-1] remote_url = remote_url[:-1]
@ -1977,7 +2032,7 @@ class Task(_Task):
return self._logger return self._logger
def _connect_output_model(self, model): def _connect_output_model(self, model, name=None):
assert isinstance(model, OutputModel) assert isinstance(model, OutputModel)
model.connect(self) model.connect(self)
return model return model
@ -2001,7 +2056,7 @@ class Task(_Task):
if self._connected_output_model: if self._connected_output_model:
self.connect(self._connected_output_model) self.connect(self._connected_output_model)
def _connect_input_model(self, model): def _connect_input_model(self, model, name=None):
assert isinstance(model, InputModel) assert isinstance(model, InputModel)
# we only allow for an input model to be connected once # we only allow for an input model to be connected once
# at least until we support multiple input models # at least until we support multiple input models
@ -2039,7 +2094,7 @@ class Task(_Task):
# added support for multiple type connections through _Arguments # added support for multiple type connections through _Arguments
return option return option
def _connect_argparse(self, parser, args=None, namespace=None, parsed_args=None): def _connect_argparse(self, parser, args=None, namespace=None, parsed_args=None, name=None):
# do not allow argparser to connect to jupyter notebook # do not allow argparser to connect to jupyter notebook
# noinspection PyBroadException # noinspection PyBroadException
try: try:
@ -2079,15 +2134,15 @@ class Task(_Task):
parser, args=args, namespace=namespace, parsed_args=parsed_args) parser, args=args, namespace=namespace, parsed_args=parsed_args)
return parser return parser
def _connect_dictionary(self, dictionary): def _connect_dictionary(self, dictionary, name=None):
def _update_args_dict(task, config_dict): def _update_args_dict(task, config_dict):
# noinspection PyProtectedMember # noinspection PyProtectedMember
task._arguments.copy_from_dict(flatten_dictionary(config_dict)) task._arguments.copy_from_dict(flatten_dictionary(config_dict), prefix=name)
def _refresh_args_dict(task, config_dict): def _refresh_args_dict(task, config_dict):
# reread from task including newly added keys # reread from task including newly added keys
# noinspection PyProtectedMember # noinspection PyProtectedMember
a_flat_dict = task._arguments.copy_to_dict(flatten_dictionary(config_dict)) a_flat_dict = task._arguments.copy_to_dict(flatten_dictionary(config_dict), prefix=name)
# noinspection PyProtectedMember # noinspection PyProtectedMember
nested_dict = config_dict._to_dict() nested_dict = config_dict._to_dict()
config_dict.clear() config_dict.clear()
@ -2096,23 +2151,28 @@ class Task(_Task):
self._try_set_connected_parameter_type(self._ConnectedParametersType.dictionary) self._try_set_connected_parameter_type(self._ConnectedParametersType.dictionary)
if not running_remotely() or not self.is_main_task(): if not running_remotely() or not self.is_main_task():
self._arguments.copy_from_dict(flatten_dictionary(dictionary)) self._arguments.copy_from_dict(flatten_dictionary(dictionary), prefix=name)
dictionary = ProxyDictPostWrite(self, _update_args_dict, **dictionary) dictionary = ProxyDictPostWrite(self, _update_args_dict, **dictionary)
else: else:
flat_dict = flatten_dictionary(dictionary) flat_dict = flatten_dictionary(dictionary)
flat_dict = self._arguments.copy_to_dict(flat_dict) flat_dict = self._arguments.copy_to_dict(flat_dict, prefix=name)
dictionary = nested_from_flat_dictionary(dictionary, flat_dict) dictionary = nested_from_flat_dictionary(dictionary, flat_dict)
dictionary = ProxyDictPostWrite(self, _refresh_args_dict, **dictionary) dictionary = ProxyDictPostWrite(self, _refresh_args_dict, **dictionary)
return dictionary return dictionary
def _connect_task_parameters(self, attr_class): def _connect_task_parameters(self, attr_class, name=None):
self._try_set_connected_parameter_type(self._ConnectedParametersType.task_parameters) self._try_set_connected_parameter_type(self._ConnectedParametersType.task_parameters)
if running_remotely() and self.is_main_task(): if running_remotely() and self.is_main_task():
attr_class.update_from_dict(self.get_parameters()) parameters = self.get_parameters()
if not name:
attr_class.update_from_dict(parameters)
else:
attr_class.update_from_dict(
dict((k[len(name)+1:], v) for k, v in parameters.items() if k.startswith('{}/'.format(name))))
else: else:
self.set_parameters(attr_class.to_dict()) self.set_parameters(attr_class.to_dict(), __parameters_prefix=name)
return attr_class return attr_class
def _validate(self, check_output_dest_credentials=False): def _validate(self, check_output_dest_credentials=False):