From f52fcb9668294b4ab06be0774280e2a4426d6588 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Fri, 12 Nov 2021 20:08:25 +0200 Subject: [PATCH] Fix TF 2.7 support (get logdir on with multiple TB writers) --- clearml/binding/frameworks/tensorflow_bind.py | 44 ++++++++++++++++++- 1 file changed, 42 insertions(+), 2 deletions(-) diff --git a/clearml/binding/frameworks/tensorflow_bind.py b/clearml/binding/frameworks/tensorflow_bind.py index e44a7dc2..1f7bcd7b 100644 --- a/clearml/binding/frameworks/tensorflow_bind.py +++ b/clearml/binding/frameworks/tensorflow_bind.py @@ -1165,6 +1165,7 @@ class PatchTensorFlowEager(object): __original_fn_hist = None __original_fn_image = None __trains_event_writer = {} + __tf_tb_writer_id_to_logdir = {} defaults_dict = dict( report_freq=1, image_report_freq=1, histogram_update_freq_multiplier=5, histogram_granularity=50) @@ -1206,6 +1207,42 @@ class PatchTensorFlowEager(object): except Exception as ex: LoggerRoot.get_base_logger(TensorflowBinding).debug(str(ex)) + # tensorflow 2.7 support (getting logdir) + try: + import tensorflow # noqa + import tensorflow.python # noqa + from tensorflow.python.ops import gen_summary_ops + gen_summary_ops.create_summary_file_writer = _patched_call( + gen_summary_ops.create_summary_file_writer, + PatchTensorFlowEager._create_summary_file_writer + ) + except Exception: + pass + + @staticmethod + def _create_summary_file_writer(original_fn, *args, **kwargs): + # noinspection PyBroadException + try: + a_logdir = None + a_writer = None + if kwargs and 'logdir' in kwargs: + a_logdir = kwargs.get('logdir') + elif args and len(args) >= 2: + a_logdir = args[1] + + if kwargs and 'writer' in kwargs: + a_writer = kwargs.get('writer') + elif args and len(args) >= 1: + a_writer = args[0] + + if a_writer is not None and a_logdir is not None: + a_logdir = a_logdir.numpy().decode() + PatchTensorFlowEager.__tf_tb_writer_id_to_logdir[id(a_writer)] = a_logdir + except Exception: + pass + + return original_fn(*args, **kwargs) + @staticmethod def _get_event_writer(writer): if not PatchTensorFlowEager.__main_task: @@ -1226,11 +1263,14 @@ class PatchTensorFlowEager(object): from tensorflow.python.ops.summary_ops_v2 import _summary_state # noqa logdir = _summary_state.writer._init_op_fn.keywords.get('logdir') except Exception: - logdir = None + try: + logdir = PatchTensorFlowEager.__tf_tb_writer_id_to_logdir[id(writer)] + except Exception: + logdir = None # noinspection PyBroadException try: if logdir is not None: - logdir = logdir.numpy().decode() + logdir = logdir.numpy().decode() if not isinstance(logdir, str) else logdir except Exception: logdir = None