From 7f4b100042febb028d6bcb36bf9869b0f49d95d5 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Thu, 16 Apr 2020 16:40:14 +0300 Subject: [PATCH] Fix text encoding utf-8 and pr_curve broken in Tensorboard support --- trains/binding/frameworks/tensorflow_bind.py | 21 +++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/trains/binding/frameworks/tensorflow_bind.py b/trains/binding/frameworks/tensorflow_bind.py index fc15b263..d95a497b 100644 --- a/trains/binding/frameworks/tensorflow_bind.py +++ b/trains/binding/frameworks/tensorflow_bind.py @@ -418,18 +418,21 @@ class EventTrainsWriter(object): def _add_plot(self, tag, step, values, vdict): # noinspection PyBroadException try: - plot_values = np.frombuffer(base64.b64decode(values['tensorContent'].encode('utf-8')), - dtype=np.float32) + if values.get('floatVal'): + plot_values = np.array(values.get('floatVal'), dtype=np.float32) + else: + plot_values = np.frombuffer(base64.b64decode(values['tensorContent'].encode('utf-8')), + dtype=np.float32) plot_values = plot_values.reshape((int(values['tensorShape']['dim'][0]['size']), int(values['tensorShape']['dim'][1]['size']))) if 'metadata' in vdict: if tag not in self._series_name_lookup: - self._series_name_lookup[tag] = [(tag, vdict['metadata']['displayName'], + self._series_name_lookup[tag] = [(tag, vdict['metadata'].get('displayName', ''), vdict['metadata']['pluginData']['pluginName'])] else: # this should not happen, maybe it's another run, let increase the value - self._series_name_lookup[tag] += [(tag + '_%d' % len(self._series_name_lookup[tag]) + 1, - vdict['metadata']['displayName'], + self._series_name_lookup[tag] += [(tag + '_%d' % (len(self._series_name_lookup[tag]) + 1), + vdict['metadata'].get('displayName', ''), vdict['metadata']['pluginData']['pluginName'])] tag, series, plugin_name = self._series_name_lookup.get(tag, [(tag, tag, '')])[-1] @@ -439,7 +442,7 @@ class EventTrainsWriter(object): # width = 1.0 / (num_thresholds - 1) # thresholds = [0.0, 1*width, 2*width, 3*width, ..., 1.0] num_thresholds = plot_values.shape[1] - width = 1.0 / (num_thresholds - 1) + width = 1.0 / num_thresholds thresholds = np.arange(0.0, 1.0, width, dtype=plot_values.dtype) data_points = ['TP ', 'FP ', 'TN ', 'FN ', 'Precision ', ' Recall'] series = [{'name': series, 'data': np.vstack((thresholds, plot_values[-2])).T, @@ -561,13 +564,13 @@ class EventTrainsWriter(object): self._add_audio(tag, step, values) elif metric == 'tensor' and values.get('dtype') == 'DT_STRING': # text, just print to console - text = base64.b64decode('\n'.join(values['stringVal'])).decode('utf-8') + text = base64.b64decode('\n'.join(values['stringVal'])).decode('utf-8', errors='replace') self._logger.report_text(msg='SUMMARY LOG: {} {}'.format(tag, text), print_console=False) elif metric == 'tensor' and values.get('dtype') == 'DT_FLOAT': self._add_plot(tag, step, values, vdict) else: - LoggerRoot.get_base_logger(TensorflowBinding).debug('Event unsupported. tag = %s, vdict keys [%s]' - % (tag, ', '.join(vdict.keys()))) + LoggerRoot.get_base_logger(TensorflowBinding).debug( + 'Event unsupported. tag = %s, vdict keys [%s]' % (tag, ', '.join(vdict.keys()))) continue def get_logdir(self):