mirror of
https://github.com/clearml/clearml
synced 2025-06-26 18:16:07 +00:00
Add tensorboard hparams binding
This commit is contained in:
parent
5a6ec697e1
commit
260690b990
@ -32,11 +32,11 @@ except ImportError:
|
|||||||
|
|
||||||
class TensorflowBinding(object):
|
class TensorflowBinding(object):
|
||||||
@classmethod
|
@classmethod
|
||||||
def update_current_task(cls, task, patch_reporting=True, patch_model_io=True):
|
def update_current_task(cls, task, patch_reporting=True, patch_model_io=True, report_hparams=True):
|
||||||
if not task:
|
if not task:
|
||||||
IsTensorboardInit.clear_tensorboard_used()
|
IsTensorboardInit.clear_tensorboard_used()
|
||||||
|
|
||||||
EventTrainsWriter.update_current_task(task)
|
EventTrainsWriter.update_current_task(task, report_hparams=report_hparams)
|
||||||
|
|
||||||
if patch_reporting:
|
if patch_reporting:
|
||||||
PatchSummaryToEventTransformer.update_current_task(task)
|
PatchSummaryToEventTransformer.update_current_task(task)
|
||||||
@ -232,6 +232,7 @@ class EventTrainsWriter(object):
|
|||||||
ClearML events and reports the events (metrics) for an ClearML task (logger).
|
ClearML events and reports the events (metrics) for an ClearML task (logger).
|
||||||
"""
|
"""
|
||||||
__main_task = None
|
__main_task = None
|
||||||
|
__report_hparams = True
|
||||||
_add_lock = threading.RLock()
|
_add_lock = threading.RLock()
|
||||||
_series_name_lookup = {}
|
_series_name_lookup = {}
|
||||||
|
|
||||||
@ -627,6 +628,24 @@ class EventTrainsWriter(object):
|
|||||||
max_history=self.max_keep_images,
|
max_history=self.max_keep_images,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _add_hparams(self, hparams_metadata):
|
||||||
|
if not EventTrainsWriter.__report_hparams:
|
||||||
|
return
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
from tensorboard.plugins.hparams.metadata import parse_session_start_info_plugin_data
|
||||||
|
|
||||||
|
content = hparams_metadata["metadata"]["pluginData"]["content"]
|
||||||
|
content = base64.b64decode(content)
|
||||||
|
session_start_info = parse_session_start_info_plugin_data(content)
|
||||||
|
session_start_info = MessageToDict(session_start_info)
|
||||||
|
hparams = session_start_info["hparams"]
|
||||||
|
EventTrainsWriter.__main_task.update_parameters(
|
||||||
|
{"TB_hparams/{}".format(k): v for k, v in hparams.items()}
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
def _add_text(self, tag, step, tensor_bytes):
|
def _add_text(self, tag, step, tensor_bytes):
|
||||||
# noinspection PyProtectedMember
|
# noinspection PyProtectedMember
|
||||||
title, series = self.tag_splitter(tag, num_split_parts=3, default_title='Text', logdir_header='title',
|
title, series = self.tag_splitter(tag, num_split_parts=3, default_title='Text', logdir_header='title',
|
||||||
@ -718,6 +737,14 @@ class EventTrainsWriter(object):
|
|||||||
LoggerRoot.get_base_logger(TensorflowBinding).debug(
|
LoggerRoot.get_base_logger(TensorflowBinding).debug(
|
||||||
'No tag for \'value\' existing keys %s' % ', '.join(vdict.keys()))
|
'No tag for \'value\' existing keys %s' % ', '.join(vdict.keys()))
|
||||||
continue
|
continue
|
||||||
|
try:
|
||||||
|
from tensorboard.plugins.hparams.metadata import SESSION_START_INFO_TAG
|
||||||
|
|
||||||
|
if tag == SESSION_START_INFO_TAG:
|
||||||
|
self._add_hparams(vdict)
|
||||||
|
continue
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
metric, values = get_data(vdict, supported_metrics)
|
metric, values = get_data(vdict, supported_metrics)
|
||||||
if metric == 'simpleValue':
|
if metric == 'simpleValue':
|
||||||
self._add_scalar(tag=tag, step=step, scalar_data=values)
|
self._add_scalar(tag=tag, step=step, scalar_data=values)
|
||||||
@ -814,7 +841,8 @@ class EventTrainsWriter(object):
|
|||||||
return origin_tag
|
return origin_tag
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def update_current_task(cls, task):
|
def update_current_task(cls, task, **kwargs):
|
||||||
|
cls.__report_hparams = kwargs.get('report_hparams', False)
|
||||||
if cls.__main_task != task:
|
if cls.__main_task != task:
|
||||||
with cls._add_lock:
|
with cls._add_lock:
|
||||||
cls._series_name_lookup = {}
|
cls._series_name_lookup = {}
|
||||||
|
@ -366,10 +366,13 @@ class Task(_Task):
|
|||||||
- ``True`` - Automatically connect (default)
|
- ``True`` - Automatically connect (default)
|
||||||
- ``False`` - Do not automatically connect
|
- ``False`` - Do not automatically connect
|
||||||
- A dictionary - In addition to a boolean, you can use a dictionary for fined grained control of connected
|
- A dictionary - In addition to a boolean, you can use a dictionary for fined grained control of connected
|
||||||
frameworks. The dictionary keys are frameworks and values are booleans or wildcard strings.
|
frameworks. The dictionary keys are frameworks and the values are booleans, other dictionaries used for
|
||||||
|
finer control or wildcard strings.
|
||||||
In case of wildcard strings, the local path of models have to match at least one wildcard to be
|
In case of wildcard strings, the local path of models have to match at least one wildcard to be
|
||||||
saved/loaded by ClearML.
|
saved/loaded by ClearML.
|
||||||
Keys missing from the dictionary default to ``True``, and an empty dictionary defaults to ``False``.
|
Keys missing from the dictionary default to ``True``, and an empty dictionary defaults to ``False``.
|
||||||
|
Supported keys for finer control:
|
||||||
|
'tensorboard': {'report_hparams': bool} # whether or not to report TensorBoard hyperparameters
|
||||||
|
|
||||||
For example:
|
For example:
|
||||||
|
|
||||||
@ -382,6 +385,9 @@ class Task(_Task):
|
|||||||
'joblib': True, 'megengine': True, 'catboost': True
|
'joblib': True, 'megengine': True, 'catboost': True
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.. code-block:: py
|
||||||
|
auto_connect_frameworks={'tensorboard': {'report_hparams': False}}
|
||||||
|
|
||||||
:param bool auto_resource_monitoring: Automatically create machine resource monitoring plots
|
:param bool auto_resource_monitoring: Automatically create machine resource monitoring plots
|
||||||
These plots appear in in the **ClearML Web-App (UI)**, **RESULTS** tab, **SCALARS** sub-tab,
|
These plots appear in in the **ClearML Web-App (UI)**, **RESULTS** tab, **SCALARS** sub-tab,
|
||||||
with a title of **:resource monitor:**.
|
with a title of **:resource monitor:**.
|
||||||
@ -563,38 +569,49 @@ class Task(_Task):
|
|||||||
# always patch OS forking because of ProcessPool and the alike
|
# always patch OS forking because of ProcessPool and the alike
|
||||||
PatchOsFork.patch_fork()
|
PatchOsFork.patch_fork()
|
||||||
if auto_connect_frameworks:
|
if auto_connect_frameworks:
|
||||||
is_auto_connect_frameworks_bool = not isinstance(auto_connect_frameworks, dict)
|
def should_connect(*keys):
|
||||||
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('hydra', True):
|
"""
|
||||||
|
Evaluates value of auto_connect_frameworks[keys[0]]...[keys[-1]].
|
||||||
|
If at some point in the evaluation, the value of auto_connect_frameworks[keys[0]]...[keys[-1]] is a bool,
|
||||||
|
that value will be returned. If a dictionary is empty, it will be evaluated to False.
|
||||||
|
If a key will not be found in the current dictionary, True will be returned.
|
||||||
|
"""
|
||||||
|
should_bind_framework = auto_connect_frameworks
|
||||||
|
for key in keys:
|
||||||
|
if not isinstance(should_bind_framework, dict):
|
||||||
|
return bool(should_bind_framework)
|
||||||
|
if should_bind_framework == {}:
|
||||||
|
return False
|
||||||
|
should_bind_framework = should_bind_framework.get(key, True)
|
||||||
|
return bool(should_bind_framework)
|
||||||
|
|
||||||
|
if should_connect("hydra"):
|
||||||
PatchHydra.update_current_task(task)
|
PatchHydra.update_current_task(task)
|
||||||
if is_auto_connect_frameworks_bool or (
|
if should_connect("scikit") and should_connect("joblib"):
|
||||||
auto_connect_frameworks.get('scikit', True) and
|
|
||||||
auto_connect_frameworks.get('joblib', True)):
|
|
||||||
PatchedJoblib.update_current_task(task)
|
PatchedJoblib.update_current_task(task)
|
||||||
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('matplotlib', True):
|
if should_connect("matplotlib"):
|
||||||
PatchedMatplotlib.update_current_task(Task.__main_task)
|
PatchedMatplotlib.update_current_task(Task.__main_task)
|
||||||
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('tensorflow', True) \
|
if should_connect("tensorflow") or should_connect("tensorboard"):
|
||||||
or auto_connect_frameworks.get('tensorboard', True):
|
|
||||||
# allow to disable tfdefines
|
# allow to disable tfdefines
|
||||||
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('tfdefines', True):
|
if should_connect("tfdefines"):
|
||||||
PatchAbsl.update_current_task(Task.__main_task)
|
PatchAbsl.update_current_task(Task.__main_task)
|
||||||
TensorflowBinding.update_current_task(
|
TensorflowBinding.update_current_task(
|
||||||
task,
|
task,
|
||||||
patch_reporting=(is_auto_connect_frameworks_bool
|
patch_reporting=should_connect("tensorboard"),
|
||||||
or auto_connect_frameworks.get('tensorboard', True)),
|
patch_model_io=should_connect("tensorflow"),
|
||||||
patch_model_io=(is_auto_connect_frameworks_bool
|
report_hparams=should_connect("tensorboard", "report_hparams"),
|
||||||
or auto_connect_frameworks.get('tensorflow', True)),
|
|
||||||
)
|
)
|
||||||
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('pytorch', True):
|
if should_connect("pytorch"):
|
||||||
PatchPyTorchModelIO.update_current_task(task)
|
PatchPyTorchModelIO.update_current_task(task)
|
||||||
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('megengine', True):
|
if should_connect("megengine"):
|
||||||
PatchMegEngineModelIO.update_current_task(task)
|
PatchMegEngineModelIO.update_current_task(task)
|
||||||
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('xgboost', True):
|
if should_connect("xgboost"):
|
||||||
PatchXGBoostModelIO.update_current_task(task)
|
PatchXGBoostModelIO.update_current_task(task)
|
||||||
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('catboost', True):
|
if should_connect("catboost"):
|
||||||
PatchCatBoostModelIO.update_current_task(task)
|
PatchCatBoostModelIO.update_current_task(task)
|
||||||
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('fastai', True):
|
if should_connect("fastai"):
|
||||||
PatchFastai.update_current_task(task)
|
PatchFastai.update_current_task(task)
|
||||||
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('lightgbm', True):
|
if should_connect("lightgbm"):
|
||||||
PatchLIGHTgbmModelIO.update_current_task(task)
|
PatchLIGHTgbmModelIO.update_current_task(task)
|
||||||
if auto_resource_monitoring and not is_sub_process_task_id:
|
if auto_resource_monitoring and not is_sub_process_task_id:
|
||||||
resource_monitor_cls = auto_resource_monitoring \
|
resource_monitor_cls = auto_resource_monitoring \
|
||||||
|
Loading…
Reference in New Issue
Block a user