mirror of
https://github.com/clearml/clearml
synced 2025-04-02 12:08:33 +00:00
Fix TensorFlow >=2 histogram binding
This commit is contained in:
parent
20a9f0997d
commit
9c1d08b826
@ -88,15 +88,16 @@ class WeightsGradientHistHelper(object):
|
||||
|
||||
if isinstance(hist_data, dict):
|
||||
pass
|
||||
elif isinstance(hist_data, np.ndarray):
|
||||
hist_data = np.histogram(hist_data)
|
||||
hist_data = {'bucketLimit': hist_data[1].tolist(), 'bucket': hist_data[0].tolist()}
|
||||
else:
|
||||
elif isinstance(hist_data, np.ndarray) and np.atleast_2d(hist_data).shape[1] == 3:
|
||||
# prepare the dictionary, assume numpy
|
||||
# histo_data['bucketLimit'] is the histogram bucket right side limit, meaning X axis
|
||||
# histo_data['bucket'] is the histogram height, meaning the Y axis
|
||||
# hist_data['bucketLimit'] is the histogram bucket right side limit, meaning X axis
|
||||
# hist_data['bucket'] is the histogram height, meaning the Y axis
|
||||
# notice hist_data[:, 1] is the right side limit, for backwards compatibility we take the left side
|
||||
hist_data = {'bucketLimit': hist_data[:, 0].tolist(), 'bucket': hist_data[:, 2].tolist()}
|
||||
else:
|
||||
# assume we have to do the histogram on the data
|
||||
hist_data = np.histogram(hist_data)
|
||||
hist_data = {'bucketLimit': hist_data[1].tolist(), 'bucket': hist_data[0].tolist()}
|
||||
|
||||
self._add_histogram(title=title, series=series, step=step, hist_data=hist_data)
|
||||
|
||||
@ -128,8 +129,8 @@ class WeightsGradientHistHelper(object):
|
||||
|
||||
# add current sample, if not already here
|
||||
hist_iters = np.append(hist_iters, step)
|
||||
# histo_data['bucketLimit'] is the histogram bucket right side limit, meaning X axis
|
||||
# histo_data['bucket'] is the histogram height, meaning the Y axis
|
||||
# hist_data['bucketLimit'] is the histogram bucket right side limit, meaning X axis
|
||||
# hist_data['bucket'] is the histogram height, meaning the Y axis
|
||||
hist = np.array(list(zip(hist_data['bucketLimit'], hist_data['bucket'])), dtype=np.float32)
|
||||
hist = hist[~np.isinf(hist[:, 0]), :]
|
||||
hist_list.append(hist)
|
||||
@ -453,7 +454,7 @@ class EventTrainsWriter(object):
|
||||
value=scalar_data,
|
||||
)
|
||||
|
||||
def _add_histogram(self, tag, step, histo_data):
|
||||
def _add_histogram(self, tag, step, hist_data):
|
||||
title, series = self.tag_splitter(tag, num_split_parts=1, default_title='Histograms',
|
||||
logdir_header='series')
|
||||
|
||||
@ -461,7 +462,7 @@ class EventTrainsWriter(object):
|
||||
title=title,
|
||||
series=series,
|
||||
step=step,
|
||||
hist_data=histo_data
|
||||
hist_data=hist_data
|
||||
)
|
||||
|
||||
def _add_plot(self, tag, step, values, vdict):
|
||||
@ -615,7 +616,7 @@ class EventTrainsWriter(object):
|
||||
if metric == 'simpleValue':
|
||||
self._add_scalar(tag=tag, step=step, scalar_data=values)
|
||||
elif metric == 'histo':
|
||||
self._add_histogram(tag=tag, step=step, histo_data=values)
|
||||
self._add_histogram(tag=tag, step=step, hist_data=values)
|
||||
elif metric == 'image':
|
||||
self._add_image(tag=tag, step=step, img_data=values)
|
||||
elif metric == 'audio':
|
||||
|
Loading…
Reference in New Issue
Block a user