Add support for Hydra command-line syntax for modifying omegaconf

This commit is contained in:
allegroai 2023-10-24 18:36:53 +03:00
parent 9142f861fd
commit c26efb83af
3 changed files with 69 additions and 47 deletions

View File

@ -1344,12 +1344,14 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
params = self.get_parameters(cast=cast) params = self.get_parameters(cast=cast)
return params.get(name, default) return params.get(name, default)
def delete_parameter(self, name): def delete_parameter(self, name, force=False):
# type: (str) -> bool # type: (str, bool) -> bool
""" """
Delete a parameter by its full name Section/name. Delete a parameter by its full name Section/name.
:param name: Parameter name in full, i.e. Section/name. For example, 'Args/batch_size' :param name: Parameter name in full, i.e. Section/name. For example, 'Args/batch_size'
:param force: If True, the hyper-parameter can be deleted from both new and running tasks.
Otherwise, it can be deleted only from new tasks. Default is False
:return: True if the parameter was deleted successfully :return: True if the parameter was deleted successfully
""" """
if not Session.check_min_api_version('2.9'): if not Session.check_min_api_version('2.9'):
@ -1360,7 +1362,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
with self._edit_lock: with self._edit_lock:
paramkey = tasks.ParamKey(section=name.split('/', 1)[0], name=name.split('/', 1)[1]) paramkey = tasks.ParamKey(section=name.split('/', 1)[0], name=name.split('/', 1)[1])
res = self.send(tasks.DeleteHyperParamsRequest( res = self.send(tasks.DeleteHyperParamsRequest(
task=self.id, hyperparams=[paramkey]), raise_on_errors=False) task=self.id, hyperparams=[paramkey], force=force), raise_on_errors=False)
self.reload() self.reload()
return res.ok() return res.ok()

View File

