diff --git a/docs/trains.conf b/docs/trains.conf index d45211cb..79e0b727 100644 --- a/docs/trains.conf +++ b/docs/trains.conf @@ -42,6 +42,9 @@ sdk { quality: 87 subsampling: 0 } + + # Support plot-per-graph fully matching Tensorboard behavior (i.e. if this is set to False, each series should have its own graph) + tensorboard_auto_group_scalars: True } network { diff --git a/trains/binding/frameworks/tensorflow_bind.py b/trains/binding/frameworks/tensorflow_bind.py index 6caaf49b..659f9f21 100644 --- a/trains/binding/frameworks/tensorflow_bind.py +++ b/trains/binding/frameworks/tensorflow_bind.py @@ -14,7 +14,7 @@ from PIL import Image from ...debugging.log import LoggerRoot from ..frameworks import _patched_call, WeightsFileHandler, _Empty from ..import_bind import PostImportHookPatching -from ...config import running_remotely +from ...config import running_remotely, config from ...model import InputModel, OutputModel, Framework try: @@ -274,7 +274,14 @@ class EventTrainsWriter(object): title, series = self.tag_splitter(tag, num_split_parts=1, default_title=default_title, logdir_header='series_last') step = self._fix_step_counter(title, series, step) + tag = self._get_add_scalars_event_tag(default_title) + group = config.get('metrics.tensorboard_auto_group_scalars', True) + possible_title = None if group else tag + possible_tag = tag if group else None + + title = title + possible_title if possible_title else title + series = possible_tag or series # update scalar cache num, value = self._scalar_report_cache.get((title, series), (0, 0)) self._scalar_report_cache[(title, series)] = (num + 1, value + scalar_data) @@ -553,6 +560,38 @@ class EventTrainsWriter(object): """ pass + def _get_add_scalars_event_tag(self, title_prefix): + """ + + :param str title_prefix: the table title prefix that was added to the series. + :return: str same as tensorboard use + """ + # HACK - this is tensorboard Summary util function, original path: + # ~/torch/utils/tensorboard/summary.py + def _clean_tag(name): + import re as _re + _INVALID_TAG_CHARACTERS = _re.compile(r'[^-/\w\.]') + if name is not None: + new_name = _INVALID_TAG_CHARACTERS.sub('_', name) + new_name = new_name.lstrip('/') # Remove leading slashes + if new_name != name: + LoggerRoot.get_base_logger(TensorflowBinding).debug( + 'Summary name %s is illegal; using %s instead.' % (name, new_name)) + name = new_name + return name + + main_path = self._logdir + try: + main_path = _clean_tag(main_path) + origin_tag = main_path.rpartition("/")[2].replace(title_prefix, "", 1) + if title_prefix and origin_tag[0] == "_": # add_scalars tag + origin_tag = origin_tag[1:] # Remove the first "_" that was added by the main_tag in tensorboard + else: + return "" + except Exception: + origin_tag = "" + return origin_tag + class ProxyEventsWriter(object): def __init__(self, events): diff --git a/trains/config/default/sdk.conf b/trains/config/default/sdk.conf index efcbbf54..bccb3c10 100644 --- a/trains/config/default/sdk.conf +++ b/trains/config/default/sdk.conf @@ -32,6 +32,8 @@ quality: 87 subsampling: 0 } + # Support plot-per-graph fully matching Tensorboard behavior (i.e. if this is set to False, each series should have its own graph) + tensorboard_auto_group_scalars: True } network {