mirror of
https://github.com/clearml/clearml
synced 2025-05-03 12:31:00 +00:00
Fix Tensorflow V1/V2 audio support
This commit is contained in:
parent
8f1dd8ba8b
commit
cb139f2d17
@ -15,7 +15,7 @@ from PIL import Image
|
||||
from ...debugging.log import LoggerRoot
|
||||
from ..frameworks import _patched_call, WeightsFileHandler, _Empty
|
||||
from ..import_bind import PostImportHookPatching
|
||||
from ...config import running_remotely, config
|
||||
from ...config import running_remotely
|
||||
from ...model import InputModel, OutputModel, Framework
|
||||
|
||||
try:
|
||||
@ -189,6 +189,7 @@ class EventTrainsWriter(object):
|
||||
self._hist_x_granularity = 50
|
||||
self._max_step = 0
|
||||
self._graph_name_lookup = {}
|
||||
self._generic_tensor_type_name_lookup = {}
|
||||
|
||||
def _decode_image(self, img_str, width, height, color_channels):
|
||||
# noinspection PyBroadException
|
||||
@ -458,22 +459,28 @@ class EventTrainsWriter(object):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _add_audio(self, tag, step, values):
|
||||
def _add_audio(self, tag, step, values, audio_data=None):
|
||||
# only report images every specific interval
|
||||
if step % self.image_report_freq != 0:
|
||||
return None
|
||||
|
||||
if values:
|
||||
audio_str = values['encodedAudioString']
|
||||
audio_data = base64.b64decode(audio_str)
|
||||
if audio_data is None:
|
||||
return
|
||||
|
||||
title, series = self.tag_splitter(tag, num_split_parts=3, default_title='Images', logdir_header='title',
|
||||
title, series = self.tag_splitter(tag, num_split_parts=3, default_title='Audio', logdir_header='title',
|
||||
auto_reduce_num_split=True)
|
||||
step = self._fix_step_counter(title, series, step)
|
||||
|
||||
stream = BytesIO(audio_data)
|
||||
file_extension = guess_extension(values['contentType']) or '.{}'.format(values['contentType'].split('/')[-1])
|
||||
if values:
|
||||
file_extension = guess_extension(values['contentType']) or \
|
||||
'.{}'.format(values['contentType'].split('/')[-1])
|
||||
else:
|
||||
# assume wav as default
|
||||
file_extension = '.wav'
|
||||
self._logger.report_media(
|
||||
title=title,
|
||||
series=series,
|
||||
@ -550,8 +557,8 @@ class EventTrainsWriter(object):
|
||||
tag = vdict.pop('tag', None)
|
||||
if tag is None:
|
||||
# we should not get here
|
||||
LoggerRoot.get_base_logger(TensorflowBinding).debug('No tag for \'value\' existing keys %s'
|
||||
% ', '.join(vdict.keys()))
|
||||
LoggerRoot.get_base_logger(TensorflowBinding).debug(
|
||||
'No tag for \'value\' existing keys %s' % ', '.join(vdict.keys()))
|
||||
continue
|
||||
metric, values = get_data(vdict, supported_metrics)
|
||||
if metric == 'simpleValue':
|
||||
@ -563,9 +570,20 @@ class EventTrainsWriter(object):
|
||||
elif metric == 'audio':
|
||||
self._add_audio(tag, step, values)
|
||||
elif metric == 'tensor' and values.get('dtype') == 'DT_STRING':
|
||||
# generic tensor
|
||||
tensor_bytes = base64.b64decode('\n'.join(values['stringVal']))
|
||||
plugin_type = self._generic_tensor_type_name_lookup.get(tag) or \
|
||||
vdict.get('metadata', {}).get('pluginData', {}).get('pluginName', '').lower()
|
||||
if plugin_type == 'audio':
|
||||
self._generic_tensor_type_name_lookup[tag] = plugin_type
|
||||
self._add_audio(tag, step, None, tensor_bytes)
|
||||
elif plugin_type == 'text':
|
||||
# text, just print to console
|
||||
text = base64.b64decode('\n'.join(values['stringVal'])).decode('utf-8', errors='replace')
|
||||
text = tensor_bytes.decode('utf-8', errors='replace')
|
||||
self._logger.report_text(msg='SUMMARY LOG: {} {}'.format(tag, text), print_console=False)
|
||||
else:
|
||||
# we do not support it
|
||||
pass
|
||||
elif metric == 'tensor' and values.get('dtype') == 'DT_FLOAT':
|
||||
self._add_plot(tag, step, values, vdict)
|
||||
else:
|
||||
@ -1079,6 +1097,12 @@ class PatchTensorFlowEager(object):
|
||||
PatchTensorFlowEager._add_histogram_event_helper(
|
||||
event_writer, tag=str(tag), step=int(step.numpy()) if not isinstance(step, int) else step,
|
||||
hist_data=tensor.numpy())
|
||||
elif 'audio' in plugin_type:
|
||||
audio_bytes_list = [a for a in tensor.numpy().flatten() if a]
|
||||
for i, audio_bytes in enumerate(audio_bytes_list):
|
||||
event_writer._add_audio(tag=str(tag) + ('/{}'.format(i) if len(audio_bytes_list) > 1 else ''),
|
||||
step=int(step.numpy()) if not isinstance(step, int) else step,
|
||||
values=None, audio_data=audio_bytes)
|
||||
else:
|
||||
pass # print('unsupported plugin_type', plugin_type)
|
||||
except Exception as ex:
|
||||
|
Loading…
Reference in New Issue
Block a user