Allow users to disable remote overrides when using Task.connect() or Task.connect_configuration()

This commit is contained in:
allegroai 2024-02-04 19:33:41 +02:00
parent 233f94f741
commit 240b762a2a
2 changed files with 100 additions and 41 deletions

View File

@@ -1932,8 +1932,8 @@ class InputModel(Model):
# type: () -> str # type: () -> str
return self._base_model_id return self._base_model_id
def connect(self, task, name=None): def connect(self, task, name=None, ignore_remote_overrides=False):
# type: (Task, Optional[str]) -> None # type: (Task, Optional[str], bool) -> None
""" """
Connect the current model to a Task object, if the model is preexisting. Preexisting models include: Connect the current model to a Task object, if the model is preexisting. Preexisting models include:
@@ -1943,24 +1943,31 @@ class InputModel(Model):
- Models whose origin is not ClearML that are used to create an InputModel object. For example, - Models whose origin is not ClearML that are used to create an InputModel object. For example,
models created using TensorFlow models. models created using TensorFlow models.
When the experiment is executed remotely in a worker, the input model already specified in the experiment is When the experiment is executed remotely in a worker, the input model specified in the experiment UI/backend
used. is used, unless `ignore_remote_overrides` is set to True.
.. note:: .. note::
The **ClearML Web-App** allows you to switch one input model for another and then enqueue the experiment The **ClearML Web-App** allows you to switch one input model for another and then enqueue the experiment
to execute in a worker. to execute in a worker.
:param object task: A Task object. :param object task: A Task object.
:param ignore_remote_overrides: If True, changing the model in the UI/backend will have no
effect when running remotely.
Default is False, meaning that any changes made in the UI/backend will be applied in remote execution.
:param str name: The model name to be stored on the Task :param str name: The model name to be stored on the Task
(default to filename of the model weights, without the file extension, or to `Input Model` if that is not found) (default to filename of the model weights, without the file extension, or to `Input Model`
if that is not found)
""" """
self._set_task(task) self._set_task(task)
name = name or InputModel._get_connect_name(self) name = name or InputModel._get_connect_name(self)
InputModel._warn_on_same_name_connect(name) InputModel._warn_on_same_name_connect(name)
ignore_remote_overrides = task._handle_ignore_remote_overrides(
name + "/_ignore_remote_overrides_input_model_", ignore_remote_overrides
)
model_id = None model_id = None
# noinspection PyProtectedMember # noinspection PyProtectedMember
if running_remotely() and (task.is_main_task() or task._is_remote_main_task()): if running_remotely() and (task.is_main_task() or task._is_remote_main_task()) and not ignore_remote_overrides:
input_models = task.input_models_id input_models = task.input_models_id
# noinspection PyBroadException # noinspection PyBroadException
try: try:
@@ -2245,7 +2252,7 @@ class OutputModel(BaseModel):
pass pass
self.connect(task, name=name) self.connect(task, name=name)
def connect(self, task, name=None): def connect(self, task, name=None, **kwargs):
# type: (Task, Optional[str]) -> None # type: (Task, Optional[str]) -> None
""" """
Connect the current model to a Task object, if the model is a preexisting model. Preexisting models include: Connect the current model to a Task object, if the model is a preexisting model. Preexisting models include:

View File

