Fix TF 2.7 support (get logdir on with multiple TB writers)

This commit is contained in:
allegroai 2021-11-12 20:08:25 +02:00
parent 77c985d961
commit f52fcb9668

View File

@ -1165,6 +1165,7 @@ class PatchTensorFlowEager(object):
__original_fn_hist = None
__original_fn_image = None
__trains_event_writer = {}
__tf_tb_writer_id_to_logdir = {}
defaults_dict = dict(
report_freq=1, image_report_freq=1, histogram_update_freq_multiplier=5,
histogram_granularity=50)
@ -1206,6 +1207,42 @@ class PatchTensorFlowEager(object):
except Exception as ex:
LoggerRoot.get_base_logger(TensorflowBinding).debug(str(ex))
# tensorflow 2.7 support (getting logdir)
try:
import tensorflow # noqa
import tensorflow.python # noqa
from tensorflow.python.ops import gen_summary_ops
gen_summary_ops.create_summary_file_writer = _patched_call(
gen_summary_ops.create_summary_file_writer,
PatchTensorFlowEager._create_summary_file_writer
)
except Exception:
pass
@staticmethod
def _create_summary_file_writer(original_fn, *args, **kwargs):
# noinspection PyBroadException
try:
a_logdir = None
a_writer = None
if kwargs and 'logdir' in kwargs:
a_logdir = kwargs.get('logdir')
elif args and len(args) >= 2:
a_logdir = args[1]
if kwargs and 'writer' in kwargs:
a_writer = kwargs.get('writer')
elif args and len(args) >= 1:
a_writer = args[0]
if a_writer is not None and a_logdir is not None:
a_logdir = a_logdir.numpy().decode()
PatchTensorFlowEager.__tf_tb_writer_id_to_logdir[id(a_writer)] = a_logdir
except Exception:
pass
return original_fn(*args, **kwargs)
@staticmethod
def _get_event_writer(writer):
if not PatchTensorFlowEager.__main_task:
@ -1226,11 +1263,14 @@ class PatchTensorFlowEager(object):
from tensorflow.python.ops.summary_ops_v2 import _summary_state # noqa
logdir = _summary_state.writer._init_op_fn.keywords.get('logdir')
except Exception:
logdir = None
try:
logdir = PatchTensorFlowEager.__tf_tb_writer_id_to_logdir[id(writer)]
except Exception:
logdir = None
# noinspection PyBroadException
try:
if logdir is not None:
logdir = logdir.numpy().decode()
logdir = logdir.numpy().decode() if not isinstance(logdir, str) else logdir
except Exception:
logdir = None