Fix Tensorflow V1/V2 audio support

This commit is contained in:
allegroai 2020-04-16 16:46:02 +03:00
parent 8f1dd8ba8b
commit cb139f2d17

View File

@ -15,7 +15,7 @@ from PIL import Image
from ...debugging.log import LoggerRoot
from ..frameworks import _patched_call, WeightsFileHandler, _Empty
from ..import_bind import PostImportHookPatching
from ...config import running_remotely, config
from ...config import running_remotely
from ...model import InputModel, OutputModel, Framework
try:
@ -189,6 +189,7 @@ class EventTrainsWriter(object):
self._hist_x_granularity = 50
self._max_step = 0
self._graph_name_lookup = {}
self._generic_tensor_type_name_lookup = {}
def _decode_image(self, img_str, width, height, color_channels):
# noinspection PyBroadException
@ -458,22 +459,28 @@ class EventTrainsWriter(object):
except Exception:
pass
def _add_audio(self, tag, step, values):
def _add_audio(self, tag, step, values, audio_data=None):
# only report images every specific interval
if step % self.image_report_freq != 0:
return None
if values:
audio_str = values['encodedAudioString']
audio_data = base64.b64decode(audio_str)
if audio_data is None:
return
title, series = self.tag_splitter(tag, num_split_parts=3, default_title='Images', logdir_header='title',
title, series = self.tag_splitter(tag, num_split_parts=3, default_title='Audio', logdir_header='title',
auto_reduce_num_split=True)
step = self._fix_step_counter(title, series, step)
stream = BytesIO(audio_data)
file_extension = guess_extension(values['contentType']) or '.{}'.format(values['contentType'].split('/')[-1])
if values:
file_extension = guess_extension(values['contentType']) or \
'.{}'.format(values['contentType'].split('/')[-1])
else:
# assume wav as default
file_extension = '.wav'
self._logger.report_media(
title=title,
series=series,
@ -550,8 +557,8 @@ class EventTrainsWriter(object):
tag = vdict.pop('tag', None)
if tag is None:
# we should not get here
LoggerRoot.get_base_logger(TensorflowBinding).debug('No tag for \'value\' existing keys %s'
% ', '.join(vdict.keys()))
LoggerRoot.get_base_logger(TensorflowBinding).debug(
'No tag for \'value\' existing keys %s' % ', '.join(vdict.keys()))
continue
metric, values = get_data(vdict, supported_metrics)
if metric == 'simpleValue':
@ -563,9 +570,20 @@ class EventTrainsWriter(object):
elif metric == 'audio':
self._add_audio(tag, step, values)
elif metric == 'tensor' and values.get('dtype') == 'DT_STRING':
# generic tensor
tensor_bytes = base64.b64decode('\n'.join(values['stringVal']))
plugin_type = self._generic_tensor_type_name_lookup.get(tag) or \
vdict.get('metadata', {}).get('pluginData', {}).get('pluginName', '').lower()
if plugin_type == 'audio':
self._generic_tensor_type_name_lookup[tag] = plugin_type
self._add_audio(tag, step, None, tensor_bytes)
elif plugin_type == 'text':
# text, just print to console
text = base64.b64decode('\n'.join(values['stringVal'])).decode('utf-8', errors='replace')
text = tensor_bytes.decode('utf-8', errors='replace')
self._logger.report_text(msg='SUMMARY LOG: {} {}'.format(tag, text), print_console=False)
else:
# we do not support it
pass
elif metric == 'tensor' and values.get('dtype') == 'DT_FLOAT':
self._add_plot(tag, step, values, vdict)
else:
@ -1079,6 +1097,12 @@ class PatchTensorFlowEager(object):
PatchTensorFlowEager._add_histogram_event_helper(
event_writer, tag=str(tag), step=int(step.numpy()) if not isinstance(step, int) else step,
hist_data=tensor.numpy())
elif 'audio' in plugin_type:
audio_bytes_list = [a for a in tensor.numpy().flatten() if a]
for i, audio_bytes in enumerate(audio_bytes_list):
event_writer._add_audio(tag=str(tag) + ('/{}'.format(i) if len(audio_bytes_list) > 1 else ''),
step=int(step.numpy()) if not isinstance(step, int) else step,
values=None, audio_data=audio_bytes)
else:
pass # print('unsupported plugin_type', plugin_type)
except Exception as ex: