From c26efb83aff12828758a1ddedcac4aaa9039ccfd Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Tue, 24 Oct 2023 18:36:53 +0300 Subject: [PATCH] Add support for Hydra command-line syntax for modifying omegaconf --- clearml/backend_interface/task/task.py | 6 +- clearml/binding/hydra_bind.py | 108 ++++++++++++++----------- clearml/task.py | 2 + 3 files changed, 69 insertions(+), 47 deletions(-) diff --git a/clearml/backend_interface/task/task.py b/clearml/backend_interface/task/task.py index 5f4323f3..fc00233c 100644 --- a/clearml/backend_interface/task/task.py +++ b/clearml/backend_interface/task/task.py @@ -1344,12 +1344,14 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): params = self.get_parameters(cast=cast) return params.get(name, default) - def delete_parameter(self, name): + def delete_parameter(self, name, force=False): # type: (str) -> bool """ Delete a parameter by its full name Section/name. :param name: Parameter name in full, i.e. Section/name. For example, 'Args/batch_size' + :param force: If set to True then both new and running task hyper params can be deleted. + Otherwise only the new task ones. Default is False :return: True if the parameter was deleted successfully """ if not Session.check_min_api_version('2.9'): @@ -1360,7 +1362,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): with self._edit_lock: paramkey = tasks.ParamKey(section=name.split('/', 1)[0], name=name.split('/', 1)[1]) res = self.send(tasks.DeleteHyperParamsRequest( - task=self.id, hyperparams=[paramkey]), raise_on_errors=False) + task=self.id, hyperparams=[paramkey], force=force), raise_on_errors=False) self.reload() return res.ok() diff --git a/clearml/binding/hydra_bind.py b/clearml/binding/hydra_bind.py index 6c8af3f9..b89f1dbc 100644 --- a/clearml/binding/hydra_bind.py +++ b/clearml/binding/hydra_bind.py @@ -15,6 +15,8 @@ class PatchHydra(object): _config_section = 'OmegaConf' _parameter_section = 'Hydra' _parameter_allow_full_edit = '_allow_omegaconf_edit_' + _should_delete_overrides = False + _overrides_section = "Args/overrides" @classmethod def patch_hydra(cls): @@ -42,6 +44,12 @@ class PatchHydra(object): except Exception: return False + @classmethod + def delete_overrides(cls): + if not cls._should_delete_overrides or not cls._current_task: + return + cls._current_task.delete_parameter(cls._overrides_section, force=True) + @staticmethod def update_current_task(task): # set current Task before patching @@ -50,11 +58,24 @@ class PatchHydra(object): return if PatchHydra.patch_hydra(): # check if we have an untracked state, store it. - if PatchHydra._last_untracked_state.get('connect'): - PatchHydra._current_task.connect(**PatchHydra._last_untracked_state['connect']) - if PatchHydra._last_untracked_state.get('_set_configuration'): + if PatchHydra._last_untracked_state.get("connect"): + if PatchHydra._parameter_allow_full_edit in PatchHydra._last_untracked_state["connect"].get("mutable", {}): + allow_omegaconf_edit_section = PatchHydra._parameter_section + "/" + PatchHydra._parameter_allow_full_edit + allow_omegaconf_edit_section_val = PatchHydra._last_untracked_state["connect"]["mutable"].pop( + PatchHydra._parameter_allow_full_edit + ) + PatchHydra._current_task.set_parameter( + allow_omegaconf_edit_section, + allow_omegaconf_edit_section_val, + description="If True, the `{}` parameter section will be completely ignored. The OmegaConf will instead be pulled from the `{}` section".format( + PatchHydra._parameter_section, + PatchHydra._config_section + ) + ) + PatchHydra._current_task.connect(**PatchHydra._last_untracked_state["connect"]) + if PatchHydra._last_untracked_state.get("_set_configuration"): # noinspection PyProtectedMember - PatchHydra._current_task._set_configuration(**PatchHydra._last_untracked_state['_set_configuration']) + PatchHydra._current_task._set_configuration(**PatchHydra._last_untracked_state["_set_configuration"]) PatchHydra._last_untracked_state = {} else: # if patching failed set it to None @@ -63,36 +84,34 @@ class PatchHydra(object): @staticmethod def _patched_hydra_run(self, config_name, task_function, overrides, *args, **kwargs): PatchHydra._allow_omegaconf_edit = False + if not running_remotely(): + return PatchHydra._original_hydra_run(self, config_name, task_function, overrides, *args, **kwargs) - # store the config + # get the parameters from the backend # noinspection PyBroadException try: - if running_remotely(): - if not PatchHydra._current_task: - from ..task import Task - PatchHydra._current_task = Task.get_task(task_id=get_remote_task_id()) - # get the _parameter_allow_full_edit casted back to boolean - connected_config = dict() - connected_config[PatchHydra._parameter_allow_full_edit] = False - PatchHydra._current_task.connect(connected_config, name=PatchHydra._parameter_section) - PatchHydra._allow_omegaconf_edit = connected_config.pop(PatchHydra._parameter_allow_full_edit, None) - # get all the overrides - full_parameters = PatchHydra._current_task.get_parameters(backwards_compatibility=False) - stored_config = {k[len(PatchHydra._parameter_section)+1:]: v for k, v in full_parameters.items() - if k.startswith(PatchHydra._parameter_section+'/')} - stored_config.pop(PatchHydra._parameter_allow_full_edit, None) - # noinspection PyBroadException - try: - overrides = yaml.safe_load(full_parameters.get("Args/overrides", "")) or [] - except Exception: - overrides = [] - if overrides and not isinstance(overrides, (list, tuple)): - overrides = [overrides] - overrides += ['{}={}'.format(k, v) for k, v in stored_config.items()] - overrides = [("+" + o) if (o.startswith("+") and not o.startswith("++")) else o for o in overrides] - else: - # We take care of it inside the _patched_run_job - pass + if not PatchHydra._current_task: + from ..task import Task + PatchHydra._current_task = Task.get_task(task_id=get_remote_task_id()) + # get the _parameter_allow_full_edit casted back to boolean + connected_config = {} + connected_config[PatchHydra._parameter_allow_full_edit] = False + PatchHydra._current_task.connect(connected_config, name=PatchHydra._parameter_section) + PatchHydra._allow_omegaconf_edit = connected_config.pop(PatchHydra._parameter_allow_full_edit, None) + # get all the overrides + full_parameters = PatchHydra._current_task.get_parameters(backwards_compatibility=False) + stored_config = {k[len(PatchHydra._parameter_section)+1:]: v for k, v in full_parameters.items() + if k.startswith(PatchHydra._parameter_section+'/')} + stored_config.pop(PatchHydra._parameter_allow_full_edit, None) + for override_k, override_v in stored_config.items(): + if override_k.startswith("~"): + new_override = override_k + else: + new_override = "++" + override_k.lstrip("+") + if override_v is not None and override_v != "": + new_override += "=" + override_v + overrides.append(new_override) + PatchHydra._should_delete_overrides = True except Exception: pass @@ -114,12 +133,18 @@ class PatchHydra(object): # store the config # noinspection PyBroadException try: - if running_remotely(): - # we take care of it in the hydra run (where we have access to the overrides) - pass - else: + if not running_remotely(): + # note that we fetch the overrides from the backend in hydra run when running remotely, + # here we just get them from hydra to be stored as configuration/parameters overrides = config.hydra.overrides.task - stored_config = dict(arg.split('=', 1) for arg in overrides) + stored_config = {} + for arg in overrides: + arg = arg.lstrip("+") + if "=" in arg: + k, v = arg.split("=", 1) + stored_config[k] = v + else: + stored_config[arg] = None stored_config[PatchHydra._parameter_allow_full_edit] = False if PatchHydra._current_task: PatchHydra._current_task.connect(stored_config, name=PatchHydra._parameter_section) @@ -127,9 +152,7 @@ class PatchHydra(object): else: PatchHydra._last_untracked_state['connect'] = dict( mutable=stored_config, name=PatchHydra._parameter_section) - # Maybe ?! remove the overrides section from the Args (we have it here) - # But when used with a Pipeline this is the only section we get... so we leave it here anyhow - # PatchHydra._current_task.delete_parameter('Args/overrides') + PatchHydra._should_delete_overrides = True except Exception: pass @@ -176,8 +199,7 @@ class PatchHydra(object): else: # noinspection PyProtectedMember omega_yaml = PatchHydra._current_task._get_configuration_text(PatchHydra._config_section) - loaded_config = OmegaConf.load(io.StringIO(omega_yaml)) - a_config = OmegaConf.merge(a_config, loaded_config) + a_config = OmegaConf.load(io.StringIO(omega_yaml)) PatchHydra._register_omegaconf(a_config, is_read_only=False) return task_function(a_config, *a_args, **a_kwargs) @@ -194,10 +216,6 @@ class PatchHydra(object): description = 'Full OmegaConf YAML configuration overridden! ({}/{}=True)'.format( PatchHydra._parameter_section, PatchHydra._parameter_allow_full_edit) - # we should not have the hydra section in the config, but this seems never to be the case anymore. - # config = config.copy() - # config.pop('hydra', None) - configuration = dict( name=PatchHydra._config_section, description=description, diff --git a/clearml/task.py b/clearml/task.py index a24c8f8e..ac896798 100644 --- a/clearml/task.py +++ b/clearml/task.py @@ -739,6 +739,8 @@ class Task(_Task): if argparser_parseargs_called(): for parser, parsed_args in get_argparser_last_args(): task._connect_argparse(parser=parser, parsed_args=parsed_args) + + PatchHydra.delete_overrides() elif argparser_parseargs_called(): # actually we have nothing to do, in remote running, the argparser will ignore # all non argparser parameters, only caveat if parameter connected with the same name