mirror of
				https://github.com/clearml/clearml
				synced 2025-06-26 18:16:07 +00:00 
			
		
		
		
	Add Hydra support (issue #219)
This commit is contained in:
		
							parent
							
								
									501e27057b
								
							
						
					
					
						commit
						6dd7b4e02e
					
				
							
								
								
									
										117
									
								
								trains/binding/hydra_bind.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										117
									
								
								trains/binding/hydra_bind.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,117 @@ | |||||||
|  | import io | ||||||
|  | import sys | ||||||
|  | 
 | ||||||
|  | from ..config import running_remotely | ||||||
|  | from ..debugging.log import LoggerRoot | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class PatchHydra(object): | ||||||
|  |     _original_run_job = None | ||||||
|  |     _current_task = None | ||||||
|  |     _config_section = 'OmegaConf' | ||||||
|  |     _parameter_section = 'Hydra' | ||||||
|  |     _parameter_allow_full_edit = '_allow_omegaconf_edit_' | ||||||
|  | 
 | ||||||
|  |     @classmethod | ||||||
|  |     def patch_hydra(cls): | ||||||
|  |         # noinspection PyBroadException | ||||||
|  |         try: | ||||||
|  |             # only once | ||||||
|  |             if cls._original_run_job: | ||||||
|  |                 return True | ||||||
|  |             # if hydra is not loaded, do not patch anything | ||||||
|  |             if not sys.modules.get('hydra'): | ||||||
|  |                 return False | ||||||
|  | 
 | ||||||
|  |             from hydra.core import utils  # noqa | ||||||
|  |             from hydra._internal import hydra as internal_hydra  # noqa | ||||||
|  | 
 | ||||||
|  |             # check if hydra is already initialized | ||||||
|  |             if utils.HydraConfig.initialized(): | ||||||
|  |                 LoggerRoot.get_base_logger().warning( | ||||||
|  |                     "Hydra is already loaded storing read-only OmegaConf, " | ||||||
|  |                     "For full support call Task.init(...) before the Hydra App") | ||||||
|  |                 # noinspection PyBroadException | ||||||
|  |                 try: | ||||||
|  |                     # noinspection PyProtectedMember,PyUnresolvedReferences | ||||||
|  |                     PatchHydra._register_omegaconf(utils.HydraConfig.get()._get_root()) | ||||||
|  |                 except Exception: | ||||||
|  |                     pass | ||||||
|  |                 return False | ||||||
|  | 
 | ||||||
|  |             cls._original_run_job = internal_hydra.Hydra.run | ||||||
|  |             internal_hydra.Hydra.run = cls._patched_run_job | ||||||
|  |             return True | ||||||
|  |         except Exception: | ||||||
|  |             return False | ||||||
|  | 
 | ||||||
|  |     @staticmethod | ||||||
|  |     def update_current_task(task): | ||||||
|  |         # set current Task before patching | ||||||
|  |         PatchHydra._current_task = task | ||||||
|  |         if not PatchHydra.patch_hydra(): | ||||||
|  |             # if patching failed set it to None | ||||||
|  |             PatchHydra._current_task = None | ||||||
|  | 
 | ||||||
|  |     @staticmethod | ||||||
|  |     def _patched_run_job(self, config_name, task_function, overrides, *args, **kwargs): | ||||||
|  |         if not PatchHydra._current_task: | ||||||
|  |             return PatchHydra._original_run_job(self, config_name, task_function, overrides, *args, **kwargs) | ||||||
|  |         allow_omegaconf_edit = False | ||||||
|  | 
 | ||||||
|  |         def patched_task_function(a_config, *a_args, **a_kwargs): | ||||||
|  |             from omegaconf import OmegaConf  # noqa | ||||||
|  |             if not running_remotely() or not allow_omegaconf_edit: | ||||||
|  |                 PatchHydra._register_omegaconf(a_config) | ||||||
|  |             else: | ||||||
|  |                 # noinspection PyProtectedMember | ||||||
|  |                 omega_yaml = PatchHydra._current_task._get_configuration_text(PatchHydra._config_section) | ||||||
|  |                 loaded_config = OmegaConf.load(io.StringIO(omega_yaml)) | ||||||
|  |                 a_config = OmegaConf.merge(a_config, loaded_config) | ||||||
|  |                 PatchHydra._register_omegaconf(a_config, is_read_only=False) | ||||||
|  |             return task_function(a_config, *a_args, **a_kwargs) | ||||||
|  | 
 | ||||||
|  |         # store the config | ||||||
|  |         # noinspection PyBroadException | ||||||
|  |         try: | ||||||
|  |             if running_remotely(): | ||||||
|  |                 # get the _parameter_allow_full_edit casted back to boolean | ||||||
|  |                 connected_config = dict() | ||||||
|  |                 connected_config[PatchHydra._parameter_allow_full_edit] = False | ||||||
|  |                 PatchHydra._current_task.connect(connected_config, name=PatchHydra._parameter_section) | ||||||
|  |                 allow_omegaconf_edit = connected_config.pop(PatchHydra._parameter_allow_full_edit, None) | ||||||
|  |                 # get all the overrides | ||||||
|  |                 full_parameters = PatchHydra._current_task.get_parameters(backwards_compatibility=False) | ||||||
|  |                 stored_config = {k[len(PatchHydra._parameter_section)+1:]: v for k, v in full_parameters.items() | ||||||
|  |                                  if k.startswith(PatchHydra._parameter_section+'/')} | ||||||
|  |                 stored_config.pop(PatchHydra._parameter_allow_full_edit, None) | ||||||
|  |                 overrides = ['{}={}'.format(k, v) for k, v in stored_config.items()] | ||||||
|  |             else: | ||||||
|  |                 stored_config = dict(arg.split('=', 1) for arg in overrides) | ||||||
|  |                 stored_config[PatchHydra._parameter_allow_full_edit] = False | ||||||
|  |                 PatchHydra._current_task.connect(stored_config, name=PatchHydra._parameter_section) | ||||||
|  |                 # todo: remove the overrides section from the Args (we have it here) | ||||||
|  |                 # PatchHydra._current_task.delete_parameter('Args/overrides') | ||||||
|  |         except Exception: | ||||||
|  |             pass | ||||||
|  |         return PatchHydra._original_run_job(self, config_name, patched_task_function, overrides, *args, **kwargs) | ||||||
|  | 
 | ||||||
|  |     @staticmethod | ||||||
|  |     def _register_omegaconf(config, is_read_only=True): | ||||||
|  |         from omegaconf import OmegaConf  # noqa | ||||||
|  | 
 | ||||||
|  |         if is_read_only: | ||||||
|  |             description = 'Full OmegaConf YAML configuration. ' \ | ||||||
|  |                           'This is a read-only section, unless \'{}/{}\' is set to True'.format( | ||||||
|  |                 PatchHydra._parameter_section, PatchHydra._parameter_allow_full_edit) | ||||||
|  |         else: | ||||||
|  |             description = 'Full OmegaConf YAML configuration overridden! ({}/{}=True)'.format( | ||||||
|  |                 PatchHydra._parameter_section, PatchHydra._parameter_allow_full_edit) | ||||||
|  | 
 | ||||||
|  |         # noinspection PyProtectedMember | ||||||
|  |         PatchHydra._current_task._set_configuration( | ||||||
|  |             name=PatchHydra._config_section, | ||||||
|  |             description=description, | ||||||
|  |             config_type='OmegaConf YAML', | ||||||
|  |             config_text=OmegaConf.to_yaml({k: v for k, v in config.items() if k not in ('hydra', )}) | ||||||
|  |         ) | ||||||
| @ -42,6 +42,7 @@ from .binding.frameworks.tensorflow_bind import TensorflowBinding | |||||||
| from .binding.frameworks.xgboost_bind import PatchXGBoostModelIO | from .binding.frameworks.xgboost_bind import PatchXGBoostModelIO | ||||||
| from .binding.joblib_bind import PatchedJoblib | from .binding.joblib_bind import PatchedJoblib | ||||||
| from .binding.matplotlib_bind import PatchedMatplotlib | from .binding.matplotlib_bind import PatchedMatplotlib | ||||||
|  | from .binding.hydra_bind import PatchHydra | ||||||
| from .config import config, DEV_TASK_NO_REUSE, get_is_master_node | from .config import config, DEV_TASK_NO_REUSE, get_is_master_node | ||||||
| from .config import running_remotely, get_remote_task_id | from .config import running_remotely, get_remote_task_id | ||||||
| from .config.cache import SessionCache | from .config.cache import SessionCache | ||||||
| @ -337,7 +338,7 @@ class Task(_Task): | |||||||
|             .. code-block:: py |             .. code-block:: py | ||||||
| 
 | 
 | ||||||
|                auto_connect_frameworks={'matplotlib': True, 'tensorflow': True, 'pytorch': True, |                auto_connect_frameworks={'matplotlib': True, 'tensorflow': True, 'pytorch': True, | ||||||
|                     'xgboost': True, 'scikit': True, 'fastai': True, 'lightgbm': True} |                     'xgboost': True, 'scikit': True, 'fastai': True, 'lightgbm': True, 'hydra': True} | ||||||
| 
 | 
 | ||||||
|         :param bool auto_resource_monitoring: Automatically create machine resource monitoring plots |         :param bool auto_resource_monitoring: Automatically create machine resource monitoring plots | ||||||
|             These plots appear in in the **Trains Web-App (UI)**, **RESULTS** tab, **SCALARS** sub-tab, |             These plots appear in in the **Trains Web-App (UI)**, **RESULTS** tab, **SCALARS** sub-tab, | ||||||
| @ -493,6 +494,8 @@ class Task(_Task): | |||||||
|             PatchOsFork.patch_fork() |             PatchOsFork.patch_fork() | ||||||
|             if auto_connect_frameworks: |             if auto_connect_frameworks: | ||||||
|                 is_auto_connect_frameworks_bool = not isinstance(auto_connect_frameworks, dict) |                 is_auto_connect_frameworks_bool = not isinstance(auto_connect_frameworks, dict) | ||||||
|  |                 if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('hydra', True): | ||||||
|  |                     PatchHydra.update_current_task(task) | ||||||
|                 if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('scikit', True): |                 if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('scikit', True): | ||||||
|                     PatchedJoblib.update_current_task(task) |                     PatchedJoblib.update_current_task(task) | ||||||
|                 if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('matplotlib', True): |                 if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('matplotlib', True): | ||||||
|  | |||||||
		Loading…
	
		Reference in New Issue
	
	Block a user
	 allegroai
						allegroai