From 260690b99046033b71f34e47431973ab3b22c92a Mon Sep 17 00:00:00 2001
From: allegroai <>
Date: Sat, 9 Apr 2022 14:15:38 +0300
Subject: [PATCH] Add tensorboard hparams binding

---
 clearml/binding/frameworks/tensorflow_bind.py | 43 +++++++++++++-
 clearml/task.py                               | 57 ++++++++++++-------
 2 files changed, 77 insertions(+), 23 deletions(-)

diff --git a/clearml/binding/frameworks/tensorflow_bind.py b/clearml/binding/frameworks/tensorflow_bind.py
index 266fedf9..3a589753 100644
--- a/clearml/binding/frameworks/tensorflow_bind.py
+++ b/clearml/binding/frameworks/tensorflow_bind.py
@@ -32,11 +32,11 @@ except ImportError:
 
 class TensorflowBinding(object):
     @classmethod
-    def update_current_task(cls, task, patch_reporting=True, patch_model_io=True):
+    def update_current_task(cls, task, patch_reporting=True, patch_model_io=True, report_hparams=True):
         if not task:
             IsTensorboardInit.clear_tensorboard_used()
 
-        EventTrainsWriter.update_current_task(task)
+        EventTrainsWriter.update_current_task(task, report_hparams=report_hparams)
 
         if patch_reporting:
             PatchSummaryToEventTransformer.update_current_task(task)
