Add tensorboard hparams binding

This commit is contained in:
allegroai 2022-04-09 14:15:38 +03:00
parent 5a6ec697e1
commit 260690b990
2 changed files with 68 additions and 23 deletions

View File

@ -32,11 +32,11 @@ except ImportError:
class TensorflowBinding(object):
@classmethod
def update_current_task(cls, task, patch_reporting=True, patch_model_io=True):
def update_current_task(cls, task, patch_reporting=True, patch_model_io=True, report_hparams=True):
if not task:
IsTensorboardInit.clear_tensorboard_used()
EventTrainsWriter.update_current_task(task)
EventTrainsWriter.update_current_task(task, report_hparams=report_hparams)
if patch_reporting:
PatchSummaryToEventTransformer.update_current_task(task)
@ -232,6 +232,7 @@ class EventTrainsWriter(object):
ClearML events and reports the events (metrics) for an ClearML task (logger).
"""
__main_task = None
__report_hparams = True
_add_lock = threading.RLock()
_series_name_lookup = {}
@ -627,6 +628,24 @@ class EventTrainsWriter(object):
max_history=self.max_keep_images,
)
def _add_hparams(self, hparams_metadata):
if not EventTrainsWriter.__report_hparams:
return
# noinspection PyBroadException
try:
from tensorboard.plugins.hparams.metadata import parse_session_start_info_plugin_data
content = hparams_metadata["metadata"]["pluginData"]["content"]
content = base64.b64decode(content)
session_start_info = parse_session_start_info_plugin_data(content)
session_start_info = MessageToDict(session_start_info)
hparams = session_start_info["hparams"]
EventTrainsWriter.__main_task.update_parameters(
{"TB_hparams/{}".format(k): v for k, v in hparams.items()}
)
except Exception:
pass
def _add_text(self, tag, step, tensor_bytes):
# noinspection PyProtectedMember
title, series = self.tag_splitter(tag, num_split_parts=3, default_title='Text', logdir_header='title',
@ -718,6 +737,14 @@ class EventTrainsWriter(object):
LoggerRoot.get_base_logger(TensorflowBinding).debug(
'No tag for \'value\' existing keys %s' % ', '.join(vdict.keys()))
continue
try:
from tensorboard.plugins.hparams.metadata import SESSION_START_INFO_TAG
if tag == SESSION_START_INFO_TAG:
self._add_hparams(vdict)
continue
except ImportError:
pass
metric, values = get_data(vdict, supported_metrics)
if metric == 'simpleValue':
self._add_scalar(tag=tag, step=step, scalar_data=values)
@ -814,7 +841,8 @@ class EventTrainsWriter(object):
return origin_tag
@classmethod
def update_current_task(cls, task):
def update_current_task(cls, task, **kwargs):
cls.__report_hparams = kwargs.get('report_hparams', False)
if cls.__main_task != task:
with cls._add_lock:
cls._series_name_lookup = {}

View File

@ -366,10 +366,13 @@ class Task(_Task):
- ``True`` - Automatically connect (default)
- ``False`` - Do not automatically connect
- A dictionary - In addition to a boolean, you can use a dictionary for fined grained control of connected
frameworks. The dictionary keys are frameworks and values are booleans or wildcard strings.
frameworks. The dictionary keys are frameworks and the values are booleans, other dictionaries used for
finer control or wildcard strings.
In case of wildcard strings, the local path of models have to match at least one wildcard to be
saved/loaded by ClearML.
Keys missing from the dictionary default to ``True``, and an empty dictionary defaults to ``False``.
Supported keys for finer control:
'tensorboard': {'report_hparams': bool} # whether or not to report TensorBoard hyperparameters
For example:
@ -382,6 +385,9 @@ class Task(_Task):
'joblib': True, 'megengine': True, 'catboost': True
}
.. code-block:: py
auto_connect_frameworks={'tensorboard': {'report_hparams': False}}
:param bool auto_resource_monitoring: Automatically create machine resource monitoring plots
These plots appear in in the **ClearML Web-App (UI)**, **RESULTS** tab, **SCALARS** sub-tab,
with a title of **:resource monitor:**.
@ -563,38 +569,49 @@ class Task(_Task):
# always patch OS forking because of ProcessPool and the alike
PatchOsFork.patch_fork()
if auto_connect_frameworks:
is_auto_connect_frameworks_bool = not isinstance(auto_connect_frameworks, dict)
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('hydra', True):
def should_connect(*keys):
"""
Evaluates value of auto_connect_frameworks[keys[0]]...[keys[-1]].
If at some point in the evaluation, the value of auto_connect_frameworks[keys[0]]...[keys[-1]] is a bool,
that value will be returned. If a dictionary is empty, it will be evaluated to False.
If a key will not be found in the current dictionary, True will be returned.
"""
should_bind_framework = auto_connect_frameworks
for key in keys:
if not isinstance(should_bind_framework, dict):
return bool(should_bind_framework)
if should_bind_framework == {}:
return False
should_bind_framework = should_bind_framework.get(key, True)
return bool(should_bind_framework)
if should_connect("hydra"):
PatchHydra.update_current_task(task)
if is_auto_connect_frameworks_bool or (
auto_connect_frameworks.get('scikit', True) and
auto_connect_frameworks.get('joblib', True)):
if should_connect("scikit") and should_connect("joblib"):
PatchedJoblib.update_current_task(task)
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('matplotlib', True):
if should_connect("matplotlib"):
PatchedMatplotlib.update_current_task(Task.__main_task)
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('tensorflow', True) \
or auto_connect_frameworks.get('tensorboard', True):
if should_connect("tensorflow") or should_connect("tensorboard"):
# allow to disable tfdefines
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('tfdefines', True):
if should_connect("tfdefines"):
PatchAbsl.update_current_task(Task.__main_task)
TensorflowBinding.update_current_task(
task,
patch_reporting=(is_auto_connect_frameworks_bool
or auto_connect_frameworks.get('tensorboard', True)),
patch_model_io=(is_auto_connect_frameworks_bool
or auto_connect_frameworks.get('tensorflow', True)),
patch_reporting=should_connect("tensorboard"),
patch_model_io=should_connect("tensorflow"),
report_hparams=should_connect("tensorboard", "report_hparams"),
)
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('pytorch', True):
if should_connect("pytorch"):
PatchPyTorchModelIO.update_current_task(task)
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('megengine', True):
if should_connect("megengine"):
PatchMegEngineModelIO.update_current_task(task)
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('xgboost', True):
if should_connect("xgboost"):
PatchXGBoostModelIO.update_current_task(task)
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('catboost', True):
if should_connect("catboost"):
PatchCatBoostModelIO.update_current_task(task)
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('fastai', True):
if should_connect("fastai"):
PatchFastai.update_current_task(task)
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('lightgbm', True):
if should_connect("lightgbm"):
PatchLIGHTgbmModelIO.update_current_task(task)
if auto_resource_monitoring and not is_sub_process_task_id:
resource_monitor_cls = auto_resource_monitoring \