@ -15,6 +15,8 @@ class PatchHydra(object):
_config_section = 'OmegaConf' _config_section = 'OmegaConf'
_parameter_section = 'Hydra' _parameter_section = 'Hydra'
_parameter_allow_full_edit = '_allow_omegaconf_edit_' _parameter_allow_full_edit = '_allow_omegaconf_edit_'
_should_delete_overrides = False
_overrides_section = "Args/overrides"
@classmethod @classmethod
def patch_hydra(cls): def patch_hydra(cls):
@ -42,6 +44,12 @@ class PatchHydra(object):
except Exception: except Exception:
return False return False
@classmethod
def delete_overrides(cls):
if not cls._should_delete_overrides or not cls._current_task:
return
cls._current_task.delete_parameter(cls._overrides_section, force=True)
@staticmethod @staticmethod
def update_current_task(task): def update_current_task(task):
# set current Task before patching # set current Task before patching
@ -50,11 +58,24 @@ class PatchHydra(object):
return return
if PatchHydra.patch_hydra(): if PatchHydra.patch_hydra():
# check if we have an untracked state, store it. # check if we have an untracked state, store it.
if PatchHydra._last_untracked_state.get('connect'): if PatchHydra._last_untracked_state.get("connect"):
PatchHydra._current_task.connect(**PatchHydra._last_untracked_state['connect']) if PatchHydra._parameter_allow_full_edit in PatchHydra._last_untracked_state["connect"].get("mutable", {}):
if PatchHydra._last_untracked_state.get('_set_configuration'): allow_omegaconf_edit_section = PatchHydra._parameter_section + "/" + PatchHydra._parameter_allow_full_edit
allow_omegaconf_edit_section_val = PatchHydra._last_untracked_state["connect"]["mutable"].pop(
PatchHydra._parameter_allow_full_edit
)
PatchHydra._current_task.set_parameter(
allow_omegaconf_edit_section,
allow_omegaconf_edit_section_val,
description="If True, the `{}` parameter section will be completely ignored. The OmegaConf will instead be pulled from the `{}` section".format(
PatchHydra._parameter_section,
PatchHydra._config_section
)
)
PatchHydra._current_task.connect(**PatchHydra._last_untracked_state["connect"])
if PatchHydra._last_untracked_state.get("_set_configuration"):
# noinspection PyProtectedMember # noinspection PyProtectedMember
PatchHydra._current_task._set_configuration(**PatchHydra._last_untracked_state['_set_configuration']) PatchHydra._current_task._set_configuration(**PatchHydra._last_untracked_state["_set_configuration"])
PatchHydra._last_untracked_state = {} PatchHydra._last_untracked_state = {}
else: else:
# if patching failed set it to None # if patching failed set it to None
@ -63,36 +84,34 @@ class PatchHydra(object):
@staticmethod @staticmethod
def _patched_hydra_run(self, config_name, task_function, overrides, *args, **kwargs): def _patched_hydra_run(self, config_name, task_function, overrides, *args, **kwargs):
PatchHydra._allow_omegaconf_edit = False PatchHydra._allow_omegaconf_edit = False
if not running_remotely():
return PatchHydra._original_hydra_run(self, config_name, task_function, overrides, *args, **kwargs)
# store the config # get the parameters from the backend
# noinspection PyBroadException # noinspection PyBroadException
try: try:
if running_remotely(): if not PatchHydra._current_task:
if not PatchHydra._current_task: from ..task import Task
from ..task import Task PatchHydra._current_task = Task.get_task(task_id=get_remote_task_id())
PatchHydra._current_task = Task.get_task(task_id=get_remote_task_id()) # get the _parameter_allow_full_edit casted back to boolean
# get the _parameter_allow_full_edit casted back to boolean connected_config = {}
connected_config = dict() connected_config[PatchHydra._parameter_allow_full_edit] = False
connected_config[PatchHydra._parameter_allow_full_edit] = False PatchHydra._current_task.connect(connected_config, name=PatchHydra._parameter_section)
PatchHydra._current_task.connect(connected_config, name=PatchHydra._parameter_section) PatchHydra._allow_omegaconf_edit = connected_config.pop(PatchHydra._parameter_allow_full_edit, None)
PatchHydra._allow_omegaconf_edit = connected_config.pop(PatchHydra._parameter_allow_full_edit, None) # get all the overrides
# get all the overrides full_parameters = PatchHydra._current_task.get_parameters(backwards_compatibility=False)
full_parameters = PatchHydra._current_task.get_parameters(backwards_compatibility=False) stored_config = {k[len(PatchHydra._parameter_section)+1:]: v for k, v in full_parameters.items()
stored_config = {k[len(PatchHydra._parameter_section)+1:]: v for k, v in full_parameters.items() if k.startswith(PatchHydra._parameter_section+'/')}
if k.startswith(PatchHydra._parameter_section+'/')} stored_config.pop(PatchHydra._parameter_allow_full_edit, None)
stored_config.pop(PatchHydra._parameter_allow_full_edit, None) for override_k, override_v in stored_config.items():
# noinspection PyBroadException if override_k.startswith("~"):
try: new_override = override_k
overrides = yaml.safe_load(full_parameters.get("Args/overrides", "")) or [] else:
except Exception: new_override = "++" + override_k.lstrip("+")
overrides = [] if override_v is not None and override_v != "":
if overrides and not isinstance(overrides, (list, tuple)): new_override += "=" + override_v
overrides = [overrides] overrides.append(new_override)
overrides += ['{}={}'.format(k, v) for k, v in stored_config.items()] PatchHydra._should_delete_overrides = True
overrides = [("+" + o) if (o.startswith("+") and not o.startswith("++")) else o for o in overrides]
else:
# We take care of it inside the _patched_run_job
pass
except Exception: except Exception:
pass pass
@ -114,12 +133,18 @@ class PatchHydra(object):
# store the config # store the config
# noinspection PyBroadException # noinspection PyBroadException
try: try:
if running_remotely(): if not running_remotely():
# we take care of it in the hydra run (where we have access to the overrides) # note that we fetch the overrides from the backend in hydra run when running remotely,
pass # here we just get them from hydra to be stored as configuration/parameters
else:
overrides = config.hydra.overrides.task overrides = config.hydra.overrides.task
stored_config = dict(arg.split('=', 1) for arg in overrides) stored_config = {}
for arg in overrides:
arg = arg.lstrip("+")
if "=" in arg:
k, v = arg.split("=", 1)
stored_config[k] = v
else:
stored_config[arg] = None
stored_config[PatchHydra._parameter_allow_full_edit] = False stored_config[PatchHydra._parameter_allow_full_edit] = False
if PatchHydra._current_task: if PatchHydra._current_task:
PatchHydra._current_task.connect(stored_config, name=PatchHydra._parameter_section) PatchHydra._current_task.connect(stored_config, name=PatchHydra._parameter_section)
@ -127,9 +152,7 @@ class PatchHydra(object):
else: else:
PatchHydra._last_untracked_state['connect'] = dict( PatchHydra._last_untracked_state['connect'] = dict(
mutable=stored_config, name=PatchHydra._parameter_section) mutable=stored_config, name=PatchHydra._parameter_section)
# Maybe ?! remove the overrides section from the Args (we have it here) PatchHydra._should_delete_overrides = True
# But when used with a Pipeline this is the only section we get... so we leave it here anyhow
# PatchHydra._current_task.delete_parameter('Args/overrides')
except Exception: except Exception:
pass pass
@ -176,8 +199,7 @@ class PatchHydra(object):
else: else:
# noinspection PyProtectedMember # noinspection PyProtectedMember
omega_yaml = PatchHydra._current_task._get_configuration_text(PatchHydra._config_section) omega_yaml = PatchHydra._current_task._get_configuration_text(PatchHydra._config_section)
loaded_config = OmegaConf.load(io.StringIO(omega_yaml)) a_config = OmegaConf.load(io.StringIO(omega_yaml))
a_config = OmegaConf.merge(a_config, loaded_config)
PatchHydra._register_omegaconf(a_config, is_read_only=False) PatchHydra._register_omegaconf(a_config, is_read_only=False)
return task_function(a_config, *a_args, **a_kwargs) return task_function(a_config, *a_args, **a_kwargs)
@ -194,10 +216,6 @@ class PatchHydra(object):
description = 'Full OmegaConf YAML configuration overridden! ({}/{}=True)'.format( description = 'Full OmegaConf YAML configuration overridden! ({}/{}=True)'.format(
PatchHydra._parameter_section, PatchHydra._parameter_allow_full_edit) PatchHydra._parameter_section, PatchHydra._parameter_allow_full_edit)
# we should not have the hydra section in the config, but this seems never to be the case anymore.
# config = config.copy()
# config.pop('hydra', None)
configuration = dict( configuration = dict(
name=PatchHydra._config_section, name=PatchHydra._config_section,
description=description, description=description,

View File

@ -739,6 +739,8 @@ class Task(_Task):
if argparser_parseargs_called(): if argparser_parseargs_called():
for parser, parsed_args in get_argparser_last_args(): for parser, parsed_args in get_argparser_last_args():
task._connect_argparse(parser=parser, parsed_args=parsed_args) task._connect_argparse(parser=parser, parsed_args=parsed_args)
PatchHydra.delete_overrides()
elif argparser_parseargs_called(): elif argparser_parseargs_called():
# actually we have nothing to do, in remote running, the argparser will ignore # actually we have nothing to do, in remote running, the argparser will ignore
# all non argparser parameters, only caveat if parameter connected with the same name # all non argparser parameters, only caveat if parameter connected with the same name