From e378de1e41fef8c0647cfafa044a511cf1a9702b Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Sat, 8 Aug 2020 12:35:03 +0300 Subject: [PATCH] Add multi configuration section support (hyperparams and configurations) Support setting offline mode API version using TRAINS_OFFLINE_MODE env var --- trains/backend_api/session/datamodel.py | 6 +- trains/backend_api/session/session.py | 16 +- trains/backend_interface/task/args.py | 87 +++++---- trains/backend_interface/task/task.py | 235 ++++++++++++++++++++---- trains/model.py | 2 +- trains/task.py | 108 ++++++++--- 6 files changed, 355 insertions(+), 99 deletions(-) diff --git a/trains/backend_api/session/datamodel.py b/trains/backend_api/session/datamodel.py index ff733ce0..72d96382 100644 --- a/trains/backend_api/session/datamodel.py +++ b/trains/backend_api/session/datamodel.py @@ -64,7 +64,11 @@ class DataModel(object): @classmethod def _to_base_type(cls, value): - if isinstance(value, DataModel): + if isinstance(value, dict): + # Note: this should come before DataModel to handle data models that are simply a dict + # (and thus are not expected to have additional named properties) + return {k: cls._to_base_type(v) for k, v in value.items()} + elif isinstance(value, DataModel): return value.to_dict() elif isinstance(value, enum.Enum): return value.value diff --git a/trains/backend_api/session/session.py b/trains/backend_api/session/session.py index 6ca3c54d..6ae4ca4d 100644 --- a/trains/backend_api/session/session.py +++ b/trains/backend_api/session/session.py @@ -62,6 +62,7 @@ class Session(TokenManager): default_files = "https://demofiles.trains.allegro.ai" default_key = "EGRTCO8JMSIGI6S39GTP43NFWXDQOW" default_secret = "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8" + force_max_api_version = None # TODO: add requests.codes.gateway_timeout once we support async commits _retry_codes = [ @@ -182,6 +183,9 @@ class Session(TokenManager): self.__class__._sessions_created += 1 + if self.force_max_api_version and self.check_min_api_version(self.force_max_api_version): + Session.api_version = str(self.force_max_api_version) + def _send_request( self, service, @@ -539,7 +543,17 @@ class Session(TokenManager): # If no session was created, create a default one, in order to get the backend api version. if cls._sessions_created <= 0: if cls._offline_mode: - cls.api_version = cls._offline_default_version + # allow to change the offline mode version by setting ENV_OFFLINE_MODE to the required API version + if cls.api_version != cls._offline_default_version: + offline_api = ENV_OFFLINE_MODE.get(converter=lambda x: x) + if offline_api: + try: + # check cast to float, but leave original str if we pass it. 
+ float(offline_api) + cls._offline_default_version = str(offline_api) + except ValueError: + pass + cls.api_version = cls._offline_default_version else: # noinspection PyBroadException try: diff --git a/trains/backend_interface/task/args.py b/trains/backend_interface/task/args.py index 2555c754..6465acfc 100644 --- a/trains/backend_interface/task/args.py +++ b/trains/backend_interface/task/args.py @@ -10,32 +10,35 @@ from ...utilities.args import call_original_argparser class _Arguments(object): _prefix_sep = '/' # TODO: separate dict and argparse after we add UI support - _prefix_dict = 'dict' + _prefix_sep - _prefix_args = 'argparse' + _prefix_sep + _prefix_args = 'Args' + _prefix_sep _prefix_tf_defines = 'TF_DEFINE' + _prefix_sep class _ProxyDictWrite(dict): """ Dictionary wrapper that updates an arguments instance on any item set in the dictionary """ - def __init__(self, arguments, *args, **kwargs): + def __init__(self, __arguments, __section_name, *args, **kwargs): super(_Arguments._ProxyDictWrite, self).__init__(*args, **kwargs) - self._arguments = arguments + self._arguments = __arguments + self._section_name = (__section_name.strip(_Arguments._prefix_sep) + _Arguments._prefix_sep) \ + if __section_name else None def __setitem__(self, key, value): super(_Arguments._ProxyDictWrite, self).__setitem__(key, value) if self._arguments: - self._arguments.copy_from_dict(self) + self._arguments.copy_from_dict(self, prefix=self._section_name) class _ProxyDictReadOnly(dict): """ Dictionary wrapper that prevents modifications to the dictionary """ - def __init__(self, arguments, *args, **kwargs): + def __init__(self, __arguments, __section_name, *args, **kwargs): super(_Arguments._ProxyDictReadOnly, self).__init__(*args, **kwargs) - self._arguments = arguments + self._arguments = __arguments + self._section_name = (__section_name.strip(_Arguments._prefix_sep) + _Arguments._prefix_sep) \ + if __section_name else None def __setitem__(self, key, value): if self._arguments: - param_dict = self._arguments.copy_to_dict({key: value}) + param_dict = self._arguments.copy_to_dict({key: value}, prefix=self._section_name) value = param_dict.get(key, value) super(_Arguments._ProxyDictReadOnly, self).__setitem__(key, value) @@ -45,24 +48,32 @@ class _Arguments(object): self._exclude_parser_args = {} def exclude_parser_args(self, excluded_args): + """ + You can use a dictionary for fined grained control of connected + arguments. The dictionary keys are argparse variable names and the values are booleans. + The ``False`` value excludes the specified argument from the Task's parameter section. + Keys missing from the dictionary default to ``True``, and an empty dictionary defaults to ``False``. 
+ + :param excluded_args: dict + """ self._exclude_parser_args = excluded_args or {} def set_defaults(self, *dicts, **kwargs): - self._task.set_parameters(*dicts, **kwargs) + # noinspection PyProtectedMember + self._task._set_parameters(*dicts, __parameters_prefix=self._prefix_args, **kwargs) def add_argument(self, option_strings, type=None, default=None, help=None): if not option_strings: raise Exception('Expected at least one argument name (option string)') name = option_strings[0].strip('- \t') if isinstance(option_strings, list) else option_strings.strip('- \t') - # TODO: add argparse prefix - # name = self._prefix_args + name + name = self._prefix_args + name self._task.set_parameter(name=name, value=default, description=help) def connect(self, parser): self._task.connect_argparse(parser) @classmethod - def _add_to_defaults(cls, a_parser, defaults, a_args=None, a_namespace=None, a_parsed_args=None): + def _add_to_defaults(cls, a_parser, defaults, descriptions, a_args=None, a_namespace=None, a_parsed_args=None): actions = [ a for a in a_parser._actions if isinstance(a, _StoreAction) or isinstance(a, _StoreConstAction) @@ -87,6 +98,9 @@ class _Arguments(object): for a in actions } + desc_ = {a.dest: a.help for a in actions} + descriptions.update(desc_) + full_args_dict = copy(defaults) full_args_dict.update(args_dict) defaults.update(defaults_) @@ -101,17 +115,19 @@ class _Arguments(object): defaults[sub_parser.dest] = full_args_dict.get(sub_parser.dest) or '' for choice in sub_parser.choices.values(): # recursively parse - defaults = cls._add_to_defaults( + defaults, descriptions = cls._add_to_defaults( a_parser=choice, defaults=defaults, + descriptions=descriptions, a_parsed_args=a_parsed_args or full_args_dict ) - return defaults + return defaults, descriptions def copy_defaults_from_argparse(self, parser, args=None, namespace=None, parsed_args=None): task_defaults = {} - self._add_to_defaults(parser, task_defaults, args, namespace, parsed_args) + task_defaults_descriptions = {} + self._add_to_defaults(parser, task_defaults, task_defaults_descriptions, args, namespace, parsed_args) # Make sure we didn't miss anything if parsed_args: @@ -133,12 +149,13 @@ class _Arguments(object): except Exception: del task_defaults[k] - # Skip excluded arguments, Add prefix, TODO: add argparse prefix - # task_defaults = dict([(self._prefix_args + k, v) for k, v in task_defaults.items() - # if k not in self._exclude_parser_args]) - task_defaults = dict([(k, v) for k, v in task_defaults.items() if self._exclude_parser_args.get(k, True)]) + # Skip excluded arguments, Add prefix. 
+ task_defaults = dict([(self._prefix_args + k, v) for k, v in task_defaults.items() + if self._exclude_parser_args.get(k, True)]) + task_defaults_descriptions = dict([(self._prefix_args + k, v) for k, v in task_defaults_descriptions.items() + if self._exclude_parser_args.get(k, True)]) # Store to task - self._task.update_parameters(task_defaults) + self._task.update_parameters(task_defaults, __parameters_descriptions=task_defaults_descriptions) @classmethod def _find_parser_action(cls, a_parser, name): @@ -158,11 +175,10 @@ class _Arguments(object): return _actions def copy_to_parser(self, parser, parsed_args): - # todo: change to argparse prefix only - # task_arguments = dict([(k[len(self._prefix_args):], v) for k, v in self._task.get_parameters().items() - # if k.startswith(self._prefix_args)]) - task_arguments = dict([(k, v) for k, v in self._task.get_parameters().items() - if not k.startswith(self._prefix_tf_defines) and self._exclude_parser_args.get(k, True)]) + # Change to argparse prefix only + task_arguments = dict([(k[len(self._prefix_args):], v) for k, v in self._task.get_parameters().items() + if k.startswith(self._prefix_args) and + self._exclude_parser_args.get(k[len(self._prefix_args):], True)]) arg_parser_argeuments = {} for k, v in task_arguments.items(): # python2 unicode support @@ -304,25 +320,24 @@ class _Arguments(object): parser.set_defaults(**arg_parser_argeuments) def copy_from_dict(self, dictionary, prefix=None): - # TODO: add dict prefix - prefix = prefix or '' # self._prefix_dict + # add dict prefix + prefix = prefix # or self._prefix_dict if prefix: - with self._task._edit_lock: - prefix_dictionary = dict([(prefix + k, v) for k, v in dictionary.items()]) - cur_params = dict([(k, v) for k, v in self._task.get_parameters().items() if not k.startswith(prefix)]) - cur_params.update(prefix_dictionary) - self._task.set_parameters(cur_params) + prefix = prefix.strip(self._prefix_sep) + self._prefix_sep + # this will only set the specific section + self._task.set_parameters(dictionary, __parameters_prefix=prefix) else: self._task.update_parameters(dictionary) if not isinstance(dictionary, self._ProxyDictWrite): - return self._ProxyDictWrite(self, **dictionary) + return self._ProxyDictWrite(self, prefix, **dictionary) return dictionary def copy_to_dict(self, dictionary, prefix=None): # iterate over keys and merge values according to parameter type in dictionary - # TODO: add dict prefix - prefix = prefix or '' # self._prefix_dict + # add dict prefix + prefix = prefix # or self._prefix_dict if prefix: + prefix = prefix.strip(self._prefix_sep) + self._prefix_sep parameters = dict([(k[len(prefix):], v) for k, v in self._task.get_parameters().items() if k.startswith(prefix)]) else: @@ -396,5 +411,5 @@ class _Arguments(object): # dictionary[k] = v if not isinstance(dictionary, self._ProxyDictReadOnly): - return self._ProxyDictReadOnly(self, **dictionary) + return self._ProxyDictReadOnly(self, prefix, **dictionary) return dictionary diff --git a/trains/backend_interface/task/task.py b/trains/backend_interface/task/task.py index bc30d0ce..05e4fd60 100644 --- a/trains/backend_interface/task/task.py +++ b/trains/backend_interface/task/task.py @@ -30,13 +30,15 @@ from ...backend_api import Session from ...backend_api.services import tasks, models, events, projects from ...backend_api.session.defs import ENV_OFFLINE_MODE from ...utilities.pyhocon import ConfigTree, ConfigFactory +from ...utilities.config import config_dict_to_text, text_to_config_dict from ..base import 
IdObjectBase, InterfaceBase from ..metrics import Metrics, Reporter from ..model import Model from ..setupuploadmixin import SetupUploadMixin -from ..util import make_message, get_or_create_project, get_single_result, \ - exact_match_regex +from ..util import ( + make_message, get_or_create_project, get_single_result, + exact_match_regex, mutually_exclusive, ) from ...config import ( get_config_for_bucket, get_remote_task_id, TASK_ID_ENV_VAR, get_log_to_backend, running_remotely, get_cache_dir, DOCKER_IMAGE_ENV_VAR, get_offline_dir) @@ -768,12 +770,52 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): self._edit(execution=self.data.execution) + def get_parameters(self, backwards_compatibility=True): + # type: (bool) -> (Optional[dict]) + """ + Get the parameters for a Task. This method returns a complete group of key-value parameter pairs, but does not + support parameter descriptions (the result is a dictionary of key-value pairs). + :param backwards_compatibility: If True (default) parameters without section name + (API version < 2.9, trains-server < 0.16) will be at dict root level. + If False, parameters without section name, will be nested under "general/" key. + :return: dict of the task parameters, all flattened to key/value. + Different sections with key prefix "section/" + """ + if not Session.check_min_api_version('2.9'): + return self._get_task_property('execution.parameters') + + # API will makes sure we get old parameters under _legacy + parameters = dict() + hyperparams = self._get_task_property('hyperparams', default=dict()) + if len(hyperparams) == 1 and '_legacy' in hyperparams and backwards_compatibility: + for section in ('_legacy', ): + for key, section_param in hyperparams[section].items(): + parameters['{}'.format(key)] = section_param.value + else: + for section in hyperparams: + for key, section_param in hyperparams[section].items(): + parameters['{}/{}'.format(section, key)] = section_param.value + + return parameters + def set_parameters(self, *args, **kwargs): # type: (*dict, **Any) -> () """ Set the parameters for a Task. This method sets a complete group of key-value parameter pairs, but does not support parameter descriptions (the input is a dictionary of key-value pairs). + :param args: Positional arguments, which are one or more dictionary or (key, value) iterable. They are + merged into a single key-value pair dictionary. + :param kwargs: Key-value pairs, merged into the parameters dictionary created from ``args``. + """ + return self._set_parameters(*args, __update=False, **kwargs) + + def _set_parameters(self, *args, **kwargs): + # type: (*dict, **Any) -> () + """ + Set the parameters for a Task. This method sets a complete group of key-value parameter pairs, but does not + support parameter descriptions (the input is a dictionary of key-value pairs). + :param args: Positional arguments, which are one or more dictionary or (key, value) iterable. They are merged into a single key-value pair dictionary. :param kwargs: Key-value pairs, merged into the parameters dictionary created from ``args``. 
@@ -781,58 +823,113 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): if not all(isinstance(x, (dict, Iterable)) for x in args): raise ValueError('only dict or iterable are supported as positional arguments') + prefix = kwargs.pop('__parameters_prefix', None) + descriptions = kwargs.pop('__parameters_descriptions', None) or dict() + params_types = kwargs.pop('__parameters_types', None) or dict() update = kwargs.pop('__update', False) + # new parameters dict + new_parameters = dict(itertools.chain.from_iterable(x.items() if isinstance(x, dict) else x for x in args)) + new_parameters.update(kwargs) + if prefix: + prefix = prefix.strip('/') + new_parameters = dict(('{}/{}'.format(prefix, k), v) for k, v in new_parameters.items()) + + # verify parameters type: + not_allowed = { + k: type(v).__name__ + for k, v in new_parameters.items() + if not isinstance(v, self._parameters_allowed_types) + } + if not_allowed: + raise ValueError( + "Only builtin types ({}) are allowed for values (got {})".format( + ', '.join(t.__name__ for t in self._parameters_allowed_types), + ', '.join('%s=>%s' % p for p in not_allowed.items())), + ) + + use_hyperparams = Session.check_min_api_version('2.9') + with self._edit_lock: self.reload() - if update: - parameters = self.get_parameters() + # if we have a specific prefix and we use hyperparameters, and we use set. + # overwrite only the prefix, leave the rest as is. + if not update and prefix: + parameters = dict(**(self.get_parameters() or {})) + parameters = dict((k, v) for k, v in parameters.items() if not k.startswith(prefix+'/')) + elif update: + parameters = dict(**(self.get_parameters() or {})) else: parameters = dict() - parameters.update(itertools.chain.from_iterable(x.items() if isinstance(x, dict) else x for x in args)) - parameters.update(kwargs) - not_allowed = { - k: type(v).__name__ - for k, v in parameters.items() - if not isinstance(v, self._parameters_allowed_types) - } - if not_allowed: - raise ValueError( - "Only builtin types ({}) are allowed for values (got {})".format( - ', '.join(t.__name__ for t in self._parameters_allowed_types), - ', '.join('%s=>%s' % p for p in not_allowed.items())), - ) + parameters.update(new_parameters) # force cast all variables to strings (so that we can later edit them in UI) parameters = {k: str(v) if v is not None else "" for k, v in parameters.items()} - execution = self.data.execution - if execution is None: - execution = tasks.Execution( - parameters=parameters, artifacts=[], dataviews=[], model='', - model_desc={}, model_labels={}, docker_cmd='') - else: - execution.parameters = parameters - self._edit(execution=execution) + if use_hyperparams: + # build nested dict from flat parameters dict: + org_hyperparams = self.data.hyperparams or {} + hyperparams = dict() - def set_parameter(self, name, value, description=None): - # type: (str, str, Optional[str]) -> () + # if the task is a legacy task, we should put everything back under _legacy + if self.data.hyperparams and '_legacy' in self.data.hyperparams: + for k, v in parameters.items(): + section_name, key = '_legacy', k + section = hyperparams.get(section_name, dict()) + description = \ + descriptions.get(k) or \ + org_hyperparams.get(section_name, dict()).get(key, tasks.ParamsItem()).description + param_type = \ + params_types.get(k) or \ + org_hyperparams.get(section_name, dict()).get(key, tasks.ParamsItem()).type + section[key] = tasks.ParamsItem( + section=section_name, name=key, value=v, description=description, type=str(param_type)) + 
hyperparams[section_name] = section + else: + for k, v in parameters.items(): + org_k = k + if '/' not in k: + k = 'General/{}'.format(k) + section_name, key = k.split('/', 1) + section = hyperparams.get(section_name, dict()) + description = \ + descriptions.get(org_k) or \ + org_hyperparams.get(section_name, dict()).get(key, tasks.ParamsItem()).description + section[key] = tasks.ParamsItem\ + (section=section_name, name=key, value=v, description=description) + hyperparams[section_name] = section + self._edit(hyperparams=hyperparams) + else: + execution = self.data.execution + if execution is None: + execution = tasks.Execution( + parameters=parameters, artifacts=[], dataviews=[], model='', + model_desc={}, model_labels={}, docker_cmd='') + else: + execution.parameters = parameters + self._edit(execution=execution) + + def set_parameter(self, name, value, description=None, value_type=None): + # type: (str, str, Optional[str], Optional[Any]) -> () """ Set a single Task parameter. This overrides any previous value for this parameter. :param name: The parameter name. :param value: The parameter value. :param description: The parameter description. - - .. note:: - The ``description`` is not yet in use. + :param value_type: The type of the parameters (cast to string and store) """ - # not supported yet - if description: - # noinspection PyUnusedLocal + if not Session.check_min_api_version('2.9'): + # not supported yet description = None - self.set_parameters({name: value}, __update=True) + value_type = None + + self._set_parameters( + {name: value}, __update=True, + __parameters_descriptions={name: description}, + __parameters_types={name: value_type} + ) def get_parameter(self, name, default=None): # type: (str, Any) -> Any @@ -856,7 +953,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): merged into a single key-value pair dictionary. :param kwargs: Key-value pairs, merged into the parameters dictionary created from ``args``. """ - self.set_parameters(__update=True, *args, **kwargs) + self._set_parameters(*args, __update=True, **kwargs) def set_model_label_enumeration(self, enumeration=None): # type: (Mapping[str, int]) -> () @@ -1316,7 +1413,12 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): self._update_requirements('') - if Session.check_min_api_version('2.3'): + if Session.check_min_api_version('2.9'): + self._set_task_property("system_tags", system_tags) + self._edit(system_tags=self._data.system_tags, comment=self._data.comment, + script=self._data.script, execution=self._data.execution, output_dest='', + hyperparams=dict(), configuration=dict()) + elif Session.check_min_api_version('2.3'): self._set_task_property("system_tags", system_tags) self._edit(system_tags=self._data.system_tags, comment=self._data.comment, script=self._data.script, execution=self._data.execution, output_dest='') @@ -1386,6 +1488,67 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): self.data.script = script self._edit(script=script) + def _set_configuration(self, name, description=None, config_type=None, config_text=None, config_dict=None): + # type: (str, Optional[str], Optional[str], Optional[str], Optional[Mapping]) -> None + """ + Set Task configuration text/dict. Multiple configurations are supported. + + :param str name: Configuration name. + :param str description: Configuration section description. + :param config_text: model configuration (unconstrained text string). usually the content + of a configuration file. 
If `config_text` is not None, `config_dict` must not be provided. + :param config_dict: model configuration parameters dictionary. + If `config_dict` is not None, `config_text` must not be provided. + """ + # make sure we have wither dict or text + mutually_exclusive(config_dict=config_dict, config_text=config_text) + + if not Session.check_min_api_version('2.9'): + raise ValueError("Multiple configurations are not supported with the current 'trains-server', " + "please upgrade to the latest version") + + if description: + description = str(description) + config = config_dict_to_text(config_text or config_dict) + with self._edit_lock: + self.reload() + configuration = self.data.configuration or {} + configuration[name] = tasks.ConfigurationItem( + name=name, value=config, description=description or None, type=config_type or None) + self._edit(configuration=configuration) + + def _get_configuration_text(self, name): + # type: (str) -> Optional[str] + """ + Get Task configuration section as text + + :param str name: Configuration name. + :return: The Task configuration as text (unconstrained text string). + return None if configuration name is not valid. + """ + if not Session.check_min_api_version('2.9'): + raise ValueError("Multiple configurations are not supported with the current 'trains-server', " + "please upgrade to the latest version") + + configuration = self.data.configuration or {} + if not configuration.get(name): + return None + return configuration[name].value + + def _get_configuration_dict(self, name): + # type: (str) -> Optional[dict] + """ + Get Task configuration section as dictionary + + :param str name: Configuration name. + :return: The Task configuration as dictionary. + return None if configuration name is not valid. + """ + config_text = self._get_configuration_text(name) + if not config_text: + return None + return text_to_config_dict(config_text) + def get_offline_mode_folder(self): # type: () -> (Optional[Path]) """ diff --git a/trains/model.py b/trains/model.py index ced338e5..d17f72d4 100644 --- a/trains/model.py +++ b/trains/model.py @@ -339,7 +339,7 @@ class BaseModel(object): def _config_dict_to_text(config): if not isinstance(config, six.string_types) and not isinstance(config, dict): raise ValueError("Model configuration only supports dictionary or string objects") - return config_dict_to_text + return config_dict_to_text(config) @staticmethod def _text_to_config_dict(text): diff --git a/trains/task.py b/trains/task.py index 63158137..e293a4ce 100644 --- a/trains/task.py +++ b/trains/task.py @@ -124,6 +124,7 @@ class Task(_Task): __task_id_reuse_time_window_in_hours = float(config.get('development.task_reuse_time_window_in_hours', 24.0)) __detect_repo_async = config.get('development.vcs_repo_detect_async', False) __default_output_uri = config.get('development.default_output_uri', None) + __default_configuration_name = 'General' class _ConnectedParametersType(object): argparse = "argument_parser" @@ -904,8 +905,8 @@ class Task(_Task): self.data.tags.extend(tags) self._edit(tags=list(set(self.data.tags))) - def connect(self, mutable): - # type: (Any) -> Any + def connect(self, mutable, name=None): + # type: (Any, Optional[str]) -> Any """ Connect an object to a Task object. This connects an experiment component (part of an experiment) to the experiment. For example, connect hyperparameters or models. @@ -918,6 +919,12 @@ class Task(_Task): - TaskParameters - A TaskParameters object. 
- model - A model object for initial model warmup, or for model update/snapshot uploading. + :param str name: A section name associated with the connected object. Default: 'General' + Currently only supported for `dict` / `TaskParameter` objects + Examples: + name='General' will put the connected dictionary under the General section in the hyper-parameters + name='Train' will put the connected dictionary under the Train section in the hyper-parameters + :return: The result returned when connecting the object, if supported. :raise: Raise an exception on unsupported objects. @@ -931,14 +938,22 @@ class Task(_Task): (TaskParameters, self._connect_task_parameters), ) + multi_config_support = Session.check_min_api_version('2.9') + if multi_config_support and not name: + name = self.__default_configuration_name + + if not multi_config_support and name and name != self.__default_configuration_name: + raise ValueError("Multiple configurations are not supported with the current 'trains-server', " + "please upgrade to the latest version") + for mutable_type, method in dispatch: if isinstance(mutable, mutable_type): - return method(mutable) + return method(mutable, name=name) raise Exception('Unsupported mutable type %s: no connect function found' % type(mutable).__name__) - def connect_configuration(self, configuration): - # type: (Union[Mapping, Path, str]) -> Union[Mapping, Path, str] + def connect_configuration(self, configuration, name=None, description=None): + # type: (Union[Mapping, Path, str], Optional[str], Optional[str]) -> Union[Mapping, Path, str] """ Connect a configuration dictionary or configuration file (pathlib.Path / str) to a Task object. This method should be called before reading the configuration file. @@ -968,25 +983,54 @@ class Task(_Task): A local path must be relative path. When executing a Task remotely in a worker, the contents brought from the **Trains Server** (backend) overwrites the contents of the file. + :param str name: Configuration section name. default: 'General' + Allowing users to store multiple configuration dicts/files + + :param str description: Configuration section description (text). default: None + :return: If a dictionary is specified, then a dictionary is returned. If pathlib2.Path / string is specified, then a path to a local configuration file is returned. Configuration object. 
""" + pathlib_Path = None if not isinstance(configuration, (dict, Path, six.string_types)): - raise ValueError("connect_configuration supports `dict`, `str` and 'Path' types, " - "{} is not supported".format(type(configuration))) + try: + from pathlib import Path as pathlib_Path + except ImportError: + pass + if not pathlib_Path or not isinstance(configuration, pathlib_Path): + raise ValueError("connect_configuration supports `dict`, `str` and 'Path' types, " + "{} is not supported".format(type(configuration))) + + multi_config_support = Session.check_min_api_version('2.9') + if multi_config_support and not name: + name = self.__default_configuration_name + + if not multi_config_support and name and name != self.__default_configuration_name: + raise ValueError("Multiple configurations are not supported with the current 'trains-server', " + "please upgrade to the latest version") # parameter dictionary if isinstance(configuration, dict): def _update_config_dict(task, config_dict): - # noinspection PyProtectedMember - task._set_model_config(config_dict=config_dict) + if multi_config_support: + # noinspection PyProtectedMember + task._set_configuration( + name=name, description=description, config_type='dictionary', config_dict=config_dict) + else: + # noinspection PyProtectedMember + task._set_model_config(config_dict=config_dict) if not running_remotely() or not self.is_main_task(): - self._set_model_config(config_dict=configuration) + if multi_config_support: + self._set_configuration( + name=name, description=description, config_type='dictionary', config_dict=configuration) + else: + self._set_model_config(config_dict=configuration) configuration = ProxyDictPostWrite(self, _update_config_dict, **configuration) else: configuration.clear() - configuration.update(self._get_model_config_dict()) + configuration.update(self._get_configuration_dict(name=name) if multi_config_support + else self._get_model_config_dict()) configuration = ProxyDictPreWrite(False, False, **configuration) return configuration @@ -1002,16 +1046,26 @@ class Task(_Task): except Exception: raise ValueError("Could not connect configuration file {}, file could not be read".format( configuration_path.as_posix())) - self._set_model_config(config_text=configuration_text) + if multi_config_support: + self._set_configuration( + name=name, description=description, + config_type=configuration_path.suffixes[-1].lstrip('.') + if configuration_path.suffixes and configuration_path.suffixes[-1] else 'file', + config_text=configuration_text) + else: + self._set_model_config(config_text=configuration_text) return configuration else: - configuration_text = self._get_model_config_text() + configuration_text = self._get_configuration_text(name=name) if multi_config_support \ + else self._get_model_config_text() configuration_path = Path(configuration) fd, local_filename = mkstemp(prefix='trains_task_config_', suffix=configuration_path.suffixes[-1] if configuration_path.suffixes else '.txt') os.write(fd, configuration_text.encode('utf-8')) os.close(fd) + if pathlib_Path: + return pathlib_Path(local_filename) return Path(local_filename) if isinstance(configuration, Path) else local_filename def connect_label_enumeration(self, enumeration): @@ -1651,6 +1705,7 @@ class Task(_Task): # noinspection PyProtectedMember offline_folder = os.path.join(export_data.get('offline_folder', ''), 'data/') + # noinspection PyProtectedMember remote_url = task._get_default_report_storage_uri() if remote_url and remote_url.endswith('/'): remote_url = 
remote_url[:-1] @@ -1977,7 +2032,7 @@ class Task(_Task): return self._logger - def _connect_output_model(self, model): + def _connect_output_model(self, model, name=None): assert isinstance(model, OutputModel) model.connect(self) return model @@ -2001,7 +2056,7 @@ class Task(_Task): if self._connected_output_model: self.connect(self._connected_output_model) - def _connect_input_model(self, model): + def _connect_input_model(self, model, name=None): assert isinstance(model, InputModel) # we only allow for an input model to be connected once # at least until we support multiple input models @@ -2039,7 +2094,7 @@ class Task(_Task): # added support for multiple type connections through _Arguments return option - def _connect_argparse(self, parser, args=None, namespace=None, parsed_args=None): + def _connect_argparse(self, parser, args=None, namespace=None, parsed_args=None, name=None): # do not allow argparser to connect to jupyter notebook # noinspection PyBroadException try: @@ -2079,15 +2134,15 @@ class Task(_Task): parser, args=args, namespace=namespace, parsed_args=parsed_args) return parser - def _connect_dictionary(self, dictionary): + def _connect_dictionary(self, dictionary, name=None): def _update_args_dict(task, config_dict): # noinspection PyProtectedMember - task._arguments.copy_from_dict(flatten_dictionary(config_dict)) + task._arguments.copy_from_dict(flatten_dictionary(config_dict), prefix=name) def _refresh_args_dict(task, config_dict): # reread from task including newly added keys # noinspection PyProtectedMember - a_flat_dict = task._arguments.copy_to_dict(flatten_dictionary(config_dict)) + a_flat_dict = task._arguments.copy_to_dict(flatten_dictionary(config_dict), prefix=name) # noinspection PyProtectedMember nested_dict = config_dict._to_dict() config_dict.clear() @@ -2096,23 +2151,28 @@ class Task(_Task): self._try_set_connected_parameter_type(self._ConnectedParametersType.dictionary) if not running_remotely() or not self.is_main_task(): - self._arguments.copy_from_dict(flatten_dictionary(dictionary)) + self._arguments.copy_from_dict(flatten_dictionary(dictionary), prefix=name) dictionary = ProxyDictPostWrite(self, _update_args_dict, **dictionary) else: flat_dict = flatten_dictionary(dictionary) - flat_dict = self._arguments.copy_to_dict(flat_dict) + flat_dict = self._arguments.copy_to_dict(flat_dict, prefix=name) dictionary = nested_from_flat_dictionary(dictionary, flat_dict) dictionary = ProxyDictPostWrite(self, _refresh_args_dict, **dictionary) return dictionary - def _connect_task_parameters(self, attr_class): + def _connect_task_parameters(self, attr_class, name=None): self._try_set_connected_parameter_type(self._ConnectedParametersType.task_parameters) if running_remotely() and self.is_main_task(): - attr_class.update_from_dict(self.get_parameters()) + parameters = self.get_parameters() + if not name: + attr_class.update_from_dict(parameters) + else: + attr_class.update_from_dict( + dict((k[len(name)+1:], v) for k, v in parameters.items() if k.startswith('{}/'.format(name)))) else: - self.set_parameters(attr_class.to_dict()) + self.set_parameters(attr_class.to_dict(), __parameters_prefix=name) return attr_class def _validate(self, check_output_dest_credentials=False):
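
Usage sketch for the features introduced by this patch (illustrative only, not part of the commit): it assumes trains with this patch installed and, for the non-offline path, a trains-server exposing API >= 2.9. Project, task, section and parameter names below are placeholders.

    import os

    # Optional, per the subject line above: TRAINS_OFFLINE_MODE can carry the API
    # version to emulate while offline; the value is only picked up as a version
    # if it parses as a float (e.g. "2.9"). Set it before importing trains.
    os.environ['TRAINS_OFFLINE_MODE'] = '2.9'

    from trains import Task

    task = Task.init(project_name='examples', task_name='multi-section demo')

    # Hyper-parameters: each connected dict can now live in its own section.
    train_params = {'lr': 0.01, 'batch_size': 32}
    eval_params = {'batch_size': 64}
    task.connect(train_params, name='Train')  # stored as Train/lr, Train/batch_size
    task.connect(eval_params, name='Eval')    # stored as Eval/batch_size
    # argparse arguments connected through the Task are now stored under the
    # 'Args/' prefix automatically (see _prefix_args above).

    # Configurations: several named configuration objects per task.
    task.connect_configuration({'backbone': 'resnet50', 'layers': [64, 64]},
                               name='model', description='model layout')
    # A file path (str / pathlib Path) is also accepted; the file must exist locally:
    # task.connect_configuration('data.yaml', name='dataset')

    # Single parameter with a description (descriptions require API >= 2.9).
    task.set_parameter('General/seed', 42, description='random seed')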