mirror of
https://github.com/clearml/clearml
synced 2025-04-16 21:42:10 +00:00
Add tensorboard hparams binding
This commit is contained in:
parent
5a6ec697e1
commit
260690b990
@ -32,11 +32,11 @@ except ImportError:
|
||||
|
||||
class TensorflowBinding(object):
|
||||
@classmethod
|
||||
def update_current_task(cls, task, patch_reporting=True, patch_model_io=True):
|
||||
def update_current_task(cls, task, patch_reporting=True, patch_model_io=True, report_hparams=True):
|
||||
if not task:
|
||||
IsTensorboardInit.clear_tensorboard_used()
|
||||
|
||||
EventTrainsWriter.update_current_task(task)
|
||||
EventTrainsWriter.update_current_task(task, report_hparams=report_hparams)
|
||||
|
||||
if patch_reporting:
|
||||
PatchSummaryToEventTransformer.update_current_task(task)
|
||||
@ -232,6 +232,7 @@ class EventTrainsWriter(object):
|
||||
ClearML events and reports the events (metrics) for a ClearML task (logger).
|
||||
"""
|
||||
__main_task = None
|
||||
__report_hparams = True
|
||||
_add_lock = threading.RLock()
|
||||
_series_name_lookup = {}
|
||||
|
||||
@ -627,6 +628,24 @@ class EventTrainsWriter(object):
|
||||
max_history=self.max_keep_images,
|
||||
)
|
||||
|
||||
def _add_hparams(self, hparams_metadata):
    """
    Report TensorBoard hyperparameters (hparams plugin) as parameters of the main task.

    The base64 encoded plugin payload found under
    ``hparams_metadata["metadata"]["pluginData"]["content"]`` is decoded, parsed with the
    TensorBoard hparams plugin and converted to a dictionary. Every hyperparameter is then
    stored on the main task under the name ``TB_hparams/<name>``.

    Reporting is best-effort: nothing is done when hparams reporting is disabled or when the
    session carries no hyperparameters, and any failure is logged at debug level instead of
    being raised to the caller.

    :param hparams_metadata: The summary value entry of the hparams "session start info" tag.
        It is accessed as nested dictionaries (``metadata`` -> ``pluginData`` -> ``content``).
    """
    if not EventTrainsWriter.__report_hparams:
        return
    # noinspection PyBroadException
    try:
        from tensorboard.plugins.hparams.metadata import parse_session_start_info_plugin_data

        content = base64.b64decode(hparams_metadata["metadata"]["pluginData"]["content"])
        session_start_info = MessageToDict(parse_session_start_info_plugin_data(content))
        # the converted message may not contain an "hparams" entry at all,
        # in which case there is simply nothing to report
        hparams = session_start_info.get("hparams")
        if not hparams:
            return
        EventTrainsWriter.__main_task.update_parameters(
            {"TB_hparams/{}".format(k): v for k, v in hparams.items()}
        )
    except Exception as ex:
        # never break the training loop because of hparams reporting, but leave a trace
        LoggerRoot.get_base_logger(TensorflowBinding).debug(
            'Failed reporting TensorBoard hparams: {}'.format(ex))
|
||||
|
||||
def _add_text(self, tag, step, tensor_bytes):
|
||||
# noinspection PyProtectedMember
|
||||
title, series = self.tag_splitter(tag, num_split_parts=3, default_title='Text', logdir_header='title',
|
||||
@ -718,6 +737,14 @@ class EventTrainsWriter(object):
|
||||
LoggerRoot.get_base_logger(TensorflowBinding).debug(
|
||||
'No tag for \'value\' existing keys %s' % ', '.join(vdict.keys()))
|
||||
continue
|
||||
try:
|
||||
from tensorboard.plugins.hparams.metadata import SESSION_START_INFO_TAG
|
||||
|
||||
if tag == SESSION_START_INFO_TAG:
|
||||
self._add_hparams(vdict)
|
||||
continue
|
||||
except ImportError:
|
||||
pass
|
||||
metric, values = get_data(vdict, supported_metrics)
|
||||
if metric == 'simpleValue':
|
||||
self._add_scalar(tag=tag, step=step, scalar_data=values)
|
||||
@ -814,7 +841,8 @@ class EventTrainsWriter(object):
|
||||
return origin_tag
|
||||
|
||||
@classmethod
|
||||
def update_current_task(cls, task):
|
||||
def update_current_task(cls, task, **kwargs):
|
||||
cls.__report_hparams = kwargs.get('report_hparams', False)
|
||||
if cls.__main_task != task:
|
||||
with cls._add_lock:
|
||||
cls._series_name_lookup = {}
|
||||
|
@ -366,10 +366,13 @@ class Task(_Task):
|
||||
- ``True`` - Automatically connect (default)
|
||||
- ``False`` - Do not automatically connect
|
||||
- A dictionary - In addition to a boolean, you can use a dictionary for fine-grained control of connected
|
||||
frameworks. The dictionary keys are frameworks and values are booleans or wildcard strings.
|
||||
frameworks. The dictionary keys are frameworks and the values are booleans, other dictionaries used for
|
||||
finer control or wildcard strings.
|
||||
In case of wildcard strings, the local path of models have to match at least one wildcard to be
|
||||
saved/loaded by ClearML.
|
||||
Keys missing from the dictionary default to ``True``, and an empty dictionary defaults to ``False``.
|
||||
Supported keys for finer control:
|
||||
'tensorboard': {'report_hparams': bool} # whether or not to report TensorBoard hyperparameters
|
||||
|
||||
For example:
|
||||
|
||||
@ -382,6 +385,9 @@ class Task(_Task):
|
||||
'joblib': True, 'megengine': True, 'catboost': True
|
||||
}
|
||||
|
||||
.. code-block:: py
|
||||
auto_connect_frameworks={'tensorboard': {'report_hparams': False}}
|
||||
|
||||
:param bool auto_resource_monitoring: Automatically create machine resource monitoring plots
|
||||
These plots appear in the **ClearML Web-App (UI)**, **RESULTS** tab, **SCALARS** sub-tab,
|
||||
with a title of **:resource monitor:**.
|
||||
@ -563,38 +569,49 @@ class Task(_Task):
|
||||
# always patch OS forking because of ProcessPool and the alike
|
||||
PatchOsFork.patch_fork()
|
||||
if auto_connect_frameworks:
|
||||
is_auto_connect_frameworks_bool = not isinstance(auto_connect_frameworks, dict)
|
||||
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('hydra', True):
|
||||
def should_connect(*keys):
    """
    Resolve auto_connect_frameworks[keys[0]]...[keys[-1]] into a boolean.

    The lookup walks the nested configuration one key at a time:
    - reaching a value that is not a dictionary stops the walk, and the truthiness
      of that value is returned
    - reaching an empty dictionary returns False
    - a key missing from the current dictionary is treated as True
    """
    node = auto_connect_frameworks
    for name in keys:
        if not isinstance(node, dict):
            # a plain value (bool, wildcard string, ...) decides for every deeper key
            break
        if not node:
            # an empty dictionary disables everything below it
            return False
        node = node.get(name, True)
    return bool(node)
|
||||
|
||||
if should_connect("hydra"):
|
||||
PatchHydra.update_current_task(task)
|
||||
if is_auto_connect_frameworks_bool or (
|
||||
auto_connect_frameworks.get('scikit', True) and
|
||||
auto_connect_frameworks.get('joblib', True)):
|
||||
if should_connect("scikit") and should_connect("joblib"):
|
||||
PatchedJoblib.update_current_task(task)
|
||||
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('matplotlib', True):
|
||||
if should_connect("matplotlib"):
|
||||
PatchedMatplotlib.update_current_task(Task.__main_task)
|
||||
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('tensorflow', True) \
|
||||
or auto_connect_frameworks.get('tensorboard', True):
|
||||
if should_connect("tensorflow") or should_connect("tensorboard"):
|
||||
# allow to disable tfdefines
|
||||
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('tfdefines', True):
|
||||
if should_connect("tfdefines"):
|
||||
PatchAbsl.update_current_task(Task.__main_task)
|
||||
TensorflowBinding.update_current_task(
|
||||
task,
|
||||
patch_reporting=(is_auto_connect_frameworks_bool
|
||||
or auto_connect_frameworks.get('tensorboard', True)),
|
||||
patch_model_io=(is_auto_connect_frameworks_bool
|
||||
or auto_connect_frameworks.get('tensorflow', True)),
|
||||
patch_reporting=should_connect("tensorboard"),
|
||||
patch_model_io=should_connect("tensorflow"),
|
||||
report_hparams=should_connect("tensorboard", "report_hparams"),
|
||||
)
|
||||
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('pytorch', True):
|
||||
if should_connect("pytorch"):
|
||||
PatchPyTorchModelIO.update_current_task(task)
|
||||
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('megengine', True):
|
||||
if should_connect("megengine"):
|
||||
PatchMegEngineModelIO.update_current_task(task)
|
||||
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('xgboost', True):
|
||||
if should_connect("xgboost"):
|
||||
PatchXGBoostModelIO.update_current_task(task)
|
||||
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('catboost', True):
|
||||
if should_connect("catboost"):
|
||||
PatchCatBoostModelIO.update_current_task(task)
|
||||
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('fastai', True):
|
||||
if should_connect("fastai"):
|
||||
PatchFastai.update_current_task(task)
|
||||
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('lightgbm', True):
|
||||
if should_connect("lightgbm"):
|
||||
PatchLIGHTgbmModelIO.update_current_task(task)
|
||||
if auto_resource_monitoring and not is_sub_process_task_id:
|
||||
resource_monitor_cls = auto_resource_monitoring \
|
||||
|
Loading…
Reference in New Issue
Block a user