From e9920e27edb6dc708013092ae93e7a778a6b5331 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Fri, 20 Nov 2020 00:07:21 +0200 Subject: [PATCH] Add Hydra bind support for Task.init call within the Hydra app, Issue #219 --- trains/binding/hydra_bind.py | 69 +++++++++++++++++++++++------------- 1 file changed, 44 insertions(+), 25 deletions(-) diff --git a/trains/binding/hydra_bind.py b/trains/binding/hydra_bind.py index 2b536bd6..7ed61780 100644 --- a/trains/binding/hydra_bind.py +++ b/trains/binding/hydra_bind.py @@ -1,13 +1,13 @@ import io import sys -from ..config import running_remotely -from ..debugging.log import LoggerRoot +from ..config import running_remotely, get_remote_task_id class PatchHydra(object): _original_run_job = None _current_task = None + _last_untracked_state = {} _config_section = 'OmegaConf' _parameter_section = 'Hydra' _parameter_allow_full_edit = '_allow_omegaconf_edit_' @@ -23,22 +23,8 @@ class PatchHydra(object): if not sys.modules.get('hydra'): return False - from hydra.core import utils # noqa from hydra._internal import hydra as internal_hydra # noqa - # check if hydra is already initialized - if utils.HydraConfig.initialized(): - LoggerRoot.get_base_logger().warning( - "Hydra is already loaded storing read-only OmegaConf, " - "For full support call Task.init(...) before the Hydra App") - # noinspection PyBroadException - try: - # noinspection PyProtectedMember,PyUnresolvedReferences - PatchHydra._register_omegaconf(utils.HydraConfig.get()._get_root()) - except Exception: - pass - return False - cls._original_run_job = internal_hydra.Hydra.run internal_hydra.Hydra.run = cls._patched_run_job return True @@ -49,14 +35,20 @@ class PatchHydra(object): def update_current_task(task): # set current Task before patching PatchHydra._current_task = task - if not PatchHydra.patch_hydra(): + if PatchHydra.patch_hydra(): + # check if we have an untracked state, store it. + if PatchHydra._last_untracked_state.get('connect'): + PatchHydra._current_task.connect(**PatchHydra._last_untracked_state['connect']) + if PatchHydra._last_untracked_state.get('_set_configuration'): + # noinspection PyProtectedMember + PatchHydra._current_task._set_configuration(**PatchHydra._last_untracked_state['_set_configuration']) + PatchHydra._last_untracked_state = {} + else: # if patching failed set it to None PatchHydra._current_task = None @staticmethod def _patched_run_job(self, config_name, task_function, overrides, *args, **kwargs): - if not PatchHydra._current_task: - return PatchHydra._original_run_job(self, config_name, task_function, overrides, *args, **kwargs) allow_omegaconf_edit = False def patched_task_function(a_config, *a_args, **a_kwargs): @@ -75,6 +67,9 @@ class PatchHydra(object): # noinspection PyBroadException try: if running_remotely(): + if not PatchHydra._current_task: + from ..task import Task + PatchHydra._current_task = Task.get_task(task_id=get_remote_task_id()) # get the _parameter_allow_full_edit casted back to boolean connected_config = dict() connected_config[PatchHydra._parameter_allow_full_edit] = False @@ -89,7 +84,12 @@ class PatchHydra(object): else: stored_config = dict(arg.split('=', 1) for arg in overrides) stored_config[PatchHydra._parameter_allow_full_edit] = False - PatchHydra._current_task.connect(stored_config, name=PatchHydra._parameter_section) + if PatchHydra._current_task: + PatchHydra._current_task.connect(stored_config, name=PatchHydra._parameter_section) + PatchHydra._last_untracked_state.pop('connect', None) + else: + PatchHydra._last_untracked_state['connect'] = dict( + mutable=stored_config, name=PatchHydra._parameter_section) # todo: remove the overrides section from the Args (we have it here) # PatchHydra._current_task.delete_parameter('Args/overrides') except Exception: @@ -101,17 +101,36 @@ class PatchHydra(object): from omegaconf import OmegaConf # noqa if is_read_only: - description = 'Full OmegaConf YAML configuration. ' \ - 'This is a read-only section, unless \'{}/{}\' is set to True'.format( - PatchHydra._parameter_section, PatchHydra._parameter_allow_full_edit) + description = \ + 'Full OmegaConf YAML configuration. ' \ + 'This is a read-only section, unless \'{}/{}\' is set to True'.format( + PatchHydra._parameter_section, PatchHydra._parameter_allow_full_edit) else: description = 'Full OmegaConf YAML configuration overridden! ({}/{}=True)'.format( PatchHydra._parameter_section, PatchHydra._parameter_allow_full_edit) - # noinspection PyProtectedMember - PatchHydra._current_task._set_configuration( + configuration = dict( name=PatchHydra._config_section, description=description, config_type='OmegaConf YAML', config_text=OmegaConf.to_yaml({k: v for k, v in config.items() if k not in ('hydra', )}) ) + if PatchHydra._current_task: + # noinspection PyProtectedMember + PatchHydra._current_task._set_configuration(**configuration) + PatchHydra._last_untracked_state.pop('_set_configuration', None) + else: + PatchHydra._last_untracked_state['_set_configuration'] = configuration + + +def __global_hydra_bind(): + # noinspection PyBroadException + try: + import hydra # noqa + PatchHydra.patch_hydra() + except Exception: + pass + + +# patch hydra +__global_hydra_bind()