Refactor histogram code for PyTorch Ignite integration

commit 0298b84030
parent 966cd6118a
Author: allegroai
Date:   2020-05-08 22:12:49 +03:00

2 changed files with 156 additions and 117 deletions

@@ -54,12 +54,17 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
             elif hasattr(f, 'as_posix'):
                 filename = f.as_posix()
             elif hasattr(f, 'name'):
-                filename = f.name
+                # noinspection PyBroadException
+                try:
+                    f.flush()
+                except Exception:
+                    pass
+                if not isinstance(f.name, six.string_types):
+                    # Probably a BufferedRandom object that has no meaningful name (still no harm flushing)
+                    return ret
+                filename = f.name
             else:
                 filename = None
         except Exception:
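
The change above hardens filename extraction for file-like objects passed to the patched save call: the object is flushed first (harmless even when we bail out), and a non-string `.name` attribute (e.g. an `io.BufferedRandom` wrapping a raw descriptor) no longer produces a bogus filename. A minimal standalone sketch of the same logic, with the surrounding try/except elided; the helper name is hypothetical:

import six

def resolve_filename(f):
    """Best-effort filename extraction for str / pathlib / file-like inputs."""
    if isinstance(f, six.string_types):
        return f
    if hasattr(f, 'as_posix'):
        # pathlib.Path and friends expose as_posix()
        return f.as_posix()
    if hasattr(f, 'name'):
        try:
            f.flush()  # flush buffered data; harmless even if we return early
        except Exception:
            pass
        if not isinstance(f.name, six.string_types):
            # e.g. io.BufferedRandom over a raw fd: .name is an int, not a path
            return None
        return f.name
    return None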


@@ -51,6 +51,136 @@ class IsTensorboardInit(object):
         return original_init(self, *args, **kwargs)
 
 
+class WeightsGradientHistHelper(object):
+    def __init__(self, logger, report_freq=100, histogram_update_freq_multiplier=10, histogram_granularity=50):
+        self._logger = logger
+        self.report_freq = report_freq
+        self._histogram_granularity = histogram_granularity
+        self._histogram_update_freq_multiplier = histogram_update_freq_multiplier
+        self._histogram_update_call_counter = 0
+        self._hist_report_cache = {}
+        self._hist_x_granularity = 50
+
+    @staticmethod
+    def _sample_histograms(_hist_iters, _histogram_granularity):
+        # re-sample history based on distribution of samples across time (steps)
+        ratio = ((_hist_iters[-1] - _hist_iters[_histogram_granularity]) /
+                 (_hist_iters[_histogram_granularity - 1] - _hist_iters[0])) if \
+            _hist_iters.size > _histogram_granularity else 0.
+        cur_idx_below = np.arange(0, min(_hist_iters.size, _histogram_granularity - 1))
+        np.random.shuffle(cur_idx_below)
+        cur_idx_below = cur_idx_below[:int(_histogram_granularity * (1.0 - ratio / (1 + ratio)) + 0.5)]
+        if ratio > 0.0:
+            cur_idx_above = np.arange(_histogram_granularity - 1, _hist_iters.size)
+            np.random.shuffle(cur_idx_above)
+            cur_idx_above = cur_idx_above[:int(_histogram_granularity * ratio / (1 + ratio))]
+        else:
+            cur_idx_above = np.array([])
+        _cur_idx = np.unique(np.sort(np.concatenate((cur_idx_below, cur_idx_above)).astype(np.int)))
+        return _cur_idx
+
+    def add_histogram(self, title, series, step, hist_data):
+        # only collect histogram every specific interval
+        self._histogram_update_call_counter += 1
+        if self._histogram_update_call_counter % self.report_freq != 0 or \
+                self._histogram_update_call_counter < self.report_freq - 1:
+            return None
+
+        if isinstance(hist_data, dict):
+            pass
+        elif isinstance(hist_data, np.ndarray):
+            hist_data = np.histogram(hist_data)
+            hist_data = {'bucketLimit': hist_data[1].tolist(), 'bucket': hist_data[0].tolist()}
+        else:
+            # prepare the dictionary, assume numpy
+            # hist_data['bucketLimit'] is the histogram bucket right side limit, meaning X axis
+            # hist_data['bucket'] is the histogram height, meaning the Y axis
+            # notice hist_data[:, 1] is the right side limit, for backwards compatibility we take the left side
+            hist_data = {'bucketLimit': hist_data[:, 0].tolist(), 'bucket': hist_data[:, 2].tolist()}
+
+        self._add_histogram(title=title, series=series, step=step, hist_data=hist_data)
+
+    def _add_histogram(self, title, series, step, hist_data):
+        # only collect histogram every specific interval
+        self._histogram_update_call_counter += 1
+        if self._histogram_update_call_counter % self.report_freq != 0 or \
+                self._histogram_update_call_counter < self.report_freq - 1:
+            return None
+
+        # generate forward matrix of the histograms
+        # Y-axis (rows) is iteration (from 0 to current Step)
+        # X-axis averaged bins (conformed sample 'bucketLimit')
+        # Z-axis actual value (interpolated 'bucket')
+        step = EventTrainsWriter._fix_step_counter(title, series, step)
+
+        # get histograms from cache
+        hist_list, hist_iters, minmax = self._hist_report_cache.get((title, series), ([], np.array([]), None))
+        # resample data so we are always constrained in the number of histograms we keep
+        if hist_iters.size >= self._histogram_granularity ** 2:
+            idx = self._sample_histograms(hist_iters, self._histogram_granularity)
+            hist_iters = hist_iters[idx]
+            hist_list = [hist_list[i] for i in idx]
+
+        # check if current sample is not already here (actually happens sometimes)
+        if step in hist_iters:
+            return None
+        # add current sample, if not already here
+        hist_iters = np.append(hist_iters, step)
+        # hist_data['bucketLimit'] is the histogram bucket right side limit, meaning X axis
+        # hist_data['bucket'] is the histogram height, meaning the Y axis
+        hist = np.array(list(zip(hist_data['bucketLimit'], hist_data['bucket'])), dtype=np.float32)
+        hist = hist[~np.isinf(hist[:, 0]), :]
+        hist_list.append(hist)
+        # keep track of min/max values of histograms (for later re-binning)
+        if minmax is None:
+            minmax = hist[:, 0].min(), hist[:, 0].max()
+        else:
+            minmax = min(minmax[0], hist[:, 0].min()), max(minmax[1], hist[:, 0].max())
+
+        # update the cache
+        self._hist_report_cache[(title, series)] = hist_list, hist_iters, minmax
+
+        # only report histogram every specific interval, but do report the first few, so you know there are histograms
+        if hist_iters.size < 1 or (hist_iters.size >= self._histogram_update_freq_multiplier and
+                                   hist_iters.size % self._histogram_update_freq_multiplier != 0):
+            return None
+
+        # resample histograms on a unified bin axis
+        _minmax = minmax[0] - 1, minmax[1] + 1
+        prev_xedge = np.arange(start=_minmax[0],
+                               step=(_minmax[1] - _minmax[0]) / (self._hist_x_granularity - 2), stop=_minmax[1])
+        # uniformly select histograms and the last one
+        cur_idx = self._sample_histograms(hist_iters, self._histogram_granularity)
+        report_hist = np.zeros(shape=(len(cur_idx), prev_xedge.size), dtype=np.float32)
+        for i, n in enumerate(cur_idx):
+            h = hist_list[n]
+            report_hist[i, :] = np.interp(prev_xedge, h[:, 0], h[:, 1], right=0, left=0)
+        yedges = hist_iters[cur_idx]
+        xedges = prev_xedge
+
+        # if there is only a single line, add another zero line for the scatter plot to draw
+        if report_hist.shape[0] < 2:
+            report_hist = np.vstack((np.zeros_like(report_hist), report_hist))
+
+        # create 3d line (scatter) of histograms
+        skipx = max(1, int(xedges.size / 10))
+        skipy = max(1, int(yedges.size / 10))
+        xlabels = ['%.2f' % v if i % skipx == 0 else '' for i, v in enumerate(xedges[:-1])]
+        ylabels = [str(int(v)) if i % skipy == 0 else '' for i, v in enumerate(yedges)]
+        self._logger.report_surface(
+            title=title,
+            series=series,
+            iteration=0,
+            xaxis=' ',
+            yaxis='iteration',
+            xlabels=xlabels,
+            ylabels=ylabels,
+            matrix=report_hist,
+            camera=(-0.1, +1.3, 1.4))
+
+
 class EventTrainsWriter(object):
     """
     TF SummaryWriter implementation that converts the tensorboard's summary into
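
To build intuition for `_sample_histograms` above: once the number of recorded steps exceeds the granularity, it splits the candidate indices into an "old" range and a "recent" range and draws from each in proportion to the time span it covers. A quick illustrative driver, not part of the commit; it assumes a pre-1.20 numpy, since the original code still uses the since-removed `np.int` alias:

import numpy as np

# 100 evenly spaced recorded iterations, granularity of 50
hist_iters = np.arange(0, 500, 5)
idx = WeightsGradientHistHelper._sample_histograms(hist_iters, 50)

# The result is a sorted, de-duplicated index set of at most ~50 entries;
# here the "recent" range (indices >= 49) spans as much time as the "old"
# one, so the selection splits roughly half and half between them.
assert idx.size <= 50
assert np.all(np.diff(idx) > 0)                 # unique and ascending
assert np.any(idx >= 49) and np.any(idx < 49)   # both ranges represented
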
@@ -190,6 +320,12 @@ class EventTrainsWriter(object):
         self._max_step = 0
         self._graph_name_lookup = {}
         self._generic_tensor_type_name_lookup = {}
+        self._grad_helper = WeightsGradientHistHelper(
+            logger=logger,
+            report_freq=report_freq,
+            histogram_update_freq_multiplier=histogram_update_freq_multiplier,
+            histogram_granularity=histogram_granularity
+        )
 
     def _decode_image(self, img_str, width, height, color_channels):
         # noinspection PyBroadException
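
With the helper now constructed in `EventTrainsWriter.__init__`, the same histogram pipeline becomes reusable outside the TensorBoard event path, which is what the PyTorch Ignite integration in the commit title needs. A hedged usage sketch, not from the commit; `fetch_layer_weights` is a hypothetical stand-in for whatever produces the values:

import numpy as np
from trains import Task

task = Task.init(project_name='examples', task_name='weights histogram demo')
helper = WeightsGradientHistHelper(logger=task.get_logger(), report_freq=1)

for step in range(100):
    # hypothetical source of values; any 1-D np.ndarray works, since
    # add_histogram() bins raw array input via np.histogram()
    values = fetch_layer_weights(step)
    helper.add_histogram(title='layer1/weight', series='histogram', step=step, hist_data=values)
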
@@ -318,103 +454,15 @@ class EventTrainsWriter(object):
         )
 
     def _add_histogram(self, tag, step, histo_data):
-        def _sample_histograms(_hist_iters, _histogram_granularity):
-            # resample history based on distribution of samples across time (steps)
-            ratio = ((_hist_iters[-1] - _hist_iters[_histogram_granularity]) /
-                     (_hist_iters[_histogram_granularity - 1] - _hist_iters[0])) if \
-                _hist_iters.size > _histogram_granularity else 0.
-            cur_idx_below = np.arange(0, min(_hist_iters.size, _histogram_granularity - 1))
-            np.random.shuffle(cur_idx_below)
-            cur_idx_below = cur_idx_below[:int(_histogram_granularity * (1.0 - ratio / (1 + ratio)) + 0.5)]
-            if ratio > 0.0:
-                cur_idx_above = np.arange(_histogram_granularity - 1, _hist_iters.size)
-                np.random.shuffle(cur_idx_above)
-                cur_idx_above = cur_idx_above[:int(_histogram_granularity * ratio / (1 + ratio))]
-            else:
-                cur_idx_above = np.array([])
-            _cur_idx = np.unique(np.sort(np.concatenate((cur_idx_below, cur_idx_above)).astype(np.int)))
-            return _cur_idx
-
-        # only collect histogram every specific interval
-        self._histogram_update_call_counter += 1
-        if self._histogram_update_call_counter % self.report_freq != 0 or \
-                self._histogram_update_call_counter < self.report_freq - 1:
-            return None
-
-        # generate forward matrix of the histograms
-        # Y-axis (rows) is iteration (from 0 to current Step)
-        # X-axis averaged bins (conformed sample 'bucketLimit')
-        # Z-axis actual value (interpolated 'bucket')
         title, series = self.tag_splitter(tag, num_split_parts=1, default_title='Histograms',
                                           logdir_header='series')
-        step = self._fix_step_counter(title, series, step)
-
-        # get histograms from cache
-        hist_list, hist_iters, minmax = self._hist_report_cache.get((title, series), ([], np.array([]), None))
-        # resample data so we are always constrained in the number of histograms we keep
-        if hist_iters.size >= self.histogram_granularity ** 2:
-            idx = _sample_histograms(hist_iters, self.histogram_granularity)
-            hist_iters = hist_iters[idx]
-            hist_list = [hist_list[i] for i in idx]
-
-        # check if current sample is not already here (actually happens sometimes)
-        if step in hist_iters:
-            return None
-        # add current sample, if not already here
-        hist_iters = np.append(hist_iters, step)
-        # histo_data['bucketLimit'] is the histogram bucket right side limit, meaning X axis
-        # histo_data['bucket'] is the histogram height, meaning the Y axis
-        hist = np.array(list(zip(histo_data['bucketLimit'], histo_data['bucket'])), dtype=np.float32)
-        hist = hist[~np.isinf(hist[:, 0]), :]
-        hist_list.append(hist)
-        # keep track of min/max values of histograms (for later re-binning)
-        if minmax is None:
-            minmax = hist[:, 0].min(), hist[:, 0].max()
-        else:
-            minmax = min(minmax[0], hist[:, 0].min()), max(minmax[1], hist[:, 0].max())
-
-        # update the cache
-        self._hist_report_cache[(title, series)] = hist_list, hist_iters, minmax
-
-        # only report histogram every specific interval, but do report the first few, so you know there are histograms
-        if hist_iters.size < 1 or (hist_iters.size >= self.histogram_update_freq_multiplier and
-                                   hist_iters.size % self.histogram_update_freq_multiplier != 0):
-            return None
-
-        # resample histograms on a unified bin axis
-        _minmax = minmax[0] - 1, minmax[1] + 1
-        prev_xedge = np.arange(start=_minmax[0],
-                               step=(_minmax[1] - _minmax[0]) / (self._hist_x_granularity - 2), stop=_minmax[1])
-        # uniformly select histograms and the last one
-        cur_idx = _sample_histograms(hist_iters, self.histogram_granularity)
-        report_hist = np.zeros(shape=(len(cur_idx), prev_xedge.size), dtype=np.float32)
-        for i, n in enumerate(cur_idx):
-            h = hist_list[n]
-            report_hist[i, :] = np.interp(prev_xedge, h[:, 0], h[:, 1], right=0, left=0)
-        yedges = hist_iters[cur_idx]
-        xedges = prev_xedge
-
-        # if there is only a single line, add another zero line for the scatter plot to draw
-        if report_hist.shape[0] < 2:
-            report_hist = np.vstack((np.zeros_like(report_hist), report_hist))
-
-        # create 3d line (scatter) of histograms
-        skipx = max(1, int(xedges.size / 10))
-        skipy = max(1, int(yedges.size / 10))
-        xlabels = ['%.2f' % v if i % skipx == 0 else '' for i, v in enumerate(xedges[:-1])]
-        ylabels = [str(int(v)) if i % skipy == 0 else '' for i, v in enumerate(yedges)]
-        self._logger.report_surface(
+        self._grad_helper.add_histogram(
             title=title,
             series=series,
-            iteration=0,
-            xaxis=' ',
-            yaxis='iteration',
-            xlabels=xlabels,
-            ylabels=ylabels,
-            matrix=report_hist,
-            camera=(-0.1, +1.3, 1.4))
+            step=step,
+            hist_data=histo_data
+        )
 
     def _add_plot(self, tag, step, values, vdict):
         # noinspection PyBroadException
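
The input normalization that used to live in `_add_histogram_event_helper` (removed further down) now lives in `WeightsGradientHistHelper.add_histogram`, which accepts three forms of `hist_data`. An illustrative summary, not from the commit:

import numpy as np

raw = np.random.randn(1000)

# 1) a raw np.ndarray of values: binned internally with np.histogram()
counts, edges = np.histogram(raw)

# 2) a pre-binned dict, passed through unchanged:
#    'bucketLimit' = bucket edges (X axis), 'bucket' = bucket heights (Y axis)
pre_binned = {'bucketLimit': edges.tolist(), 'bucket': counts.tolist()}

# 3) a TensorBoard-style (N, 3) array of [left_edge, right_edge, count] rows;
#    column 0 (left edge) and column 2 (count) are used
tb_rows = np.stack([edges[:-1], edges[1:], counts], axis=1)

Note that the dict form keeps one more edge than it has counts; the later zip() in _add_histogram pairs each edge with its count and simply drops the surplus edge.
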
@@ -490,7 +538,8 @@ class EventTrainsWriter(object):
             max_history=self.max_keep_images,
         )
 
-    def _fix_step_counter(self, title, series, step):
+    @staticmethod
+    def _fix_step_counter(title, series, step):
         key = (title, series)
         if key not in EventTrainsWriter._title_series_wraparound_counter:
             EventTrainsWriter._title_series_wraparound_counter[key] = {'first_step': step, 'last_step': step,
@@ -1094,9 +1143,10 @@ class PatchTensorFlowEager(object):
                     PatchTensorFlowEager._add_image_event_helper(event_writer, img_data_np=img_data_np,
                                                                  tag=tag, step=step, **kwargs)
                 elif plugin_type.endswith('histograms'):
-                    PatchTensorFlowEager._add_histogram_event_helper(
-                        event_writer, tag=str(tag), step=int(step.numpy()) if not isinstance(step, int) else step,
-                        hist_data=tensor.numpy())
+                    event_writer._add_histogram(
+                        tag=str(tag), step=int(step.numpy()) if not isinstance(step, int) else step,
+                        histo_data=tensor.numpy()
+                    )
                 elif 'audio' in plugin_type:
                     audio_bytes_list = [a for a in tensor.numpy().flatten() if a]
                     for i, audio_bytes in enumerate(audio_bytes_list):
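
For reference, this is the kind of user code whose eager-mode summaries end up routed through the patched writer above; a minimal sketch assuming TensorFlow 2.x, with an arbitrary log directory:

import tensorflow as tf

writer = tf.summary.create_file_writer('./logs')
with writer.as_default():
    for step in range(100):
        # each histogram summary is intercepted by PatchTensorFlowEager and
        # forwarded to EventTrainsWriter._add_histogram() as a numpy payload
        tf.summary.histogram('layer1/weights', tf.random.normal([1000]), step=step)

The `int(step.numpy())` dance in the patch exists because in this path `step` arrives as a TF tensor rather than a plain Python int.
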
@@ -1126,9 +1176,10 @@ class PatchTensorFlowEager(object):
             event_writer = PatchTensorFlowEager._get_event_writer(writer)
             if event_writer:
                 try:
-                    PatchTensorFlowEager._add_histogram_event_helper(
-                        event_writer, tag=str(tag), step=int(step.numpy()) if not isinstance(step, int) else step,
-                        hist_data=values.numpy())
+                    event_writer._add_histogram(
+                        tag=str(tag), step=int(step.numpy()) if not isinstance(step, int) else step,
+                        histo_data=values.numpy()
+                    )
                 except Exception as ex:
                     LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
         return PatchTensorFlowEager.__original_fn_hist(writer, step, tag, values, name, **kwargs)
@@ -1145,23 +1196,6 @@ class PatchTensorFlowEager(object):
             return PatchTensorFlowEager.__original_fn_image(writer, step, tag, tensor, bad_color, max_images, name,
                                                             **kwargs)
 
-    @staticmethod
-    def _add_histogram_event_helper(event_writer, hist_data, tag, step):
-        if isinstance(hist_data, dict):
-            event_writer._add_histogram(tag=tag, step=step, histo_data=hist_data)
-            return
-
-        if isinstance(hist_data, np.ndarray):
-            hist_data = np.histogram(hist_data)
-            histo_data = {'bucketLimit': hist_data[1].tolist(), 'bucket': hist_data[0].tolist()}
-        else:
-            # prepare the dictionary, assume numpy
-            # histo_data['bucketLimit'] is the histogram bucket right side limit, meaning X axis
-            # histo_data['bucket'] is the histogram height, meaning the Y axis
-            # notice hist_data[:, 1] is the right side limit, for backwards compatibility we take the left side
-            histo_data = {'bucketLimit': hist_data[:, 0].tolist(), 'bucket': hist_data[:, 2].tolist()}
-
-        event_writer._add_histogram(tag=tag, step=step, histo_data=histo_data)
     @staticmethod
     def _add_image_event_helper(event_writer, img_data_np, tag, step, **kwargs):
         if img_data_np.ndim == 1 and img_data_np.size >= 3 and \