Fix Hydra support both Hydra section overrides and

This commit is contained in:
Alex Burlacu 2023-08-11 13:09:19 +03:00
parent c15f012e1b
commit 2c44bff461

View File

@ -1,7 +1,7 @@
import io
import sys
from functools import partial
import yaml
from ..config import running_remotely, get_remote_task_id, DEV_TASK_NO_REUSE
from ..debugging.log import LoggerRoot
@ -81,7 +81,14 @@ class PatchHydra(object):
stored_config = {k[len(PatchHydra._parameter_section)+1:]: v for k, v in full_parameters.items()
if k.startswith(PatchHydra._parameter_section+'/')}
stored_config.pop(PatchHydra._parameter_allow_full_edit, None)
overrides = ['{}={}'.format(k, v) for k, v in stored_config.items()]
# noinspection PyBroadException
try:
overrides = yaml.safe_load(full_parameters.get("Args/overrides", "")) or []
except Exception:
overrides = []
if overrides and not isinstance(overrides, (list, tuple)):
overrides = [overrides]
overrides += ['{}={}'.format(k, v) for k, v in stored_config.items()]
else:
# We take care of it inside the _patched_run_job
pass
@ -119,7 +126,8 @@ class PatchHydra(object):
else:
PatchHydra._last_untracked_state['connect'] = dict(
mutable=stored_config, name=PatchHydra._parameter_section)
# todo: remove the overrides section from the Args (we have it here)
# Maybe ?! remove the overrides section from the Args (we have it here)
# But when used with a Pipeline this is the only section we get... so we leave it here anyhow
# PatchHydra._current_task.delete_parameter('Args/overrides')
except Exception:
pass