mirror of
https://github.com/clearml/clearml
synced 2025-01-31 17:17:00 +00:00
Support multiple EventWriter in TensorFlow eager mode (TF 2.0+)
This commit is contained in:
parent
b4050ecf25
commit
5db53ba643
@ -271,15 +271,16 @@ class EventTrainsWriter(object):
|
||||
|
||||
def _add_scalar(self, tag, step, scalar_data):
|
||||
default_title = tag if not self._logger._get_tensorboard_auto_group_scalars() else 'Scalars'
|
||||
series_per_graph = self._logger._get_tensorboard_single_series_per_graph()
|
||||
|
||||
title, series = self.tag_splitter(
|
||||
tag, num_split_parts=1, default_title=default_title, logdir_header='series_last'
|
||||
tag, num_split_parts=1, default_title=default_title,
|
||||
logdir_header='title' if series_per_graph else 'series_last'
|
||||
)
|
||||
|
||||
step = self._fix_step_counter(title, series, step)
|
||||
tag = self._get_add_scalars_event_tag(default_title)
|
||||
|
||||
series_per_graph = self._logger._get_tensorboard_single_series_per_graph()
|
||||
|
||||
possible_title = tag if series_per_graph else None
|
||||
possible_tag = None if series_per_graph else tag
|
||||
|
||||
@ -954,7 +955,7 @@ class PatchTensorFlowEager(object):
|
||||
__original_fn_scalar = None
|
||||
__original_fn_hist = None
|
||||
__original_fn_image = None
|
||||
__trains_event_writer = None
|
||||
__trains_event_writer = {}
|
||||
defaults_dict = dict(
|
||||
report_freq=1, image_report_freq=1, histogram_update_freq_multiplier=5,
|
||||
histogram_granularity=50)
|
||||
@ -997,19 +998,28 @@ class PatchTensorFlowEager(object):
|
||||
def _get_event_writer(writer):
|
||||
if not PatchTensorFlowEager.__main_task:
|
||||
return None
|
||||
if PatchTensorFlowEager.__trains_event_writer is None:
|
||||
if not PatchTensorFlowEager.__trains_event_writer.get(id(writer)):
|
||||
try:
|
||||
logdir = writer.get_logdir()
|
||||
except Exception:
|
||||
logdir = None
|
||||
PatchTensorFlowEager.__trains_event_writer = EventTrainsWriter(
|
||||
# check if we are in eager mode, let's get the global context lopdir
|
||||
try:
|
||||
from tensorflow.python.eager import context
|
||||
logdir = context.context().summary_writer._init_op_fn.keywords.get('logdir')
|
||||
logdir = logdir.numpy().decode()
|
||||
except:
|
||||
logdir = None
|
||||
PatchTensorFlowEager.__trains_event_writer[id(writer)] = EventTrainsWriter(
|
||||
logger=PatchTensorFlowEager.__main_task.get_logger(), logdir=logdir,
|
||||
**PatchTensorFlowEager.defaults_dict)
|
||||
return PatchTensorFlowEager.__trains_event_writer
|
||||
return PatchTensorFlowEager.__trains_event_writer[id(writer)]
|
||||
|
||||
@staticmethod
|
||||
def trains_object(self):
|
||||
return PatchTensorFlowEager.__trains_event_writer
|
||||
if not PatchTensorFlowEager.__trains_event_writer:
|
||||
return None
|
||||
return PatchTensorFlowEager.__trains_event_writer.get(
|
||||
id(self), list(PatchTensorFlowEager.__trains_event_writer.values())[0])
|
||||
|
||||
@staticmethod
|
||||
def _write_summary(writer, step, tensor, tag, summary_metadata, name=None, **kwargs):
|
||||
|
Loading…
Reference in New Issue
Block a user