@@ -1516,11 +1516,13 @@ class Task(_Task):
self.data.tags = list(set((self.data.tags or []) + tags)) self.data.tags = list(set((self.data.tags or []) + tags))
self._edit(tags=self.data.tags) self._edit(tags=self.data.tags)
def connect(self, mutable, name=None): def connect(self, mutable, name=None, ignore_remote_overrides=False):
# type: (Any, Optional[str]) -> Any # type: (Any, Optional[str], bool) -> Any
""" """
Connect an object to a Task object. This connects an experiment component (part of an experiment) to the Connect an object to a Task object. This connects an experiment component (part of an experiment) to the
experiment. For example, an experiment component can be a valid object containing some hyperparameters, or a :class:`Model`. experiment. For example, an experiment component can be a valid object containing some hyperparameters, or a :class:`Model`.
When running remotely, the value of the connected object is overridden by the corresponding value found
under the experiment's UI/backend (unless `ignore_remote_overrides` is True).
:param object mutable: The experiment component to connect. The object must be one of the following types: :param object mutable: The experiment component to connect. The object must be one of the following types:
@@ -1533,16 +1535,23 @@ class Task(_Task):
:param str name: A section name associated with the connected object, if 'name' is None defaults to 'General' :param str name: A section name associated with the connected object, if 'name' is None defaults to 'General'
Currently, `name` is only supported for `dict` and `TaskParameter` objects, and should be omitted for the other supported types. (Optional) Currently, `name` is only supported for `dict` and `TaskParameter` objects, and should be omitted for the other supported types. (Optional)
For example, by setting `name='General'` the connected dictionary will be under the General section in the hyperparameters section. For example, by setting `name='General'` the connected dictionary will be under the General section in the hyperparameters section.
While by setting `name='Train'` the connected dictionary will be under the Train section in the hyperparameters section. While by setting `name='Train'` the connected dictionary will be under the Train section in the hyperparameters section.
:param ignore_remote_overrides: If True, ignore UI/backend overrides when running remotely.
Default is False, meaning that any changes made in the UI/backend will be applied in remote execution.
:return: It will return the same object that was passed as the `mutable` argument to the method, except if the type of the object is dict. :return: It will return the same object that was passed as the `mutable` argument to the method, except if the type of the object is dict.
For dicts the :meth:`Task.connect` will return the dict decorated as a `ProxyDictPostWrite`. For dicts the :meth:`Task.connect` will return the dict decorated as a `ProxyDictPostWrite`.
This is done to allow propagating the updates from the connected object. This is done to allow propagating the updates from the connected object.
:raise: Raises an exception if passed an unsupported object. :raise: Raises an exception if passed an unsupported object.
""" """
# input model connect and task parameters will handle this instead
if not isinstance(mutable, (InputModel, TaskParameters)):
ignore_remote_overrides = self._handle_ignore_remote_overrides(
(name or "General") + "/_ignore_remote_overrides_", ignore_remote_overrides
)
# dispatching by match order # dispatching by match order
dispatch = ( dispatch = (
(OutputModel, self._connect_output_model), (OutputModel, self._connect_output_model),
@@ -1564,7 +1573,7 @@ class Task(_Task):
for mutable_type, method in dispatch: for mutable_type, method in dispatch:
if isinstance(mutable, mutable_type): if isinstance(mutable, mutable_type):
return method(mutable, name=name) return method(mutable, name=name, ignore_remote_overrides=ignore_remote_overrides)
raise Exception('Unsupported mutable type %s: no connect function found' % type(mutable).__name__) raise Exception('Unsupported mutable type %s: no connect function found' % type(mutable).__name__)
@@ -1631,8 +1640,8 @@ class Task(_Task):
self.data.script.version_num = commit or "" self.data.script.version_num = commit or ""
self._edit(script=self.data.script) self._edit(script=self.data.script)
def connect_configuration(self, configuration, name=None, description=None): def connect_configuration(self, configuration, name=None, description=None, ignore_remote_overrides=False):
# type: (Union[Mapping, list, Path, str], Optional[str], Optional[str]) -> Union[dict, Path, str] # type: (Union[Mapping, list, Path, str], Optional[str], Optional[str], bool) -> Union[dict, Path, str]
""" """
Connect a configuration dictionary or configuration file (pathlib.Path / str) to a Task object. Connect a configuration dictionary or configuration file (pathlib.Path / str) to a Task object.
This method should be called before reading the configuration file. This method should be called before reading the configuration file.
@@ -1650,6 +1659,9 @@ class Task(_Task):
my_params = task.connect_configuration(my_params) my_params = task.connect_configuration(my_params)
When running remotely, the value of the connected configuration is overridden by the corresponding value found
under the experiment's UI/backend (unless `ignore_remote_overrides` is True).
:param configuration: The configuration. This is usually the configuration used in the model training process. :param configuration: The configuration. This is usually the configuration used in the model training process.
Specify one of the following: Specify one of the following:
@@ -1664,9 +1676,15 @@ class Task(_Task):
:param str description: Configuration section description (text). default: None :param str description: Configuration section description (text). default: None
:param bool ignore_remote_overrides: If True, ignore UI/backend overrides when running remotely.
Default is False, meaning that any changes made in the UI/backend will be applied in remote execution.
:return: If a dictionary is specified, then a dictionary is returned. If pathlib2.Path / string is :return: If a dictionary is specified, then a dictionary is returned. If pathlib2.Path / string is
specified, then a path to a local configuration file is returned. Configuration object. specified, then a path to a local configuration file is returned. Configuration object.
""" """
ignore_remote_overrides = self._handle_ignore_remote_overrides(
(name or "General") + "/_ignore_remote_overrides_config_", ignore_remote_overrides
)
pathlib_Path = None # noqa pathlib_Path = None # noqa
cast_Path = Path cast_Path = Path
if not isinstance(configuration, (dict, list, Path, six.string_types)): if not isinstance(configuration, (dict, list, Path, six.string_types)):
@@ -1710,7 +1728,7 @@ class Task(_Task):
configuration_ = ProxyDictPostWrite(self, _update_config_dict, configuration_) configuration_ = ProxyDictPostWrite(self, _update_config_dict, configuration_)
return configuration_ return configuration_
if not running_remotely() or not (self.is_main_task() or self._is_remote_main_task()): if not running_remotely() or not (self.is_main_task() or self._is_remote_main_task()) or ignore_remote_overrides:
configuration = get_dev_config(configuration) configuration = get_dev_config(configuration)
else: else:
# noinspection PyBroadException # noinspection PyBroadException
@@ -1744,7 +1762,7 @@ class Task(_Task):
return configuration return configuration
# it is a path to a local file # it is a path to a local file
if not running_remotely() or not (self.is_main_task() or self._is_remote_main_task()): if not running_remotely() or not (self.is_main_task() or self._is_remote_main_task()) or ignore_remote_overrides:
# check if not absolute path # check if not absolute path
configuration_path = cast_Path(configuration) configuration_path = cast_Path(configuration)
if not configuration_path.is_file(): if not configuration_path.is_file():
@@ -1793,8 +1811,8 @@ class Task(_Task):
f.write(configuration_text) f.write(configuration_text)
return cast_Path(local_filename) if isinstance(configuration, cast_Path) else local_filename return cast_Path(local_filename) if isinstance(configuration, cast_Path) else local_filename
def connect_label_enumeration(self, enumeration): def connect_label_enumeration(self, enumeration, ignore_remote_overrides=False):
# type: (Dict[str, int]) -> Dict[str, int] # type: (Dict[str, int], bool) -> Dict[str, int]
""" """
Connect a label enumeration dictionary to a Task (experiment) object. Connect a label enumeration dictionary to a Task (experiment) object.
@@ -1811,13 +1829,22 @@ class Task(_Task):
"person": 1 "person": 1
} }
:param ignore_remote_overrides: If True, ignore UI/backend overrides when running remotely.
Default is False, meaning that any changes made in the UI/backend will be applied in remote execution.
:return: The label enumeration dictionary (JSON). :return: The label enumeration dictionary (JSON).
""" """
ignore_remote_overrides = self._handle_ignore_remote_overrides(
"General/_ignore_remote_overrides_label_enumeration_", ignore_remote_overrides
)
if not isinstance(enumeration, dict): if not isinstance(enumeration, dict):
raise ValueError("connect_label_enumeration supports only `dict` type, " raise ValueError("connect_label_enumeration supports only `dict` type, "
"{} is not supported".format(type(enumeration))) "{} is not supported".format(type(enumeration)))
if not running_remotely() or not (self.is_main_task() or self._is_remote_main_task()): if (
not running_remotely()
or not (self.is_main_task() or self._is_remote_main_task())
or ignore_remote_overrides
):
self.set_model_label_enumeration(enumeration) self.set_model_label_enumeration(enumeration)
else: else:
# pop everything # pop everything
@@ -3750,9 +3777,9 @@ class Task(_Task):
return self._logger return self._logger
def _connect_output_model(self, model, name=None): def _connect_output_model(self, model, name=None, **kwargs):
assert isinstance(model, OutputModel) assert isinstance(model, OutputModel)
model.connect(self, name=name) model.connect(self, name=name, ignore_remote_overrides=False)
return model return model
def _save_output_model(self, model): def _save_output_model(self, model):
@@ -3764,6 +3791,19 @@ class Task(_Task):
# deprecated # deprecated
self._connected_output_model = model self._connected_output_model = model
def _handle_ignore_remote_overrides(self, overrides_name, ignore_remote_overrides):
if self.running_locally() and ignore_remote_overrides:
self.set_parameter(
overrides_name,
True,
description="If True, ignore UI/backend overrides when running remotely."
" Set it to False if you would like the overrides to be applied",
value_type=bool
)
elif not self.running_locally():
ignore_remote_overrides = self.get_parameter(overrides_name, default=ignore_remote_overrides, cast=True)
return ignore_remote_overrides
def _reconnect_output_model(self): def _reconnect_output_model(self):
""" """
Deprecated: If there is a saved connected output model, connect it again. Deprecated: If there is a saved connected output model, connect it again.
@@ -3776,7 +3816,7 @@ class Task(_Task):
if self._connected_output_model: if self._connected_output_model:
self.connect(self._connected_output_model) self.connect(self._connected_output_model)
def _connect_input_model(self, model, name=None): def _connect_input_model(self, model, name=None, ignore_remote_overrides=False):
assert isinstance(model, InputModel) assert isinstance(model, InputModel)
# we only allow for an input model to be connected once # we only allow for an input model to be connected once
# at least until we support multiple input models # at least until we support multiple input models
@@ -3791,18 +3831,21 @@ class Task(_Task):
comment += 'Using model id: {}'.format(model.id) comment += 'Using model id: {}'.format(model.id)
self.set_comment(comment) self.set_comment(comment)
model.connect(self, name) model.connect(self, name, ignore_remote_overrides=ignore_remote_overrides)
return model return model
def _connect_argparse(self, parser, args=None, namespace=None, parsed_args=None, name=None): def _connect_argparse(
self, parser, args=None, namespace=None, parsed_args=None, name=None, ignore_remote_overrides=False
):
# do not allow argparser to connect to jupyter notebook # do not allow argparser to connect to jupyter notebook
# noinspection PyBroadException # noinspection PyBroadException
try: try:
if 'IPython' in sys.modules: if "IPython" in sys.modules:
# noinspection PyPackageRequirements # noinspection PyPackageRequirements
from IPython import get_ipython # noqa from IPython import get_ipython # noqa
ip = get_ipython() ip = get_ipython()
if ip is not None and 'IPKernelApp' in ip.config: if ip is not None and "IPKernelApp" in ip.config:
return parser return parser
except Exception: except Exception:
pass pass
@@ -3825,14 +3868,14 @@ class Task(_Task):
if parsed_args is None and parser == _parser: if parsed_args is None and parser == _parser:
parsed_args = _parsed_args parsed_args = _parsed_args
if running_remotely() and (self.is_main_task() or self._is_remote_main_task()): if running_remotely() and (self.is_main_task() or self._is_remote_main_task()) and not ignore_remote_overrides:
self._arguments.copy_to_parser(parser, parsed_args) self._arguments.copy_to_parser(parser, parsed_args)
else: else:
self._arguments.copy_defaults_from_argparse( self._arguments.copy_defaults_from_argparse(
parser, args=args, namespace=namespace, parsed_args=parsed_args) parser, args=args, namespace=namespace, parsed_args=parsed_args)
return parser return parser
def _connect_dictionary(self, dictionary, name=None): def _connect_dictionary(self, dictionary, name=None, ignore_remote_overrides=False):
def _update_args_dict(task, config_dict): def _update_args_dict(task, config_dict):
# noinspection PyProtectedMember # noinspection PyProtectedMember
task._arguments.copy_from_dict(flatten_dictionary(config_dict), prefix=name) task._arguments.copy_from_dict(flatten_dictionary(config_dict), prefix=name)
@@ -3862,7 +3905,7 @@ class Task(_Task):
if isinstance(v, dict): if isinstance(v, dict):
_check_keys(v, warning_sent) _check_keys(v, warning_sent)
if not running_remotely() or not (self.is_main_task() or self._is_remote_main_task()): if not running_remotely() or not (self.is_main_task() or self._is_remote_main_task()) or ignore_remote_overrides:
_check_keys(dictionary) _check_keys(dictionary)
flat_dict = {str(k): v for k, v in flatten_dictionary(dictionary).items()} flat_dict = {str(k): v for k, v in flatten_dictionary(dictionary).items()}
self._arguments.copy_from_dict(flat_dict, prefix=name) self._arguments.copy_from_dict(flat_dict, prefix=name)
@@ -3875,19 +3918,28 @@ class Task(_Task):
return dictionary return dictionary
def _connect_task_parameters(self, attr_class, name=None): def _connect_task_parameters(self, attr_class, name=None, ignore_remote_overrides=False):
if running_remotely() and (self.is_main_task() or self._is_remote_main_task()): ignore_remote_overrides_section = "_ignore_remote_overrides_"
parameters = self.get_parameters() if running_remotely():
if not name: ignore_remote_overrides = self.get_parameter(
(name or "General") + "/" + ignore_remote_overrides_section, default=ignore_remote_overrides, cast=True
)
if running_remotely() and (self.is_main_task() or self._is_remote_main_task()) and not ignore_remote_overrides:
parameters = self.get_parameters(cast=True)
if name:
parameters = dict(
(k[len(name) + 1:], v) for k, v in parameters.items() if k.startswith("{}/".format(name))
)
parameters.pop(ignore_remote_overrides_section, None)
attr_class.update_from_dict(parameters) attr_class.update_from_dict(parameters)
else: else:
attr_class.update_from_dict( parameters_dict = attr_class.to_dict()
dict((k[len(name) + 1:], v) for k, v in parameters.items() if k.startswith('{}/'.format(name)))) if ignore_remote_overrides:
else: parameters_dict[ignore_remote_overrides_section] = True
self.set_parameters(attr_class.to_dict(), __parameters_prefix=name) self.set_parameters(parameters_dict, __parameters_prefix=name)
return attr_class return attr_class
def _connect_object(self, an_object, name=None): def _connect_object(self, an_object, name=None, ignore_remote_overrides=False):
def verify_type(key, value): def verify_type(key, value):
if str(key).startswith('_') or not isinstance(value, self._parameters_allowed_types): if str(key).startswith('_') or not isinstance(value, self._parameters_allowed_types):
return False return False
@@ -3904,15 +3956,15 @@ class Task(_Task):
for k, v in cls_.__dict__.items() for k, v in cls_.__dict__.items()
if verify_type(k, v) if verify_type(k, v)
} }
if running_remotely() and (self.is_main_task() or self._is_remote_main_task()): if running_remotely() and (self.is_main_task() or self._is_remote_main_task()) and not ignore_remote_overrides:
a_dict = self._connect_dictionary(a_dict, name) a_dict = self._connect_dictionary(a_dict, name, ignore_remote_overrides=ignore_remote_overrides)
for k, v in a_dict.items(): for k, v in a_dict.items():
if getattr(an_object, k, None) != a_dict[k]: if getattr(an_object, k, None) != a_dict[k]:
setattr(an_object, k, v) setattr(an_object, k, v)
return an_object return an_object
else: else:
self._connect_dictionary(a_dict, name) self._connect_dictionary(a_dict, name, ignore_remote_overrides=ignore_remote_overrides)
return an_object return an_object
def _dev_mode_stop_task(self, stop_reason, pid=None): def _dev_mode_stop_task(self, stop_reason, pid=None):