Support multiple EventWriter in TensorFlow eager mode (TF 2.0+)

This commit is contained in:
allegroai 2020-03-20 10:29:18 +02:00
parent b4050ecf25
commit 5db53ba643

View File

@ -271,15 +271,16 @@ class EventTrainsWriter(object):
def _add_scalar(self, tag, step, scalar_data):
default_title = tag if not self._logger._get_tensorboard_auto_group_scalars() else 'Scalars'
series_per_graph = self._logger._get_tensorboard_single_series_per_graph()
title, series = self.tag_splitter(
tag, num_split_parts=1, default_title=default_title, logdir_header='series_last'
tag, num_split_parts=1, default_title=default_title,
logdir_header='title' if series_per_graph else 'series_last'
)
step = self._fix_step_counter(title, series, step)
tag = self._get_add_scalars_event_tag(default_title)
series_per_graph = self._logger._get_tensorboard_single_series_per_graph()
possible_title = tag if series_per_graph else None
possible_tag = None if series_per_graph else tag
@ -954,7 +955,7 @@ class PatchTensorFlowEager(object):
__original_fn_scalar = None
__original_fn_hist = None
__original_fn_image = None
__trains_event_writer = None
__trains_event_writer = {}
defaults_dict = dict(
report_freq=1, image_report_freq=1, histogram_update_freq_multiplier=5,
histogram_granularity=50)
@ -997,19 +998,28 @@ class PatchTensorFlowEager(object):
def _get_event_writer(writer):
if not PatchTensorFlowEager.__main_task:
return None
if PatchTensorFlowEager.__trains_event_writer is None:
if not PatchTensorFlowEager.__trains_event_writer.get(id(writer)):
try:
logdir = writer.get_logdir()
except Exception:
logdir = None
PatchTensorFlowEager.__trains_event_writer = EventTrainsWriter(
# check if we are in eager mode, let's get the global context lopdir
try:
from tensorflow.python.eager import context
logdir = context.context().summary_writer._init_op_fn.keywords.get('logdir')
logdir = logdir.numpy().decode()
except:
logdir = None
PatchTensorFlowEager.__trains_event_writer[id(writer)] = EventTrainsWriter(
logger=PatchTensorFlowEager.__main_task.get_logger(), logdir=logdir,
**PatchTensorFlowEager.defaults_dict)
return PatchTensorFlowEager.__trains_event_writer
return PatchTensorFlowEager.__trains_event_writer[id(writer)]
@staticmethod
def trains_object(self):
return PatchTensorFlowEager.__trains_event_writer
if not PatchTensorFlowEager.__trains_event_writer:
return None
return PatchTensorFlowEager.__trains_event_writer.get(
id(self), list(PatchTensorFlowEager.__trains_event_writer.values())[0])
@staticmethod
def _write_summary(writer, step, tensor, tag, summary_metadata, name=None, **kwargs):