From 9c1d08b8264c98f70e52f7587467301630406bb5 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Mon, 15 Jun 2020 22:23:09 +0300 Subject: [PATCH] Fix TensorFlow >=2 histogram binding --- trains/binding/frameworks/tensorflow_bind.py | 23 ++++++++++---------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/trains/binding/frameworks/tensorflow_bind.py b/trains/binding/frameworks/tensorflow_bind.py index 40d70867..71a2af7e 100644 --- a/trains/binding/frameworks/tensorflow_bind.py +++ b/trains/binding/frameworks/tensorflow_bind.py @@ -88,15 +88,16 @@ class WeightsGradientHistHelper(object): if isinstance(hist_data, dict): pass - elif isinstance(hist_data, np.ndarray): - hist_data = np.histogram(hist_data) - hist_data = {'bucketLimit': hist_data[1].tolist(), 'bucket': hist_data[0].tolist()} - else: + elif isinstance(hist_data, np.ndarray) and np.atleast_2d(hist_data).shape[1] == 3: # prepare the dictionary, assume numpy - # histo_data['bucketLimit'] is the histogram bucket right side limit, meaning X axis - # histo_data['bucket'] is the histogram height, meaning the Y axis + # hist_data['bucketLimit'] is the histogram bucket right side limit, meaning X axis + # hist_data['bucket'] is the histogram height, meaning the Y axis # notice hist_data[:, 1] is the right side limit, for backwards compatibility we take the left side hist_data = {'bucketLimit': hist_data[:, 0].tolist(), 'bucket': hist_data[:, 2].tolist()} + else: + # assume we have to do the histogram on the data + hist_data = np.histogram(hist_data) + hist_data = {'bucketLimit': hist_data[1].tolist(), 'bucket': hist_data[0].tolist()} self._add_histogram(title=title, series=series, step=step, hist_data=hist_data) @@ -128,8 +129,8 @@ class WeightsGradientHistHelper(object): # add current sample, if not already here hist_iters = np.append(hist_iters, step) - # histo_data['bucketLimit'] is the histogram bucket right side limit, meaning X axis - # histo_data['bucket'] is the histogram height, meaning the Y axis + # hist_data['bucketLimit'] is the histogram bucket right side limit, meaning X axis + # hist_data['bucket'] is the histogram height, meaning the Y axis hist = np.array(list(zip(hist_data['bucketLimit'], hist_data['bucket'])), dtype=np.float32) hist = hist[~np.isinf(hist[:, 0]), :] hist_list.append(hist) @@ -453,7 +454,7 @@ class EventTrainsWriter(object): value=scalar_data, ) - def _add_histogram(self, tag, step, histo_data): + def _add_histogram(self, tag, step, hist_data): title, series = self.tag_splitter(tag, num_split_parts=1, default_title='Histograms', logdir_header='series') @@ -461,7 +462,7 @@ class EventTrainsWriter(object): title=title, series=series, step=step, - hist_data=histo_data + hist_data=hist_data ) def _add_plot(self, tag, step, values, vdict): @@ -615,7 +616,7 @@ class EventTrainsWriter(object): if metric == 'simpleValue': self._add_scalar(tag=tag, step=step, scalar_data=values) elif metric == 'histo': - self._add_histogram(tag=tag, step=step, histo_data=values) + self._add_histogram(tag=tag, step=step, hist_data=values) elif metric == 'image': self._add_image(tag=tag, step=step, img_data=values) elif metric == 'audio':