mirror of
				https://github.com/clearml/clearml
				synced 2025-06-26 18:16:07 +00:00 
			
		
		
		
	Add support for Hydra command-line syntax for modifying omegaconf
This commit is contained in:
		
							parent
							
								
									9142f861fd
								
							
						
					
					
						commit
						c26efb83af
					
				@ -1344,12 +1344,14 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
 | 
			
		||||
        params = self.get_parameters(cast=cast)
 | 
			
		||||
        return params.get(name, default)
 | 
			
		||||
 | 
			
		||||
    def delete_parameter(self, name):
 | 
			
		||||
    def delete_parameter(self, name, force=False):
 | 
			
		||||
        # type: (str) -> bool
 | 
			
		||||
        """
 | 
			
		||||
        Delete a parameter by its full name Section/name.
 | 
			
		||||
 | 
			
		||||
        :param name: Parameter name in full, i.e. Section/name. For example, 'Args/batch_size'
 | 
			
		||||
        :param force: If set to True then both new and running task hyper params can be deleted.
 | 
			
		||||
            Otherwise only the new task ones. Default is False
 | 
			
		||||
        :return: True if the parameter was deleted successfully
 | 
			
		||||
        """
 | 
			
		||||
        if not Session.check_min_api_version('2.9'):
 | 
			
		||||
@ -1360,7 +1362,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
 | 
			
		||||
        with self._edit_lock:
 | 
			
		||||
            paramkey = tasks.ParamKey(section=name.split('/', 1)[0], name=name.split('/', 1)[1])
 | 
			
		||||
            res = self.send(tasks.DeleteHyperParamsRequest(
 | 
			
		||||
                task=self.id, hyperparams=[paramkey]), raise_on_errors=False)
 | 
			
		||||
                task=self.id, hyperparams=[paramkey], force=force), raise_on_errors=False)
 | 
			
		||||
            self.reload()
 | 
			
		||||
 | 
			
		||||
        return res.ok()
 | 
			
		||||
 | 
			
		||||
@ -15,6 +15,8 @@ class PatchHydra(object):
 | 
			
		||||
    _config_section = 'OmegaConf'
 | 
			
		||||
    _parameter_section = 'Hydra'
 | 
			
		||||
    _parameter_allow_full_edit = '_allow_omegaconf_edit_'
 | 
			
		||||
    _should_delete_overrides = False
 | 
			
		||||
    _overrides_section = "Args/overrides"
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def patch_hydra(cls):
 | 
			
		||||
@ -42,6 +44,12 @@ class PatchHydra(object):
 | 
			
		||||
        except Exception:
 | 
			
		||||
            return False
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def delete_overrides(cls):
 | 
			
		||||
        if not cls._should_delete_overrides or not cls._current_task:
 | 
			
		||||
            return
 | 
			
		||||
        cls._current_task.delete_parameter(cls._overrides_section, force=True)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def update_current_task(task):
 | 
			
		||||
        # set current Task before patching
 | 
			
		||||
@ -50,11 +58,24 @@ class PatchHydra(object):
 | 
			
		||||
            return
 | 
			
		||||
        if PatchHydra.patch_hydra():
 | 
			
		||||
            # check if we have an untracked state, store it.
 | 
			
		||||
            if PatchHydra._last_untracked_state.get('connect'):
 | 
			
		||||
                PatchHydra._current_task.connect(**PatchHydra._last_untracked_state['connect'])
 | 
			
		||||
            if PatchHydra._last_untracked_state.get('_set_configuration'):
 | 
			
		||||
            if PatchHydra._last_untracked_state.get("connect"):
 | 
			
		||||
                if PatchHydra._parameter_allow_full_edit in PatchHydra._last_untracked_state["connect"].get("mutable", {}):
 | 
			
		||||
                    allow_omegaconf_edit_section = PatchHydra._parameter_section + "/" + PatchHydra._parameter_allow_full_edit
 | 
			
		||||
                    allow_omegaconf_edit_section_val = PatchHydra._last_untracked_state["connect"]["mutable"].pop(
 | 
			
		||||
                        PatchHydra._parameter_allow_full_edit
 | 
			
		||||
                    )
 | 
			
		||||
                    PatchHydra._current_task.set_parameter(
 | 
			
		||||
                        allow_omegaconf_edit_section,
 | 
			
		||||
                        allow_omegaconf_edit_section_val,
 | 
			
		||||
                        description="If True, the `{}` parameter section will be completely ignored. The OmegaConf will instead be pulled from the `{}` section".format(
 | 
			
		||||
                            PatchHydra._parameter_section,
 | 
			
		||||
                            PatchHydra._config_section
 | 
			
		||||
                        )
 | 
			
		||||
                    )
 | 
			
		||||
                PatchHydra._current_task.connect(**PatchHydra._last_untracked_state["connect"])
 | 
			
		||||
            if PatchHydra._last_untracked_state.get("_set_configuration"):
 | 
			
		||||
                # noinspection PyProtectedMember
 | 
			
		||||
                PatchHydra._current_task._set_configuration(**PatchHydra._last_untracked_state['_set_configuration'])
 | 
			
		||||
                PatchHydra._current_task._set_configuration(**PatchHydra._last_untracked_state["_set_configuration"])
 | 
			
		||||
            PatchHydra._last_untracked_state = {}
 | 
			
		||||
        else:
 | 
			
		||||
            # if patching failed set it to None
 | 
			
		||||
