From 7a25f8afd9e227e77e58132d2fa8e88fdbda36c2 Mon Sep 17 00:00:00 2001
From: allegroai <>
Date: Thu, 11 Feb 2021 14:27:28 +0200
Subject: [PATCH] Fix hydra multi-run support (issue #306)

---
 clearml/binding/hydra_bind.py | 87 +++++++++++++++++++++++++++--------
 1 file changed, 68 insertions(+), 19 deletions(-)

diff --git a/clearml/binding/hydra_bind.py b/clearml/binding/hydra_bind.py
index 7ed61780..fcb3523d 100644
--- a/clearml/binding/hydra_bind.py
+++ b/clearml/binding/hydra_bind.py
@@ -1,11 +1,15 @@
 import io
 import sys
+from functools import partial
 
-from ..config import running_remotely, get_remote_task_id
+from ..config import running_remotely, get_remote_task_id, DEV_TASK_NO_REUSE
+from ..debugging.log import LoggerRoot
 
 
 class PatchHydra(object):
     _original_run_job = None
+    _original_hydra_run = None
+    _allow_omegaconf_edit = None
     _current_task = None
     _last_untracked_state = {}
     _config_section = 'OmegaConf'
@@ -24,9 +28,16 @@ class PatchHydra(object):
                 return False
 
             from hydra._internal import hydra as internal_hydra  # noqa
+            from hydra.core import utils as utils_hydra  # noqa
+            from hydra._internal.core_plugins import basic_launcher  # noqa
 
-            cls._original_run_job = internal_hydra.Hydra.run
-            internal_hydra.Hydra.run = cls._patched_run_job
+            cls._original_hydra_run = internal_hydra.Hydra.run
+            internal_hydra.Hydra.run = cls._patched_hydra_run
+
+            cls._original_run_job = utils_hydra.run_job
+            utils_hydra.run_job = cls._patched_run_job
+            internal_hydra.run_job = cls._patched_run_job
+            basic_launcher.run_job = cls._patched_run_job
             return True
         except Exception:
             return False
@@ -48,20 +59,8 @@ class PatchHydra(object):
             PatchHydra._current_task = None
 
     @staticmethod
-    def _patched_run_job(self, config_name, task_function, overrides, *args, **kwargs):
-        allow_omegaconf_edit = False
-
-        def patched_task_function(a_config, *a_args, **a_kwargs):
-            from omegaconf import OmegaConf  # noqa
-            if not running_remotely() or not allow_omegaconf_edit:
-                PatchHydra._register_omegaconf(a_config)
-            else:
-                # noinspection PyProtectedMember
-                omega_yaml = PatchHydra._current_task._get_configuration_text(PatchHydra._config_section)
-                loaded_config = OmegaConf.load(io.StringIO(omega_yaml))
-                a_config = OmegaConf.merge(a_config, loaded_config)
-                PatchHydra._register_omegaconf(a_config, is_read_only=False)
-            return task_function(a_config, *a_args, **a_kwargs)
+    def _patched_hydra_run(self, config_name, task_function, overrides, *args, **kwargs):
+        PatchHydra._allow_omegaconf_edit = False
 
         # store the config
         # noinspection PyBroadException
@@ -74,7 +73,7 @@ class PatchHydra(object):
                 connected_config = dict()
                 connected_config[PatchHydra._parameter_allow_full_edit] = False
                 PatchHydra._current_task.connect(connected_config, name=PatchHydra._parameter_section)
-                allow_omegaconf_edit = connected_config.pop(PatchHydra._parameter_allow_full_edit, None)
+                PatchHydra._allow_omegaconf_edit = connected_config.pop(PatchHydra._parameter_allow_full_edit, None)
                 # get all the overrides
                 full_parameters = PatchHydra._current_task.get_parameters(backwards_compatibility=False)
                 stored_config = {k[len(PatchHydra._parameter_section)+1:]: v for k, v in full_parameters.items()
@@ -82,6 +81,23 @@ class PatchHydra(object):
                 stored_config.pop(PatchHydra._parameter_allow_full_edit, None)
                 overrides = ['{}={}'.format(k, v) for k, v in stored_config.items()]
             else:
+                # We take care of it inside the _patched_run_job
+                pass
+        except Exception:
+            pass
+
+        return PatchHydra._original_hydra_run(self, config_name, task_function, overrides, *args, **kwargs)
+
+    @staticmethod
+    def _patched_run_job(config, task_function, *args, **kwargs):
+        # store the config
+        # noinspection PyBroadException
+        try:
+            if running_remotely():
+                # we take care of it in the hydra run (where we have access to the overrides)
+                pass
+            else:
+                overrides = config.hydra.overrides.task
                 stored_config = dict(arg.split('=', 1) for arg in overrides)
                 stored_config[PatchHydra._parameter_allow_full_edit] = False
                 if PatchHydra._current_task:
@@ -94,7 +110,40 @@ class PatchHydra(object):
                 # PatchHydra._current_task.delete_parameter('Args/overrides')
         except Exception:
             pass
-        return PatchHydra._original_run_job(self, config_name, patched_task_function, overrides, *args, **kwargs)
+
+        pre_app_task_init_call = bool(PatchHydra._current_task)
+
+        if pre_app_task_init_call:
+            LoggerRoot.get_base_logger(PatchHydra).info(
+                'Task.init called outside of Hydra-App. For full Hydra multi-run support, '
+                'move the Task.init call into the Hydra-App main function')
+
+        result = PatchHydra._original_run_job(
+            config, partial(PatchHydra._patched_task_function, task_function,),
+            *args, **kwargs)
+
+        # if we have Task.init called inside the App, we close it after the app is done.
+        # This will make sure that hydra run will create multiple Tasks
+        if not running_remotely() and not pre_app_task_init_call and PatchHydra._current_task:
+            PatchHydra._current_task.close()
+            # make sure we do not reuse the Task if we have a multi-run session
+            DEV_TASK_NO_REUSE.set(True)
+            PatchHydra._current_task = None
+
+        return result
+
+    @staticmethod
+    def _patched_task_function(task_function, a_config, *a_args, **a_kwargs):
+        from omegaconf import OmegaConf  # noqa
+        if not running_remotely() or not PatchHydra._allow_omegaconf_edit:
+            PatchHydra._register_omegaconf(a_config)
+        else:
+            # noinspection PyProtectedMember
+            omega_yaml = PatchHydra._current_task._get_configuration_text(PatchHydra._config_section)
+            loaded_config = OmegaConf.load(io.StringIO(omega_yaml))
+            a_config = OmegaConf.merge(a_config, loaded_config)
+            PatchHydra._register_omegaconf(a_config, is_read_only=False)
+        return task_function(a_config, *a_args, **a_kwargs)
 
     @staticmethod
     def _register_omegaconf(config, is_read_only=True):