Allow users to disable remote overrides when using Task.connect() or Task.connect_configuration()

This commit is contained in:
allegroai 2024-02-04 19:33:41 +02:00
parent 233f94f741
commit 240b762a2a
2 changed files with 100 additions and 41 deletions

View File

@ -1932,8 +1932,8 @@ class InputModel(Model):
# type: () -> str
return self._base_model_id
def connect(self, task, name=None):
# type: (Task, Optional[str]) -> None
def connect(self, task, name=None, ignore_remote_overrides=False):
# type: (Task, Optional[str], bool) -> None
"""
Connect the current model to a Task object, if the model is preexisting. Preexisting models include:
@ -1943,24 +1943,31 @@ class InputModel(Model):
- Models whose origin is not ClearML that are used to create an InputModel object. For example,
models created using TensorFlow models.
When the experiment is executed remotely in a worker, the input model already specified in the experiment is
used.
When the experiment is executed remotely in a worker, the input model specified in the experiment UI/backend
is used, unless `ignore_remote_overrides` is set to True.
.. note::
The **ClearML Web-App** allows you to switch one input model for another and then enqueue the experiment
to execute in a worker.
:param object task: A Task object.
:param ignore_remote_overrides: If True, changing the model in the UI/backend will have no
effect when running remotely.
Default is False, meaning that any changes made in the UI/backend will be applied in remote execution.
:param str name: The model name to be stored on the Task
(default to filename of the model weights, without the file extension, or to `Input Model` if that is not found)
(default to filename of the model weights, without the file extension, or to `Input Model`
if that is not found)
"""
self._set_task(task)
name = name or InputModel._get_connect_name(self)
InputModel._warn_on_same_name_connect(name)
ignore_remote_overrides = task._handle_ignore_remote_overrides(
name + "/_ignore_remote_overrides_input_model_", ignore_remote_overrides
)
model_id = None
# noinspection PyProtectedMember
if running_remotely() and (task.is_main_task() or task._is_remote_main_task()):
if running_remotely() and (task.is_main_task() or task._is_remote_main_task()) and not ignore_remote_overrides:
input_models = task.input_models_id
# noinspection PyBroadException
try:
@ -2245,7 +2252,7 @@ class OutputModel(BaseModel):
pass
self.connect(task, name=name)
def connect(self, task, name=None):
def connect(self, task, name=None, **kwargs):
# type: (Task, Optional[str]) -> None
"""
Connect the current model to a Task object, if the model is a preexisting model. Preexisting models include:

View File

@ -1516,11 +1516,13 @@ class Task(_Task):
self.data.tags = list(set((self.data.tags or []) + tags))
self._edit(tags=self.data.tags)
def connect(self, mutable, name=None):
# type: (Any, Optional[str]) -> Any
def connect(self, mutable, name=None, ignore_remote_overrides=False):
# type: (Any, Optional[str], bool) -> Any
"""
Connect an object to a Task object. This connects an experiment component (part of an experiment) to the
experiment. For example, an experiment component can be a valid object containing some hyperparameters, or a :class:`Model`.
When running remotely, the value of the connected object is overridden by the corresponding value found
under the experiment's UI/backend (unless `ignore_remote_overrides` is True).
:param object mutable: The experiment component to connect. The object must be one of the following types:
@ -1533,16 +1535,23 @@ class Task(_Task):
:param str name: A section name associated with the connected object, if 'name' is None defaults to 'General'
Currently, `name` is only supported for `dict` and `TaskParameter` objects, and should be omitted for the other supported types. (Optional)
For example, by setting `name='General'` the connected dictionary will be under the General section in the hyperparameters section.
While by setting `name='Train'` the connected dictionary will be under the Train section in the hyperparameters section.
:param ignore_remote_overrides: If True, ignore UI/backend overrides when running remotely.
Default is False, meaning that any changes made in the UI/backend will be applied in remote execution.
:return: It will return the same object that was passed as the `mutable` argument to the method, except if the type of the object is dict.
For dicts the :meth:`Task.connect` will return the dict decorated as a `ProxyDictPostWrite`.
This is done to allow propagating the updates from the connected object.
:raise: Raises an exception if passed an unsupported object.
"""
# input model connect and task parameters will handle this instead
if not isinstance(mutable, (InputModel, TaskParameters)):
ignore_remote_overrides = self._handle_ignore_remote_overrides(
(name or "General") + "/_ignore_remote_overrides_", ignore_remote_overrides
)
# dispatching by match order
dispatch = (
(OutputModel, self._connect_output_model),
@ -1564,7 +1573,7 @@ class Task(_Task):
for mutable_type, method in dispatch:
if isinstance(mutable, mutable_type):
return method(mutable, name=name)
return method(mutable, name=name, ignore_remote_overrides=ignore_remote_overrides)
raise Exception('Unsupported mutable type %s: no connect function found' % type(mutable).__name__)
@ -1631,8 +1640,8 @@ class Task(_Task):
self.data.script.version_num = commit or ""
self._edit(script=self.data.script)
def connect_configuration(self, configuration, name=None, description=None):
# type: (Union[Mapping, list, Path, str], Optional[str], Optional[str]) -> Union[dict, Path, str]
def connect_configuration(self, configuration, name=None, description=None, ignore_remote_overrides=False):
# type: (Union[Mapping, list, Path, str], Optional[str], Optional[str], bool) -> Union[dict, Path, str]
"""
Connect a configuration dictionary or configuration file (pathlib.Path / str) to a Task object.
This method should be called before reading the configuration file.
@ -1650,6 +1659,9 @@ class Task(_Task):
my_params = task.connect_configuration(my_params)
When running remotely, the value of the connected configuration is overridden by the corresponding value found
under the experiment's UI/backend (unless `ignore_remote_overrides` is True).
:param configuration: The configuration. This is usually the configuration used in the model training process.
Specify one of the following:
@ -1664,9 +1676,15 @@ class Task(_Task):
:param str description: Configuration section description (text). default: None
:param bool ignore_remote_overrides: If True, ignore UI/backend overrides when running remotely.
Default is False, meaning that any changes made in the UI/backend will be applied in remote execution.
:return: If a dictionary is specified, then a dictionary is returned. If pathlib2.Path / string is
specified, then a path to a local configuration file is returned. Configuration object.
"""
ignore_remote_overrides = self._handle_ignore_remote_overrides(
(name or "General") + "/_ignore_remote_overrides_config_", ignore_remote_overrides
)
pathlib_Path = None # noqa
cast_Path = Path
if not isinstance(configuration, (dict, list, Path, six.string_types)):
@ -1710,7 +1728,7 @@ class Task(_Task):
configuration_ = ProxyDictPostWrite(self, _update_config_dict, configuration_)
return configuration_
if not running_remotely() or not (self.is_main_task() or self._is_remote_main_task()):
if not running_remotely() or not (self.is_main_task() or self._is_remote_main_task()) or ignore_remote_overrides:
configuration = get_dev_config(configuration)
else:
# noinspection PyBroadException
@ -1744,7 +1762,7 @@ class Task(_Task):
return configuration
# it is a path to a local file
if not running_remotely() or not (self.is_main_task() or self._is_remote_main_task()):
if not running_remotely() or not (self.is_main_task() or self._is_remote_main_task()) or ignore_remote_overrides:
# check if not absolute path
configuration_path = cast_Path(configuration)
if not configuration_path.is_file():
@ -1793,8 +1811,8 @@ class Task(_Task):
f.write(configuration_text)
return cast_Path(local_filename) if isinstance(configuration, cast_Path) else local_filename
def connect_label_enumeration(self, enumeration):
# type: (Dict[str, int]) -> Dict[str, int]
def connect_label_enumeration(self, enumeration, ignore_remote_overrides=False):
# type: (Dict[str, int], bool) -> Dict[str, int]
"""
Connect a label enumeration dictionary to a Task (experiment) object.
@ -1811,13 +1829,22 @@ class Task(_Task):
"person": 1
}
:param ignore_remote_overrides: If True, ignore UI/backend overrides when running remotely.
Default is False, meaning that any changes made in the UI/backend will be applied in remote execution.
:return: The label enumeration dictionary (JSON).
"""
ignore_remote_overrides = self._handle_ignore_remote_overrides(
"General/_ignore_remote_overrides_label_enumeration_", ignore_remote_overrides
)
if not isinstance(enumeration, dict):
raise ValueError("connect_label_enumeration supports only `dict` type, "
"{} is not supported".format(type(enumeration)))
if not running_remotely() or not (self.is_main_task() or self._is_remote_main_task()):
if (
not running_remotely()
or not (self.is_main_task() or self._is_remote_main_task())
or ignore_remote_overrides
):
self.set_model_label_enumeration(enumeration)
else:
# pop everything
@ -3750,9 +3777,9 @@ class Task(_Task):
return self._logger
def _connect_output_model(self, model, name=None):
def _connect_output_model(self, model, name=None, **kwargs):
assert isinstance(model, OutputModel)
model.connect(self, name=name)
model.connect(self, name=name, ignore_remote_overrides=False)
return model
def _save_output_model(self, model):
@ -3764,6 +3791,19 @@ class Task(_Task):
# deprecated
self._connected_output_model = model
def _handle_ignore_remote_overrides(self, overrides_name, ignore_remote_overrides):
if self.running_locally() and ignore_remote_overrides:
self.set_parameter(
overrides_name,
True,
description="If True, ignore UI/backend overrides when running remotely."
" Set it to False if you would like the overrides to be applied",
value_type=bool
)
elif not self.running_locally():
ignore_remote_overrides = self.get_parameter(overrides_name, default=ignore_remote_overrides, cast=True)
return ignore_remote_overrides
def _reconnect_output_model(self):
"""
Deprecated: If there is a saved connected output model, connect it again.
@ -3776,7 +3816,7 @@ class Task(_Task):
if self._connected_output_model:
self.connect(self._connected_output_model)
def _connect_input_model(self, model, name=None):
def _connect_input_model(self, model, name=None, ignore_remote_overrides=False):
assert isinstance(model, InputModel)
# we only allow for an input model to be connected once
# at least until we support multiple input models
@ -3791,18 +3831,21 @@ class Task(_Task):
comment += 'Using model id: {}'.format(model.id)
self.set_comment(comment)
model.connect(self, name)
model.connect(self, name, ignore_remote_overrides=ignore_remote_overrides)
return model
def _connect_argparse(self, parser, args=None, namespace=None, parsed_args=None, name=None):
def _connect_argparse(
self, parser, args=None, namespace=None, parsed_args=None, name=None, ignore_remote_overrides=False
):
# do not allow argparser to connect to jupyter notebook
# noinspection PyBroadException
try:
if 'IPython' in sys.modules:
if "IPython" in sys.modules:
# noinspection PyPackageRequirements
from IPython import get_ipython # noqa
ip = get_ipython()
if ip is not None and 'IPKernelApp' in ip.config:
if ip is not None and "IPKernelApp" in ip.config:
return parser
except Exception:
pass
@ -3825,14 +3868,14 @@ class Task(_Task):
if parsed_args is None and parser == _parser:
parsed_args = _parsed_args
if running_remotely() and (self.is_main_task() or self._is_remote_main_task()):
if running_remotely() and (self.is_main_task() or self._is_remote_main_task()) and not ignore_remote_overrides:
self._arguments.copy_to_parser(parser, parsed_args)
else:
self._arguments.copy_defaults_from_argparse(
parser, args=args, namespace=namespace, parsed_args=parsed_args)
return parser
def _connect_dictionary(self, dictionary, name=None):
def _connect_dictionary(self, dictionary, name=None, ignore_remote_overrides=False):
def _update_args_dict(task, config_dict):
# noinspection PyProtectedMember
task._arguments.copy_from_dict(flatten_dictionary(config_dict), prefix=name)
@ -3862,7 +3905,7 @@ class Task(_Task):
if isinstance(v, dict):
_check_keys(v, warning_sent)
if not running_remotely() or not (self.is_main_task() or self._is_remote_main_task()):
if not running_remotely() or not (self.is_main_task() or self._is_remote_main_task()) or ignore_remote_overrides:
_check_keys(dictionary)
flat_dict = {str(k): v for k, v in flatten_dictionary(dictionary).items()}
self._arguments.copy_from_dict(flat_dict, prefix=name)
@ -3875,19 +3918,28 @@ class Task(_Task):
return dictionary
def _connect_task_parameters(self, attr_class, name=None):
if running_remotely() and (self.is_main_task() or self._is_remote_main_task()):
parameters = self.get_parameters()
if not name:
attr_class.update_from_dict(parameters)
else:
attr_class.update_from_dict(
dict((k[len(name) + 1:], v) for k, v in parameters.items() if k.startswith('{}/'.format(name))))
def _connect_task_parameters(self, attr_class, name=None, ignore_remote_overrides=False):
ignore_remote_overrides_section = "_ignore_remote_overrides_"
if running_remotely():
ignore_remote_overrides = self.get_parameter(
(name or "General") + "/" + ignore_remote_overrides_section, default=ignore_remote_overrides, cast=True
)
if running_remotely() and (self.is_main_task() or self._is_remote_main_task()) and not ignore_remote_overrides:
parameters = self.get_parameters(cast=True)
if name:
parameters = dict(
(k[len(name) + 1:], v) for k, v in parameters.items() if k.startswith("{}/".format(name))
)
parameters.pop(ignore_remote_overrides_section, None)
attr_class.update_from_dict(parameters)
else:
self.set_parameters(attr_class.to_dict(), __parameters_prefix=name)
parameters_dict = attr_class.to_dict()
if ignore_remote_overrides:
parameters_dict[ignore_remote_overrides_section] = True
self.set_parameters(parameters_dict, __parameters_prefix=name)
return attr_class
def _connect_object(self, an_object, name=None):
def _connect_object(self, an_object, name=None, ignore_remote_overrides=False):
def verify_type(key, value):
if str(key).startswith('_') or not isinstance(value, self._parameters_allowed_types):
return False
@ -3904,15 +3956,15 @@ class Task(_Task):
for k, v in cls_.__dict__.items()
if verify_type(k, v)
}
if running_remotely() and (self.is_main_task() or self._is_remote_main_task()):
a_dict = self._connect_dictionary(a_dict, name)
if running_remotely() and (self.is_main_task() or self._is_remote_main_task()) and not ignore_remote_overrides:
a_dict = self._connect_dictionary(a_dict, name, ignore_remote_overrides=ignore_remote_overrides)
for k, v in a_dict.items():
if getattr(an_object, k, None) != a_dict[k]:
setattr(an_object, k, v)
return an_object
else:
self._connect_dictionary(a_dict, name)
self._connect_dictionary(a_dict, name, ignore_remote_overrides=ignore_remote_overrides)
return an_object
def _dev_mode_stop_task(self, stop_reason, pid=None):