@@ -232,6 +232,7 @@ class EventTrainsWriter(object):
     ClearML events and reports the events (metrics) for an ClearML task (logger).
     """
     __main_task = None
+    __report_hparams = True
     _add_lock = threading.RLock()
     _series_name_lookup = {}
 
@@ -627,6 +628,33 @@ class EventTrainsWriter(object):
             max_history=self.max_keep_images,
         )
 
+    def _add_hparams(self, hparams_metadata):
+        """
+        Report TensorBoard hparams as task parameters, with keys prefixed by ``TB_hparams/``.
+
+        Best effort: does nothing when hparams reporting is disabled, and any failure while
+        decoding or reporting is logged at debug level and otherwise ignored.
+
+        :param dict hparams_metadata: nested dict; the base64 encoded session start info is
+            read from ``hparams_metadata["metadata"]["pluginData"]["content"]``
+        """
+        if not EventTrainsWriter.__report_hparams:
+            return
+        # noinspection PyBroadException
+        try:
+            from tensorboard.plugins.hparams.metadata import parse_session_start_info_plugin_data
+
+            content = hparams_metadata["metadata"]["pluginData"]["content"]
+            content = base64.b64decode(content)
+            session_start_info = parse_session_start_info_plugin_data(content)
+            session_start_info = MessageToDict(session_start_info)
+            hparams = session_start_info["hparams"]
+            EventTrainsWriter.__main_task.update_parameters(
+                {"TB_hparams/{}".format(k): v for k, v in hparams.items()}
+            )
+        except Exception as ex:
+            LoggerRoot.get_base_logger(TensorflowBinding).debug('Failed reporting TensorBoard hparams: {}'.format(ex))
+
     def _add_text(self, tag, step, tensor_bytes):
         # noinspection PyProtectedMember
         title, series = self.tag_splitter(tag, num_split_parts=3, default_title='Text', logdir_header='title',
@@ -718,6 +746,14 @@ class EventTrainsWriter(object):
                 LoggerRoot.get_base_logger(TensorflowBinding).debug(
                     'No tag for \'value\' existing keys %s' % ', '.join(vdict.keys()))
                 continue
+            try:
+                from tensorboard.plugins.hparams.metadata import SESSION_START_INFO_TAG
+
+                if tag == SESSION_START_INFO_TAG:
+                    self._add_hparams(vdict)
+                    continue
+            except ImportError:
+                pass
             metric, values = get_data(vdict, supported_metrics)
             if metric == 'simpleValue':
                 self._add_scalar(tag=tag, step=step, scalar_data=values)
@@ -814,7 +850,8 @@ class EventTrainsWriter(object):
         return origin_tag
 
     @classmethod
-    def update_current_task(cls, task):
+    def update_current_task(cls, task, **kwargs):
+        cls.__report_hparams = kwargs.get('report_hparams', True)
         if cls.__main_task != task:
             with cls._add_lock:
                 cls._series_name_lookup = {}
diff --git a/clearml/task.py b/clearml/task.py
index 17e4f461..1f13c17f 100644
--- a/clearml/task.py
+++ b/clearml/task.py
@@ -366,10 +366,13 @@ class Task(_Task):
- ``True`` - Automatically connect (default) - ``False`` - Do not automatically connect - A dictionary - In addition to a boolean, you can use a dictionary for fined grained control of connected - frameworks. The dictionary keys are frameworks and values are booleans or wildcard strings. + frameworks. The dictionary keys are frameworks and the values are booleans, other dictionaries used for + finer control or wildcard strings. In case of wildcard strings, the local path of models have to match at least one wildcard to be saved/loaded by ClearML. Keys missing from the dictionary default to ``True``, and an empty dictionary defaults to ``False``. + Supported keys for finer control: + 'tensorboard': {'report_hparams': bool} # whether or not to report TensorBoard hyperparameters For example: @@ -382,6 +385,9 @@ class Task(_Task): 'joblib': True, 'megengine': True, 'catboost': True } + .. code-block:: py + auto_connect_frameworks={'tensorboard': {'report_hparams': False}} + :param bool auto_resource_monitoring: Automatically create machine resource monitoring plots These plots appear in in the **ClearML Web-App (UI)**, **RESULTS** tab, **SCALARS** sub-tab, with a title of **:resource monitor:**. @@ -563,38 +569,49 @@ class Task(_Task): # always patch OS forking because of ProcessPool and the alike PatchOsFork.patch_fork() if auto_connect_frameworks: - is_auto_connect_frameworks_bool = not isinstance(auto_connect_frameworks, dict) - if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('hydra', True): + def should_connect(*keys): + """ + Evaluates value of auto_connect_frameworks[keys[0]]...[keys[-1]]. + If at some point in the evaluation, the value of auto_connect_frameworks[keys[0]]...[keys[-1]] is a bool, + that value will be returned. If a dictionary is empty, it will be evaluated to False. + If a key will not be found in the current dictionary, True will be returned. 
+ """ + should_bind_framework = auto_connect_frameworks + for key in keys: + if not isinstance(should_bind_framework, dict): + return bool(should_bind_framework) + if should_bind_framework == {}: + return False + should_bind_framework = should_bind_framework.get(key, True) + return bool(should_bind_framework) + + if should_connect("hydra"): PatchHydra.update_current_task(task) - if is_auto_connect_frameworks_bool or ( - auto_connect_frameworks.get('scikit', True) and - auto_connect_frameworks.get('joblib', True)): + if should_connect("scikit") and should_connect("joblib"): PatchedJoblib.update_current_task(task) - if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('matplotlib', True): + if should_connect("matplotlib"): PatchedMatplotlib.update_current_task(Task.__main_task) - if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('tensorflow', True) \ - or auto_connect_frameworks.get('tensorboard', True): + if should_connect("tensorflow") or should_connect("tensorboard"): # allow to disable tfdefines - if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('tfdefines', True): + if should_connect("tfdefines"): PatchAbsl.update_current_task(Task.__main_task) TensorflowBinding.update_current_task( task, - patch_reporting=(is_auto_connect_frameworks_bool - or auto_connect_frameworks.get('tensorboard', True)), - patch_model_io=(is_auto_connect_frameworks_bool - or auto_connect_frameworks.get('tensorflow', True)), + patch_reporting=should_connect("tensorboard"), + patch_model_io=should_connect("tensorflow"), + report_hparams=should_connect("tensorboard", "report_hparams"), ) - if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('pytorch', True): + if should_connect("pytorch"): PatchPyTorchModelIO.update_current_task(task) - if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('megengine', True): + if should_connect("megengine"): PatchMegEngineModelIO.update_current_task(task) - if 
is_auto_connect_frameworks_bool or auto_connect_frameworks.get('xgboost', True): + if should_connect("xgboost"): PatchXGBoostModelIO.update_current_task(task) - if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('catboost', True): + if should_connect("catboost"): PatchCatBoostModelIO.update_current_task(task) - if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('fastai', True): + if should_connect("fastai"): PatchFastai.update_current_task(task) - if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('lightgbm', True): + if should_connect("lightgbm"): PatchLIGHTgbmModelIO.update_current_task(task) if auto_resource_monitoring and not is_sub_process_task_id: resource_monitor_cls = auto_resource_monitoring \