From 0298b8403036dd65ceacabc61eb9fe8a08cb634d Mon Sep 17 00:00:00 2001
From: allegroai <>
Date: Fri, 8 May 2020 22:12:49 +0300
Subject: [PATCH] Refactor histogram code for PyTorch Ignite integration

---
 trains/binding/frameworks/pytorch_bind.py    |   7 +-
 trains/binding/frameworks/tensorflow_bind.py | 264 +++++++++++--------
 2 files changed, 154 insertions(+), 117 deletions(-)

diff --git a/trains/binding/frameworks/pytorch_bind.py b/trains/binding/frameworks/pytorch_bind.py
index 9a80d4ed..da178967 100644
--- a/trains/binding/frameworks/pytorch_bind.py
+++ b/trains/binding/frameworks/pytorch_bind.py
@@ -54,12 +54,17 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
             elif hasattr(f, 'as_posix'):
                 filename = f.as_posix()
             elif hasattr(f, 'name'):
-                filename = f.name
                 # noinspection PyBroadException
                 try:
                     f.flush()
                 except Exception:
                     pass
+
+                if not isinstance(f.name, six.string_types):
+                    # Probably a BufferedRandom object that has no meaningful name (still no harm in flushing)
+                    return ret
+
+                filename = f.name
             else:
                 filename = None
         except Exception:
diff --git a/trains/binding/frameworks/tensorflow_bind.py b/trains/binding/frameworks/tensorflow_bind.py
index fbc3a6bd..46f701ca 100644
--- a/trains/binding/frameworks/tensorflow_bind.py
+++ b/trains/binding/frameworks/tensorflow_bind.py
@@ -51,6 +51,134 @@ class IsTensorboardInit(object):
         return original_init(self, *args, **kwargs)
 
 
+class WeightsGradientHistHelper(object):
+    def __init__(self, logger, report_freq=100, histogram_update_freq_multiplier=10, histogram_granularity=50):
+        self._logger = logger
+        self.report_freq = report_freq
+        self._histogram_granularity = histogram_granularity
+        self._histogram_update_freq_multiplier = histogram_update_freq_multiplier
+        self._histogram_update_call_counter = 0
+        self._hist_report_cache = {}
+        self._hist_x_granularity = 50
+
+    @staticmethod
+    def _sample_histograms(_hist_iters, _histogram_granularity):
+        # re-sample the histogram history based on the distribution of samples across time (steps)
+        ratio = ((_hist_iters[-1] - _hist_iters[_histogram_granularity]) /
+                 (_hist_iters[_histogram_granularity - 1] - _hist_iters[0])) if \
+            _hist_iters.size > _histogram_granularity else 0.
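+        # `ratio` weighs the time span covered by the newer steps (past the first
+        # `_histogram_granularity` recorded iterations) against the span covered by the
+        # oldest ones; the sampling budget below is then split between the two regions
+        # accordingly, so densely logged recent history does not crowd out early steps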
+        cur_idx_below = np.arange(0, min(_hist_iters.size, _histogram_granularity - 1))
+        np.random.shuffle(cur_idx_below)
+        cur_idx_below = cur_idx_below[:int(_histogram_granularity * (1.0 - ratio / (1 + ratio)) + 0.5)]
+        if ratio > 0.0:
+            cur_idx_above = np.arange(_histogram_granularity - 1, _hist_iters.size)
+            np.random.shuffle(cur_idx_above)
+            cur_idx_above = cur_idx_above[:int(_histogram_granularity * ratio / (1 + ratio))]
+        else:
+            cur_idx_above = np.array([])
+        _cur_idx = np.unique(np.sort(np.concatenate((cur_idx_below, cur_idx_above)).astype(int)))
+        return _cur_idx
+
+    def add_histogram(self, title, series, step, hist_data):
+        # only collect histogram every specific interval
+        self._histogram_update_call_counter += 1
+        if self._histogram_update_call_counter % self.report_freq != 0 or \
+                self._histogram_update_call_counter < self.report_freq - 1:
+            return None
+
+        if isinstance(hist_data, dict):
+            pass
+        elif isinstance(hist_data, np.ndarray):
+            hist_data = np.histogram(hist_data)
+            hist_data = {'bucketLimit': hist_data[1].tolist(), 'bucket': hist_data[0].tolist()}
+        else:
+            # prepare the dictionary, assume numpy
+            # hist_data['bucketLimit'] is the histogram bucket right-side limit, i.e. the X axis
+            # hist_data['bucket'] is the histogram height, i.e. the Y axis
+            # notice hist_data[:, 1] is the right-side limit; for backwards compatibility we take the left side
+            hist_data = {'bucketLimit': hist_data[:, 0].tolist(), 'bucket': hist_data[:, 2].tolist()}
+
+        self._add_histogram(title=title, series=series, step=step, hist_data=hist_data)
+
+    def _add_histogram(self, title, series, step, hist_data):
+        # generate forward matrix of the histograms
+        # Y-axis (rows) is iteration (from 0 to current step)
+        # X-axis averaged bins (conformed sample 'bucketLimit')
+        # Z-axis actual value (interpolated 'bucket')
+        step = EventTrainsWriter._fix_step_counter(title, series, step)
+
+        # get histograms from cache
+        hist_list, hist_iters, minmax = self._hist_report_cache.get((title, series), ([], np.array([]), None))
+
+        # re-sample data so we are always constrained in the number of histograms we keep
+        if hist_iters.size >= self._histogram_granularity ** 2:
+            idx = self._sample_histograms(hist_iters, self._histogram_granularity)
+            hist_iters = hist_iters[idx]
+            hist_list = [hist_list[i] for i in idx]
+
+        # check if the current sample is not already here (actually happens sometimes)
+        if step in hist_iters:
+            return None
+
+        # add current sample, if not already here
+        hist_iters = np.append(hist_iters, step)
+        # hist_data['bucketLimit'] is the histogram bucket right-side limit, i.e. the X axis
+        # hist_data['bucket'] is the histogram height, i.e. the Y axis
+        hist = np.array(list(zip(hist_data['bucketLimit'], hist_data['bucket'])), dtype=np.float32)
+        hist = hist[~np.isinf(hist[:, 0]), :]
+        hist_list.append(hist)
+        # keep track of min/max values of histograms (for later re-binning)
+        if minmax is None:
+            minmax = hist[:, 0].min(), hist[:, 0].max()
+        else:
+            minmax = min(minmax[0], hist[:, 0].min()), max(minmax[1], hist[:, 0].max())
+
+        # update the cache
+        self._hist_report_cache[(title, series)] = hist_list, hist_iters, minmax
+
+        # only report the histogram at the requested interval, but do report the first few so you know there are histograms
+        if hist_iters.size < 1 or (hist_iters.size >= self._histogram_update_freq_multiplier and
+                                   hist_iters.size % self._histogram_update_freq_multiplier != 0):
+            return None
+
+        # re-sample histograms on a unified bin axis
+        _minmax = minmax[0] - 1, minmax[1] + 1
+        prev_xedge = np.arange(start=_minmax[0],
+                               step=(_minmax[1] - _minmax[0]) / (self._hist_x_granularity - 2), stop=_minmax[1])
+        # uniformly select histograms, plus always the last one
+        cur_idx = self._sample_histograms(hist_iters, self._histogram_granularity)
+        report_hist = np.zeros(shape=(len(cur_idx), prev_xedge.size), dtype=np.float32)
+        for i, n in enumerate(cur_idx):
+            h = hist_list[n]
+            report_hist[i, :] = np.interp(prev_xedge, h[:, 0], h[:, 1], right=0, left=0)
+        yedges = hist_iters[cur_idx]
+        xedges = prev_xedge
+
+        # if there is only a single line, add another zero line for the scatter plot to draw
+        if report_hist.shape[0] < 2:
+            report_hist = np.vstack((np.zeros_like(report_hist), report_hist))
+
+        # create a 3d line (scatter) of the histograms
+        skipx = max(1, int(xedges.size / 10))
+        skipy = max(1, int(yedges.size / 10))
+        xlabels = ['%.2f' % v if i % skipx == 0 else '' for i, v in enumerate(xedges[:-1])]
+        ylabels = [str(int(v)) if i % skipy == 0 else '' for i, v in enumerate(yedges)]
+        self._logger.report_surface(
+            title=title,
+            series=series,
+            iteration=0,
+            xaxis=' ',
+            yaxis='iteration',
+            xlabels=xlabels,
+            ylabels=ylabels,
+            matrix=report_hist,
+            camera=(-0.1, +1.3, 1.4))
+
+
 class EventTrainsWriter(object):
     """
     TF SummaryWriter implementation that converts the tensorboard's summary into
@@ -190,6 +318,12 @@ class EventTrainsWriter(object):
         self._max_step = 0
         self._graph_name_lookup = {}
         self._generic_tensor_type_name_lookup = {}
+        self._grad_helper = WeightsGradientHistHelper(
+            logger=logger,
+            report_freq=report_freq,
+            histogram_update_freq_multiplier=histogram_update_freq_multiplier,
+            histogram_granularity=histogram_granularity
+        )
 
     def _decode_image(self, img_str, width, height, color_channels):
         # noinspection PyBroadException
@@ -318,103 +452,15 @@ class EventTrainsWriter(object):
         )
 
     def _add_histogram(self, tag, step, histo_data):
-        def _sample_histograms(_hist_iters, _histogram_granularity):
-            # resample history based on distribution of samples across time (steps)
-            ratio = ((_hist_iters[-1] - _hist_iters[_histogram_granularity]) /
-                     (_hist_iters[_histogram_granularity - 1] - _hist_iters[0])) if \
-                _hist_iters.size > _histogram_granularity else 0.
-            cur_idx_below = np.arange(0, min(_hist_iters.size, _histogram_granularity - 1))
-            np.random.shuffle(cur_idx_below)
-            cur_idx_below = cur_idx_below[:int(_histogram_granularity * (1.0 - ratio / (1 + ratio)) + 0.5)]
-            if ratio > 0.0:
-                cur_idx_above = np.arange(_histogram_granularity - 1, _hist_iters.size)
-                np.random.shuffle(cur_idx_above)
-                cur_idx_above = cur_idx_above[:int(_histogram_granularity * ratio / (1 + ratio))]
-            else:
-                cur_idx_above = np.array([])
-            _cur_idx = np.unique(np.sort(np.concatenate((cur_idx_below, cur_idx_above)).astype(np.int)))
-            return _cur_idx
-
-        # only collect histogram every specific interval
-        self._histogram_update_call_counter += 1
-        if self._histogram_update_call_counter % self.report_freq != 0 or \
-                self._histogram_update_call_counter < self.report_freq - 1:
-            return None
-
-        # generate forward matrix of the histograms
-        # Y-axis (rows) is iteration (from 0 to current Step)
-        # X-axis averaged bins (conformed sample 'bucketLimit')
-        # Z-axis actual value (interpolated 'bucket')
         title, series = self.tag_splitter(tag, num_split_parts=1, default_title='Histograms',
                                           logdir_header='series')
-        step = self._fix_step_counter(title, series, step)
-
-        # get histograms from cache
-        hist_list, hist_iters, minmax = self._hist_report_cache.get((title, series), ([], np.array([]), None))
-
-        # resample data so we are always constrained in number of histogram we keep
-        if hist_iters.size >= self.histogram_granularity ** 2:
-            idx = _sample_histograms(hist_iters, self.histogram_granularity)
-            hist_iters = hist_iters[idx]
-            hist_list = [hist_list[i] for i in idx]
-
-        # check if current sample is not already here (actually happens some times)
-        if step in hist_iters:
-            return None
-
-        # add current sample, if not already here
-        hist_iters = np.append(hist_iters, step)
-        # histo_data['bucketLimit'] is the histogram bucket right side limit, meaning X axis
-        # histo_data['bucket'] is the histogram height, meaning the Y axis
-        hist = np.array(list(zip(histo_data['bucketLimit'], histo_data['bucket'])), dtype=np.float32)
-        hist = hist[~np.isinf(hist[:, 0]), :]
-        hist_list.append(hist)
-        # keep track of min/max values of histograms (for later re-binning)
-        if minmax is None:
-            minmax = hist[:, 0].min(), hist[:, 0].max()
-        else:
-            minmax = min(minmax[0], hist[:, 0].min()), max(minmax[1], hist[:, 0].max())
-
-        # update the cache
-        self._hist_report_cache[(title, series)] = hist_list, hist_iters, minmax
-
-        # only report histogram every specific interval, but do report the first few, so you know there are histograms
-        if hist_iters.size < 1 or (hist_iters.size >= self.histogram_update_freq_multiplier and
-                                   hist_iters.size % self.histogram_update_freq_multiplier != 0):
-            return None
-
-        # resample histograms on a unified bin axis
-        _minmax = minmax[0] - 1, minmax[1] + 1
-        prev_xedge = np.arange(start=_minmax[0],
-                               step=(_minmax[1] - _minmax[0]) / (self._hist_x_granularity - 2), stop=_minmax[1])
-        # uniformly select histograms and the last one
-        cur_idx = _sample_histograms(hist_iters, self.histogram_granularity)
-        report_hist = np.zeros(shape=(len(cur_idx), prev_xedge.size), dtype=np.float32)
-        for i, n in enumerate(cur_idx):
-            h = hist_list[n]
-            report_hist[i, :] = np.interp(prev_xedge, h[:, 0], h[:, 1], right=0, left=0)
-        yedges = hist_iters[cur_idx]
-        xedges = prev_xedge
-
-        # if only a single line make, add another zero line, for the scatter plot to draw
-        if report_hist.shape[0] < 2:
-            report_hist = np.vstack((np.zeros_like(report_hist), report_hist))
-
-        # create 3d line (scatter) of histograms
-        skipx = max(1, int(xedges.size / 10))
-        skipy = max(1, int(yedges.size / 10))
-        xlabels = ['%.2f' % v if i % skipx == 0 else '' for i, v in enumerate(xedges[:-1])]
-        ylabels = [str(int(v)) if i % skipy == 0 else '' for i, v in enumerate(yedges)]
-        self._logger.report_surface(
+        self._grad_helper.add_histogram(
             title=title,
             series=series,
-            iteration=0,
-            xaxis=' ',
-            yaxis='iteration',
-            xlabels=xlabels,
-            ylabels=ylabels,
-            matrix=report_hist,
-            camera=(-0.1, +1.3, 1.4))
+            step=step,
+            hist_data=histo_data
+        )
 
     def _add_plot(self, tag, step, values, vdict):
         # noinspection PyBroadException
@@ -490,7 +536,8 @@ class EventTrainsWriter(object):
                 max_history=self.max_keep_images,
             )
 
-    def _fix_step_counter(self, title, series, step):
+    @staticmethod
+    def _fix_step_counter(title, series, step):
         key = (title, series)
         if key not in EventTrainsWriter._title_series_wraparound_counter:
             EventTrainsWriter._title_series_wraparound_counter[key] = {'first_step': step, 'last_step': step,
@@ -1094,9 +1141,10 @@ class PatchTensorFlowEager(object):
                     PatchTensorFlowEager._add_image_event_helper(event_writer, img_data_np=img_data_np,
                                                                  tag=tag, step=step, **kwargs)
             elif plugin_type.endswith('histograms'):
-                PatchTensorFlowEager._add_histogram_event_helper(
-                    event_writer, tag=str(tag), step=int(step.numpy()) if not isinstance(step, int) else step,
-                    hist_data=tensor.numpy())
+                event_writer._add_histogram(
+                    tag=str(tag), step=int(step.numpy()) if not isinstance(step, int) else step,
+                    histo_data=tensor.numpy()
+                )
             elif 'audio' in plugin_type:
                 audio_bytes_list = [a for a in tensor.numpy().flatten() if a]
                 for i, audio_bytes in enumerate(audio_bytes_list):
@@ -1126,9 +1174,10 @@ class PatchTensorFlowEager(object):
         event_writer = PatchTensorFlowEager._get_event_writer(writer)
         if event_writer:
             try:
-                PatchTensorFlowEager._add_histogram_event_helper(
-                    event_writer, tag=str(tag), step=int(step.numpy()) if not isinstance(step, int) else step,
-                    hist_data=values.numpy())
+                event_writer._add_histogram(
+                    tag=str(tag), step=int(step.numpy()) if not isinstance(step, int) else step,
+                    histo_data=values.numpy()
+                )
             except Exception as ex:
                 LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
         return PatchTensorFlowEager.__original_fn_hist(writer, step, tag, values, name, **kwargs)
@@ -1145,23 +1194,6 @@ class PatchTensorFlowEager(object):
             return PatchTensorFlowEager.__original_fn_image(writer, step, tag, tensor, bad_color, max_images,
                                                             name, **kwargs)
 
-    @staticmethod
-    def _add_histogram_event_helper(event_writer, hist_data, tag, step):
-        if isinstance(hist_data, dict):
-            event_writer._add_histogram(tag=tag, step=step, histo_data=hist_data)
-            return
-
-        if isinstance(hist_data, np.ndarray):
-            hist_data = np.histogram(hist_data)
-            histo_data = {'bucketLimit': hist_data[1].tolist(), 'bucket': hist_data[0].tolist()}
-        else:
-            # prepare the dictionary, assume numpy
-            # histo_data['bucketLimit'] is the histogram bucket right side limit, meaning X axis
-            # histo_data['bucket'] is the histogram height, meaning the Y axis
-            # notice hist_data[:, 1] is the right side limit, for backwards compatibility we take the left side
-            histo_data = {'bucketLimit': hist_data[:, 0].tolist(), 'bucket': hist_data[:, 2].tolist()}
-        event_writer._add_histogram(tag=tag, step=step, histo_data=histo_data)
-
     @staticmethod
     def _add_image_event_helper(event_writer, img_data_np, tag, step, **kwargs):
         if img_data_np.ndim == 1 and img_data_np.size >= 3 and \
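
Note: with WeightsGradientHistHelper extracted from EventTrainsWriter, histogram
reporting no longer requires a TensorBoard event stream, so an integration such as
PyTorch Ignite can drive it directly. Below is a minimal sketch of such a caller,
assuming a trains Task can be initialized and that the module imports without
TensorFlow installed; names and parameter values are illustrative, not part of this
patch:

    import numpy as np
    from trains import Task
    from trains.binding.frameworks.tensorflow_bind import WeightsGradientHistHelper

    task = Task.init(project_name='examples', task_name='hist_helper_sketch')
    # report on every call and redraw the surface for every stored histogram
    helper = WeightsGradientHistHelper(logger=task.get_logger(), report_freq=1,
                                       histogram_update_freq_multiplier=1)

    for step in range(10):
        weights = np.random.randn(1000)  # stand-in for a layer's weight tensor
        # raw ndarrays are binned internally via np.histogram()
        helper.add_histogram(title='layer1', series='weights', step=step, hist_data=weights)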