@ -63,36 +84,34 @@ class PatchHydra(object):
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def _patched_hydra_run(self, config_name, task_function, overrides, *args, **kwargs):
 | 
			
		||||
        PatchHydra._allow_omegaconf_edit = False
 | 
			
		||||
        if not running_remotely():
 | 
			
		||||
            return PatchHydra._original_hydra_run(self, config_name, task_function, overrides, *args, **kwargs)
 | 
			
		||||
 | 
			
		||||
        # store the config
 | 
			
		||||
        # get the parameters from the backend
 | 
			
		||||
        # noinspection PyBroadException
 | 
			
		||||
        try:
 | 
			
		||||
            if running_remotely():
 | 
			
		||||
                if not PatchHydra._current_task:
 | 
			
		||||
                    from ..task import Task
 | 
			
		||||
                    PatchHydra._current_task = Task.get_task(task_id=get_remote_task_id())
 | 
			
		||||
                # get the _parameter_allow_full_edit casted back to boolean
 | 
			
		||||
                connected_config = dict()
 | 
			
		||||
                connected_config[PatchHydra._parameter_allow_full_edit] = False
 | 
			
		||||
                PatchHydra._current_task.connect(connected_config, name=PatchHydra._parameter_section)
 | 
			
		||||
                PatchHydra._allow_omegaconf_edit = connected_config.pop(PatchHydra._parameter_allow_full_edit, None)
 | 
			
		||||
                # get all the overrides
 | 
			
		||||
                full_parameters = PatchHydra._current_task.get_parameters(backwards_compatibility=False)
 | 
			
		||||
                stored_config = {k[len(PatchHydra._parameter_section)+1:]: v for k, v in full_parameters.items()
 | 
			
		||||
                                 if k.startswith(PatchHydra._parameter_section+'/')}
 | 
			
		||||
                stored_config.pop(PatchHydra._parameter_allow_full_edit, None)
 | 
			
		||||
                # noinspection PyBroadException
 | 
			
		||||
                try:
 | 
			
		||||
                    overrides = yaml.safe_load(full_parameters.get("Args/overrides", "")) or []
 | 
			
		||||
                except Exception:
 | 
			
		||||
                    overrides = []
 | 
			
		||||
                if overrides and not isinstance(overrides, (list, tuple)):
 | 
			
		||||
                    overrides = [overrides]
 | 
			
		||||
                overrides += ['{}={}'.format(k, v) for k, v in stored_config.items()]
 | 
			
		||||
                overrides = [("+" + o) if (o.startswith("+") and not o.startswith("++")) else o for o in overrides]
 | 
			
		||||
            else:
 | 
			
		||||
                # We take care of it inside the _patched_run_job
 | 
			
		||||
                pass
 | 
			
		||||
            if not PatchHydra._current_task:
 | 
			
		||||
                from ..task import Task
 | 
			
		||||
                PatchHydra._current_task = Task.get_task(task_id=get_remote_task_id())
 | 
			
		||||
            # get the _parameter_allow_full_edit casted back to boolean
 | 
			
		||||
            connected_config = {}
 | 
			
		||||
            connected_config[PatchHydra._parameter_allow_full_edit] = False
 | 
			
		||||
            PatchHydra._current_task.connect(connected_config, name=PatchHydra._parameter_section)
 | 
			
		||||
            PatchHydra._allow_omegaconf_edit = connected_config.pop(PatchHydra._parameter_allow_full_edit, None)
 | 
			
		||||
            # get all the overrides
 | 
			
		||||
            full_parameters = PatchHydra._current_task.get_parameters(backwards_compatibility=False)
 | 
			
		||||
            stored_config = {k[len(PatchHydra._parameter_section)+1:]: v for k, v in full_parameters.items()
 | 
			
		||||
                             if k.startswith(PatchHydra._parameter_section+'/')}
 | 
			
		||||
            stored_config.pop(PatchHydra._parameter_allow_full_edit, None)
 | 
			
		||||
            for override_k, override_v in stored_config.items():
 | 
			
		||||
                if override_k.startswith("~"):
 | 
			
		||||
                    new_override = override_k
 | 
			
		||||
                else:
 | 
			
		||||
                    new_override = "++" + override_k.lstrip("+")
 | 
			
		||||
                if override_v is not None and override_v != "":
 | 
			
		||||
                    new_override += "=" + override_v
 | 
			
		||||
                overrides.append(new_override)
 | 
			
		||||
            PatchHydra._should_delete_overrides = True
 | 
			
		||||
        except Exception:
 | 
			
		||||
            pass
 | 
			
		||||
 | 
			
		||||
