Fix TensorFlow >=2 histogram binding

This commit is contained in:
allegroai 2020-06-15 22:23:09 +03:00
parent 20a9f0997d
commit 9c1d08b826

View File

@ -88,15 +88,16 @@ class WeightsGradientHistHelper(object):
if isinstance(hist_data, dict):
pass
elif isinstance(hist_data, np.ndarray):
hist_data = np.histogram(hist_data)
hist_data = {'bucketLimit': hist_data[1].tolist(), 'bucket': hist_data[0].tolist()}
else:
elif isinstance(hist_data, np.ndarray) and np.atleast_2d(hist_data).shape[1] == 3:
# prepare the dictionary, assume numpy
# histo_data['bucketLimit'] is the histogram bucket right side limit, meaning X axis
# histo_data['bucket'] is the histogram height, meaning the Y axis
# hist_data['bucketLimit'] is the histogram bucket right side limit, meaning X axis
# hist_data['bucket'] is the histogram height, meaning the Y axis
# notice hist_data[:, 1] is the right side limit, for backwards compatibility we take the left side
hist_data = {'bucketLimit': hist_data[:, 0].tolist(), 'bucket': hist_data[:, 2].tolist()}
else:
# assume we have to do the histogram on the data
hist_data = np.histogram(hist_data)
hist_data = {'bucketLimit': hist_data[1].tolist(), 'bucket': hist_data[0].tolist()}
self._add_histogram(title=title, series=series, step=step, hist_data=hist_data)
@ -128,8 +129,8 @@ class WeightsGradientHistHelper(object):
# add current sample, if not already here
hist_iters = np.append(hist_iters, step)
# histo_data['bucketLimit'] is the histogram bucket right side limit, meaning X axis
# histo_data['bucket'] is the histogram height, meaning the Y axis
# hist_data['bucketLimit'] is the histogram bucket right side limit, meaning X axis
# hist_data['bucket'] is the histogram height, meaning the Y axis
hist = np.array(list(zip(hist_data['bucketLimit'], hist_data['bucket'])), dtype=np.float32)
hist = hist[~np.isinf(hist[:, 0]), :]
hist_list.append(hist)
@ -453,7 +454,7 @@ class EventTrainsWriter(object):
value=scalar_data,
)
def _add_histogram(self, tag, step, histo_data):
def _add_histogram(self, tag, step, hist_data):
title, series = self.tag_splitter(tag, num_split_parts=1, default_title='Histograms',
logdir_header='series')
@ -461,7 +462,7 @@ class EventTrainsWriter(object):
title=title,
series=series,
step=step,
hist_data=histo_data
hist_data=hist_data
)
def _add_plot(self, tag, step, values, vdict):
@ -615,7 +616,7 @@ class EventTrainsWriter(object):
if metric == 'simpleValue':
self._add_scalar(tag=tag, step=step, scalar_data=values)
elif metric == 'histo':
self._add_histogram(tag=tag, step=step, histo_data=values)
self._add_histogram(tag=tag, step=step, hist_data=values)
elif metric == 'image':
self._add_image(tag=tag, step=step, img_data=values)
elif metric == 'audio':