mirror of
https://github.com/clearml/clearml
synced 2025-01-31 00:56:57 +00:00
Allow users to disable remote overrides when using Task.connect()
or Task.connect_configuration()
This commit is contained in:
parent
233f94f741
commit
240b762a2a
@ -1932,8 +1932,8 @@ class InputModel(Model):
|
||||
# type: () -> str
|
||||
return self._base_model_id
|
||||
|
||||
def connect(self, task, name=None):
|
||||
# type: (Task, Optional[str]) -> None
|
||||
def connect(self, task, name=None, ignore_remote_overrides=False):
|
||||
# type: (Task, Optional[str], bool) -> None
|
||||
"""
|
||||
Connect the current model to a Task object, if the model is preexisting. Preexisting models include:
|
||||
|
||||
@ -1943,24 +1943,31 @@ class InputModel(Model):
|
||||
- Models whose origin is not ClearML that are used to create an InputModel object. For example,
|
||||
models created using TensorFlow models.
|
||||
|
||||
When the experiment is executed remotely in a worker, the input model already specified in the experiment is
|
||||
used.
|
||||
When the experiment is executed remotely in a worker, the input model specified in the experiment UI/backend
|
||||
is used, unless `ignore_remote_overrides` is set to True.
|
||||
|
||||
.. note::
|
||||
The **ClearML Web-App** allows you to switch one input model for another and then enqueue the experiment
|
||||
to execute in a worker.
|
||||
|
||||
:param object task: A Task object.
|
||||
:param ignore_remote_overrides: If True, changing the model in the UI/backend will have no
|
||||
effect when running remotely.
|
||||
Default is False, meaning that any changes made in the UI/backend will be applied in remote execution.
|
||||
:param str name: The model name to be stored on the Task
|
||||
(default to filename of the model weights, without the file extension, or to `Input Model` if that is not found)
|
||||
(default to filename of the model weights, without the file extension, or to `Input Model`
|
||||
if that is not found)
|
||||
"""
|
||||
self._set_task(task)
|
||||
name = name or InputModel._get_connect_name(self)
|
||||
InputModel._warn_on_same_name_connect(name)
|
||||
ignore_remote_overrides = task._handle_ignore_remote_overrides(
|
||||
name + "/_ignore_remote_overrides_input_model_", ignore_remote_overrides
|
||||
)
|
||||
|
||||
model_id = None
|
||||
# noinspection PyProtectedMember
|
||||
if running_remotely() and (task.is_main_task() or task._is_remote_main_task()):
|
||||
if running_remotely() and (task.is_main_task() or task._is_remote_main_task()) and not ignore_remote_overrides:
|
||||
input_models = task.input_models_id
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
@ -2245,7 +2252,7 @@ class OutputModel(BaseModel):
|
||||
pass
|
||||
self.connect(task, name=name)
|
||||
|
||||
def connect(self, task, name=None):
|
||||
def connect(self, task, name=None, **kwargs):
|
||||
# type: (Task, Optional[str]) -> None
|
||||
"""
|
||||
Connect the current model to a Task object, if the model is a preexisting model. Preexisting models include:
|
||||
|
120
clearml/task.py
120
clearml/task.py
@ -1516,11 +1516,13 @@ class Task(_Task):
|
||||
self.data.tags = list(set((self.data.tags or []) + tags))
|
||||
self._edit(tags=self.data.tags)
|
||||
|
||||
def connect(self, mutable, name=None):
|
||||
# type: (Any, Optional[str]) -> Any
|
||||
def connect(self, mutable, name=None, ignore_remote_overrides=False):
|
||||
# type: (Any, Optional[str], bool) -> Any
|
||||
"""
|
||||
Connect an object to a Task object. This connects an experiment component (part of an experiment) to the
|
||||
experiment. For example, an experiment component can be a valid object containing some hyperparameters, or a :class:`Model`.
|
||||
When running remotely, the value of the connected object is overridden by the corresponding value found
|
||||
under the experiment's UI/backend (unless `ignore_remote_overrides` is True).
|
||||
|
||||
:param object mutable: The experiment component to connect. The object must be one of the following types:
|
||||
|
||||
@ -1533,16 +1535,23 @@ class Task(_Task):
|
||||
|
||||
:param str name: A section name associated with the connected object, if 'name' is None defaults to 'General'
|
||||
Currently, `name` is only supported for `dict` and `TaskParameter` objects, and should be omitted for the other supported types. (Optional)
|
||||
|
||||
For example, by setting `name='General'` the connected dictionary will be under the General section in the hyperparameters section.
|
||||
While by setting `name='Train'` the connected dictionary will be under the Train section in the hyperparameters section.
|
||||
|
||||
:param ignore_remote_overrides: If True, ignore UI/backend overrides when running remotely.
|
||||
Default is False, meaning that any changes made in the UI/backend will be applied in remote execution.
|
||||
|
||||
:return: It will return the same object that was passed as the `mutable` argument to the method, except if the type of the object is dict.
|
||||
For dicts the :meth:`Task.connect` will return the dict decorated as a `ProxyDictPostWrite`.
|
||||
This is done to allow propagating the updates from the connected object.
|
||||
|
||||
:raise: Raises an exception if passed an unsupported object.
|
||||
"""
|
||||
# input model connect and task parameters will handle this instead
|
||||
if not isinstance(mutable, (InputModel, TaskParameters)):
|
||||
ignore_remote_overrides = self._handle_ignore_remote_overrides(
|
||||
(name or "General") + "/_ignore_remote_overrides_", ignore_remote_overrides
|
||||
)
|
||||
# dispatching by match order
|
||||
dispatch = (
|
||||
(OutputModel, self._connect_output_model),
|
||||
@ -1564,7 +1573,7 @@ class Task(_Task):
|
||||
|
||||
for mutable_type, method in dispatch:
|
||||
if isinstance(mutable, mutable_type):
|
||||
return method(mutable, name=name)
|
||||
return method(mutable, name=name, ignore_remote_overrides=ignore_remote_overrides)
|
||||
|
||||
raise Exception('Unsupported mutable type %s: no connect function found' % type(mutable).__name__)
|
||||
|
||||
@ -1631,8 +1640,8 @@ class Task(_Task):
|
||||
self.data.script.version_num = commit or ""
|
||||
self._edit(script=self.data.script)
|
||||
|
||||
def connect_configuration(self, configuration, name=None, description=None):
|
||||
# type: (Union[Mapping, list, Path, str], Optional[str], Optional[str]) -> Union[dict, Path, str]
|
||||
def connect_configuration(self, configuration, name=None, description=None, ignore_remote_overrides=False):
|
||||
# type: (Union[Mapping, list, Path, str], Optional[str], Optional[str], bool) -> Union[dict, Path, str]
|
||||
"""
|
||||
Connect a configuration dictionary or configuration file (pathlib.Path / str) to a Task object.
|
||||
This method should be called before reading the configuration file.
|
||||
@ -1650,6 +1659,9 @@ class Task(_Task):
|
||||
|
||||
my_params = task.connect_configuration(my_params)
|
||||
|
||||
When running remotely, the value of the connected configuration is overridden by the corresponding value found
|
||||
under the experiment's UI/backend (unless `ignore_remote_overrides` is True).
|
||||
|
||||
:param configuration: The configuration. This is usually the configuration used in the model training process.
|
||||
Specify one of the following:
|
||||
|
||||
@ -1664,9 +1676,15 @@ class Task(_Task):
|
||||
|
||||
:param str description: Configuration section description (text). default: None
|
||||
|
||||
:param bool ignore_remote_overrides: If True, ignore UI/backend overrides when running remotely.
|
||||
Default is False, meaning that any changes made in the UI/backend will be applied in remote execution.
|
||||
|
||||
:return: If a dictionary is specified, then a dictionary is returned. If pathlib2.Path / string is
|
||||
specified, then a path to a local configuration file is returned. Configuration object.
|
||||
"""
|
||||
ignore_remote_overrides = self._handle_ignore_remote_overrides(
|
||||
(name or "General") + "/_ignore_remote_overrides_config_", ignore_remote_overrides
|
||||
)
|
||||
pathlib_Path = None # noqa
|
||||
cast_Path = Path
|
||||
if not isinstance(configuration, (dict, list, Path, six.string_types)):
|
||||
@ -1710,7 +1728,7 @@ class Task(_Task):
|
||||
configuration_ = ProxyDictPostWrite(self, _update_config_dict, configuration_)
|
||||
return configuration_
|
||||
|
||||
if not running_remotely() or not (self.is_main_task() or self._is_remote_main_task()):
|
||||
if not running_remotely() or not (self.is_main_task() or self._is_remote_main_task()) or ignore_remote_overrides:
|
||||
configuration = get_dev_config(configuration)
|
||||
else:
|
||||
# noinspection PyBroadException
|
||||
@ -1744,7 +1762,7 @@ class Task(_Task):
|
||||
return configuration
|
||||
|
||||
# it is a path to a local file
|
||||
if not running_remotely() or not (self.is_main_task() or self._is_remote_main_task()):
|
||||
if not running_remotely() or not (self.is_main_task() or self._is_remote_main_task()) or ignore_remote_overrides:
|
||||
# check if not absolute path
|
||||
configuration_path = cast_Path(configuration)
|
||||
if not configuration_path.is_file():
|
||||
@ -1793,8 +1811,8 @@ class Task(_Task):
|
||||
f.write(configuration_text)
|
||||
return cast_Path(local_filename) if isinstance(configuration, cast_Path) else local_filename
|
||||
|
||||
def connect_label_enumeration(self, enumeration):
|
||||
# type: (Dict[str, int]) -> Dict[str, int]
|
||||
def connect_label_enumeration(self, enumeration, ignore_remote_overrides=False):
|
||||
# type: (Dict[str, int], bool) -> Dict[str, int]
|
||||
"""
|
||||
Connect a label enumeration dictionary to a Task (experiment) object.
|
||||
|
||||
@ -1811,13 +1829,22 @@ class Task(_Task):
|
||||
"person": 1
|
||||
}
|
||||
|
||||
:param ignore_remote_overrides: If True, ignore UI/backend overrides when running remotely.
|
||||
Default is False, meaning that any changes made in the UI/backend will be applied in remote execution.
|
||||
:return: The label enumeration dictionary (JSON).
|
||||
"""
|
||||
ignore_remote_overrides = self._handle_ignore_remote_overrides(
|
||||
"General/_ignore_remote_overrides_label_enumeration_", ignore_remote_overrides
|
||||
)
|
||||
if not isinstance(enumeration, dict):
|
||||
raise ValueError("connect_label_enumeration supports only `dict` type, "
|
||||
"{} is not supported".format(type(enumeration)))
|
||||
|
||||
if not running_remotely() or not (self.is_main_task() or self._is_remote_main_task()):
|
||||
if (
|
||||
not running_remotely()
|
||||
or not (self.is_main_task() or self._is_remote_main_task())
|
||||
or ignore_remote_overrides
|
||||
):
|
||||
self.set_model_label_enumeration(enumeration)
|
||||
else:
|
||||
# pop everything
|
||||
@ -3750,9 +3777,9 @@ class Task(_Task):
|
||||
|
||||
return self._logger
|
||||
|
||||
def _connect_output_model(self, model, name=None):
|
||||
def _connect_output_model(self, model, name=None, **kwargs):
|
||||
assert isinstance(model, OutputModel)
|
||||
model.connect(self, name=name)
|
||||
model.connect(self, name=name, ignore_remote_overrides=False)
|
||||
return model
|
||||
|
||||
def _save_output_model(self, model):
|
||||
@ -3764,6 +3791,19 @@ class Task(_Task):
|
||||
# deprecated
|
||||
self._connected_output_model = model
|
||||
|
||||
def _handle_ignore_remote_overrides(self, overrides_name, ignore_remote_overrides):
|
||||
if self.running_locally() and ignore_remote_overrides:
|
||||
self.set_parameter(
|
||||
overrides_name,
|
||||
True,
|
||||
description="If True, ignore UI/backend overrides when running remotely."
|
||||
" Set it to False if you would like the overrides to be applied",
|
||||
value_type=bool
|
||||
)
|
||||
elif not self.running_locally():
|
||||
ignore_remote_overrides = self.get_parameter(overrides_name, default=ignore_remote_overrides, cast=True)
|
||||
return ignore_remote_overrides
|
||||
|
||||
def _reconnect_output_model(self):
|
||||
"""
|
||||
Deprecated: If there is a saved connected output model, connect it again.
|
||||
@ -3776,7 +3816,7 @@ class Task(_Task):
|
||||
if self._connected_output_model:
|
||||
self.connect(self._connected_output_model)
|
||||
|
||||
def _connect_input_model(self, model, name=None):
|
||||
def _connect_input_model(self, model, name=None, ignore_remote_overrides=False):
|
||||
assert isinstance(model, InputModel)
|
||||
# we only allow for an input model to be connected once
|
||||
# at least until we support multiple input models
|
||||
@ -3791,18 +3831,21 @@ class Task(_Task):
|
||||
comment += 'Using model id: {}'.format(model.id)
|
||||
self.set_comment(comment)
|
||||
|
||||
model.connect(self, name)
|
||||
model.connect(self, name, ignore_remote_overrides=ignore_remote_overrides)
|
||||
return model
|
||||
|
||||
def _connect_argparse(self, parser, args=None, namespace=None, parsed_args=None, name=None):
|
||||
def _connect_argparse(
|
||||
self, parser, args=None, namespace=None, parsed_args=None, name=None, ignore_remote_overrides=False
|
||||
):
|
||||
# do not allow argparser to connect to jupyter notebook
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
if 'IPython' in sys.modules:
|
||||
if "IPython" in sys.modules:
|
||||
# noinspection PyPackageRequirements
|
||||
from IPython import get_ipython # noqa
|
||||
|
||||
ip = get_ipython()
|
||||
if ip is not None and 'IPKernelApp' in ip.config:
|
||||
if ip is not None and "IPKernelApp" in ip.config:
|
||||
return parser
|
||||
except Exception:
|
||||
pass
|
||||
@ -3825,14 +3868,14 @@ class Task(_Task):
|
||||
if parsed_args is None and parser == _parser:
|
||||
parsed_args = _parsed_args
|
||||
|
||||
if running_remotely() and (self.is_main_task() or self._is_remote_main_task()):
|
||||
if running_remotely() and (self.is_main_task() or self._is_remote_main_task()) and not ignore_remote_overrides:
|
||||
self._arguments.copy_to_parser(parser, parsed_args)
|
||||
else:
|
||||
self._arguments.copy_defaults_from_argparse(
|
||||
parser, args=args, namespace=namespace, parsed_args=parsed_args)
|
||||
return parser
|
||||
|
||||
def _connect_dictionary(self, dictionary, name=None):
|
||||
def _connect_dictionary(self, dictionary, name=None, ignore_remote_overrides=False):
|
||||
def _update_args_dict(task, config_dict):
|
||||
# noinspection PyProtectedMember
|
||||
task._arguments.copy_from_dict(flatten_dictionary(config_dict), prefix=name)
|
||||
@ -3862,7 +3905,7 @@ class Task(_Task):
|
||||
if isinstance(v, dict):
|
||||
_check_keys(v, warning_sent)
|
||||
|
||||
if not running_remotely() or not (self.is_main_task() or self._is_remote_main_task()):
|
||||
if not running_remotely() or not (self.is_main_task() or self._is_remote_main_task()) or ignore_remote_overrides:
|
||||
_check_keys(dictionary)
|
||||
flat_dict = {str(k): v for k, v in flatten_dictionary(dictionary).items()}
|
||||
self._arguments.copy_from_dict(flat_dict, prefix=name)
|
||||
@ -3875,19 +3918,28 @@ class Task(_Task):
|
||||
|
||||
return dictionary
|
||||
|
||||
def _connect_task_parameters(self, attr_class, name=None):
|
||||
if running_remotely() and (self.is_main_task() or self._is_remote_main_task()):
|
||||
parameters = self.get_parameters()
|
||||
if not name:
|
||||
attr_class.update_from_dict(parameters)
|
||||
else:
|
||||
attr_class.update_from_dict(
|
||||
dict((k[len(name) + 1:], v) for k, v in parameters.items() if k.startswith('{}/'.format(name))))
|
||||
def _connect_task_parameters(self, attr_class, name=None, ignore_remote_overrides=False):
|
||||
ignore_remote_overrides_section = "_ignore_remote_overrides_"
|
||||
if running_remotely():
|
||||
ignore_remote_overrides = self.get_parameter(
|
||||
(name or "General") + "/" + ignore_remote_overrides_section, default=ignore_remote_overrides, cast=True
|
||||
)
|
||||
if running_remotely() and (self.is_main_task() or self._is_remote_main_task()) and not ignore_remote_overrides:
|
||||
parameters = self.get_parameters(cast=True)
|
||||
if name:
|
||||
parameters = dict(
|
||||
(k[len(name) + 1:], v) for k, v in parameters.items() if k.startswith("{}/".format(name))
|
||||
)
|
||||
parameters.pop(ignore_remote_overrides_section, None)
|
||||
attr_class.update_from_dict(parameters)
|
||||
else:
|
||||
self.set_parameters(attr_class.to_dict(), __parameters_prefix=name)
|
||||
parameters_dict = attr_class.to_dict()
|
||||
if ignore_remote_overrides:
|
||||
parameters_dict[ignore_remote_overrides_section] = True
|
||||
self.set_parameters(parameters_dict, __parameters_prefix=name)
|
||||
return attr_class
|
||||
|
||||
def _connect_object(self, an_object, name=None):
|
||||
def _connect_object(self, an_object, name=None, ignore_remote_overrides=False):
|
||||
def verify_type(key, value):
|
||||
if str(key).startswith('_') or not isinstance(value, self._parameters_allowed_types):
|
||||
return False
|
||||
@ -3904,15 +3956,15 @@ class Task(_Task):
|
||||
for k, v in cls_.__dict__.items()
|
||||
if verify_type(k, v)
|
||||
}
|
||||
if running_remotely() and (self.is_main_task() or self._is_remote_main_task()):
|
||||
a_dict = self._connect_dictionary(a_dict, name)
|
||||
if running_remotely() and (self.is_main_task() or self._is_remote_main_task()) and not ignore_remote_overrides:
|
||||
a_dict = self._connect_dictionary(a_dict, name, ignore_remote_overrides=ignore_remote_overrides)
|
||||
for k, v in a_dict.items():
|
||||
if getattr(an_object, k, None) != a_dict[k]:
|
||||
setattr(an_object, k, v)
|
||||
|
||||
return an_object
|
||||
else:
|
||||
self._connect_dictionary(a_dict, name)
|
||||
self._connect_dictionary(a_dict, name, ignore_remote_overrides=ignore_remote_overrides)
|
||||
return an_object
|
||||
|
||||
def _dev_mode_stop_task(self, stop_reason, pid=None):
|
||||
|
Loading…
Reference in New Issue
Block a user