@ -114,12 +133,18 @@ class PatchHydra(object):
 | 
			
		||||
        # store the config
 | 
			
		||||
        # noinspection PyBroadException
 | 
			
		||||
        try:
 | 
			
		||||
            if running_remotely():
 | 
			
		||||
                # we take care of it in the hydra run (where we have access to the overrides)
 | 
			
		||||
                pass
 | 
			
		||||
            else:
 | 
			
		||||
            if not running_remotely():
 | 
			
		||||
                # note that we fetch the overrides from the backend in hydra run when running remotely,
 | 
			
		||||
                # here we just get them from hydra to be stored as configuration/parameters
 | 
			
		||||
                overrides = config.hydra.overrides.task
 | 
			
		||||
                stored_config = dict(arg.split('=', 1) for arg in overrides)
 | 
			
		||||
                stored_config = {}
 | 
			
		||||
                for arg in overrides:
 | 
			
		||||
                    arg = arg.lstrip("+")
 | 
			
		||||
                    if "=" in arg:
 | 
			
		||||
                        k, v = arg.split("=", 1)
 | 
			
		||||
                        stored_config[k] = v
 | 
			
		||||
                    else:
 | 
			
		||||
                        stored_config[arg] = None
 | 
			
		||||
                stored_config[PatchHydra._parameter_allow_full_edit] = False
 | 
			
		||||
                if PatchHydra._current_task:
 | 
			
		||||
                    PatchHydra._current_task.connect(stored_config, name=PatchHydra._parameter_section)
 | 
			
		||||
@ -127,9 +152,7 @@ class PatchHydra(object):
 | 
			
		||||
                else:
 | 
			
		||||
                    PatchHydra._last_untracked_state['connect'] = dict(
 | 
			
		||||
                        mutable=stored_config, name=PatchHydra._parameter_section)
 | 
			
		||||
                # Maybe ?! remove the overrides section from the Args (we have it here)
 | 
			
		||||
                # But when used with a Pipeline this is the only section we get... so we leave it here anyhow
 | 
			
		||||
                # PatchHydra._current_task.delete_parameter('Args/overrides')
 | 
			
		||||
                PatchHydra._should_delete_overrides = True
 | 
			
		||||
        except Exception:
 | 
			
		||||
            pass
 | 
			
		||||
 | 
			
		||||
@ -176,8 +199,7 @@ class PatchHydra(object):
 | 
			
		||||
        else:
 | 
			
		||||
            # noinspection PyProtectedMember
 | 
			
		||||
            omega_yaml = PatchHydra._current_task._get_configuration_text(PatchHydra._config_section)
 | 
			
		||||
            loaded_config = OmegaConf.load(io.StringIO(omega_yaml))
 | 
			
		||||
            a_config = OmegaConf.merge(a_config, loaded_config)
 | 
			
		||||
            a_config = OmegaConf.load(io.StringIO(omega_yaml))
 | 
			
		||||
            PatchHydra._register_omegaconf(a_config, is_read_only=False)
 | 
			
		||||
        return task_function(a_config, *a_args, **a_kwargs)
 | 
			
		||||
 | 
			
		||||
@ -194,10 +216,6 @@ class PatchHydra(object):
 | 
			
		||||
            description = 'Full OmegaConf YAML configuration overridden! ({}/{}=True)'.format(
 | 
			
		||||
                PatchHydra._parameter_section, PatchHydra._parameter_allow_full_edit)
 | 
			
		||||
 | 
			
		||||
        # we should not have the hydra section in the config, but this seems never to be the case anymore.
 | 
			
		||||
        # config = config.copy()
 | 
			
		||||
        # config.pop('hydra', None)
 | 
			
		||||
 | 
			
		||||
        configuration = dict(
 | 
			
		||||
            name=PatchHydra._config_section,
 | 
			
		||||
            description=description,
 | 
			
		||||
 | 
			
		||||
@ -739,6 +739,8 @@ class Task(_Task):
 | 
			
		||||
                if argparser_parseargs_called():
 | 
			
		||||
                    for parser, parsed_args in get_argparser_last_args():
 | 
			
		||||
                        task._connect_argparse(parser=parser, parsed_args=parsed_args)
 | 
			
		||||
 | 
			
		||||
                PatchHydra.delete_overrides()
 | 
			
		||||
            elif argparser_parseargs_called():
 | 
			
		||||
                # actually we have nothing to do, in remote running, the argparser will ignore
 | 
			
		||||
                # all non argparser parameters, only caveat if parameter connected with the same name
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user