mirror of
https://github.com/clearml/clearml
synced 2025-03-03 10:42:00 +00:00
Fix hydra multi-run support (issue #306)
This commit is contained in:
parent
f64d1a0993
commit
7a25f8afd9
@ -1,11 +1,15 @@
|
||||
import io
|
||||
import sys
|
||||
from functools import partial
|
||||
|
||||
from ..config import running_remotely, get_remote_task_id
|
||||
from ..config import running_remotely, get_remote_task_id, DEV_TASK_NO_REUSE
|
||||
from ..debugging.log import LoggerRoot
|
||||
|
||||
|
||||
class PatchHydra(object):
|
||||
_original_run_job = None
|
||||
_original_hydra_run = None
|
||||
_allow_omegaconf_edit = None
|
||||
_current_task = None
|
||||
_last_untracked_state = {}
|
||||
_config_section = 'OmegaConf'
|
||||
@ -24,9 +28,16 @@ class PatchHydra(object):
|
||||
return False
|
||||
|
||||
from hydra._internal import hydra as internal_hydra # noqa
|
||||
from hydra.core import utils as utils_hydra # noqa
|
||||
from hydra._internal.core_plugins import basic_launcher # noqa
|
||||
|
||||
cls._original_run_job = internal_hydra.Hydra.run
|
||||
internal_hydra.Hydra.run = cls._patched_run_job
|
||||
cls._original_hydra_run = internal_hydra.Hydra.run
|
||||
internal_hydra.Hydra.run = cls._patched_hydra_run
|
||||
|
||||
cls._original_run_job = utils_hydra.run_job
|
||||
utils_hydra.run_job = cls._patched_run_job
|
||||
internal_hydra.run_job = cls._patched_run_job
|
||||
basic_launcher.run_job = cls._patched_run_job
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
@ -48,20 +59,8 @@ class PatchHydra(object):
|
||||
PatchHydra._current_task = None
|
||||
|
||||
@staticmethod
|
||||
def _patched_run_job(self, config_name, task_function, overrides, *args, **kwargs):
|
||||
allow_omegaconf_edit = False
|
||||
|
||||
def patched_task_function(a_config, *a_args, **a_kwargs):
|
||||
from omegaconf import OmegaConf # noqa
|
||||
if not running_remotely() or not allow_omegaconf_edit:
|
||||
PatchHydra._register_omegaconf(a_config)
|
||||
else:
|
||||
# noinspection PyProtectedMember
|
||||
omega_yaml = PatchHydra._current_task._get_configuration_text(PatchHydra._config_section)
|
||||
loaded_config = OmegaConf.load(io.StringIO(omega_yaml))
|
||||
a_config = OmegaConf.merge(a_config, loaded_config)
|
||||
PatchHydra._register_omegaconf(a_config, is_read_only=False)
|
||||
return task_function(a_config, *a_args, **a_kwargs)
|
||||
def _patched_hydra_run(self, config_name, task_function, overrides, *args, **kwargs):
|
||||
PatchHydra._allow_omegaconf_edit = False
|
||||
|
||||
# store the config
|
||||
# noinspection PyBroadException
|
||||
@ -74,7 +73,7 @@ class PatchHydra(object):
|
||||
connected_config = dict()
|
||||
connected_config[PatchHydra._parameter_allow_full_edit] = False
|
||||
PatchHydra._current_task.connect(connected_config, name=PatchHydra._parameter_section)
|
||||
allow_omegaconf_edit = connected_config.pop(PatchHydra._parameter_allow_full_edit, None)
|
||||
PatchHydra._allow_omegaconf_edit = connected_config.pop(PatchHydra._parameter_allow_full_edit, None)
|
||||
# get all the overrides
|
||||
full_parameters = PatchHydra._current_task.get_parameters(backwards_compatibility=False)
|
||||
stored_config = {k[len(PatchHydra._parameter_section)+1:]: v for k, v in full_parameters.items()
|
||||
@ -82,6 +81,23 @@ class PatchHydra(object):
|
||||
stored_config.pop(PatchHydra._parameter_allow_full_edit, None)
|
||||
overrides = ['{}={}'.format(k, v) for k, v in stored_config.items()]
|
||||
else:
|
||||
# We take care of it inside the _patched_run_job
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return PatchHydra._original_hydra_run(self, config_name, task_function, overrides, *args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def _patched_run_job(config, task_function, *args, **kwargs):
|
||||
# store the config
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
if running_remotely():
|
||||
# we take care of it in the hydra run (where we have access to the overrides)
|
||||
pass
|
||||
else:
|
||||
overrides = config.hydra.overrides.task
|
||||
stored_config = dict(arg.split('=', 1) for arg in overrides)
|
||||
stored_config[PatchHydra._parameter_allow_full_edit] = False
|
||||
if PatchHydra._current_task:
|
||||
@ -94,7 +110,40 @@ class PatchHydra(object):
|
||||
# PatchHydra._current_task.delete_parameter('Args/overrides')
|
||||
except Exception:
|
||||
pass
|
||||
return PatchHydra._original_run_job(self, config_name, patched_task_function, overrides, *args, **kwargs)
|
||||
|
||||
pre_app_task_init_call = bool(PatchHydra._current_task)
|
||||
|
||||
if pre_app_task_init_call:
|
||||
LoggerRoot.get_base_logger(PatchHydra).info(
|
||||
'Task.init called outside of Hydra-App. For full Hydra multi-run support, '
|
||||
'move the Task.init call into the Hydra-App main function')
|
||||
|
||||
result = PatchHydra._original_run_job(
|
||||
config, partial(PatchHydra._patched_task_function, task_function,),
|
||||
*args, **kwargs)
|
||||
|
||||
# if we have Task.init called inside the App, we close it after the app is done.
|
||||
# This will make sure that hydra run will create multiple Tasks
|
||||
if not running_remotely() and not pre_app_task_init_call and PatchHydra._current_task:
|
||||
PatchHydra._current_task.close()
|
||||
# make sure we do not reuse the Task if we have a multi-run session
|
||||
DEV_TASK_NO_REUSE.set(True)
|
||||
PatchHydra._current_task = None
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _patched_task_function(task_function, a_config, *a_args, **a_kwargs):
|
||||
from omegaconf import OmegaConf # noqa
|
||||
if not running_remotely() or not PatchHydra._allow_omegaconf_edit:
|
||||
PatchHydra._register_omegaconf(a_config)
|
||||
else:
|
||||
# noinspection PyProtectedMember
|
||||
omega_yaml = PatchHydra._current_task._get_configuration_text(PatchHydra._config_section)
|
||||
loaded_config = OmegaConf.load(io.StringIO(omega_yaml))
|
||||
a_config = OmegaConf.merge(a_config, loaded_config)
|
||||
PatchHydra._register_omegaconf(a_config, is_read_only=False)
|
||||
return task_function(a_config, *a_args, **a_kwargs)
|
||||
|
||||
@staticmethod
|
||||
def _register_omegaconf(config, is_read_only=True):
|
||||
|
Loading…
Reference in New Issue
Block a user