From 073f4c308ddc90978808f293891de7d4c12a36c1 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Fri, 10 Jan 2020 13:39:31 +0200 Subject: [PATCH] Convert ndarray to histogram for axis to get rid of warning in tensorflow binding --- examples/keras_tensorboard.py | 4 ++-- trains/binding/frameworks/tensorflow_bind.py | 14 +++++++++----- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/examples/keras_tensorboard.py b/examples/keras_tensorboard.py index c581a973..5627090f 100644 --- a/examples/keras_tensorboard.py +++ b/examples/keras_tensorboard.py @@ -51,8 +51,8 @@ class TensorBoardImage(TensorBoard): parser = argparse.ArgumentParser(description='Keras MNIST Example') -parser.add_argument('--batch-size', type=int, default=128, help='input batch size for training (default: 64)') -parser.add_argument('--epochs', type=int, default=6, help='number of epochs to train (default: 10)') +parser.add_argument('--batch-size', type=int, default=128, help='input batch size for training (default: 128)') +parser.add_argument('--epochs', type=int, default=6, help='number of epochs to train (default: 6)') args = parser.parse_args() # the data, shuffled and split between train and test sets diff --git a/trains/binding/frameworks/tensorflow_bind.py b/trains/binding/frameworks/tensorflow_bind.py index 4ea4972a..af53c93f 100644 --- a/trains/binding/frameworks/tensorflow_bind.py +++ b/trains/binding/frameworks/tensorflow_bind.py @@ -1026,11 +1026,15 @@ class PatchTensorFlowEager(object): event_writer._add_histogram(tag=tag, step=step, histo_data=hist_data) return - # prepare the dictionary, assume numpy - # histo_data['bucketLimit'] is the histogram bucket right side limit, meaning X axis - # histo_data['bucket'] is the histogram height, meaning the Y axis - # notice hist_data[:, 1] is the right side limit, for backwards compatibility we take the left side - histo_data = {'bucketLimit': hist_data[:, 0].tolist(), 'bucket': hist_data[:, 2].tolist()} + if isinstance(hist_data, np.ndarray): + hist_data = np.histogram(hist_data) + histo_data = {'bucketLimit': hist_data[1].tolist(), 'bucket': hist_data[0].tolist()} + else: + # prepare the dictionary, assume numpy + # histo_data['bucketLimit'] is the histogram bucket right side limit, meaning X axis + # histo_data['bucket'] is the histogram height, meaning the Y axis + # notice hist_data[:, 1] is the right side limit, for backwards compatibility we take the left side + histo_data = {'bucketLimit': hist_data[:, 0].tolist(), 'bucket': hist_data[:, 2].tolist()} event_writer._add_histogram(tag=tag, step=step, histo_data=histo_data) @staticmethod