mirror of
https://github.com/clearml/clearml
synced 2025-03-03 18:52:12 +00:00
Fix TF 2.7 support (get logdir on with multiple TB writers)
This commit is contained in:
parent
77c985d961
commit
f52fcb9668
@ -1165,6 +1165,7 @@ class PatchTensorFlowEager(object):
|
||||
__original_fn_hist = None
|
||||
__original_fn_image = None
|
||||
__trains_event_writer = {}
|
||||
__tf_tb_writer_id_to_logdir = {}
|
||||
defaults_dict = dict(
|
||||
report_freq=1, image_report_freq=1, histogram_update_freq_multiplier=5,
|
||||
histogram_granularity=50)
|
||||
@ -1206,6 +1207,42 @@ class PatchTensorFlowEager(object):
|
||||
except Exception as ex:
|
||||
LoggerRoot.get_base_logger(TensorflowBinding).debug(str(ex))
|
||||
|
||||
# tensorflow 2.7 support (getting logdir)
|
||||
try:
|
||||
import tensorflow # noqa
|
||||
import tensorflow.python # noqa
|
||||
from tensorflow.python.ops import gen_summary_ops
|
||||
gen_summary_ops.create_summary_file_writer = _patched_call(
|
||||
gen_summary_ops.create_summary_file_writer,
|
||||
PatchTensorFlowEager._create_summary_file_writer
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _create_summary_file_writer(original_fn, *args, **kwargs):
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
a_logdir = None
|
||||
a_writer = None
|
||||
if kwargs and 'logdir' in kwargs:
|
||||
a_logdir = kwargs.get('logdir')
|
||||
elif args and len(args) >= 2:
|
||||
a_logdir = args[1]
|
||||
|
||||
if kwargs and 'writer' in kwargs:
|
||||
a_writer = kwargs.get('writer')
|
||||
elif args and len(args) >= 1:
|
||||
a_writer = args[0]
|
||||
|
||||
if a_writer is not None and a_logdir is not None:
|
||||
a_logdir = a_logdir.numpy().decode()
|
||||
PatchTensorFlowEager.__tf_tb_writer_id_to_logdir[id(a_writer)] = a_logdir
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return original_fn(*args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def _get_event_writer(writer):
|
||||
if not PatchTensorFlowEager.__main_task:
|
||||
@ -1226,11 +1263,14 @@ class PatchTensorFlowEager(object):
|
||||
from tensorflow.python.ops.summary_ops_v2 import _summary_state # noqa
|
||||
logdir = _summary_state.writer._init_op_fn.keywords.get('logdir')
|
||||
except Exception:
|
||||
logdir = None
|
||||
try:
|
||||
logdir = PatchTensorFlowEager.__tf_tb_writer_id_to_logdir[id(writer)]
|
||||
except Exception:
|
||||
logdir = None
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
if logdir is not None:
|
||||
logdir = logdir.numpy().decode()
|
||||
logdir = logdir.numpy().decode() if not isinstance(logdir, str) else logdir
|
||||
except Exception:
|
||||
logdir = None
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user