diff --git a/examples/keras_tensorboard.py b/examples/keras_tensorboard.py index c581a973..5627090f 100644 --- a/examples/keras_tensorboard.py +++ b/examples/keras_tensorboard.py @@ -51,8 +51,8 @@ class TensorBoardImage(TensorBoard): parser = argparse.ArgumentParser(description='Keras MNIST Example') -parser.add_argument('--batch-size', type=int, default=128, help='input batch size for training (default: 64)') -parser.add_argument('--epochs', type=int, default=6, help='number of epochs to train (default: 10)') +parser.add_argument('--batch-size', type=int, default=128, help='input batch size for training (default: 128)') +parser.add_argument('--epochs', type=int, default=6, help='number of epochs to train (default: 6)') args = parser.parse_args() # the data, shuffled and split between train and test sets diff --git a/trains/binding/frameworks/tensorflow_bind.py b/trains/binding/frameworks/tensorflow_bind.py index 4ea4972a..af53c93f 100644 --- a/trains/binding/frameworks/tensorflow_bind.py +++ b/trains/binding/frameworks/tensorflow_bind.py @@ -1026,11 +1026,15 @@ class PatchTensorFlowEager(object): event_writer._add_histogram(tag=tag, step=step, histo_data=hist_data) return - # prepare the dictionary, assume numpy - # histo_data['bucketLimit'] is the histogram bucket right side limit, meaning X axis - # histo_data['bucket'] is the histogram height, meaning the Y axis - # notice hist_data[:, 1] is the right side limit, for backwards compatibility we take the left side - histo_data = {'bucketLimit': hist_data[:, 0].tolist(), 'bucket': hist_data[:, 2].tolist()} + if isinstance(hist_data, np.ndarray): + hist_data = np.histogram(hist_data) + histo_data = {'bucketLimit': hist_data[1].tolist(), 'bucket': hist_data[0].tolist()} + else: + # prepare the dictionary, assume numpy + # histo_data['bucketLimit'] is the histogram bucket right side limit, meaning X axis + # histo_data['bucket'] is the histogram height, meaning the Y axis + # notice hist_data[:, 1] is the right side limit, for backwards compatibility we take the left side + histo_data = {'bucketLimit': hist_data[:, 0].tolist(), 'bucket': hist_data[:, 2].tolist()} event_writer._add_histogram(tag=tag, step=step, histo_data=histo_data) @staticmethod