mirror of
https://github.com/clearml/clearml
synced 2025-01-31 17:17:00 +00:00
Fix text encoding utf-8 and pr_curve broken in Tensorboard support
This commit is contained in:
parent
4bb17ca420
commit
7f4b100042
@ -418,18 +418,21 @@ class EventTrainsWriter(object):
|
||||
def _add_plot(self, tag, step, values, vdict):
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
if values.get('floatVal'):
|
||||
plot_values = np.array(values.get('floatVal'), dtype=np.float32)
|
||||
else:
|
||||
plot_values = np.frombuffer(base64.b64decode(values['tensorContent'].encode('utf-8')),
|
||||
dtype=np.float32)
|
||||
plot_values = plot_values.reshape((int(values['tensorShape']['dim'][0]['size']),
|
||||
int(values['tensorShape']['dim'][1]['size'])))
|
||||
if 'metadata' in vdict:
|
||||
if tag not in self._series_name_lookup:
|
||||
self._series_name_lookup[tag] = [(tag, vdict['metadata']['displayName'],
|
||||
self._series_name_lookup[tag] = [(tag, vdict['metadata'].get('displayName', ''),
|
||||
vdict['metadata']['pluginData']['pluginName'])]
|
||||
else:
|
||||
# this should not happen, maybe it's another run, let increase the value
|
||||
self._series_name_lookup[tag] += [(tag + '_%d' % len(self._series_name_lookup[tag]) + 1,
|
||||
vdict['metadata']['displayName'],
|
||||
self._series_name_lookup[tag] += [(tag + '_%d' % (len(self._series_name_lookup[tag]) + 1),
|
||||
vdict['metadata'].get('displayName', ''),
|
||||
vdict['metadata']['pluginData']['pluginName'])]
|
||||
|
||||
tag, series, plugin_name = self._series_name_lookup.get(tag, [(tag, tag, '')])[-1]
|
||||
@ -439,7 +442,7 @@ class EventTrainsWriter(object):
|
||||
# width = 1.0 / (num_thresholds - 1)
|
||||
# thresholds = [0.0, 1*width, 2*width, 3*width, ..., 1.0]
|
||||
num_thresholds = plot_values.shape[1]
|
||||
width = 1.0 / (num_thresholds - 1)
|
||||
width = 1.0 / num_thresholds
|
||||
thresholds = np.arange(0.0, 1.0, width, dtype=plot_values.dtype)
|
||||
data_points = ['TP ', 'FP ', 'TN ', 'FN ', 'Precision ', ' Recall']
|
||||
series = [{'name': series, 'data': np.vstack((thresholds, plot_values[-2])).T,
|
||||
@ -561,13 +564,13 @@ class EventTrainsWriter(object):
|
||||
self._add_audio(tag, step, values)
|
||||
elif metric == 'tensor' and values.get('dtype') == 'DT_STRING':
|
||||
# text, just print to console
|
||||
text = base64.b64decode('\n'.join(values['stringVal'])).decode('utf-8')
|
||||
text = base64.b64decode('\n'.join(values['stringVal'])).decode('utf-8', errors='replace')
|
||||
self._logger.report_text(msg='SUMMARY LOG: {} {}'.format(tag, text), print_console=False)
|
||||
elif metric == 'tensor' and values.get('dtype') == 'DT_FLOAT':
|
||||
self._add_plot(tag, step, values, vdict)
|
||||
else:
|
||||
LoggerRoot.get_base_logger(TensorflowBinding).debug('Event unsupported. tag = %s, vdict keys [%s]'
|
||||
% (tag, ', '.join(vdict.keys())))
|
||||
LoggerRoot.get_base_logger(TensorflowBinding).debug(
|
||||
'Event unsupported. tag = %s, vdict keys [%s]' % (tag, ', '.join(vdict.keys())))
|
||||
continue
|
||||
|
||||
def get_logdir(self):
|
||||
|
Loading…
Reference in New Issue
Block a user