Convert ndarray to histogram for axis to get rid of warning in tensorflow binding

This commit is contained in:
allegroai 2020-01-10 13:39:31 +02:00
parent 163ace8856
commit 073f4c308d
2 changed files with 11 additions and 7 deletions

View File

@ -51,8 +51,8 @@ class TensorBoardImage(TensorBoard):
parser = argparse.ArgumentParser(description='Keras MNIST Example') parser = argparse.ArgumentParser(description='Keras MNIST Example')
parser.add_argument('--batch-size', type=int, default=128, help='input batch size for training (default: 64)') parser.add_argument('--batch-size', type=int, default=128, help='input batch size for training (default: 128)')
parser.add_argument('--epochs', type=int, default=6, help='number of epochs to train (default: 10)') parser.add_argument('--epochs', type=int, default=6, help='number of epochs to train (default: 6)')
args = parser.parse_args() args = parser.parse_args()
# the data, shuffled and split between train and test sets # the data, shuffled and split between train and test sets

View File

@ -1026,11 +1026,15 @@ class PatchTensorFlowEager(object):
event_writer._add_histogram(tag=tag, step=step, histo_data=hist_data) event_writer._add_histogram(tag=tag, step=step, histo_data=hist_data)
return return
# prepare the dictionary, assume numpy if isinstance(hist_data, np.ndarray):
# histo_data['bucketLimit'] is the histogram bucket right side limit, meaning X axis hist_data = np.histogram(hist_data)
# histo_data['bucket'] is the histogram height, meaning the Y axis histo_data = {'bucketLimit': hist_data[1].tolist(), 'bucket': hist_data[0].tolist()}
# notice hist_data[:, 1] is the right side limit, for backwards compatibility we take the left side else:
histo_data = {'bucketLimit': hist_data[:, 0].tolist(), 'bucket': hist_data[:, 2].tolist()} # prepare the dictionary, assume numpy
# histo_data['bucketLimit'] is the histogram bucket right side limit, meaning X axis
# histo_data['bucket'] is the histogram height, meaning the Y axis
# notice hist_data[:, 1] is the right side limit, for backwards compatibility we take the left side
histo_data = {'bucketLimit': hist_data[:, 0].tolist(), 'bucket': hist_data[:, 2].tolist()}
event_writer._add_histogram(tag=tag, step=step, histo_data=histo_data) event_writer._add_histogram(tag=tag, step=step, histo_data=histo_data)
@staticmethod @staticmethod