Mirror of https://github.com/clearml/clearml (synced 2025-05-29 17:48:33 +00:00)
Refactor histogram code for PyTorch Ignite integration

commit 0298b84030 (parent 966cd6118a)
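
In brief: this commit extracts the weight/gradient histogram accumulation and 3D-surface reporting logic out of EventTrainsWriter._add_histogram into a standalone WeightsGradientHistHelper class, wires an instance of that helper into EventTrainsWriter, folds the input normalization previously done by PatchTensorFlowEager._add_histogram_event_helper into the helper's add_histogram(), and hardens the PyTorch save hook against file objects whose .name attribute is not a string.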
@@ -54,12 +54,17 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
             elif hasattr(f, 'as_posix'):
                 filename = f.as_posix()
             elif hasattr(f, 'name'):
-                filename = f.name
                 # noinspection PyBroadException
                 try:
                     f.flush()
                 except Exception:
                     pass
+
+                if not isinstance(f.name, six.string_types):
+                    # Probably a BufferedRandom object that has no meaningful name (still no harm flushing)
+                    return ret
+
+                filename = f.name
             else:
                 filename = None
         except Exception:
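
Context for the new isinstance guard above: torch.save can be handed a file object with no string name, such as the BufferedRandom returned by tempfile.TemporaryFile(), whose .name is an integer file descriptor. A minimal standalone sketch (not part of the commit; the behavior shown is POSIX-specific):

import tempfile

import six

# On POSIX, TemporaryFile() returns a BufferedRandom whose .name is the
# integer file descriptor rather than a path string; the patched save hook
# above still flushes such objects but skips filename capture and returns.
with tempfile.TemporaryFile() as f:
    print(f.name, isinstance(f.name, six.string_types))  # e.g. "3 False"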
@@ -51,6 +51,136 @@ class IsTensorboardInit(object):
         return original_init(self, *args, **kwargs)


+class WeightsGradientHistHelper(object):
+    def __init__(self, logger, report_freq=100, histogram_update_freq_multiplier=10, histogram_granularity=50):
+        self._logger = logger
+        self.report_freq = report_freq
+        self._histogram_granularity = histogram_granularity
+        self._histogram_update_freq_multiplier = histogram_update_freq_multiplier
+        self._histogram_update_call_counter = 0
+        self._hist_report_cache = {}
+        self._hist_x_granularity = 50
+
+    @staticmethod
+    def _sample_histograms(_hist_iters, _histogram_granularity):
+        # re-sample history based on distribution of samples across time (steps)
+        ratio = ((_hist_iters[-1] - _hist_iters[_histogram_granularity]) /
+                 (_hist_iters[_histogram_granularity - 1] - _hist_iters[0])) if \
+            _hist_iters.size > _histogram_granularity else 0.
+        cur_idx_below = np.arange(0, min(_hist_iters.size, _histogram_granularity - 1))
+        np.random.shuffle(cur_idx_below)
+        cur_idx_below = cur_idx_below[:int(_histogram_granularity * (1.0 - ratio / (1 + ratio)) + 0.5)]
+        if ratio > 0.0:
+            cur_idx_above = np.arange(_histogram_granularity - 1, _hist_iters.size)
+            np.random.shuffle(cur_idx_above)
+            cur_idx_above = cur_idx_above[:int(_histogram_granularity * ratio / (1 + ratio))]
+        else:
+            cur_idx_above = np.array([])
+        _cur_idx = np.unique(np.sort(np.concatenate((cur_idx_below, cur_idx_above)).astype(np.int)))
+        return _cur_idx
+
+    def add_histogram(self, title, series, step, hist_data):
+        # only collect histogram every specific interval
+        self._histogram_update_call_counter += 1
+        if self._histogram_update_call_counter % self.report_freq != 0 or \
+                self._histogram_update_call_counter < self.report_freq - 1:
+            return None
+
+        if isinstance(hist_data, dict):
+            pass
+        elif isinstance(hist_data, np.ndarray):
+            hist_data = np.histogram(hist_data)
+            hist_data = {'bucketLimit': hist_data[1].tolist(), 'bucket': hist_data[0].tolist()}
+        else:
+            # prepare the dictionary, assume numpy
+            # histo_data['bucketLimit'] is the histogram bucket right side limit, meaning X axis
+            # histo_data['bucket'] is the histogram height, meaning the Y axis
+            # notice hist_data[:, 1] is the right side limit, for backwards compatibility we take the left side
+            hist_data = {'bucketLimit': hist_data[:, 0].tolist(), 'bucket': hist_data[:, 2].tolist()}
+
+        self._add_histogram(title=title, series=series, step=step, hist_data=hist_data)
+
+    def _add_histogram(self, title, series, step, hist_data):
+        # only collect histogram every specific interval
+        self._histogram_update_call_counter += 1
+        if self._histogram_update_call_counter % self.report_freq != 0 or \
+                self._histogram_update_call_counter < self.report_freq - 1:
+            return None
+
+        # generate forward matrix of the histograms
+        # Y-axis (rows) is iteration (from 0 to current Step)
+        # X-axis averaged bins (conformed sample 'bucketLimit')
+        # Z-axis actual value (interpolated 'bucket')
+        step = EventTrainsWriter._fix_step_counter(title, series, step)
+
+        # get histograms from cache
+        hist_list, hist_iters, minmax = self._hist_report_cache.get((title, series), ([], np.array([]), None))
+
+        # resample data so we are always constrained in number of histogram we keep
+        if hist_iters.size >= self._histogram_granularity ** 2:
+            idx = self._sample_histograms(hist_iters, self._histogram_granularity)
+            hist_iters = hist_iters[idx]
+            hist_list = [hist_list[i] for i in idx]
+
+        # check if current sample is not already here (actually happens some times)
+        if step in hist_iters:
+            return None
+
+        # add current sample, if not already here
+        hist_iters = np.append(hist_iters, step)
+        # histo_data['bucketLimit'] is the histogram bucket right side limit, meaning X axis
+        # histo_data['bucket'] is the histogram height, meaning the Y axis
+        hist = np.array(list(zip(hist_data['bucketLimit'], hist_data['bucket'])), dtype=np.float32)
+        hist = hist[~np.isinf(hist[:, 0]), :]
+        hist_list.append(hist)
+        # keep track of min/max values of histograms (for later re-binning)
+        if minmax is None:
+            minmax = hist[:, 0].min(), hist[:, 0].max()
+        else:
+            minmax = min(minmax[0], hist[:, 0].min()), max(minmax[1], hist[:, 0].max())
+
+        # update the cache
+        self._hist_report_cache[(title, series)] = hist_list, hist_iters, minmax
+
+        # only report histogram every specific interval, but do report the first few, so you know there are histograms
+        if hist_iters.size < 1 or (hist_iters.size >= self._histogram_update_freq_multiplier and
+                                   hist_iters.size % self._histogram_update_freq_multiplier != 0):
+            return None
+
+        # resample histograms on a unified bin axis
+        _minmax = minmax[0] - 1, minmax[1] + 1
+        prev_xedge = np.arange(start=_minmax[0],
+                               step=(_minmax[1] - _minmax[0]) / (self._hist_x_granularity - 2), stop=_minmax[1])
+        # uniformly select histograms and the last one
+        cur_idx = self._sample_histograms(hist_iters, self._histogram_granularity)
+        report_hist = np.zeros(shape=(len(cur_idx), prev_xedge.size), dtype=np.float32)
+        for i, n in enumerate(cur_idx):
+            h = hist_list[n]
+            report_hist[i, :] = np.interp(prev_xedge, h[:, 0], h[:, 1], right=0, left=0)
+        yedges = hist_iters[cur_idx]
+        xedges = prev_xedge
+
+        # if only a single line make, add another zero line, for the scatter plot to draw
+        if report_hist.shape[0] < 2:
+            report_hist = np.vstack((np.zeros_like(report_hist), report_hist))
+
+        # create 3d line (scatter) of histograms
+        skipx = max(1, int(xedges.size / 10))
+        skipy = max(1, int(yedges.size / 10))
+        xlabels = ['%.2f' % v if i % skipx == 0 else '' for i, v in enumerate(xedges[:-1])]
+        ylabels = [str(int(v)) if i % skipy == 0 else '' for i, v in enumerate(yedges)]
+        self._logger.report_surface(
+            title=title,
+            series=series,
+            iteration=0,
+            xaxis=' ',
+            yaxis='iteration',
+            xlabels=xlabels,
+            ylabels=ylabels,
+            matrix=report_hist,
+            camera=(-0.1, +1.3, 1.4))
+
+
 class EventTrainsWriter(object):
     """
     TF SummaryWriter implementation that converts the tensorboard's summary into
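
Two standalone sketches (illustrative additions, not part of the commit) of what the new helper does. First, the input normalization in add_histogram(): a raw 1-D array is binned with np.histogram and packed into the 'bucketLimit'/'bucket' dict consumed by _add_histogram():

import numpy as np

# Raw 1-D data -> the dict format used above: 'bucketLimit' holds the bin
# edges (X axis), 'bucket' holds the bin heights (Y axis).
values = np.random.randn(1000)
counts, edges = np.histogram(values)
hist_data = {'bucketLimit': edges.tolist(), 'bucket': counts.tolist()}
print(len(hist_data['bucketLimit']), len(hist_data['bucket']))  # 11 edges, 10 counts

Second, the arithmetic behind _sample_histograms(): once more steps are cached than the granularity, the span of the newest steps relative to the oldest decides how many indices are kept from each region:

import numpy as np

# Reproduces the re-sampling ratio from _sample_histograms for 120 recorded
# steps and granularity 50: roughly 21 indices come from the older region
# and 29 from the newer one, keeping about 50 histograms in total.
granularity = 50
hist_iters = np.arange(120)
ratio = ((hist_iters[-1] - hist_iters[granularity]) /
         (hist_iters[granularity - 1] - hist_iters[0]))
n_below = int(granularity * (1.0 - ratio / (1 + ratio)) + 0.5)
n_above = int(granularity * ratio / (1 + ratio))
print(round(ratio, 2), n_below, n_above)  # 1.41 21 29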
@@ -190,6 +320,12 @@ class EventTrainsWriter(object):
         self._max_step = 0
         self._graph_name_lookup = {}
         self._generic_tensor_type_name_lookup = {}
+        self._grad_helper = WeightsGradientHistHelper(
+            logger=logger,
+            report_freq=report_freq,
+            histogram_update_freq_multiplier=histogram_update_freq_multiplier,
+            histogram_granularity=histogram_granularity
+        )

     def _decode_image(self, img_str, width, height, color_channels):
         # noinspection PyBroadException
@@ -318,103 +454,15 @@ class EventTrainsWriter(object):
         )

     def _add_histogram(self, tag, step, histo_data):
-        def _sample_histograms(_hist_iters, _histogram_granularity):
-            # resample history based on distribution of samples across time (steps)
-            ratio = ((_hist_iters[-1] - _hist_iters[_histogram_granularity]) /
-                     (_hist_iters[_histogram_granularity - 1] - _hist_iters[0])) if \
-                _hist_iters.size > _histogram_granularity else 0.
-            cur_idx_below = np.arange(0, min(_hist_iters.size, _histogram_granularity - 1))
-            np.random.shuffle(cur_idx_below)
-            cur_idx_below = cur_idx_below[:int(_histogram_granularity * (1.0 - ratio / (1 + ratio)) + 0.5)]
-            if ratio > 0.0:
-                cur_idx_above = np.arange(_histogram_granularity - 1, _hist_iters.size)
-                np.random.shuffle(cur_idx_above)
-                cur_idx_above = cur_idx_above[:int(_histogram_granularity * ratio / (1 + ratio))]
-            else:
-                cur_idx_above = np.array([])
-            _cur_idx = np.unique(np.sort(np.concatenate((cur_idx_below, cur_idx_above)).astype(np.int)))
-            return _cur_idx
-
-        # only collect histogram every specific interval
-        self._histogram_update_call_counter += 1
-        if self._histogram_update_call_counter % self.report_freq != 0 or \
-                self._histogram_update_call_counter < self.report_freq - 1:
-            return None
-
-        # generate forward matrix of the histograms
-        # Y-axis (rows) is iteration (from 0 to current Step)
-        # X-axis averaged bins (conformed sample 'bucketLimit')
-        # Z-axis actual value (interpolated 'bucket')
         title, series = self.tag_splitter(tag, num_split_parts=1, default_title='Histograms',
                                           logdir_header='series')
-        step = self._fix_step_counter(title, series, step)
-
-        # get histograms from cache
-        hist_list, hist_iters, minmax = self._hist_report_cache.get((title, series), ([], np.array([]), None))
-
-        # resample data so we are always constrained in number of histogram we keep
-        if hist_iters.size >= self.histogram_granularity ** 2:
-            idx = _sample_histograms(hist_iters, self.histogram_granularity)
-            hist_iters = hist_iters[idx]
-            hist_list = [hist_list[i] for i in idx]
-
-        # check if current sample is not already here (actually happens some times)
-        if step in hist_iters:
-            return None
-
-        # add current sample, if not already here
-        hist_iters = np.append(hist_iters, step)
-        # histo_data['bucketLimit'] is the histogram bucket right side limit, meaning X axis
-        # histo_data['bucket'] is the histogram height, meaning the Y axis
-        hist = np.array(list(zip(histo_data['bucketLimit'], histo_data['bucket'])), dtype=np.float32)
-        hist = hist[~np.isinf(hist[:, 0]), :]
-        hist_list.append(hist)
-        # keep track of min/max values of histograms (for later re-binning)
-        if minmax is None:
-            minmax = hist[:, 0].min(), hist[:, 0].max()
-        else:
-            minmax = min(minmax[0], hist[:, 0].min()), max(minmax[1], hist[:, 0].max())
-
-        # update the cache
-        self._hist_report_cache[(title, series)] = hist_list, hist_iters, minmax
-
-        # only report histogram every specific interval, but do report the first few, so you know there are histograms
-        if hist_iters.size < 1 or (hist_iters.size >= self.histogram_update_freq_multiplier and
-                                   hist_iters.size % self.histogram_update_freq_multiplier != 0):
-            return None
-
-        # resample histograms on a unified bin axis
-        _minmax = minmax[0] - 1, minmax[1] + 1
-        prev_xedge = np.arange(start=_minmax[0],
-                               step=(_minmax[1] - _minmax[0]) / (self._hist_x_granularity - 2), stop=_minmax[1])
-        # uniformly select histograms and the last one
-        cur_idx = _sample_histograms(hist_iters, self.histogram_granularity)
-        report_hist = np.zeros(shape=(len(cur_idx), prev_xedge.size), dtype=np.float32)
-        for i, n in enumerate(cur_idx):
-            h = hist_list[n]
-            report_hist[i, :] = np.interp(prev_xedge, h[:, 0], h[:, 1], right=0, left=0)
-        yedges = hist_iters[cur_idx]
-        xedges = prev_xedge
-
-        # if only a single line make, add another zero line, for the scatter plot to draw
-        if report_hist.shape[0] < 2:
-            report_hist = np.vstack((np.zeros_like(report_hist), report_hist))
-
-        # create 3d line (scatter) of histograms
-        skipx = max(1, int(xedges.size / 10))
-        skipy = max(1, int(yedges.size / 10))
-        xlabels = ['%.2f' % v if i % skipx == 0 else '' for i, v in enumerate(xedges[:-1])]
-        ylabels = [str(int(v)) if i % skipy == 0 else '' for i, v in enumerate(yedges)]
-        self._logger.report_surface(
-            title=title,
-            series=series,
-            iteration=0,
-            xaxis=' ',
-            yaxis='iteration',
-            xlabels=xlabels,
-            ylabels=ylabels,
-            matrix=report_hist,
-            camera=(-0.1, +1.3, 1.4))
+        self._grad_helper.add_histogram(
+            title=title,
+            series=series,
+            step=step,
+            hist_data=histo_data
+        )

     def _add_plot(self, tag, step, values, vdict):
         # noinspection PyBroadException
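
With the helper extracted, _add_histogram() above reduces to tag parsing plus delegation, which is what lets integrations without a TensorBoard event stream (such as the PyTorch Ignite support named in the commit title) reuse the same histogram logic. A hedged sketch, assuming WeightsGradientHistHelper is importable from the patched module; _StubLogger is a hypothetical stand-in for the ClearML logger, of which only report_surface() is used:

import numpy as np

class _StubLogger(object):
    # Hypothetical stand-in: the helper only calls report_surface() on the
    # logger it is handed, so any object with that method will do.
    def report_surface(self, **kwargs):
        print('surface:', kwargs['title'], kwargs['series'], kwargs['matrix'].shape)

helper = WeightsGradientHistHelper(logger=_StubLogger(), report_freq=1)
for step in range(5):
    helper.add_histogram(title='fc1', series='weights', step=step,
                         hist_data=np.random.randn(256))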
@@ -490,7 +538,8 @@ class EventTrainsWriter(object):
             max_history=self.max_keep_images,
         )

-    def _fix_step_counter(self, title, series, step):
+    @staticmethod
+    def _fix_step_counter(title, series, step):
         key = (title, series)
         if key not in EventTrainsWriter._title_series_wraparound_counter:
             EventTrainsWriter._title_series_wraparound_counter[key] = {'first_step': step, 'last_step': step,
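
Making _fix_step_counter a @staticmethod is what allows the extracted helper to call EventTrainsWriter._fix_step_counter(title, series, step) without holding a writer instance, as the new WeightsGradientHistHelper._add_histogram does above.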
@@ -1094,9 +1143,10 @@ class PatchTensorFlowEager(object):
                     PatchTensorFlowEager._add_image_event_helper(event_writer, img_data_np=img_data_np,
                                                                  tag=tag, step=step, **kwargs)
                 elif plugin_type.endswith('histograms'):
-                    PatchTensorFlowEager._add_histogram_event_helper(
-                        event_writer, tag=str(tag), step=int(step.numpy()) if not isinstance(step, int) else step,
-                        hist_data=tensor.numpy())
+                    event_writer._add_histogram(
+                        tag=str(tag), step=int(step.numpy()) if not isinstance(step, int) else step,
+                        hist_data=tensor.numpy()
+                    )
                 elif 'audio' in plugin_type:
                     audio_bytes_list = [a for a in tensor.numpy().flatten() if a]
                     for i, audio_bytes in enumerate(audio_bytes_list):
@@ -1126,9 +1176,10 @@ class PatchTensorFlowEager(object):
         event_writer = PatchTensorFlowEager._get_event_writer(writer)
         if event_writer:
             try:
-                PatchTensorFlowEager._add_histogram_event_helper(
-                    event_writer, tag=str(tag), step=int(step.numpy()) if not isinstance(step, int) else step,
-                    hist_data=values.numpy())
+                event_writer._add_histogram(
+                    tag=str(tag), step=int(step.numpy()) if not isinstance(step, int) else step,
+                    hist_data=values.numpy()
+                )
             except Exception as ex:
                 LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
         return PatchTensorFlowEager.__original_fn_hist(writer, step, tag, values, name, **kwargs)
@@ -1145,23 +1196,6 @@ class PatchTensorFlowEager(object):
         return PatchTensorFlowEager.__original_fn_image(writer, step, tag, tensor, bad_color, max_images, name,
                                                         **kwargs)

-    @staticmethod
-    def _add_histogram_event_helper(event_writer, hist_data, tag, step):
-        if isinstance(hist_data, dict):
-            event_writer._add_histogram(tag=tag, step=step, histo_data=hist_data)
-            return
-
-        if isinstance(hist_data, np.ndarray):
-            hist_data = np.histogram(hist_data)
-            histo_data = {'bucketLimit': hist_data[1].tolist(), 'bucket': hist_data[0].tolist()}
-        else:
-            # prepare the dictionary, assume numpy
-            # histo_data['bucketLimit'] is the histogram bucket right side limit, meaning X axis
-            # histo_data['bucket'] is the histogram height, meaning the Y axis
-            # notice hist_data[:, 1] is the right side limit, for backwards compatibility we take the left side
-            histo_data = {'bucketLimit': hist_data[:, 0].tolist(), 'bucket': hist_data[:, 2].tolist()}
-        event_writer._add_histogram(tag=tag, step=step, histo_data=histo_data)
-
     @staticmethod
     def _add_image_event_helper(event_writer, img_data_np, tag, step, **kwargs):
         if img_data_np.ndim == 1 and img_data_np.size >= 3 and \
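
The removed _add_histogram_event_helper is not lost: its dict/ndarray normalization was folded into WeightsGradientHistHelper.add_histogram (added earlier in this diff), so the eager-mode call sites above now pass their tensors straight to event_writer._add_histogram.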