mirror of
https://github.com/clearml/clearml
synced 2025-02-07 21:33:25 +00:00
Support multiple EventWriter in TensorFlow eager mode (TF 2.0+)
This commit is contained in:
parent
b4050ecf25
commit
5db53ba643
@ -271,15 +271,16 @@ class EventTrainsWriter(object):
|
|||||||
|
|
||||||
def _add_scalar(self, tag, step, scalar_data):
|
def _add_scalar(self, tag, step, scalar_data):
|
||||||
default_title = tag if not self._logger._get_tensorboard_auto_group_scalars() else 'Scalars'
|
default_title = tag if not self._logger._get_tensorboard_auto_group_scalars() else 'Scalars'
|
||||||
|
series_per_graph = self._logger._get_tensorboard_single_series_per_graph()
|
||||||
|
|
||||||
title, series = self.tag_splitter(
|
title, series = self.tag_splitter(
|
||||||
tag, num_split_parts=1, default_title=default_title, logdir_header='series_last'
|
tag, num_split_parts=1, default_title=default_title,
|
||||||
|
logdir_header='title' if series_per_graph else 'series_last'
|
||||||
)
|
)
|
||||||
|
|
||||||
step = self._fix_step_counter(title, series, step)
|
step = self._fix_step_counter(title, series, step)
|
||||||
tag = self._get_add_scalars_event_tag(default_title)
|
tag = self._get_add_scalars_event_tag(default_title)
|
||||||
|
|
||||||
series_per_graph = self._logger._get_tensorboard_single_series_per_graph()
|
|
||||||
|
|
||||||
possible_title = tag if series_per_graph else None
|
possible_title = tag if series_per_graph else None
|
||||||
possible_tag = None if series_per_graph else tag
|
possible_tag = None if series_per_graph else tag
|
||||||
|
|
||||||
@ -954,7 +955,7 @@ class PatchTensorFlowEager(object):
|
|||||||
__original_fn_scalar = None
|
__original_fn_scalar = None
|
||||||
__original_fn_hist = None
|
__original_fn_hist = None
|
||||||
__original_fn_image = None
|
__original_fn_image = None
|
||||||
__trains_event_writer = None
|
__trains_event_writer = {}
|
||||||
defaults_dict = dict(
|
defaults_dict = dict(
|
||||||
report_freq=1, image_report_freq=1, histogram_update_freq_multiplier=5,
|
report_freq=1, image_report_freq=1, histogram_update_freq_multiplier=5,
|
||||||
histogram_granularity=50)
|
histogram_granularity=50)
|
||||||
@ -997,19 +998,28 @@ class PatchTensorFlowEager(object):
|
|||||||
def _get_event_writer(writer):
|
def _get_event_writer(writer):
|
||||||
if not PatchTensorFlowEager.__main_task:
|
if not PatchTensorFlowEager.__main_task:
|
||||||
return None
|
return None
|
||||||
if PatchTensorFlowEager.__trains_event_writer is None:
|
if not PatchTensorFlowEager.__trains_event_writer.get(id(writer)):
|
||||||
try:
|
try:
|
||||||
logdir = writer.get_logdir()
|
logdir = writer.get_logdir()
|
||||||
except Exception:
|
except Exception:
|
||||||
logdir = None
|
# check if we are in eager mode, let's get the global context lopdir
|
||||||
PatchTensorFlowEager.__trains_event_writer = EventTrainsWriter(
|
try:
|
||||||
|
from tensorflow.python.eager import context
|
||||||
|
logdir = context.context().summary_writer._init_op_fn.keywords.get('logdir')
|
||||||
|
logdir = logdir.numpy().decode()
|
||||||
|
except:
|
||||||
|
logdir = None
|
||||||
|
PatchTensorFlowEager.__trains_event_writer[id(writer)] = EventTrainsWriter(
|
||||||
logger=PatchTensorFlowEager.__main_task.get_logger(), logdir=logdir,
|
logger=PatchTensorFlowEager.__main_task.get_logger(), logdir=logdir,
|
||||||
**PatchTensorFlowEager.defaults_dict)
|
**PatchTensorFlowEager.defaults_dict)
|
||||||
return PatchTensorFlowEager.__trains_event_writer
|
return PatchTensorFlowEager.__trains_event_writer[id(writer)]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def trains_object(self):
|
def trains_object(self):
|
||||||
return PatchTensorFlowEager.__trains_event_writer
|
if not PatchTensorFlowEager.__trains_event_writer:
|
||||||
|
return None
|
||||||
|
return PatchTensorFlowEager.__trains_event_writer.get(
|
||||||
|
id(self), list(PatchTensorFlowEager.__trains_event_writer.values())[0])
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _write_summary(writer, step, tensor, tag, summary_metadata, name=None, **kwargs):
|
def _write_summary(writer, step, tensor, tag, summary_metadata, name=None, **kwargs):
|
||||||
|
Loading…
Reference in New Issue
Block a user