From 6778b982a77188256b1ee4c2216ae063e2612970 Mon Sep 17 00:00:00 2001
From: allegroai <>
Date: Mon, 1 Feb 2021 23:41:18 +0200
Subject: [PATCH] Fix TF +2.3 mixed eager mode execution summary metrics not reported

Summary of the changes in this patch:

* EventTrainsWriter: when a stacked image batch is packed into one big
  image, the tile height and tile width are computed separately instead
  of reusing a single tile size for both axes.
* PatchTensorFlowEager summary hooks: values are read directly only when
  an event writer exists AND `step` is an int or exposes `.numpy()`.
  The check is parenthesized explicitly, because `and` binds tighter
  than `or` and the unparenthesized form would enter the reporting
  branch (and dereference a missing event writer) whenever `step` has a
  `.numpy()` attribute.
* Scalar / histogram / image hooks: when the values cannot be read
  directly and TensorFlow is not executing eagerly, a `py_function`
  reporting op is created on "cpu:0" and added to the summary
  collection, so reporting happens when the op is executed.
* Histogram hook: the reporting op is fed `values` (the name of the
  function's parameter) rather than the undefined name `value`, which
  raised a NameError that was swallowed by the surrounding
  `except Exception` and prevented the op from being registered.

---
 clearml/binding/frameworks/tensorflow_bind.py | 103 ++++++++++++++++--
 1 file changed, 95 insertions(+), 8 deletions(-)

diff --git a/clearml/binding/frameworks/tensorflow_bind.py b/clearml/binding/frameworks/tensorflow_bind.py
index 269973f2..a6871bdd 100644
--- a/clearml/binding/frameworks/tensorflow_bind.py
+++ b/clearml/binding/frameworks/tensorflow_bind.py
@@ -399,8 +399,9 @@ class EventTrainsWriter(object):
             stack_dim = int(np.sqrt(dims[0]))
             # noinspection PyArgumentList
             res = img_data_np.reshape(stack_dim, stack_dim, *dims[1:]).transpose((0, 2, 1, 3, 4))
-            tile_size = res.shape[0] * res.shape[1]
-            img_data_np = res.reshape(tile_size, tile_size, -1)
+            tile_size_h = res.shape[0] * res.shape[1]
+            tile_size_w = res.shape[2] * res.shape[3]
+            img_data_np = res.reshape(tile_size_h, tile_size_w, -1)
 
         self._logger.report_image(
             title=title,
@@ -1192,7 +1193,8 @@ class PatchTensorFlowEager(object):
     @staticmethod
     def _write_summary(writer, step, tensor, tag, summary_metadata, name=None, **kwargs):
         event_writer = PatchTensorFlowEager._get_event_writer(writer)
-        if event_writer:
+        # make sure we can get the tensors values
+        if event_writer and (isinstance(step, int) or hasattr(step, 'numpy')):
             # noinspection PyBroadException
             try:
                 plugin_type = summary_metadata.decode()
@@ -1229,19 +1231,47 @@ class PatchTensorFlowEager(object):
     @staticmethod
     def _write_scalar_summary(writer, step, tag, value, name=None, **kwargs):
         event_writer = PatchTensorFlowEager._get_event_writer(writer)
-        if event_writer:
+        if event_writer and (isinstance(step, int) or hasattr(step, 'numpy')):
             try:
                 event_writer._add_scalar(tag=str(tag),
                                          step=int(step.numpy()) if not isinstance(step, int) else step,
                                          scalar_data=value.numpy())
             except Exception as ex:
                 LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
+        elif event_writer:
+            def _report_summary_op(a_writer, a_step, a_tag, a_value, a_name=None, **_):
+                if isinstance(a_step, int) or hasattr(a_step, 'numpy'):
+                    try:
+                        str_tag = a_tag.numpy()
+                        str_tag = str_tag.decode() if isinstance(str_tag, bytes) else str(str_tag)
+                        event_writer._add_scalar(
+                            tag=str_tag,
+                            step=int(a_step.numpy()) if not isinstance(a_step, int) else a_step,
+                            scalar_data=a_value.numpy())
+                    except Exception as a_ex:
+                        LoggerRoot.get_base_logger(TensorflowBinding).warning(
+                            '_report_summary_op: {}'.format(str(a_ex)))
+
+            # this is a mix of eager and graph execution
+            try:
+                from tensorflow.python.eager import context as _context
+                if not _context.executing_eagerly():
+                    from tensorflow import py_function
+                    from tensorflow.python.framework import ops
+                    with ops.device("cpu:0"):
+                        p_op = py_function(
+                            _report_summary_op,
+                            inp=[writer, step, tag, value, name], Tout=[])
+                    ops.add_to_collection(ops.GraphKeys._SUMMARY_COLLECTION, p_op)
+            except Exception as ex:
+                LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
+
         return PatchTensorFlowEager.__original_fn_scalar(writer, step, tag, value, name, **kwargs)
 
     @staticmethod
     def _write_hist_summary(writer, step, tag, values, name, **kwargs):
         event_writer = PatchTensorFlowEager._get_event_writer(writer)
-        if event_writer:
+        if event_writer and (isinstance(step, int) or hasattr(step, 'numpy')):
             try:
                 event_writer._add_histogram(
                     tag=str(tag), step=int(step.numpy()) if not isinstance(step, int) else step,
@@ -1249,19 +1279,76 @@ class PatchTensorFlowEager(object):
                 )
             except Exception as ex:
                 LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
+        elif event_writer:
+            def _report_summary_op(a_writer, a_step, a_tag, a_value, a_name=None, **_):
+                if isinstance(a_step, int) or hasattr(a_step, 'numpy'):
+                    try:
+                        str_tag = a_tag.numpy()
+                        str_tag = str_tag.decode() if isinstance(str_tag, bytes) else str(str_tag)
+                        event_writer._add_histogram(
+                            tag=str_tag,
+                            step=int(a_step.numpy()) if not isinstance(a_step, int) else a_step,
+                            hist_data=a_value.numpy()
+                        )
+                    except Exception as a_ex:
+                        LoggerRoot.get_base_logger(TensorflowBinding).warning(
+                            '_report_summary_op: {}'.format(str(a_ex)))
+
+            # this is a mix of eager and graph execution
+            try:
+                from tensorflow.python.eager import context as _context
+                if not _context.executing_eagerly():
+                    from tensorflow import py_function
+                    from tensorflow.python.framework import ops
+                    with ops.device("cpu:0"):
+                        p_op = py_function(
+                            _report_summary_op,
+                            inp=[writer, step, tag, values, name], Tout=[])
+                    ops.add_to_collection(ops.GraphKeys._SUMMARY_COLLECTION, p_op)
+            except Exception as ex:
+                LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
+
         return PatchTensorFlowEager.__original_fn_hist(writer, step, tag, values, name, **kwargs)
 
     @staticmethod
     def _write_image_summary(writer, step, tag, tensor, bad_color, max_images, name, **kwargs):
         event_writer = PatchTensorFlowEager._get_event_writer(writer)
-        if event_writer:
+        if event_writer and (isinstance(step, int) or hasattr(step, 'numpy')):
             try:
                 PatchTensorFlowEager._add_image_event_helper(event_writer, img_data_np=tensor.numpy(),
                                                              tag=tag, step=step, **kwargs)
             except Exception as ex:
                 LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
-        return PatchTensorFlowEager.__original_fn_image(writer, step, tag, tensor, bad_color, max_images, name,
-                                                        **kwargs)
+        elif event_writer:
+            def _report_summary_op(a_writer, a_step, a_tag, a_tensor, a_bad_color, a_max_images, a_name=None, **_):
+                if isinstance(a_step, int) or hasattr(a_step, 'numpy'):
+                    try:
+                        str_tag = a_tag.numpy()
+                        str_tag = str_tag.decode() if isinstance(str_tag, bytes) else str(str_tag)
+                        PatchTensorFlowEager._add_image_event_helper(
+                            event_writer, img_data_np=a_tensor.numpy(),
+                            tag=str_tag, step=a_step, **kwargs)
+
+                    except Exception as a_ex:
+                        LoggerRoot.get_base_logger(TensorflowBinding).warning(
+                            '_report_summary_op: {}'.format(str(a_ex)))
+
+            # this is a mix of eager and graph execution
+            try:
+                from tensorflow.python.eager import context as _context
+                if not _context.executing_eagerly():
+                    from tensorflow import py_function
+                    from tensorflow.python.framework import ops
+                    with ops.device("cpu:0"):
+                        p_op = py_function(
+                            _report_summary_op,
+                            inp=[writer, step, tag, tensor, bad_color, max_images, name], Tout=[])
+                    ops.add_to_collection(ops.GraphKeys._SUMMARY_COLLECTION, p_op)
+            except Exception as ex:
+                LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
+
+        return PatchTensorFlowEager.__original_fn_image(
+            writer, step, tag, tensor, bad_color, max_images, name, **kwargs)
 
     @staticmethod
     def _add_image_event_helper(event_writer, img_data_np, tag, step, **kwargs):