class PatchHydra(object):
    """Bind a Hydra application to the current Task.

    Once patched, ``hydra._internal.hydra.Hydra.run`` is replaced by
    :meth:`_patched_run_job`, which:

    * stores the command line overrides as Task parameters under the
      ``Hydra`` section (and restores them from the Task when running remotely),
    * stores the composed OmegaConf configuration as a Task configuration
      object under the ``OmegaConf`` section.

    The accompanying ``trains/task.py`` hunk calls
    ``PatchHydra.update_current_task(task)`` when ``auto_connect_frameworks``
    is a truthy non-dict value, or a dict whose ``'hydra'`` entry is truthy
    (missing entries default to True).
    """
    # original (unpatched) Hydra.run, also used as the "already patched" marker
    _original_run_job = None
    # Task receiving the parameters / configuration, None when binding is inactive
    _current_task = None
    # name of the Task configuration object holding the full OmegaConf YAML
    _config_section = 'OmegaConf'
    # name of the Task parameter section holding the Hydra overrides
    _parameter_section = 'Hydra'
    # parameter (inside the Hydra section) enabling editing of the full OmegaConf
    _parameter_allow_full_edit = '_allow_omegaconf_edit_'

    @classmethod
    def patch_hydra(cls):
        """Replace ``Hydra.run`` with the Task-aware wrapper (done only once).

        :return: True if Hydra is (or already was) patched. False if the
            ``hydra`` module was not imported, if Hydra was already initialized
            (in which case the current configuration is only registered as
            read-only), or if patching raised an exception.
        """
        # noinspection PyBroadException
        try:
            # only once
            if cls._original_run_job:
                return True
            # if hydra is not loaded, do not patch anything
            if not sys.modules.get('hydra'):
                return False

            from hydra.core import utils  # noqa
            from hydra._internal import hydra as internal_hydra  # noqa

            # check if hydra is already initialized
            if utils.HydraConfig.initialized():
                LoggerRoot.get_base_logger().warning(
                    "Hydra is already loaded storing read-only OmegaConf, "
                    "For full support call Task.init(...) before the Hydra App")
                # noinspection PyBroadException
                try:
                    # too late to wrap the run call, best effort: store the active configuration
                    # noinspection PyProtectedMember,PyUnresolvedReferences
                    PatchHydra._register_omegaconf(utils.HydraConfig.get()._get_root())
                except Exception:
                    pass
                return False

            cls._original_run_job = internal_hydra.Hydra.run
            internal_hydra.Hydra.run = cls._patched_run_job
            return True
        except Exception:
            # patching is optional, never fail the caller because of it
            return False

    @staticmethod
    def update_current_task(task):
        """Set the Task used by the binding and try to patch Hydra.

        :param task: Task object that will receive the Hydra parameters/configuration.
            It is reset to None if patching did not succeed.
        """
        # set current Task before patching
        PatchHydra._current_task = task
        if not PatchHydra.patch_hydra():
            # if patching failed set it to None
            PatchHydra._current_task = None

    @staticmethod
    def _parse_overrides(overrides):
        """Split Hydra command line overrides into ``key -> value`` pairs.

        :param overrides: iterable of override strings (None is treated as empty).
        :return: tuple ``(parsed, untracked)``. ``parsed`` is a dict built from the
            ``key=value`` entries (split on the first ``=`` only). ``untracked`` is a
            list of the non-blank entries that contain no ``=`` and therefore
            cannot be stored as a parameter. Blank entries are ignored.
        """
        parsed = {}
        untracked = []
        for arg in overrides or []:
            if not arg or not arg.strip():
                continue
            if '=' not in arg:
                untracked.append(arg)
                continue
            key, value = arg.split('=', 1)
            parsed[key] = value
        return parsed, untracked

    @staticmethod
    def _patched_run_job(self, config_name, task_function, overrides, *args, **kwargs):
        """Replacement for ``Hydra.run`` that syncs overrides and configuration with the Task.

        :param self: the Hydra instance (this function is installed as a method on ``Hydra``).
        :param config_name: forwarded as-is to the original ``Hydra.run``.
        :param task_function: the user application function, wrapped so the composed
            configuration is registered on the Task before the function is called.
        :param overrides: list of command line overrides. When running remotely it is
            replaced by the overrides stored in the Task ``Hydra`` parameter section.
        :return: whatever the original ``Hydra.run`` returns.
        """
        if not PatchHydra._current_task:
            return PatchHydra._original_run_job(self, config_name, task_function, overrides, *args, **kwargs)
        allow_omegaconf_edit = False

        def patched_task_function(a_config, *a_args, **a_kwargs):
            from omegaconf import OmegaConf  # noqa
            if not running_remotely() or not allow_omegaconf_edit:
                PatchHydra._register_omegaconf(a_config)
            else:
                # full edit was requested: merge the YAML stored on the Task over the composed config
                # noinspection PyProtectedMember
                omega_yaml = PatchHydra._current_task._get_configuration_text(PatchHydra._config_section)
                loaded_config = OmegaConf.load(io.StringIO(omega_yaml))
                a_config = OmegaConf.merge(a_config, loaded_config)
                PatchHydra._register_omegaconf(a_config, is_read_only=False)
            return task_function(a_config, *a_args, **a_kwargs)

        # store the config (best effort, the Hydra application must run even if this fails)
        # noinspection PyBroadException
        try:
            if running_remotely():
                # get the _parameter_allow_full_edit casted back to boolean
                connected_config = dict()
                connected_config[PatchHydra._parameter_allow_full_edit] = False
                PatchHydra._current_task.connect(connected_config, name=PatchHydra._parameter_section)
                allow_omegaconf_edit = connected_config.pop(PatchHydra._parameter_allow_full_edit, None)
                # get all the overrides
                full_parameters = PatchHydra._current_task.get_parameters(backwards_compatibility=False)
                stored_config = {k[len(PatchHydra._parameter_section)+1:]: v for k, v in full_parameters.items()
                                 if k.startswith(PatchHydra._parameter_section+'/')}
                stored_config.pop(PatchHydra._parameter_allow_full_edit, None)
                overrides = ['{}={}'.format(k, v) for k, v in stored_config.items()]
            else:
                # an override without '=' must not prevent storing all the others
                stored_config, untracked = PatchHydra._parse_overrides(overrides)
                if untracked:
                    LoggerRoot.get_base_logger().warning(
                        'Hydra overrides without a value are not stored on the Task: {}'.format(untracked))
                stored_config[PatchHydra._parameter_allow_full_edit] = False
                PatchHydra._current_task.connect(stored_config, name=PatchHydra._parameter_section)
                # todo: remove the overrides section from the Args (we have it here)
                # PatchHydra._current_task.delete_parameter('Args/overrides')
        except Exception as ex:
            # do not fail silently: report, then continue with the overrides we have
            LoggerRoot.get_base_logger().warning(
                'Failed syncing Hydra overrides with the Task: {}'.format(ex))
        return PatchHydra._original_run_job(self, config_name, patched_task_function, overrides, *args, **kwargs)

    @staticmethod
    def _register_omegaconf(config, is_read_only=True):
        """Store the OmegaConf configuration as YAML text on the current Task.

        The top-level ``hydra`` key is excluded from the stored configuration.

        :param config: mapping-like configuration object (must support ``items()``).
        :param is_read_only: selects the description attached to the stored
            configuration: read-only section (True) or overridden section (False).
        """
        from omegaconf import OmegaConf  # noqa

        if is_read_only:
            description = 'Full OmegaConf YAML configuration. ' \
                          'This is a read-only section, unless \'{}/{}\' is set to True'.format(
                              PatchHydra._parameter_section, PatchHydra._parameter_allow_full_edit)
        else:
            description = 'Full OmegaConf YAML configuration overridden! ({}/{}=True)'.format(
                PatchHydra._parameter_section, PatchHydra._parameter_allow_full_edit)

        # noinspection PyProtectedMember
        PatchHydra._current_task._set_configuration(
            name=PatchHydra._config_section,
            description=description,
            config_type='OmegaConf YAML',
            config_text=OmegaConf.to_yaml({k: v for k, v in config.items() if k not in ('hydra', )})
        )