mirror of
https://github.com/clearml/clearml
synced 2025-06-26 18:16:07 +00:00
Add multi configuration section support (hyperparams and configurations)
Support setting offline mode API version using TRAINS_OFFLINE_MODE env var
This commit is contained in:
108
trains/task.py
108
trains/task.py
@@ -124,6 +124,7 @@ class Task(_Task):
|
||||
__task_id_reuse_time_window_in_hours = float(config.get('development.task_reuse_time_window_in_hours', 24.0))
|
||||
__detect_repo_async = config.get('development.vcs_repo_detect_async', False)
|
||||
__default_output_uri = config.get('development.default_output_uri', None)
|
||||
__default_configuration_name = 'General'
|
||||
|
||||
class _ConnectedParametersType(object):
|
||||
argparse = "argument_parser"
|
||||
@@ -904,8 +905,8 @@ class Task(_Task):
|
||||
self.data.tags.extend(tags)
|
||||
self._edit(tags=list(set(self.data.tags)))
|
||||
|
||||
def connect(self, mutable):
|
||||
# type: (Any) -> Any
|
||||
def connect(self, mutable, name=None):
|
||||
# type: (Any, Optional[str]) -> Any
|
||||
"""
|
||||
Connect an object to a Task object. This connects an experiment component (part of an experiment) to the
|
||||
experiment. For example, connect hyperparameters or models.
|
||||
@@ -918,6 +919,12 @@ class Task(_Task):
|
||||
- TaskParameters - A TaskParameters object.
|
||||
- model - A model object for initial model warmup, or for model update/snapshot uploading.
|
||||
|
||||
:param str name: A section name associated with the connected object. Default: 'General'
|
||||
Currently only supported for `dict` / `TaskParameters` objects
|
||||
Examples:
|
||||
name='General' will put the connected dictionary under the General section in the hyper-parameters
|
||||
name='Train' will put the connected dictionary under the Train section in the hyper-parameters
|
||||
|
||||
:return: The result returned when connecting the object, if supported.
|
||||
|
||||
:raise: Raise an exception on unsupported objects.
|
||||
@@ -931,14 +938,22 @@ class Task(_Task):
|
||||
(TaskParameters, self._connect_task_parameters),
|
||||
)
|
||||
|
||||
multi_config_support = Session.check_min_api_version('2.9')
|
||||
if multi_config_support and not name:
|
||||
name = self.__default_configuration_name
|
||||
|
||||
if not multi_config_support and name and name != self.__default_configuration_name:
|
||||
raise ValueError("Multiple configurations are not supported with the current 'trains-server', "
|
||||
"please upgrade to the latest version")
|
||||
|
||||
for mutable_type, method in dispatch:
|
||||
if isinstance(mutable, mutable_type):
|
||||
return method(mutable)
|
||||
return method(mutable, name=name)
|
||||
|
||||
raise Exception('Unsupported mutable type %s: no connect function found' % type(mutable).__name__)
|
||||
|
||||
def connect_configuration(self, configuration):
|
||||
# type: (Union[Mapping, Path, str]) -> Union[Mapping, Path, str]
|
||||
def connect_configuration(self, configuration, name=None, description=None):
|
||||
# type: (Union[Mapping, Path, str], Optional[str], Optional[str]) -> Union[Mapping, Path, str]
|
||||
"""
|
||||
Connect a configuration dictionary or configuration file (pathlib.Path / str) to a Task object.
|
||||
This method should be called before reading the configuration file.
|
||||
@@ -968,25 +983,54 @@ class Task(_Task):
|
||||
A local path must be relative path. When executing a Task remotely in a worker, the contents brought
|
||||
from the **Trains Server** (backend) overwrites the contents of the file.
|
||||
|
||||
:param str name: Configuration section name. default: 'General'
|
||||
This allows users to store multiple configuration dicts/files
|
||||
|
||||
:param str description: Configuration section description (text). default: None
|
||||
|
||||
:return: If a dictionary is specified, then a dictionary is returned. If pathlib2.Path / string is
|
||||
specified, then a path to a local configuration file is returned. Configuration object.
|
||||
"""
|
||||
pathlib_Path = None
|
||||
if not isinstance(configuration, (dict, Path, six.string_types)):
|
||||
raise ValueError("connect_configuration supports `dict`, `str` and 'Path' types, "
|
||||
"{} is not supported".format(type(configuration)))
|
||||
try:
|
||||
from pathlib import Path as pathlib_Path
|
||||
except ImportError:
|
||||
pass
|
||||
if not pathlib_Path or not isinstance(configuration, pathlib_Path):
|
||||
raise ValueError("connect_configuration supports `dict`, `str` and 'Path' types, "
|
||||
"{} is not supported".format(type(configuration)))
|
||||
|
||||
multi_config_support = Session.check_min_api_version('2.9')
|
||||
if multi_config_support and not name:
|
||||
name = self.__default_configuration_name
|
||||
|
||||
if not multi_config_support and name and name != self.__default_configuration_name:
|
||||
raise ValueError("Multiple configurations are not supported with the current 'trains-server', "
|
||||
"please upgrade to the latest version")
|
||||
|
||||
# parameter dictionary
|
||||
if isinstance(configuration, dict):
|
||||
def _update_config_dict(task, config_dict):
|
||||
# noinspection PyProtectedMember
|
||||
task._set_model_config(config_dict=config_dict)
|
||||
if multi_config_support:
|
||||
# noinspection PyProtectedMember
|
||||
task._set_configuration(
|
||||
name=name, description=description, config_type='dictionary', config_dict=config_dict)
|
||||
else:
|
||||
# noinspection PyProtectedMember
|
||||
task._set_model_config(config_dict=config_dict)
|
||||
|
||||
if not running_remotely() or not self.is_main_task():
|
||||
self._set_model_config(config_dict=configuration)
|
||||
if multi_config_support:
|
||||
self._set_configuration(
|
||||
name=name, description=description, config_type='dictionary', config_dict=configuration)
|
||||
else:
|
||||
self._set_model_config(config_dict=configuration)
|
||||
configuration = ProxyDictPostWrite(self, _update_config_dict, **configuration)
|
||||
else:
|
||||
configuration.clear()
|
||||
configuration.update(self._get_model_config_dict())
|
||||
configuration.update(self._get_configuration_dict(name=name) if multi_config_support
|
||||
else self._get_model_config_dict())
|
||||
configuration = ProxyDictPreWrite(False, False, **configuration)
|
||||
return configuration
|
||||
|
||||
@@ -1002,16 +1046,26 @@ class Task(_Task):
|
||||
except Exception:
|
||||
raise ValueError("Could not connect configuration file {}, file could not be read".format(
|
||||
configuration_path.as_posix()))
|
||||
self._set_model_config(config_text=configuration_text)
|
||||
if multi_config_support:
|
||||
self._set_configuration(
|
||||
name=name, description=description,
|
||||
config_type=configuration_path.suffixes[-1].lstrip('.')
|
||||
if configuration_path.suffixes and configuration_path.suffixes[-1] else 'file',
|
||||
config_text=configuration_text)
|
||||
else:
|
||||
self._set_model_config(config_text=configuration_text)
|
||||
return configuration
|
||||
else:
|
||||
configuration_text = self._get_model_config_text()
|
||||
configuration_text = self._get_configuration_text(name=name) if multi_config_support \
|
||||
else self._get_model_config_text()
|
||||
configuration_path = Path(configuration)
|
||||
fd, local_filename = mkstemp(prefix='trains_task_config_',
|
||||
suffix=configuration_path.suffixes[-1] if
|
||||
configuration_path.suffixes else '.txt')
|
||||
os.write(fd, configuration_text.encode('utf-8'))
|
||||
os.close(fd)
|
||||
if pathlib_Path:
|
||||
return pathlib_Path(local_filename)
|
||||
return Path(local_filename) if isinstance(configuration, Path) else local_filename
|
||||
|
||||
def connect_label_enumeration(self, enumeration):
|
||||
@@ -1651,6 +1705,7 @@ class Task(_Task):
|
||||
# noinspection PyProtectedMember
|
||||
offline_folder = os.path.join(export_data.get('offline_folder', ''), 'data/')
|
||||
|
||||
# noinspection PyProtectedMember
|
||||
remote_url = task._get_default_report_storage_uri()
|
||||
if remote_url and remote_url.endswith('/'):
|
||||
remote_url = remote_url[:-1]
|
||||
@@ -1977,7 +2032,7 @@ class Task(_Task):
|
||||
|
||||
return self._logger
|
||||
|
||||
def _connect_output_model(self, model):
|
||||
def _connect_output_model(self, model, name=None):
|
||||
assert isinstance(model, OutputModel)
|
||||
model.connect(self)
|
||||
return model
|
||||
@@ -2001,7 +2056,7 @@ class Task(_Task):
|
||||
if self._connected_output_model:
|
||||
self.connect(self._connected_output_model)
|
||||
|
||||
def _connect_input_model(self, model):
|
||||
def _connect_input_model(self, model, name=None):
|
||||
assert isinstance(model, InputModel)
|
||||
# we only allow for an input model to be connected once
|
||||
# at least until we support multiple input models
|
||||
@@ -2039,7 +2094,7 @@ class Task(_Task):
|
||||
# added support for multiple type connections through _Arguments
|
||||
return option
|
||||
|
||||
def _connect_argparse(self, parser, args=None, namespace=None, parsed_args=None):
|
||||
def _connect_argparse(self, parser, args=None, namespace=None, parsed_args=None, name=None):
|
||||
# do not allow argparser to connect to jupyter notebook
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
@@ -2079,15 +2134,15 @@ class Task(_Task):
|
||||
parser, args=args, namespace=namespace, parsed_args=parsed_args)
|
||||
return parser
|
||||
|
||||
def _connect_dictionary(self, dictionary):
|
||||
def _connect_dictionary(self, dictionary, name=None):
|
||||
def _update_args_dict(task, config_dict):
|
||||
# noinspection PyProtectedMember
|
||||
task._arguments.copy_from_dict(flatten_dictionary(config_dict))
|
||||
task._arguments.copy_from_dict(flatten_dictionary(config_dict), prefix=name)
|
||||
|
||||
def _refresh_args_dict(task, config_dict):
|
||||
# reread from task including newly added keys
|
||||
# noinspection PyProtectedMember
|
||||
a_flat_dict = task._arguments.copy_to_dict(flatten_dictionary(config_dict))
|
||||
a_flat_dict = task._arguments.copy_to_dict(flatten_dictionary(config_dict), prefix=name)
|
||||
# noinspection PyProtectedMember
|
||||
nested_dict = config_dict._to_dict()
|
||||
config_dict.clear()
|
||||
@@ -2096,23 +2151,28 @@ class Task(_Task):
|
||||
self._try_set_connected_parameter_type(self._ConnectedParametersType.dictionary)
|
||||
|
||||
if not running_remotely() or not self.is_main_task():
|
||||
self._arguments.copy_from_dict(flatten_dictionary(dictionary))
|
||||
self._arguments.copy_from_dict(flatten_dictionary(dictionary), prefix=name)
|
||||
dictionary = ProxyDictPostWrite(self, _update_args_dict, **dictionary)
|
||||
else:
|
||||
flat_dict = flatten_dictionary(dictionary)
|
||||
flat_dict = self._arguments.copy_to_dict(flat_dict)
|
||||
flat_dict = self._arguments.copy_to_dict(flat_dict, prefix=name)
|
||||
dictionary = nested_from_flat_dictionary(dictionary, flat_dict)
|
||||
dictionary = ProxyDictPostWrite(self, _refresh_args_dict, **dictionary)
|
||||
|
||||
return dictionary
|
||||
|
||||
def _connect_task_parameters(self, attr_class):
|
||||
def _connect_task_parameters(self, attr_class, name=None):
|
||||
self._try_set_connected_parameter_type(self._ConnectedParametersType.task_parameters)
|
||||
|
||||
if running_remotely() and self.is_main_task():
|
||||
attr_class.update_from_dict(self.get_parameters())
|
||||
parameters = self.get_parameters()
|
||||
if not name:
|
||||
attr_class.update_from_dict(parameters)
|
||||
else:
|
||||
attr_class.update_from_dict(
|
||||
dict((k[len(name)+1:], v) for k, v in parameters.items() if k.startswith('{}/'.format(name))))
|
||||
else:
|
||||
self.set_parameters(attr_class.to_dict())
|
||||
self.set_parameters(attr_class.to_dict(), __parameters_prefix=name)
|
||||
return attr_class
|
||||
|
||||
def _validate(self, check_output_dest_credentials=False):
|
||||
|
||||
Reference in New Issue
Block a user