diff --git a/clearml/binding/frameworks/tensorflow_bind.py b/clearml/binding/frameworks/tensorflow_bind.py index a6871bdd..051cbbc0 100644 --- a/clearml/binding/frameworks/tensorflow_bind.py +++ b/clearml/binding/frameworks/tensorflow_bind.py @@ -1257,12 +1257,12 @@ class PatchTensorFlowEager(object): from tensorflow.python.eager import context as _context if not _context.executing_eagerly(): from tensorflow import py_function - from tensorflow.python.framework import ops - with ops.device("cpu:0"): - p_op = py_function( - _report_summary_op, - inp=[writer, step, tag, value, name], Tout=[]) - ops.add_to_collection(ops.GraphKeys._SUMMARY_COLLECTION, p_op) + # just creating the operator is enough (for some reason) + # to make sure it is added into the execution tree. + # the operator itself, will do the reporting to the backend + py_function( + _report_summary_op, + inp=[writer, step, tag, value, name], Tout=[]) except Exception as ex: LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex)) @@ -1299,12 +1299,12 @@ class PatchTensorFlowEager(object): from tensorflow.python.eager import context as _context if not _context.executing_eagerly(): from tensorflow import py_function - from tensorflow.python.framework import ops - with ops.device("cpu:0"): - p_op = py_function( - _report_summary_op, - inp=[writer, step, tag, value, name], Tout=[]) - ops.add_to_collection(ops.GraphKeys._SUMMARY_COLLECTION, p_op) + # just creating the operator is enough (for some reason) + # to make sure it is added into the execution tree. + # the operator itself, will do the reporting to the backend + py_function( + _report_summary_op, + inp=[writer, step, tag, value, name], Tout=[]) except Exception as ex: LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex)) @@ -1338,12 +1338,12 @@ class PatchTensorFlowEager(object): from tensorflow.python.eager import context as _context if not _context.executing_eagerly(): from tensorflow import py_function - from tensorflow.python.framework import ops - with ops.device("cpu:0"): - p_op = py_function( - _report_summary_op, - inp=[writer, step, tag, tensor, bad_color, max_images, name], Tout=[]) - ops.add_to_collection(ops.GraphKeys._SUMMARY_COLLECTION, p_op) + # just creating the operator is enough (for some reason) + # to make sure it is added into the execution tree. + # the operator itself, will do the reporting to the backend + py_function( + _report_summary_op, + inp=[writer, step, tag, tensor, bad_color, max_images, name], Tout=[]) except Exception as ex: LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex)) @@ -1370,6 +1370,12 @@ class PatchTensorFlowEager(object): img_data_np=img_data_np, max_keep_images=kwargs.get('max_images')) + @staticmethod + def _nothing_op(*_, **__): + """Convenient else branch for when summaries do not record.""" + from tensorflow.python.framework import constant_op + return constant_op.constant(False) + # noinspection PyPep8Naming,SpellCheckingInspection class PatchKerasModelIO(object):