mirror of
https://github.com/clearml/clearml
synced 2025-04-03 04:21:03 +00:00
Fix TensorFlow >=2 histogram binding
This commit is contained in:
parent
20a9f0997d
commit
9c1d08b826
@ -88,15 +88,16 @@ class WeightsGradientHistHelper(object):
|
|||||||
|
|
||||||
if isinstance(hist_data, dict):
|
if isinstance(hist_data, dict):
|
||||||
pass
|
pass
|
||||||
elif isinstance(hist_data, np.ndarray):
|
elif isinstance(hist_data, np.ndarray) and np.atleast_2d(hist_data).shape[1] == 3:
|
||||||
hist_data = np.histogram(hist_data)
|
|
||||||
hist_data = {'bucketLimit': hist_data[1].tolist(), 'bucket': hist_data[0].tolist()}
|
|
||||||
else:
|
|
||||||
# prepare the dictionary, assume numpy
|
# prepare the dictionary, assume numpy
|
||||||
# histo_data['bucketLimit'] is the histogram bucket right side limit, meaning X axis
|
# hist_data['bucketLimit'] is the histogram bucket right side limit, meaning X axis
|
||||||
# histo_data['bucket'] is the histogram height, meaning the Y axis
|
# hist_data['bucket'] is the histogram height, meaning the Y axis
|
||||||
# notice hist_data[:, 1] is the right side limit, for backwards compatibility we take the left side
|
# notice hist_data[:, 1] is the right side limit, for backwards compatibility we take the left side
|
||||||
hist_data = {'bucketLimit': hist_data[:, 0].tolist(), 'bucket': hist_data[:, 2].tolist()}
|
hist_data = {'bucketLimit': hist_data[:, 0].tolist(), 'bucket': hist_data[:, 2].tolist()}
|
||||||
|
else:
|
||||||
|
# assume we have to do the histogram on the data
|
||||||
|
hist_data = np.histogram(hist_data)
|
||||||
|
hist_data = {'bucketLimit': hist_data[1].tolist(), 'bucket': hist_data[0].tolist()}
|
||||||
|
|
||||||
self._add_histogram(title=title, series=series, step=step, hist_data=hist_data)
|
self._add_histogram(title=title, series=series, step=step, hist_data=hist_data)
|
||||||
|
|
||||||
@ -128,8 +129,8 @@ class WeightsGradientHistHelper(object):
|
|||||||
|
|
||||||
# add current sample, if not already here
|
# add current sample, if not already here
|
||||||
hist_iters = np.append(hist_iters, step)
|
hist_iters = np.append(hist_iters, step)
|
||||||
# histo_data['bucketLimit'] is the histogram bucket right side limit, meaning X axis
|
# hist_data['bucketLimit'] is the histogram bucket right side limit, meaning X axis
|
||||||
# histo_data['bucket'] is the histogram height, meaning the Y axis
|
# hist_data['bucket'] is the histogram height, meaning the Y axis
|
||||||
hist = np.array(list(zip(hist_data['bucketLimit'], hist_data['bucket'])), dtype=np.float32)
|
hist = np.array(list(zip(hist_data['bucketLimit'], hist_data['bucket'])), dtype=np.float32)
|
||||||
hist = hist[~np.isinf(hist[:, 0]), :]
|
hist = hist[~np.isinf(hist[:, 0]), :]
|
||||||
hist_list.append(hist)
|
hist_list.append(hist)
|
||||||
@ -453,7 +454,7 @@ class EventTrainsWriter(object):
|
|||||||
value=scalar_data,
|
value=scalar_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _add_histogram(self, tag, step, histo_data):
|
def _add_histogram(self, tag, step, hist_data):
|
||||||
title, series = self.tag_splitter(tag, num_split_parts=1, default_title='Histograms',
|
title, series = self.tag_splitter(tag, num_split_parts=1, default_title='Histograms',
|
||||||
logdir_header='series')
|
logdir_header='series')
|
||||||
|
|
||||||
@ -461,7 +462,7 @@ class EventTrainsWriter(object):
|
|||||||
title=title,
|
title=title,
|
||||||
series=series,
|
series=series,
|
||||||
step=step,
|
step=step,
|
||||||
hist_data=histo_data
|
hist_data=hist_data
|
||||||
)
|
)
|
||||||
|
|
||||||
def _add_plot(self, tag, step, values, vdict):
|
def _add_plot(self, tag, step, values, vdict):
|
||||||
@ -615,7 +616,7 @@ class EventTrainsWriter(object):
|
|||||||
if metric == 'simpleValue':
|
if metric == 'simpleValue':
|
||||||
self._add_scalar(tag=tag, step=step, scalar_data=values)
|
self._add_scalar(tag=tag, step=step, scalar_data=values)
|
||||||
elif metric == 'histo':
|
elif metric == 'histo':
|
||||||
self._add_histogram(tag=tag, step=step, histo_data=values)
|
self._add_histogram(tag=tag, step=step, hist_data=values)
|
||||||
elif metric == 'image':
|
elif metric == 'image':
|
||||||
self._add_image(tag=tag, step=step, img_data=values)
|
self._add_image(tag=tag, step=step, img_data=values)
|
||||||
elif metric == 'audio':
|
elif metric == 'audio':
|
||||||
|
Loading…
Reference in New Issue
Block a user