Fix text encoding utf-8 and pr_curve broken in Tensorboard support

This commit is contained in:
allegroai 2020-04-16 16:40:14 +03:00
parent 4bb17ca420
commit 7f4b100042

View File

@ -418,18 +418,21 @@ class EventTrainsWriter(object):
def _add_plot(self, tag, step, values, vdict): def _add_plot(self, tag, step, values, vdict):
# noinspection PyBroadException # noinspection PyBroadException
try: try:
plot_values = np.frombuffer(base64.b64decode(values['tensorContent'].encode('utf-8')), if values.get('floatVal'):
dtype=np.float32) plot_values = np.array(values.get('floatVal'), dtype=np.float32)
else:
plot_values = np.frombuffer(base64.b64decode(values['tensorContent'].encode('utf-8')),
dtype=np.float32)
plot_values = plot_values.reshape((int(values['tensorShape']['dim'][0]['size']), plot_values = plot_values.reshape((int(values['tensorShape']['dim'][0]['size']),
int(values['tensorShape']['dim'][1]['size']))) int(values['tensorShape']['dim'][1]['size'])))
if 'metadata' in vdict: if 'metadata' in vdict:
if tag not in self._series_name_lookup: if tag not in self._series_name_lookup:
self._series_name_lookup[tag] = [(tag, vdict['metadata']['displayName'], self._series_name_lookup[tag] = [(tag, vdict['metadata'].get('displayName', ''),
vdict['metadata']['pluginData']['pluginName'])] vdict['metadata']['pluginData']['pluginName'])]
else: else:
# this should not happen, maybe it's another run, let increase the value # this should not happen, maybe it's another run, let increase the value
self._series_name_lookup[tag] += [(tag + '_%d' % len(self._series_name_lookup[tag]) + 1, self._series_name_lookup[tag] += [(tag + '_%d' % (len(self._series_name_lookup[tag]) + 1),
vdict['metadata']['displayName'], vdict['metadata'].get('displayName', ''),
vdict['metadata']['pluginData']['pluginName'])] vdict['metadata']['pluginData']['pluginName'])]
tag, series, plugin_name = self._series_name_lookup.get(tag, [(tag, tag, '')])[-1] tag, series, plugin_name = self._series_name_lookup.get(tag, [(tag, tag, '')])[-1]
@ -439,7 +442,7 @@ class EventTrainsWriter(object):
# width = 1.0 / (num_thresholds - 1) # width = 1.0 / (num_thresholds - 1)
# thresholds = [0.0, 1*width, 2*width, 3*width, ..., 1.0] # thresholds = [0.0, 1*width, 2*width, 3*width, ..., 1.0]
num_thresholds = plot_values.shape[1] num_thresholds = plot_values.shape[1]
width = 1.0 / (num_thresholds - 1) width = 1.0 / num_thresholds
thresholds = np.arange(0.0, 1.0, width, dtype=plot_values.dtype) thresholds = np.arange(0.0, 1.0, width, dtype=plot_values.dtype)
data_points = ['TP ', 'FP ', 'TN ', 'FN ', 'Precision ', ' Recall'] data_points = ['TP ', 'FP ', 'TN ', 'FN ', 'Precision ', ' Recall']
series = [{'name': series, 'data': np.vstack((thresholds, plot_values[-2])).T, series = [{'name': series, 'data': np.vstack((thresholds, plot_values[-2])).T,
@ -561,13 +564,13 @@ class EventTrainsWriter(object):
self._add_audio(tag, step, values) self._add_audio(tag, step, values)
elif metric == 'tensor' and values.get('dtype') == 'DT_STRING': elif metric == 'tensor' and values.get('dtype') == 'DT_STRING':
# text, just print to console # text, just print to console
text = base64.b64decode('\n'.join(values['stringVal'])).decode('utf-8') text = base64.b64decode('\n'.join(values['stringVal'])).decode('utf-8', errors='replace')
self._logger.report_text(msg='SUMMARY LOG: {} {}'.format(tag, text), print_console=False) self._logger.report_text(msg='SUMMARY LOG: {} {}'.format(tag, text), print_console=False)
elif metric == 'tensor' and values.get('dtype') == 'DT_FLOAT': elif metric == 'tensor' and values.get('dtype') == 'DT_FLOAT':
self._add_plot(tag, step, values, vdict) self._add_plot(tag, step, values, vdict)
else: else:
LoggerRoot.get_base_logger(TensorflowBinding).debug('Event unsupported. tag = %s, vdict keys [%s]' LoggerRoot.get_base_logger(TensorflowBinding).debug(
% (tag, ', '.join(vdict.keys()))) 'Event unsupported. tag = %s, vdict keys [%s]' % (tag, ', '.join(vdict.keys())))
continue continue
def get_logdir(self): def get_logdir(self):