mirror of
https://github.com/clearml/clearml
synced 2025-03-04 02:57:24 +00:00
Handle appends to Hydra defaults list
This commit is contained in:
parent
0af630e9cf
commit
5911d9e6d6
@ -16,6 +16,9 @@ class PatchHydra(object):
|
||||
_parameter_allow_full_edit = '_allow_omegaconf_edit_'
|
||||
_should_delete_overrides = False
|
||||
_overrides_section = "Args/overrides"
|
||||
_default_hydra_context = None
|
||||
_overrides_parser = None
|
||||
_config_group_warning_sent = False
|
||||
|
||||
@classmethod
|
||||
def patch_hydra(cls):
|
||||
@ -82,6 +85,7 @@ class PatchHydra(object):
|
||||
|
||||
@staticmethod
|
||||
def _patched_hydra_run(self, config_name, task_function, overrides, *args, **kwargs):
|
||||
PatchHydra._default_hydra_context = self
|
||||
PatchHydra._allow_omegaconf_edit = False
|
||||
if not running_remotely():
|
||||
return PatchHydra._original_hydra_run(self, config_name, task_function, overrides, *args, **kwargs)
|
||||
@ -103,12 +107,11 @@ class PatchHydra(object):
|
||||
if k.startswith(PatchHydra._parameter_section+'/')}
|
||||
stored_config.pop(PatchHydra._parameter_allow_full_edit, None)
|
||||
for override_k, override_v in stored_config.items():
|
||||
if override_k.startswith("~"):
|
||||
new_override = override_k
|
||||
else:
|
||||
new_override = "++" + override_k.lstrip("+")
|
||||
new_override = override_k
|
||||
if override_v is not None and override_v != "":
|
||||
new_override += "=" + override_v
|
||||
if not new_override.startswith("~") and not PatchHydra._is_group(self, new_override):
|
||||
new_override = "++" + new_override.lstrip("+")
|
||||
overrides.append(new_override)
|
||||
PatchHydra._should_delete_overrides = True
|
||||
except Exception:
|
||||
@ -116,6 +119,28 @@ class PatchHydra(object):
|
||||
|
||||
return PatchHydra._original_hydra_run(self, config_name, task_function, overrides, *args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def _parse_override(override):
|
||||
if PatchHydra._overrides_parser is None:
|
||||
from hydra.core.override_parser.overrides_parser import OverridesParser
|
||||
PatchHydra._overrides_parser = OverridesParser.create()
|
||||
return PatchHydra._overrides_parser.parse_overrides(overrides=[override])[0]
|
||||
|
||||
@staticmethod
|
||||
def _is_group(hydra_context, override):
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
override = PatchHydra._parse_override(override)
|
||||
group_exists = hydra_context.config_loader.repository.group_exists(override.key_or_group)
|
||||
return group_exists
|
||||
except Exception:
|
||||
if not PatchHydra._config_group_warning_sent:
|
||||
LoggerRoot.get_base_logger().warning(
|
||||
"Could not determine if Hydra is overriding a Config Group"
|
||||
)
|
||||
PatchHydra._config_group_warning_sent = True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _patched_run_job(config, task_function, *args, **kwargs):
|
||||
# noinspection PyBroadException
|
||||
@ -124,11 +149,12 @@ class PatchHydra(object):
|
||||
|
||||
failed_status = JobStatus.FAILED
|
||||
except Exception:
|
||||
LoggerRoot.get_base_logger(PatchHydra).warning(
|
||||
LoggerRoot.get_base_logger().warning(
|
||||
"Could not import JobStatus from Hydra. Failed tasks will be marked as completed"
|
||||
)
|
||||
failed_status = None
|
||||
|
||||
hydra_context = kwargs.get("hydra_context", PatchHydra._default_hydra_context)
|
||||
# store the config
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
@ -138,7 +164,8 @@ class PatchHydra(object):
|
||||
overrides = config.hydra.overrides.task
|
||||
stored_config = {}
|
||||
for arg in overrides:
|
||||
arg = arg.lstrip("+")
|
||||
if not PatchHydra._is_group(hydra_context, arg):
|
||||
arg = arg.lstrip("+")
|
||||
if "=" in arg:
|
||||
k, v = arg.split("=", 1)
|
||||
stored_config[k] = v
|
||||
|
Loading…
Reference in New Issue
Block a user