Handle appends to Hydra defaults list

This commit is contained in:
allegroai 2023-11-21 11:05:30 +02:00
parent 0af630e9cf
commit 5911d9e6d6

View File

@ -16,6 +16,9 @@ class PatchHydra(object):
_parameter_allow_full_edit = '_allow_omegaconf_edit_' _parameter_allow_full_edit = '_allow_omegaconf_edit_'
_should_delete_overrides = False _should_delete_overrides = False
_overrides_section = "Args/overrides" _overrides_section = "Args/overrides"
_default_hydra_context = None
_overrides_parser = None
_config_group_warning_sent = False
@classmethod @classmethod
def patch_hydra(cls): def patch_hydra(cls):
@ -82,6 +85,7 @@ class PatchHydra(object):
@staticmethod @staticmethod
def _patched_hydra_run(self, config_name, task_function, overrides, *args, **kwargs): def _patched_hydra_run(self, config_name, task_function, overrides, *args, **kwargs):
PatchHydra._default_hydra_context = self
PatchHydra._allow_omegaconf_edit = False PatchHydra._allow_omegaconf_edit = False
if not running_remotely(): if not running_remotely():
return PatchHydra._original_hydra_run(self, config_name, task_function, overrides, *args, **kwargs) return PatchHydra._original_hydra_run(self, config_name, task_function, overrides, *args, **kwargs)
@ -103,12 +107,11 @@ class PatchHydra(object):
if k.startswith(PatchHydra._parameter_section+'/')} if k.startswith(PatchHydra._parameter_section+'/')}
stored_config.pop(PatchHydra._parameter_allow_full_edit, None) stored_config.pop(PatchHydra._parameter_allow_full_edit, None)
for override_k, override_v in stored_config.items(): for override_k, override_v in stored_config.items():
if override_k.startswith("~"): new_override = override_k
new_override = override_k
else:
new_override = "++" + override_k.lstrip("+")
if override_v is not None and override_v != "": if override_v is not None and override_v != "":
new_override += "=" + override_v new_override += "=" + override_v
if not new_override.startswith("~") and not PatchHydra._is_group(self, new_override):
new_override = "++" + new_override.lstrip("+")
overrides.append(new_override) overrides.append(new_override)
PatchHydra._should_delete_overrides = True PatchHydra._should_delete_overrides = True
except Exception: except Exception:
@ -116,6 +119,28 @@ class PatchHydra(object):
return PatchHydra._original_hydra_run(self, config_name, task_function, overrides, *args, **kwargs) return PatchHydra._original_hydra_run(self, config_name, task_function, overrides, *args, **kwargs)
@staticmethod
def _parse_override(override):
if PatchHydra._overrides_parser is None:
from hydra.core.override_parser.overrides_parser import OverridesParser
PatchHydra._overrides_parser = OverridesParser.create()
return PatchHydra._overrides_parser.parse_overrides(overrides=[override])[0]
@staticmethod
def _is_group(hydra_context, override):
# noinspection PyBroadException
try:
override = PatchHydra._parse_override(override)
group_exists = hydra_context.config_loader.repository.group_exists(override.key_or_group)
return group_exists
except Exception:
if not PatchHydra._config_group_warning_sent:
LoggerRoot.get_base_logger().warning(
"Could not determine if Hydra is overriding a Config Group"
)
PatchHydra._config_group_warning_sent = True
return False
@staticmethod @staticmethod
def _patched_run_job(config, task_function, *args, **kwargs): def _patched_run_job(config, task_function, *args, **kwargs):
# noinspection PyBroadException # noinspection PyBroadException
@ -124,11 +149,12 @@ class PatchHydra(object):
failed_status = JobStatus.FAILED failed_status = JobStatus.FAILED
except Exception: except Exception:
LoggerRoot.get_base_logger(PatchHydra).warning( LoggerRoot.get_base_logger().warning(
"Could not import JobStatus from Hydra. Failed tasks will be marked as completed" "Could not import JobStatus from Hydra. Failed tasks will be marked as completed"
) )
failed_status = None failed_status = None
hydra_context = kwargs.get("hydra_context", PatchHydra._default_hydra_context)
# store the config # store the config
# noinspection PyBroadException # noinspection PyBroadException
try: try:
@ -138,7 +164,8 @@ class PatchHydra(object):
overrides = config.hydra.overrides.task overrides = config.hydra.overrides.task
stored_config = {} stored_config = {}
for arg in overrides: for arg in overrides:
arg = arg.lstrip("+") if not PatchHydra._is_group(hydra_context, arg):
arg = arg.lstrip("+")
if "=" in arg: if "=" in arg:
k, v = arg.split("=", 1) k, v = arg.split("=", 1)
stored_config[k] = v stored_config[k] = v