Add multi configuration section support (hyperparams and configurations)

Support setting offline mode API version using TRAINS_OFFLINE_MODE env var
This commit is contained in:
allegroai
2020-08-08 12:35:03 +03:00
parent 6d4e85de0a
commit e378de1e41
6 changed files with 355 additions and 99 deletions

View File

@@ -124,6 +124,7 @@ class Task(_Task):
__task_id_reuse_time_window_in_hours = float(config.get('development.task_reuse_time_window_in_hours', 24.0))
__detect_repo_async = config.get('development.vcs_repo_detect_async', False)
__default_output_uri = config.get('development.default_output_uri', None)
__default_configuration_name = 'General'
class _ConnectedParametersType(object):
argparse = "argument_parser"
@@ -904,8 +905,8 @@ class Task(_Task):
self.data.tags.extend(tags)
self._edit(tags=list(set(self.data.tags)))
def connect(self, mutable):
# type: (Any) -> Any
def connect(self, mutable, name=None):
# type: (Any, Optional[str]) -> Any
"""
Connect an object to a Task object. This connects an experiment component (part of an experiment) to the
experiment. For example, connect hyperparameters or models.
@@ -918,6 +919,12 @@ class Task(_Task):
- TaskParameters - A TaskParameters object.
- model - A model object for initial model warmup, or for model update/snapshot uploading.
:param str name: A section name associated with the connected object. Default: 'General'
Currently only supported for `dict` / `TaskParameters` objects.
Examples:
name='General' will put the connected dictionary under the General section in the hyper-parameters
name='Train' will put the connected dictionary under the Train section in the hyper-parameters
:return: The result returned when connecting the object, if supported.
:raise: Raise an exception on unsupported objects.
@@ -931,14 +938,22 @@ class Task(_Task):
(TaskParameters, self._connect_task_parameters),
)
multi_config_support = Session.check_min_api_version('2.9')
if multi_config_support and not name:
name = self.__default_configuration_name
if not multi_config_support and name and name != self.__default_configuration_name:
raise ValueError("Multiple configurations are not supported with the current 'trains-server', "
"please upgrade to the latest version")
for mutable_type, method in dispatch:
if isinstance(mutable, mutable_type):
return method(mutable)
return method(mutable, name=name)
raise Exception('Unsupported mutable type %s: no connect function found' % type(mutable).__name__)
def connect_configuration(self, configuration):
# type: (Union[Mapping, Path, str]) -> Union[Mapping, Path, str]
def connect_configuration(self, configuration, name=None, description=None):
# type: (Union[Mapping, Path, str], Optional[str], Optional[str]) -> Union[Mapping, Path, str]
"""
Connect a configuration dictionary or configuration file (pathlib.Path / str) to a Task object.
This method should be called before reading the configuration file.
@@ -968,25 +983,54 @@ class Task(_Task):
A local path must be a relative path. When executing a Task remotely in a worker, the contents brought
from the **Trains Server** (backend) overwrite the contents of the file.
:param str name: Configuration section name. Default: 'General'
Allows users to store multiple configuration dicts/files.
:param str description: Configuration section description (text). default: None
:return: If a dictionary is specified, then a dictionary is returned. If pathlib2.Path / string is
specified, then a path to a local configuration file is returned. Configuration object.
"""
pathlib_Path = None
if not isinstance(configuration, (dict, Path, six.string_types)):
raise ValueError("connect_configuration supports `dict`, `str` and 'Path' types, "
"{} is not supported".format(type(configuration)))
try:
from pathlib import Path as pathlib_Path
except ImportError:
pass
if not pathlib_Path or not isinstance(configuration, pathlib_Path):
raise ValueError("connect_configuration supports `dict`, `str` and 'Path' types, "
"{} is not supported".format(type(configuration)))
multi_config_support = Session.check_min_api_version('2.9')
if multi_config_support and not name:
name = self.__default_configuration_name
if not multi_config_support and name and name != self.__default_configuration_name:
raise ValueError("Multiple configurations are not supported with the current 'trains-server', "
"please upgrade to the latest version")
# parameter dictionary
if isinstance(configuration, dict):
def _update_config_dict(task, config_dict):
# noinspection PyProtectedMember
task._set_model_config(config_dict=config_dict)
if multi_config_support:
# noinspection PyProtectedMember
task._set_configuration(
name=name, description=description, config_type='dictionary', config_dict=config_dict)
else:
# noinspection PyProtectedMember
task._set_model_config(config_dict=config_dict)
if not running_remotely() or not self.is_main_task():
self._set_model_config(config_dict=configuration)
if multi_config_support:
self._set_configuration(
name=name, description=description, config_type='dictionary', config_dict=configuration)
else:
self._set_model_config(config_dict=configuration)
configuration = ProxyDictPostWrite(self, _update_config_dict, **configuration)
else:
configuration.clear()
configuration.update(self._get_model_config_dict())
configuration.update(self._get_configuration_dict(name=name) if multi_config_support
else self._get_model_config_dict())
configuration = ProxyDictPreWrite(False, False, **configuration)
return configuration
@@ -1002,16 +1046,26 @@ class Task(_Task):
except Exception:
raise ValueError("Could not connect configuration file {}, file could not be read".format(
configuration_path.as_posix()))
self._set_model_config(config_text=configuration_text)
if multi_config_support:
self._set_configuration(
name=name, description=description,
config_type=configuration_path.suffixes[-1].lstrip('.')
if configuration_path.suffixes and configuration_path.suffixes[-1] else 'file',
config_text=configuration_text)
else:
self._set_model_config(config_text=configuration_text)
return configuration
else:
configuration_text = self._get_model_config_text()
configuration_text = self._get_configuration_text(name=name) if multi_config_support \
else self._get_model_config_text()
configuration_path = Path(configuration)
fd, local_filename = mkstemp(prefix='trains_task_config_',
suffix=configuration_path.suffixes[-1] if
configuration_path.suffixes else '.txt')
os.write(fd, configuration_text.encode('utf-8'))
os.close(fd)
if pathlib_Path:
return pathlib_Path(local_filename)
return Path(local_filename) if isinstance(configuration, Path) else local_filename
def connect_label_enumeration(self, enumeration):
@@ -1651,6 +1705,7 @@ class Task(_Task):
# noinspection PyProtectedMember
offline_folder = os.path.join(export_data.get('offline_folder', ''), 'data/')
# noinspection PyProtectedMember
remote_url = task._get_default_report_storage_uri()
if remote_url and remote_url.endswith('/'):
remote_url = remote_url[:-1]
@@ -1977,7 +2032,7 @@ class Task(_Task):
return self._logger
def _connect_output_model(self, model):
def _connect_output_model(self, model, name=None):
assert isinstance(model, OutputModel)
model.connect(self)
return model
@@ -2001,7 +2056,7 @@ class Task(_Task):
if self._connected_output_model:
self.connect(self._connected_output_model)
def _connect_input_model(self, model):
def _connect_input_model(self, model, name=None):
assert isinstance(model, InputModel)
# we only allow for an input model to be connected once
# at least until we support multiple input models
@@ -2039,7 +2094,7 @@ class Task(_Task):
# added support for multiple type connections through _Arguments
return option
def _connect_argparse(self, parser, args=None, namespace=None, parsed_args=None):
def _connect_argparse(self, parser, args=None, namespace=None, parsed_args=None, name=None):
# do not allow argparser to connect to jupyter notebook
# noinspection PyBroadException
try:
@@ -2079,15 +2134,15 @@ class Task(_Task):
parser, args=args, namespace=namespace, parsed_args=parsed_args)
return parser
def _connect_dictionary(self, dictionary):
def _connect_dictionary(self, dictionary, name=None):
def _update_args_dict(task, config_dict):
# noinspection PyProtectedMember
task._arguments.copy_from_dict(flatten_dictionary(config_dict))
task._arguments.copy_from_dict(flatten_dictionary(config_dict), prefix=name)
def _refresh_args_dict(task, config_dict):
# reread from task including newly added keys
# noinspection PyProtectedMember
a_flat_dict = task._arguments.copy_to_dict(flatten_dictionary(config_dict))
a_flat_dict = task._arguments.copy_to_dict(flatten_dictionary(config_dict), prefix=name)
# noinspection PyProtectedMember
nested_dict = config_dict._to_dict()
config_dict.clear()
@@ -2096,23 +2151,28 @@ class Task(_Task):
self._try_set_connected_parameter_type(self._ConnectedParametersType.dictionary)
if not running_remotely() or not self.is_main_task():
self._arguments.copy_from_dict(flatten_dictionary(dictionary))
self._arguments.copy_from_dict(flatten_dictionary(dictionary), prefix=name)
dictionary = ProxyDictPostWrite(self, _update_args_dict, **dictionary)
else:
flat_dict = flatten_dictionary(dictionary)
flat_dict = self._arguments.copy_to_dict(flat_dict)
flat_dict = self._arguments.copy_to_dict(flat_dict, prefix=name)
dictionary = nested_from_flat_dictionary(dictionary, flat_dict)
dictionary = ProxyDictPostWrite(self, _refresh_args_dict, **dictionary)
return dictionary
def _connect_task_parameters(self, attr_class):
def _connect_task_parameters(self, attr_class, name=None):
self._try_set_connected_parameter_type(self._ConnectedParametersType.task_parameters)
if running_remotely() and self.is_main_task():
attr_class.update_from_dict(self.get_parameters())
parameters = self.get_parameters()
if not name:
attr_class.update_from_dict(parameters)
else:
attr_class.update_from_dict(
dict((k[len(name)+1:], v) for k, v in parameters.items() if k.startswith('{}/'.format(name))))
else:
self.set_parameters(attr_class.to_dict())
self.set_parameters(attr_class.to_dict(), __parameters_prefix=name)
return attr_class
def _validate(self, check_output_dest_credentials=False):