mirror of
https://github.com/clearml/clearml
synced 2025-06-26 18:16:07 +00:00
Add support for Hydra command-line syntax for modifying omegaconf
This commit is contained in:
parent
9142f861fd
commit
c26efb83af
@ -1344,12 +1344,14 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
|||||||
params = self.get_parameters(cast=cast)
|
params = self.get_parameters(cast=cast)
|
||||||
return params.get(name, default)
|
return params.get(name, default)
|
||||||
|
|
||||||
def delete_parameter(self, name):
|
def delete_parameter(self, name, force=False):
|
||||||
# type: (str) -> bool
|
# type: (str) -> bool
|
||||||
"""
|
"""
|
||||||
Delete a parameter by its full name Section/name.
|
Delete a parameter by its full name Section/name.
|
||||||
|
|
||||||
:param name: Parameter name in full, i.e. Section/name. For example, 'Args/batch_size'
|
:param name: Parameter name in full, i.e. Section/name. For example, 'Args/batch_size'
|
||||||
|
:param force: If set to True then both new and running task hyper params can be deleted.
|
||||||
|
Otherwise only the new task ones. Default is False
|
||||||
:return: True if the parameter was deleted successfully
|
:return: True if the parameter was deleted successfully
|
||||||
"""
|
"""
|
||||||
if not Session.check_min_api_version('2.9'):
|
if not Session.check_min_api_version('2.9'):
|
||||||
@ -1360,7 +1362,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
|||||||
with self._edit_lock:
|
with self._edit_lock:
|
||||||
paramkey = tasks.ParamKey(section=name.split('/', 1)[0], name=name.split('/', 1)[1])
|
paramkey = tasks.ParamKey(section=name.split('/', 1)[0], name=name.split('/', 1)[1])
|
||||||
res = self.send(tasks.DeleteHyperParamsRequest(
|
res = self.send(tasks.DeleteHyperParamsRequest(
|
||||||
task=self.id, hyperparams=[paramkey]), raise_on_errors=False)
|
task=self.id, hyperparams=[paramkey], force=force), raise_on_errors=False)
|
||||||
self.reload()
|
self.reload()
|
||||||
|
|
||||||
return res.ok()
|
return res.ok()
|
||||||
|
@ -15,6 +15,8 @@ class PatchHydra(object):
|
|||||||
_config_section = 'OmegaConf'
|
_config_section = 'OmegaConf'
|
||||||
_parameter_section = 'Hydra'
|
_parameter_section = 'Hydra'
|
||||||
_parameter_allow_full_edit = '_allow_omegaconf_edit_'
|
_parameter_allow_full_edit = '_allow_omegaconf_edit_'
|
||||||
|
_should_delete_overrides = False
|
||||||
|
_overrides_section = "Args/overrides"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def patch_hydra(cls):
|
def patch_hydra(cls):
|
||||||
@ -42,6 +44,12 @@ class PatchHydra(object):
|
|||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def delete_overrides(cls):
|
||||||
|
if not cls._should_delete_overrides or not cls._current_task:
|
||||||
|
return
|
||||||
|
cls._current_task.delete_parameter(cls._overrides_section, force=True)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def update_current_task(task):
|
def update_current_task(task):
|
||||||
# set current Task before patching
|
# set current Task before patching
|
||||||
@ -50,11 +58,24 @@ class PatchHydra(object):
|
|||||||
return
|
return
|
||||||
if PatchHydra.patch_hydra():
|
if PatchHydra.patch_hydra():
|
||||||
# check if we have an untracked state, store it.
|
# check if we have an untracked state, store it.
|
||||||
if PatchHydra._last_untracked_state.get('connect'):
|
if PatchHydra._last_untracked_state.get("connect"):
|
||||||
PatchHydra._current_task.connect(**PatchHydra._last_untracked_state['connect'])
|
if PatchHydra._parameter_allow_full_edit in PatchHydra._last_untracked_state["connect"].get("mutable", {}):
|
||||||
if PatchHydra._last_untracked_state.get('_set_configuration'):
|
allow_omegaconf_edit_section = PatchHydra._parameter_section + "/" + PatchHydra._parameter_allow_full_edit
|
||||||
|
allow_omegaconf_edit_section_val = PatchHydra._last_untracked_state["connect"]["mutable"].pop(
|
||||||
|
PatchHydra._parameter_allow_full_edit
|
||||||
|
)
|
||||||
|
PatchHydra._current_task.set_parameter(
|
||||||
|
allow_omegaconf_edit_section,
|
||||||
|
allow_omegaconf_edit_section_val,
|
||||||
|
description="If True, the `{}` parameter section will be completely ignored. The OmegaConf will instead be pulled from the `{}` section".format(
|
||||||
|
PatchHydra._parameter_section,
|
||||||
|
PatchHydra._config_section
|
||||||
|
)
|
||||||
|
)
|
||||||
|
PatchHydra._current_task.connect(**PatchHydra._last_untracked_state["connect"])
|
||||||
|
if PatchHydra._last_untracked_state.get("_set_configuration"):
|
||||||
# noinspection PyProtectedMember
|
# noinspection PyProtectedMember
|
||||||
PatchHydra._current_task._set_configuration(**PatchHydra._last_untracked_state['_set_configuration'])
|
PatchHydra._current_task._set_configuration(**PatchHydra._last_untracked_state["_set_configuration"])
|
||||||
PatchHydra._last_untracked_state = {}
|
PatchHydra._last_untracked_state = {}
|
||||||
else:
|
else:
|
||||||
# if patching failed set it to None
|
# if patching failed set it to None
|
||||||
@ -63,16 +84,17 @@ class PatchHydra(object):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _patched_hydra_run(self, config_name, task_function, overrides, *args, **kwargs):
|
def _patched_hydra_run(self, config_name, task_function, overrides, *args, **kwargs):
|
||||||
PatchHydra._allow_omegaconf_edit = False
|
PatchHydra._allow_omegaconf_edit = False
|
||||||
|
if not running_remotely():
|
||||||
|
return PatchHydra._original_hydra_run(self, config_name, task_function, overrides, *args, **kwargs)
|
||||||
|
|
||||||
# store the config
|
# get the parameters from the backend
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
if running_remotely():
|
|
||||||
if not PatchHydra._current_task:
|
if not PatchHydra._current_task:
|
||||||
from ..task import Task
|
from ..task import Task
|
||||||
PatchHydra._current_task = Task.get_task(task_id=get_remote_task_id())
|
PatchHydra._current_task = Task.get_task(task_id=get_remote_task_id())
|
||||||
# get the _parameter_allow_full_edit casted back to boolean
|
# get the _parameter_allow_full_edit casted back to boolean
|
||||||
connected_config = dict()
|
connected_config = {}
|
||||||
connected_config[PatchHydra._parameter_allow_full_edit] = False
|
connected_config[PatchHydra._parameter_allow_full_edit] = False
|
||||||
PatchHydra._current_task.connect(connected_config, name=PatchHydra._parameter_section)
|
PatchHydra._current_task.connect(connected_config, name=PatchHydra._parameter_section)
|
||||||
PatchHydra._allow_omegaconf_edit = connected_config.pop(PatchHydra._parameter_allow_full_edit, None)
|
PatchHydra._allow_omegaconf_edit = connected_config.pop(PatchHydra._parameter_allow_full_edit, None)
|
||||||
@ -81,18 +103,15 @@ class PatchHydra(object):
|
|||||||
stored_config = {k[len(PatchHydra._parameter_section)+1:]: v for k, v in full_parameters.items()
|
stored_config = {k[len(PatchHydra._parameter_section)+1:]: v for k, v in full_parameters.items()
|
||||||
if k.startswith(PatchHydra._parameter_section+'/')}
|
if k.startswith(PatchHydra._parameter_section+'/')}
|
||||||
stored_config.pop(PatchHydra._parameter_allow_full_edit, None)
|
stored_config.pop(PatchHydra._parameter_allow_full_edit, None)
|
||||||
# noinspection PyBroadException
|
for override_k, override_v in stored_config.items():
|
||||||
try:
|
if override_k.startswith("~"):
|
||||||
overrides = yaml.safe_load(full_parameters.get("Args/overrides", "")) or []
|
new_override = override_k
|
||||||
except Exception:
|
|
||||||
overrides = []
|
|
||||||
if overrides and not isinstance(overrides, (list, tuple)):
|
|
||||||
overrides = [overrides]
|
|
||||||
overrides += ['{}={}'.format(k, v) for k, v in stored_config.items()]
|
|
||||||
overrides = [("+" + o) if (o.startswith("+") and not o.startswith("++")) else o for o in overrides]
|
|
||||||
else:
|
else:
|
||||||
# We take care of it inside the _patched_run_job
|
new_override = "++" + override_k.lstrip("+")
|
||||||
pass
|
if override_v is not None and override_v != "":
|
||||||
|
new_override += "=" + override_v
|
||||||
|
overrides.append(new_override)
|
||||||
|
PatchHydra._should_delete_overrides = True
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -114,12 +133,18 @@ class PatchHydra(object):
|
|||||||
# store the config
|
# store the config
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
if running_remotely():
|
if not running_remotely():
|
||||||
# we take care of it in the hydra run (where we have access to the overrides)
|
# note that we fetch the overrides from the backend in hydra run when running remotely,
|
||||||
pass
|
# here we just get them from hydra to be stored as configuration/parameters
|
||||||
else:
|
|
||||||
overrides = config.hydra.overrides.task
|
overrides = config.hydra.overrides.task
|
||||||
stored_config = dict(arg.split('=', 1) for arg in overrides)
|
stored_config = {}
|
||||||
|
for arg in overrides:
|
||||||
|
arg = arg.lstrip("+")
|
||||||
|
if "=" in arg:
|
||||||
|
k, v = arg.split("=", 1)
|
||||||
|
stored_config[k] = v
|
||||||
|
else:
|
||||||
|
stored_config[arg] = None
|
||||||
stored_config[PatchHydra._parameter_allow_full_edit] = False
|
stored_config[PatchHydra._parameter_allow_full_edit] = False
|
||||||
if PatchHydra._current_task:
|
if PatchHydra._current_task:
|
||||||
PatchHydra._current_task.connect(stored_config, name=PatchHydra._parameter_section)
|
PatchHydra._current_task.connect(stored_config, name=PatchHydra._parameter_section)
|
||||||
@ -127,9 +152,7 @@ class PatchHydra(object):
|
|||||||
else:
|
else:
|
||||||
PatchHydra._last_untracked_state['connect'] = dict(
|
PatchHydra._last_untracked_state['connect'] = dict(
|
||||||
mutable=stored_config, name=PatchHydra._parameter_section)
|
mutable=stored_config, name=PatchHydra._parameter_section)
|
||||||
# Maybe ?! remove the overrides section from the Args (we have it here)
|
PatchHydra._should_delete_overrides = True
|
||||||
# But when used with a Pipeline this is the only section we get... so we leave it here anyhow
|
|
||||||
# PatchHydra._current_task.delete_parameter('Args/overrides')
|
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -176,8 +199,7 @@ class PatchHydra(object):
|
|||||||
else:
|
else:
|
||||||
# noinspection PyProtectedMember
|
# noinspection PyProtectedMember
|
||||||
omega_yaml = PatchHydra._current_task._get_configuration_text(PatchHydra._config_section)
|
omega_yaml = PatchHydra._current_task._get_configuration_text(PatchHydra._config_section)
|
||||||
loaded_config = OmegaConf.load(io.StringIO(omega_yaml))
|
a_config = OmegaConf.load(io.StringIO(omega_yaml))
|
||||||
a_config = OmegaConf.merge(a_config, loaded_config)
|
|
||||||
PatchHydra._register_omegaconf(a_config, is_read_only=False)
|
PatchHydra._register_omegaconf(a_config, is_read_only=False)
|
||||||
return task_function(a_config, *a_args, **a_kwargs)
|
return task_function(a_config, *a_args, **a_kwargs)
|
||||||
|
|
||||||
@ -194,10 +216,6 @@ class PatchHydra(object):
|
|||||||
description = 'Full OmegaConf YAML configuration overridden! ({}/{}=True)'.format(
|
description = 'Full OmegaConf YAML configuration overridden! ({}/{}=True)'.format(
|
||||||
PatchHydra._parameter_section, PatchHydra._parameter_allow_full_edit)
|
PatchHydra._parameter_section, PatchHydra._parameter_allow_full_edit)
|
||||||
|
|
||||||
# we should not have the hydra section in the config, but this seems never to be the case anymore.
|
|
||||||
# config = config.copy()
|
|
||||||
# config.pop('hydra', None)
|
|
||||||
|
|
||||||
configuration = dict(
|
configuration = dict(
|
||||||
name=PatchHydra._config_section,
|
name=PatchHydra._config_section,
|
||||||
description=description,
|
description=description,
|
||||||
|
@ -739,6 +739,8 @@ class Task(_Task):
|
|||||||
if argparser_parseargs_called():
|
if argparser_parseargs_called():
|
||||||
for parser, parsed_args in get_argparser_last_args():
|
for parser, parsed_args in get_argparser_last_args():
|
||||||
task._connect_argparse(parser=parser, parsed_args=parsed_args)
|
task._connect_argparse(parser=parser, parsed_args=parsed_args)
|
||||||
|
|
||||||
|
PatchHydra.delete_overrides()
|
||||||
elif argparser_parseargs_called():
|
elif argparser_parseargs_called():
|
||||||
# actually we have nothing to do, in remote running, the argparser will ignore
|
# actually we have nothing to do, in remote running, the argparser will ignore
|
||||||
# all non argparser parameters, only caveat if parameter connected with the same name
|
# all non argparser parameters, only caveat if parameter connected with the same name
|
||||||
|
Loading…
Reference in New Issue
Block a user