mirror of
https://github.com/clearml/clearml
synced 2025-02-12 07:35:08 +00:00
Add Hydra bind support for Task.init call within the Hydra app, Issue #219
This commit is contained in:
parent
f161811558
commit
e9920e27ed
@ -1,13 +1,13 @@
|
|||||||
import io
|
import io
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from ..config import running_remotely
|
from ..config import running_remotely, get_remote_task_id
|
||||||
from ..debugging.log import LoggerRoot
|
|
||||||
|
|
||||||
|
|
||||||
class PatchHydra(object):
|
class PatchHydra(object):
|
||||||
_original_run_job = None
|
_original_run_job = None
|
||||||
_current_task = None
|
_current_task = None
|
||||||
|
_last_untracked_state = {}
|
||||||
_config_section = 'OmegaConf'
|
_config_section = 'OmegaConf'
|
||||||
_parameter_section = 'Hydra'
|
_parameter_section = 'Hydra'
|
||||||
_parameter_allow_full_edit = '_allow_omegaconf_edit_'
|
_parameter_allow_full_edit = '_allow_omegaconf_edit_'
|
||||||
@ -23,22 +23,8 @@ class PatchHydra(object):
|
|||||||
if not sys.modules.get('hydra'):
|
if not sys.modules.get('hydra'):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
from hydra.core import utils # noqa
|
|
||||||
from hydra._internal import hydra as internal_hydra # noqa
|
from hydra._internal import hydra as internal_hydra # noqa
|
||||||
|
|
||||||
# check if hydra is already initialized
|
|
||||||
if utils.HydraConfig.initialized():
|
|
||||||
LoggerRoot.get_base_logger().warning(
|
|
||||||
"Hydra is already loaded storing read-only OmegaConf, "
|
|
||||||
"For full support call Task.init(...) before the Hydra App")
|
|
||||||
# noinspection PyBroadException
|
|
||||||
try:
|
|
||||||
# noinspection PyProtectedMember,PyUnresolvedReferences
|
|
||||||
PatchHydra._register_omegaconf(utils.HydraConfig.get()._get_root())
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
return False
|
|
||||||
|
|
||||||
cls._original_run_job = internal_hydra.Hydra.run
|
cls._original_run_job = internal_hydra.Hydra.run
|
||||||
internal_hydra.Hydra.run = cls._patched_run_job
|
internal_hydra.Hydra.run = cls._patched_run_job
|
||||||
return True
|
return True
|
||||||
@ -49,14 +35,20 @@ class PatchHydra(object):
|
|||||||
def update_current_task(task):
|
def update_current_task(task):
|
||||||
# set current Task before patching
|
# set current Task before patching
|
||||||
PatchHydra._current_task = task
|
PatchHydra._current_task = task
|
||||||
if not PatchHydra.patch_hydra():
|
if PatchHydra.patch_hydra():
|
||||||
|
# check if we have an untracked state, store it.
|
||||||
|
if PatchHydra._last_untracked_state.get('connect'):
|
||||||
|
PatchHydra._current_task.connect(**PatchHydra._last_untracked_state['connect'])
|
||||||
|
if PatchHydra._last_untracked_state.get('_set_configuration'):
|
||||||
|
# noinspection PyProtectedMember
|
||||||
|
PatchHydra._current_task._set_configuration(**PatchHydra._last_untracked_state['_set_configuration'])
|
||||||
|
PatchHydra._last_untracked_state = {}
|
||||||
|
else:
|
||||||
# if patching failed set it to None
|
# if patching failed set it to None
|
||||||
PatchHydra._current_task = None
|
PatchHydra._current_task = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _patched_run_job(self, config_name, task_function, overrides, *args, **kwargs):
|
def _patched_run_job(self, config_name, task_function, overrides, *args, **kwargs):
|
||||||
if not PatchHydra._current_task:
|
|
||||||
return PatchHydra._original_run_job(self, config_name, task_function, overrides, *args, **kwargs)
|
|
||||||
allow_omegaconf_edit = False
|
allow_omegaconf_edit = False
|
||||||
|
|
||||||
def patched_task_function(a_config, *a_args, **a_kwargs):
|
def patched_task_function(a_config, *a_args, **a_kwargs):
|
||||||
@ -75,6 +67,9 @@ class PatchHydra(object):
|
|||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
if running_remotely():
|
if running_remotely():
|
||||||
|
if not PatchHydra._current_task:
|
||||||
|
from ..task import Task
|
||||||
|
PatchHydra._current_task = Task.get_task(task_id=get_remote_task_id())
|
||||||
# get the _parameter_allow_full_edit casted back to boolean
|
# get the _parameter_allow_full_edit casted back to boolean
|
||||||
connected_config = dict()
|
connected_config = dict()
|
||||||
connected_config[PatchHydra._parameter_allow_full_edit] = False
|
connected_config[PatchHydra._parameter_allow_full_edit] = False
|
||||||
@ -89,7 +84,12 @@ class PatchHydra(object):
|
|||||||
else:
|
else:
|
||||||
stored_config = dict(arg.split('=', 1) for arg in overrides)
|
stored_config = dict(arg.split('=', 1) for arg in overrides)
|
||||||
stored_config[PatchHydra._parameter_allow_full_edit] = False
|
stored_config[PatchHydra._parameter_allow_full_edit] = False
|
||||||
|
if PatchHydra._current_task:
|
||||||
PatchHydra._current_task.connect(stored_config, name=PatchHydra._parameter_section)
|
PatchHydra._current_task.connect(stored_config, name=PatchHydra._parameter_section)
|
||||||
|
PatchHydra._last_untracked_state.pop('connect', None)
|
||||||
|
else:
|
||||||
|
PatchHydra._last_untracked_state['connect'] = dict(
|
||||||
|
mutable=stored_config, name=PatchHydra._parameter_section)
|
||||||
# todo: remove the overrides section from the Args (we have it here)
|
# todo: remove the overrides section from the Args (we have it here)
|
||||||
# PatchHydra._current_task.delete_parameter('Args/overrides')
|
# PatchHydra._current_task.delete_parameter('Args/overrides')
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -101,17 +101,36 @@ class PatchHydra(object):
|
|||||||
from omegaconf import OmegaConf # noqa
|
from omegaconf import OmegaConf # noqa
|
||||||
|
|
||||||
if is_read_only:
|
if is_read_only:
|
||||||
description = 'Full OmegaConf YAML configuration. ' \
|
description = \
|
||||||
|
'Full OmegaConf YAML configuration. ' \
|
||||||
'This is a read-only section, unless \'{}/{}\' is set to True'.format(
|
'This is a read-only section, unless \'{}/{}\' is set to True'.format(
|
||||||
PatchHydra._parameter_section, PatchHydra._parameter_allow_full_edit)
|
PatchHydra._parameter_section, PatchHydra._parameter_allow_full_edit)
|
||||||
else:
|
else:
|
||||||
description = 'Full OmegaConf YAML configuration overridden! ({}/{}=True)'.format(
|
description = 'Full OmegaConf YAML configuration overridden! ({}/{}=True)'.format(
|
||||||
PatchHydra._parameter_section, PatchHydra._parameter_allow_full_edit)
|
PatchHydra._parameter_section, PatchHydra._parameter_allow_full_edit)
|
||||||
|
|
||||||
# noinspection PyProtectedMember
|
configuration = dict(
|
||||||
PatchHydra._current_task._set_configuration(
|
|
||||||
name=PatchHydra._config_section,
|
name=PatchHydra._config_section,
|
||||||
description=description,
|
description=description,
|
||||||
config_type='OmegaConf YAML',
|
config_type='OmegaConf YAML',
|
||||||
config_text=OmegaConf.to_yaml({k: v for k, v in config.items() if k not in ('hydra', )})
|
config_text=OmegaConf.to_yaml({k: v for k, v in config.items() if k not in ('hydra', )})
|
||||||
)
|
)
|
||||||
|
if PatchHydra._current_task:
|
||||||
|
# noinspection PyProtectedMember
|
||||||
|
PatchHydra._current_task._set_configuration(**configuration)
|
||||||
|
PatchHydra._last_untracked_state.pop('_set_configuration', None)
|
||||||
|
else:
|
||||||
|
PatchHydra._last_untracked_state['_set_configuration'] = configuration
|
||||||
|
|
||||||
|
|
||||||
|
def __global_hydra_bind():
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
import hydra # noqa
|
||||||
|
PatchHydra.patch_hydra()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# patch hydra
|
||||||
|
__global_hydra_bind()
|
||||||
|
Loading…
Reference in New Issue
Block a user