Fix TF +2.3 mixed eager mode execution summary metrics not reported

This commit is contained in:
allegroai 2021-02-01 23:41:18 +02:00
parent 3890477056
commit 6778b982a7

View File

@ -399,8 +399,9 @@ class EventTrainsWriter(object):
stack_dim = int(np.sqrt(dims[0]))
# noinspection PyArgumentList
res = img_data_np.reshape(stack_dim, stack_dim, *dims[1:]).transpose((0, 2, 1, 3, 4))
tile_size = res.shape[0] * res.shape[1]
img_data_np = res.reshape(tile_size, tile_size, -1)
tile_size_h = res.shape[0] * res.shape[1]
tile_size_w = res.shape[2] * res.shape[3]
img_data_np = res.reshape(tile_size_h, tile_size_w, -1)
self._logger.report_image(
title=title,
@ -1192,7 +1193,8 @@ class PatchTensorFlowEager(object):
@staticmethod
def _write_summary(writer, step, tensor, tag, summary_metadata, name=None, **kwargs):
event_writer = PatchTensorFlowEager._get_event_writer(writer)
if event_writer:
# make sure we can get the tensors values
if event_writer and isinstance(step, int) or hasattr(step, 'numpy'):
# noinspection PyBroadException
try:
plugin_type = summary_metadata.decode()
@ -1229,19 +1231,47 @@ class PatchTensorFlowEager(object):
@staticmethod
def _write_scalar_summary(writer, step, tag, value, name=None, **kwargs):
event_writer = PatchTensorFlowEager._get_event_writer(writer)
if event_writer:
if event_writer and isinstance(step, int) or hasattr(step, 'numpy'):
try:
event_writer._add_scalar(tag=str(tag),
step=int(step.numpy()) if not isinstance(step, int) else step,
scalar_data=value.numpy())
except Exception as ex:
LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
elif event_writer:
def _report_summary_op(a_writer, a_step, a_tag, a_value, a_name=None, **_):
if isinstance(a_step, int) or hasattr(a_step, 'numpy'):
try:
str_tag = a_tag.numpy()
str_tag = str_tag.decode() if isinstance(str_tag, bytes) else str(str_tag)
event_writer._add_scalar(
tag=str_tag,
step=int(a_step.numpy()) if not isinstance(a_step, int) else a_step,
scalar_data=a_value.numpy())
except Exception as a_ex:
LoggerRoot.get_base_logger(TensorflowBinding).warning(
'_report_summary_op: {}'.format(str(a_ex)))
# this is a mix of eager and graph execution
try:
from tensorflow.python.eager import context as _context
if not _context.executing_eagerly():
from tensorflow import py_function
from tensorflow.python.framework import ops
with ops.device("cpu:0"):
p_op = py_function(
_report_summary_op,
inp=[writer, step, tag, value, name], Tout=[])
ops.add_to_collection(ops.GraphKeys._SUMMARY_COLLECTION, p_op)
except Exception as ex:
LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
return PatchTensorFlowEager.__original_fn_scalar(writer, step, tag, value, name, **kwargs)
@staticmethod
def _write_hist_summary(writer, step, tag, values, name, **kwargs):
event_writer = PatchTensorFlowEager._get_event_writer(writer)
if event_writer:
if event_writer and isinstance(step, int) or hasattr(step, 'numpy'):
try:
event_writer._add_histogram(
tag=str(tag), step=int(step.numpy()) if not isinstance(step, int) else step,
@ -1249,19 +1279,76 @@ class PatchTensorFlowEager(object):
)
except Exception as ex:
LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
elif event_writer:
def _report_summary_op(a_writer, a_step, a_tag, a_value, a_name=None, **_):
if isinstance(a_step, int) or hasattr(a_step, 'numpy'):
try:
str_tag = a_tag.numpy()
str_tag = str_tag.decode() if isinstance(str_tag, bytes) else str(str_tag)
event_writer._add_histogram(
tag=str_tag,
step=int(a_step.numpy()) if not isinstance(a_step, int) else a_step,
hist_data=a_value.numpy()
)
except Exception as a_ex:
LoggerRoot.get_base_logger(TensorflowBinding).warning(
'_report_summary_op: {}'.format(str(a_ex)))
# this is a mix of eager and graph execution
try:
from tensorflow.python.eager import context as _context
if not _context.executing_eagerly():
from tensorflow import py_function
from tensorflow.python.framework import ops
with ops.device("cpu:0"):
p_op = py_function(
_report_summary_op,
inp=[writer, step, tag, value, name], Tout=[])
ops.add_to_collection(ops.GraphKeys._SUMMARY_COLLECTION, p_op)
except Exception as ex:
LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
return PatchTensorFlowEager.__original_fn_hist(writer, step, tag, values, name, **kwargs)
@staticmethod
def _write_image_summary(writer, step, tag, tensor, bad_color, max_images, name, **kwargs):
event_writer = PatchTensorFlowEager._get_event_writer(writer)
if event_writer:
if event_writer and isinstance(step, int) or hasattr(step, 'numpy'):
try:
PatchTensorFlowEager._add_image_event_helper(event_writer, img_data_np=tensor.numpy(),
tag=tag, step=step, **kwargs)
except Exception as ex:
LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
return PatchTensorFlowEager.__original_fn_image(writer, step, tag, tensor, bad_color, max_images, name,
**kwargs)
elif event_writer:
def _report_summary_op(a_writer, a_step, a_tag, a_tensor, a_bad_color, a_max_images, a_name=None, **_):
if isinstance(a_step, int) or hasattr(a_step, 'numpy'):
try:
str_tag = a_tag.numpy()
str_tag = str_tag.decode() if isinstance(str_tag, bytes) else str(str_tag)
PatchTensorFlowEager._add_image_event_helper(
event_writer, img_data_np=a_tensor.numpy(),
tag=str_tag, step=a_step, **kwargs)
except Exception as a_ex:
LoggerRoot.get_base_logger(TensorflowBinding).warning(
'_report_summary_op: {}'.format(str(a_ex)))
# this is a mix of eager and graph execution
try:
from tensorflow.python.eager import context as _context
if not _context.executing_eagerly():
from tensorflow import py_function
from tensorflow.python.framework import ops
with ops.device("cpu:0"):
p_op = py_function(
_report_summary_op,
inp=[writer, step, tag, tensor, bad_color, max_images, name], Tout=[])
ops.add_to_collection(ops.GraphKeys._SUMMARY_COLLECTION, p_op)
except Exception as ex:
LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
return PatchTensorFlowEager.__original_fn_image(
writer, step, tag, tensor, bad_color, max_images, name, **kwargs)
@staticmethod
def _add_image_event_helper(event_writer, img_data_np, tag, step, **kwargs):