diff --git a/trains/binding/frameworks/tensorflow_bind.py b/trains/binding/frameworks/tensorflow_bind.py index bc73d632..0afc93ed 100644 --- a/trains/binding/frameworks/tensorflow_bind.py +++ b/trains/binding/frameworks/tensorflow_bind.py @@ -271,15 +271,16 @@ class EventTrainsWriter(object): def _add_scalar(self, tag, step, scalar_data): default_title = tag if not self._logger._get_tensorboard_auto_group_scalars() else 'Scalars' + series_per_graph = self._logger._get_tensorboard_single_series_per_graph() + title, series = self.tag_splitter( - tag, num_split_parts=1, default_title=default_title, logdir_header='series_last' + tag, num_split_parts=1, default_title=default_title, + logdir_header='title' if series_per_graph else 'series_last' ) step = self._fix_step_counter(title, series, step) tag = self._get_add_scalars_event_tag(default_title) - series_per_graph = self._logger._get_tensorboard_single_series_per_graph() - possible_title = tag if series_per_graph else None possible_tag = None if series_per_graph else tag @@ -954,7 +955,7 @@ class PatchTensorFlowEager(object): __original_fn_scalar = None __original_fn_hist = None __original_fn_image = None - __trains_event_writer = None + __trains_event_writer = {} defaults_dict = dict( report_freq=1, image_report_freq=1, histogram_update_freq_multiplier=5, histogram_granularity=50) @@ -997,19 +998,28 @@ class PatchTensorFlowEager(object): def _get_event_writer(writer): if not PatchTensorFlowEager.__main_task: return None - if PatchTensorFlowEager.__trains_event_writer is None: + if not PatchTensorFlowEager.__trains_event_writer.get(id(writer)): try: logdir = writer.get_logdir() except Exception: - logdir = None - PatchTensorFlowEager.__trains_event_writer = EventTrainsWriter( + # check if we are in eager mode, let's get the global context lopdir + try: + from tensorflow.python.eager import context + logdir = context.context().summary_writer._init_op_fn.keywords.get('logdir') + logdir = logdir.numpy().decode() + except: + logdir = None + PatchTensorFlowEager.__trains_event_writer[id(writer)] = EventTrainsWriter( logger=PatchTensorFlowEager.__main_task.get_logger(), logdir=logdir, **PatchTensorFlowEager.defaults_dict) - return PatchTensorFlowEager.__trains_event_writer + return PatchTensorFlowEager.__trains_event_writer[id(writer)] @staticmethod def trains_object(self): - return PatchTensorFlowEager.__trains_event_writer + if not PatchTensorFlowEager.__trains_event_writer: + return None + return PatchTensorFlowEager.__trains_event_writer.get( + id(self), list(PatchTensorFlowEager.__trains_event_writer.values())[0]) @staticmethod def _write_summary(writer, step, tensor, tag, summary_metadata, name=None, **kwargs):