Mirror of https://github.com/clearml/clearml, synced 2025-05-29 17:48:33 +00:00
Refactor histogram code for PyTorch Ignite integration
This commit is contained in:
parent 966cd6118a
commit 0298b84030
@@ -54,12 +54,17 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
             elif hasattr(f, 'as_posix'):
                 filename = f.as_posix()
             elif hasattr(f, 'name'):
-                filename = f.name
                 # noinspection PyBroadException
                 try:
                     f.flush()
                 except Exception:
                     pass
+
+                if not isinstance(f.name, six.string_types):
+                    # Probably a BufferedRandom object that has no meaningful name (still no harm flushing)
+                    return ret
+
+                filename = f.name
             else:
                 filename = None
         except Exception:
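Note (editor's sketch, not part of the commit): the new `isinstance()` guard matters because file-like objects can expose a non-string `.name`. For example, an anonymous temporary file reports its OS file descriptor, an int, as its name:

```python
import tempfile

f = tempfile.TemporaryFile()    # a BufferedRandom with no path on disk
print(f.name)                   # e.g. 3 -- an int fd, not a filename
print(isinstance(f.name, str))  # False, so the patched _save() returns early
```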
@@ -51,6 +51,136 @@ class IsTensorboardInit(object):
         return original_init(self, *args, **kwargs)
 
 
+class WeightsGradientHistHelper(object):
+    def __init__(self, logger, report_freq=100, histogram_update_freq_multiplier=10, histogram_granularity=50):
+        self._logger = logger
+        self.report_freq = report_freq
+        self._histogram_granularity = histogram_granularity
+        self._histogram_update_freq_multiplier = histogram_update_freq_multiplier
+        self._histogram_update_call_counter = 0
+        self._hist_report_cache = {}
+        self._hist_x_granularity = 50
+
+    @staticmethod
+    def _sample_histograms(_hist_iters, _histogram_granularity):
+        # re-sample history based on distribution of samples across time (steps)
+        ratio = ((_hist_iters[-1] - _hist_iters[_histogram_granularity]) /
+                 (_hist_iters[_histogram_granularity - 1] - _hist_iters[0])) if \
+            _hist_iters.size > _histogram_granularity else 0.
+        cur_idx_below = np.arange(0, min(_hist_iters.size, _histogram_granularity - 1))
+        np.random.shuffle(cur_idx_below)
+        cur_idx_below = cur_idx_below[:int(_histogram_granularity * (1.0 - ratio / (1 + ratio)) + 0.5)]
+        if ratio > 0.0:
+            cur_idx_above = np.arange(_histogram_granularity - 1, _hist_iters.size)
+            np.random.shuffle(cur_idx_above)
+            cur_idx_above = cur_idx_above[:int(_histogram_granularity * ratio / (1 + ratio))]
+        else:
+            cur_idx_above = np.array([])
+        _cur_idx = np.unique(np.sort(np.concatenate((cur_idx_below, cur_idx_above)).astype(np.int)))
+        return _cur_idx
+
+    def add_histogram(self, title, series, step, hist_data):
+        # only collect histogram every specific interval
+        self._histogram_update_call_counter += 1
+        if self._histogram_update_call_counter % self.report_freq != 0 or \
+                self._histogram_update_call_counter < self.report_freq - 1:
+            return None
+
+        if isinstance(hist_data, dict):
+            pass
+        elif isinstance(hist_data, np.ndarray):
+            hist_data = np.histogram(hist_data)
+            hist_data = {'bucketLimit': hist_data[1].tolist(), 'bucket': hist_data[0].tolist()}
+        else:
+            # prepare the dictionary, assume numpy
+            # histo_data['bucketLimit'] is the histogram bucket right side limit, meaning X axis
+            # histo_data['bucket'] is the histogram height, meaning the Y axis
+            # notice hist_data[:, 1] is the right side limit, for backwards compatibility we take the left side
+            hist_data = {'bucketLimit': hist_data[:, 0].tolist(), 'bucket': hist_data[:, 2].tolist()}
+
+        self._add_histogram(title=title, series=series, step=step, hist_data=hist_data)
+
+    def _add_histogram(self, title, series, step, hist_data):
+        # only collect histogram every specific interval
+        self._histogram_update_call_counter += 1
+        if self._histogram_update_call_counter % self.report_freq != 0 or \
+                self._histogram_update_call_counter < self.report_freq - 1:
+            return None
+
+        # generate forward matrix of the histograms
+        # Y-axis (rows) is iteration (from 0 to current Step)
+        # X-axis averaged bins (conformed sample 'bucketLimit')
+        # Z-axis actual value (interpolated 'bucket')
+        step = EventTrainsWriter._fix_step_counter(title, series, step)
+
+        # get histograms from cache
+        hist_list, hist_iters, minmax = self._hist_report_cache.get((title, series), ([], np.array([]), None))
+
+        # resample data so we are always constrained in the number of histograms we keep
+        if hist_iters.size >= self._histogram_granularity ** 2:
+            idx = self._sample_histograms(hist_iters, self._histogram_granularity)
+            hist_iters = hist_iters[idx]
+            hist_list = [hist_list[i] for i in idx]
+
+        # check if current sample is not already here (actually happens sometimes)
+        if step in hist_iters:
+            return None
+
+        # add current sample, if not already here
+        hist_iters = np.append(hist_iters, step)
+        # histo_data['bucketLimit'] is the histogram bucket right side limit, meaning X axis
+        # histo_data['bucket'] is the histogram height, meaning the Y axis
+        hist = np.array(list(zip(hist_data['bucketLimit'], hist_data['bucket'])), dtype=np.float32)
+        hist = hist[~np.isinf(hist[:, 0]), :]
+        hist_list.append(hist)
+        # keep track of min/max values of histograms (for later re-binning)
+        if minmax is None:
+            minmax = hist[:, 0].min(), hist[:, 0].max()
+        else:
+            minmax = min(minmax[0], hist[:, 0].min()), max(minmax[1], hist[:, 0].max())
+
+        # update the cache
+        self._hist_report_cache[(title, series)] = hist_list, hist_iters, minmax
+
+        # only report histogram every specific interval, but do report the first few, so you know there are histograms
+        if hist_iters.size < 1 or (hist_iters.size >= self._histogram_update_freq_multiplier and
+                                   hist_iters.size % self._histogram_update_freq_multiplier != 0):
+            return None
+
+        # resample histograms on a unified bin axis
+        _minmax = minmax[0] - 1, minmax[1] + 1
+        prev_xedge = np.arange(start=_minmax[0],
+                               step=(_minmax[1] - _minmax[0]) / (self._hist_x_granularity - 2), stop=_minmax[1])
+        # uniformly select histograms and the last one
+        cur_idx = self._sample_histograms(hist_iters, self._histogram_granularity)
+        report_hist = np.zeros(shape=(len(cur_idx), prev_xedge.size), dtype=np.float32)
+        for i, n in enumerate(cur_idx):
+            h = hist_list[n]
+            report_hist[i, :] = np.interp(prev_xedge, h[:, 0], h[:, 1], right=0, left=0)
+        yedges = hist_iters[cur_idx]
+        xedges = prev_xedge
+
+        # if only a single line was made, add another zero line, for the scatter plot to draw
+        if report_hist.shape[0] < 2:
+            report_hist = np.vstack((np.zeros_like(report_hist), report_hist))
+
+        # create 3d line (scatter) of histograms
+        skipx = max(1, int(xedges.size / 10))
+        skipy = max(1, int(yedges.size / 10))
+        xlabels = ['%.2f' % v if i % skipx == 0 else '' for i, v in enumerate(xedges[:-1])]
+        ylabels = [str(int(v)) if i % skipy == 0 else '' for i, v in enumerate(yedges)]
+        self._logger.report_surface(
+            title=title,
+            series=series,
+            iteration=0,
+            xaxis=' ',
+            yaxis='iteration',
+            xlabels=xlabels,
+            ylabels=ylabels,
+            matrix=report_hist,
+            camera=(-0.1, +1.3, 1.4))
+
+
 class EventTrainsWriter(object):
     """
     TF SummaryWriter implementation that converts the tensorboard's summary into
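Note (editor's sketch, not part of the commit): the point of extracting this class is that histogram reporting no longer requires a TensorBoard event stream, so integrations such as PyTorch Ignite can feed weight/gradient tensors to it directly. A minimal hypothetical driver, assuming an initialized task with today's clearml package (the library was still named trains at the time of this commit):

```python
import numpy as np
from clearml import Task
from clearml.binding.frameworks.tensorflow_bind import WeightsGradientHistHelper

task = Task.init(project_name='examples', task_name='hist-helper-demo')

# report_freq=1 lets every call through the collection-interval gate
helper = WeightsGradientHistHelper(logger=task.get_logger(), report_freq=1)

for step in range(100):
    grads = np.random.randn(1000) * (1.0 + step / 50.0)  # stand-in for real gradients
    # a plain np.ndarray is binned via np.histogram() internally
    helper.add_histogram(title='layer1/weight', series='gradients', step=step,
                         hist_data=grads)
```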
@@ -190,6 +320,12 @@ class EventTrainsWriter(object):
         self._max_step = 0
         self._graph_name_lookup = {}
         self._generic_tensor_type_name_lookup = {}
+        self._grad_helper = WeightsGradientHistHelper(
+            logger=logger,
+            report_freq=report_freq,
+            histogram_update_freq_multiplier=histogram_update_freq_multiplier,
+            histogram_granularity=histogram_granularity
+        )
 
     def _decode_image(self, img_str, width, height, color_channels):
         # noinspection PyBroadException
@@ -318,103 +454,15 @@ class EventTrainsWriter(object):
             )
 
     def _add_histogram(self, tag, step, histo_data):
-        def _sample_histograms(_hist_iters, _histogram_granularity):
-            # resample history based on distribution of samples across time (steps)
-            ratio = ((_hist_iters[-1] - _hist_iters[_histogram_granularity]) /
-                     (_hist_iters[_histogram_granularity - 1] - _hist_iters[0])) if \
-                _hist_iters.size > _histogram_granularity else 0.
-            cur_idx_below = np.arange(0, min(_hist_iters.size, _histogram_granularity - 1))
-            np.random.shuffle(cur_idx_below)
-            cur_idx_below = cur_idx_below[:int(_histogram_granularity * (1.0 - ratio / (1 + ratio)) + 0.5)]
-            if ratio > 0.0:
-                cur_idx_above = np.arange(_histogram_granularity - 1, _hist_iters.size)
-                np.random.shuffle(cur_idx_above)
-                cur_idx_above = cur_idx_above[:int(_histogram_granularity * ratio / (1 + ratio))]
-            else:
-                cur_idx_above = np.array([])
-            _cur_idx = np.unique(np.sort(np.concatenate((cur_idx_below, cur_idx_above)).astype(np.int)))
-            return _cur_idx
-
-        # only collect histogram every specific interval
-        self._histogram_update_call_counter += 1
-        if self._histogram_update_call_counter % self.report_freq != 0 or \
-                self._histogram_update_call_counter < self.report_freq - 1:
-            return None
-
-        # generate forward matrix of the histograms
-        # Y-axis (rows) is iteration (from 0 to current Step)
-        # X-axis averaged bins (conformed sample 'bucketLimit')
-        # Z-axis actual value (interpolated 'bucket')
         title, series = self.tag_splitter(tag, num_split_parts=1, default_title='Histograms',
                                           logdir_header='series')
         step = self._fix_step_counter(title, series, step)
 
-        # get histograms from cache
-        hist_list, hist_iters, minmax = self._hist_report_cache.get((title, series), ([], np.array([]), None))
-
-        # resample data so we are always constrained in the number of histograms we keep
-        if hist_iters.size >= self.histogram_granularity ** 2:
-            idx = _sample_histograms(hist_iters, self.histogram_granularity)
-            hist_iters = hist_iters[idx]
-            hist_list = [hist_list[i] for i in idx]
-
-        # check if current sample is not already here (actually happens sometimes)
-        if step in hist_iters:
-            return None
-
-        # add current sample, if not already here
-        hist_iters = np.append(hist_iters, step)
-        # histo_data['bucketLimit'] is the histogram bucket right side limit, meaning X axis
-        # histo_data['bucket'] is the histogram height, meaning the Y axis
-        hist = np.array(list(zip(histo_data['bucketLimit'], histo_data['bucket'])), dtype=np.float32)
-        hist = hist[~np.isinf(hist[:, 0]), :]
-        hist_list.append(hist)
-        # keep track of min/max values of histograms (for later re-binning)
-        if minmax is None:
-            minmax = hist[:, 0].min(), hist[:, 0].max()
-        else:
-            minmax = min(minmax[0], hist[:, 0].min()), max(minmax[1], hist[:, 0].max())
-
-        # update the cache
-        self._hist_report_cache[(title, series)] = hist_list, hist_iters, minmax
-
-        # only report histogram every specific interval, but do report the first few, so you know there are histograms
-        if hist_iters.size < 1 or (hist_iters.size >= self.histogram_update_freq_multiplier and
-                                   hist_iters.size % self.histogram_update_freq_multiplier != 0):
-            return None
-
-        # resample histograms on a unified bin axis
-        _minmax = minmax[0] - 1, minmax[1] + 1
-        prev_xedge = np.arange(start=_minmax[0],
-                               step=(_minmax[1] - _minmax[0]) / (self._hist_x_granularity - 2), stop=_minmax[1])
-        # uniformly select histograms and the last one
-        cur_idx = _sample_histograms(hist_iters, self.histogram_granularity)
-        report_hist = np.zeros(shape=(len(cur_idx), prev_xedge.size), dtype=np.float32)
-        for i, n in enumerate(cur_idx):
-            h = hist_list[n]
-            report_hist[i, :] = np.interp(prev_xedge, h[:, 0], h[:, 1], right=0, left=0)
-        yedges = hist_iters[cur_idx]
-        xedges = prev_xedge
-
-        # if only a single line was made, add another zero line, for the scatter plot to draw
-        if report_hist.shape[0] < 2:
-            report_hist = np.vstack((np.zeros_like(report_hist), report_hist))
-
-        # create 3d line (scatter) of histograms
-        skipx = max(1, int(xedges.size / 10))
-        skipy = max(1, int(yedges.size / 10))
-        xlabels = ['%.2f' % v if i % skipx == 0 else '' for i, v in enumerate(xedges[:-1])]
-        ylabels = [str(int(v)) if i % skipy == 0 else '' for i, v in enumerate(yedges)]
-        self._logger.report_surface(
+        self._grad_helper.add_histogram(
             title=title,
             series=series,
-            iteration=0,
-            xaxis=' ',
-            yaxis='iteration',
-            xlabels=xlabels,
-            ylabels=ylabels,
-            matrix=report_hist,
-            camera=(-0.1, +1.3, 1.4))
+            step=step,
+            hist_data=histo_data
+        )
 
     def _add_plot(self, tag, step, values, vdict):
         # noinspection PyBroadException
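Note (editor's sketch, not part of the commit): the `_sample_histograms` logic moved into the helper above is easier to see with numbers. Once more than `granularity` steps are cached, `ratio` compares the span of the recent steps to the span of the old ones and splits the kept indices between the two ranges accordingly:

```python
import numpy as np

g = 5                                  # histogram granularity
iters = np.arange(0, 120, 10)          # 12 recorded steps, more than g
ratio = (iters[-1] - iters[g]) / (iters[g - 1] - iters[0])  # (110-50)/(40-0) = 1.5
keep_old = int(g * (1.0 - ratio / (1 + ratio)) + 0.5)       # indices below g-1
keep_new = int(g * ratio / (1 + ratio))                     # indices at/above g-1
print(ratio, keep_old, keep_new)       # 1.5 2 3
```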
@@ -490,7 +538,8 @@ class EventTrainsWriter(object):
            max_history=self.max_keep_images,
        )
 
-    def _fix_step_counter(self, title, series, step):
+    @staticmethod
+    def _fix_step_counter(title, series, step):
         key = (title, series)
         if key not in EventTrainsWriter._title_series_wraparound_counter:
             EventTrainsWriter._title_series_wraparound_counter[key] = {'first_step': step, 'last_step': step,
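Note (editor's observation): `_fix_step_counter` becomes a `@staticmethod` because the new helper calls it without an `EventTrainsWriter` instance (`EventTrainsWriter._fix_step_counter(title, series, step)` in `WeightsGradientHistHelper._add_histogram` above); it only touches the class-level `_title_series_wraparound_counter` dict, so no instance state is needed.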
@@ -1094,9 +1143,10 @@ class PatchTensorFlowEager(object):
                     PatchTensorFlowEager._add_image_event_helper(event_writer, img_data_np=img_data_np,
                                                                  tag=tag, step=step, **kwargs)
                 elif plugin_type.endswith('histograms'):
-                    PatchTensorFlowEager._add_histogram_event_helper(
-                        event_writer, tag=str(tag), step=int(step.numpy()) if not isinstance(step, int) else step,
-                        hist_data=tensor.numpy())
+                    event_writer._add_histogram(
+                        tag=str(tag), step=int(step.numpy()) if not isinstance(step, int) else step,
+                        hist_data=tensor.numpy()
+                    )
                 elif 'audio' in plugin_type:
                     audio_bytes_list = [a for a in tensor.numpy().flatten() if a]
                     for i, audio_bytes in enumerate(audio_bytes_list):
@@ -1126,9 +1176,10 @@ class PatchTensorFlowEager(object):
         event_writer = PatchTensorFlowEager._get_event_writer(writer)
         if event_writer:
             try:
-                PatchTensorFlowEager._add_histogram_event_helper(
-                    event_writer, tag=str(tag), step=int(step.numpy()) if not isinstance(step, int) else step,
-                    hist_data=values.numpy())
+                event_writer._add_histogram(
+                    tag=str(tag), step=int(step.numpy()) if not isinstance(step, int) else step,
+                    hist_data=values.numpy()
+                )
             except Exception as ex:
                 LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
         return PatchTensorFlowEager.__original_fn_hist(writer, step, tag, values, name, **kwargs)
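Note (editor's sketch, not part of the commit): after this change both eager-mode entry points route through `EventTrainsWriter._add_histogram`, which delegates to the shared helper. A hypothetical TF2 reproduction, assuming an initialized task:

```python
import numpy as np
import tensorflow as tf
from clearml import Task

task = Task.init(project_name='examples', task_name='tf-hist-demo')

writer = tf.summary.create_file_writer('./tb_logs')
with writer.as_default():
    for step in range(100):
        # intercepted by PatchTensorFlowEager._write_hist_summary ->
        # EventTrainsWriter._add_histogram -> WeightsGradientHistHelper
        tf.summary.histogram('dense/kernel', np.random.randn(784, 128), step=step)
```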
@@ -1145,23 +1196,6 @@ class PatchTensorFlowEager(object):
         return PatchTensorFlowEager.__original_fn_image(writer, step, tag, tensor, bad_color, max_images, name,
                                                         **kwargs)
 
-    @staticmethod
-    def _add_histogram_event_helper(event_writer, hist_data, tag, step):
-        if isinstance(hist_data, dict):
-            event_writer._add_histogram(tag=tag, step=step, histo_data=hist_data)
-            return
-
-        if isinstance(hist_data, np.ndarray):
-            hist_data = np.histogram(hist_data)
-            histo_data = {'bucketLimit': hist_data[1].tolist(), 'bucket': hist_data[0].tolist()}
-        else:
-            # prepare the dictionary, assume numpy
-            # histo_data['bucketLimit'] is the histogram bucket right side limit, meaning X axis
-            # histo_data['bucket'] is the histogram height, meaning the Y axis
-            # notice hist_data[:, 1] is the right side limit, for backwards compatibility we take the left side
-            histo_data = {'bucketLimit': hist_data[:, 0].tolist(), 'bucket': hist_data[:, 2].tolist()}
-        event_writer._add_histogram(tag=tag, step=step, histo_data=histo_data)
-
     @staticmethod
     def _add_image_event_helper(event_writer, img_data_np, tag, step, **kwargs):
         if img_data_np.ndim == 1 and img_data_np.size >= 3 and \