mirror of
https://github.com/clearml/clearml
synced 2025-01-31 17:17:00 +00:00
Fix text encoding utf-8 and pr_curve broken in Tensorboard support
This commit is contained in:
parent
4bb17ca420
commit
7f4b100042
@ -418,18 +418,21 @@ class EventTrainsWriter(object):
|
|||||||
def _add_plot(self, tag, step, values, vdict):
|
def _add_plot(self, tag, step, values, vdict):
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
plot_values = np.frombuffer(base64.b64decode(values['tensorContent'].encode('utf-8')),
|
if values.get('floatVal'):
|
||||||
dtype=np.float32)
|
plot_values = np.array(values.get('floatVal'), dtype=np.float32)
|
||||||
|
else:
|
||||||
|
plot_values = np.frombuffer(base64.b64decode(values['tensorContent'].encode('utf-8')),
|
||||||
|
dtype=np.float32)
|
||||||
plot_values = plot_values.reshape((int(values['tensorShape']['dim'][0]['size']),
|
plot_values = plot_values.reshape((int(values['tensorShape']['dim'][0]['size']),
|
||||||
int(values['tensorShape']['dim'][1]['size'])))
|
int(values['tensorShape']['dim'][1]['size'])))
|
||||||
if 'metadata' in vdict:
|
if 'metadata' in vdict:
|
||||||
if tag not in self._series_name_lookup:
|
if tag not in self._series_name_lookup:
|
||||||
self._series_name_lookup[tag] = [(tag, vdict['metadata']['displayName'],
|
self._series_name_lookup[tag] = [(tag, vdict['metadata'].get('displayName', ''),
|
||||||
vdict['metadata']['pluginData']['pluginName'])]
|
vdict['metadata']['pluginData']['pluginName'])]
|
||||||
else:
|
else:
|
||||||
# this should not happen, maybe it's another run, let increase the value
|
# this should not happen, maybe it's another run, let increase the value
|
||||||
self._series_name_lookup[tag] += [(tag + '_%d' % len(self._series_name_lookup[tag]) + 1,
|
self._series_name_lookup[tag] += [(tag + '_%d' % (len(self._series_name_lookup[tag]) + 1),
|
||||||
vdict['metadata']['displayName'],
|
vdict['metadata'].get('displayName', ''),
|
||||||
vdict['metadata']['pluginData']['pluginName'])]
|
vdict['metadata']['pluginData']['pluginName'])]
|
||||||
|
|
||||||
tag, series, plugin_name = self._series_name_lookup.get(tag, [(tag, tag, '')])[-1]
|
tag, series, plugin_name = self._series_name_lookup.get(tag, [(tag, tag, '')])[-1]
|
||||||
@ -439,7 +442,7 @@ class EventTrainsWriter(object):
|
|||||||
# width = 1.0 / (num_thresholds - 1)
|
# width = 1.0 / (num_thresholds - 1)
|
||||||
# thresholds = [0.0, 1*width, 2*width, 3*width, ..., 1.0]
|
# thresholds = [0.0, 1*width, 2*width, 3*width, ..., 1.0]
|
||||||
num_thresholds = plot_values.shape[1]
|
num_thresholds = plot_values.shape[1]
|
||||||
width = 1.0 / (num_thresholds - 1)
|
width = 1.0 / num_thresholds
|
||||||
thresholds = np.arange(0.0, 1.0, width, dtype=plot_values.dtype)
|
thresholds = np.arange(0.0, 1.0, width, dtype=plot_values.dtype)
|
||||||
data_points = ['TP ', 'FP ', 'TN ', 'FN ', 'Precision ', ' Recall']
|
data_points = ['TP ', 'FP ', 'TN ', 'FN ', 'Precision ', ' Recall']
|
||||||
series = [{'name': series, 'data': np.vstack((thresholds, plot_values[-2])).T,
|
series = [{'name': series, 'data': np.vstack((thresholds, plot_values[-2])).T,
|
||||||
@ -561,13 +564,13 @@ class EventTrainsWriter(object):
|
|||||||
self._add_audio(tag, step, values)
|
self._add_audio(tag, step, values)
|
||||||
elif metric == 'tensor' and values.get('dtype') == 'DT_STRING':
|
elif metric == 'tensor' and values.get('dtype') == 'DT_STRING':
|
||||||
# text, just print to console
|
# text, just print to console
|
||||||
text = base64.b64decode('\n'.join(values['stringVal'])).decode('utf-8')
|
text = base64.b64decode('\n'.join(values['stringVal'])).decode('utf-8', errors='replace')
|
||||||
self._logger.report_text(msg='SUMMARY LOG: {} {}'.format(tag, text), print_console=False)
|
self._logger.report_text(msg='SUMMARY LOG: {} {}'.format(tag, text), print_console=False)
|
||||||
elif metric == 'tensor' and values.get('dtype') == 'DT_FLOAT':
|
elif metric == 'tensor' and values.get('dtype') == 'DT_FLOAT':
|
||||||
self._add_plot(tag, step, values, vdict)
|
self._add_plot(tag, step, values, vdict)
|
||||||
else:
|
else:
|
||||||
LoggerRoot.get_base_logger(TensorflowBinding).debug('Event unsupported. tag = %s, vdict keys [%s]'
|
LoggerRoot.get_base_logger(TensorflowBinding).debug(
|
||||||
% (tag, ', '.join(vdict.keys())))
|
'Event unsupported. tag = %s, vdict keys [%s]' % (tag, ', '.join(vdict.keys())))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
def get_logdir(self):
|
def get_logdir(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user