diff --git a/trains/binding/frameworks/tensorflow_bind.py b/trains/binding/frameworks/tensorflow_bind.py index d95a497b..fbc3a6bd 100644 --- a/trains/binding/frameworks/tensorflow_bind.py +++ b/trains/binding/frameworks/tensorflow_bind.py @@ -15,7 +15,7 @@ from PIL import Image from ...debugging.log import LoggerRoot from ..frameworks import _patched_call, WeightsFileHandler, _Empty from ..import_bind import PostImportHookPatching -from ...config import running_remotely, config +from ...config import running_remotely from ...model import InputModel, OutputModel, Framework try: @@ -189,6 +189,7 @@ class EventTrainsWriter(object): self._hist_x_granularity = 50 self._max_step = 0 self._graph_name_lookup = {} + self._generic_tensor_type_name_lookup = {} def _decode_image(self, img_str, width, height, color_channels): # noinspection PyBroadException @@ -458,22 +459,28 @@ class EventTrainsWriter(object): except Exception: pass - def _add_audio(self, tag, step, values): + def _add_audio(self, tag, step, values, audio_data=None): # only report images every specific interval if step % self.image_report_freq != 0: return None - audio_str = values['encodedAudioString'] - audio_data = base64.b64decode(audio_str) + if values: + audio_str = values['encodedAudioString'] + audio_data = base64.b64decode(audio_str) if audio_data is None: return - title, series = self.tag_splitter(tag, num_split_parts=3, default_title='Images', logdir_header='title', + title, series = self.tag_splitter(tag, num_split_parts=3, default_title='Audio', logdir_header='title', auto_reduce_num_split=True) step = self._fix_step_counter(title, series, step) stream = BytesIO(audio_data) - file_extension = guess_extension(values['contentType']) or '.{}'.format(values['contentType'].split('/')[-1]) + if values: + file_extension = guess_extension(values['contentType']) or \ + '.{}'.format(values['contentType'].split('/')[-1]) + else: + # assume wav as default + file_extension = '.wav' self._logger.report_media( title=title, series=series, @@ -550,8 +557,8 @@ class EventTrainsWriter(object): tag = vdict.pop('tag', None) if tag is None: # we should not get here - LoggerRoot.get_base_logger(TensorflowBinding).debug('No tag for \'value\' existing keys %s' - % ', '.join(vdict.keys())) + LoggerRoot.get_base_logger(TensorflowBinding).debug( + 'No tag for \'value\' existing keys %s' % ', '.join(vdict.keys())) continue metric, values = get_data(vdict, supported_metrics) if metric == 'simpleValue': @@ -563,9 +570,20 @@ class EventTrainsWriter(object): elif metric == 'audio': self._add_audio(tag, step, values) elif metric == 'tensor' and values.get('dtype') == 'DT_STRING': - # text, just print to console - text = base64.b64decode('\n'.join(values['stringVal'])).decode('utf-8', errors='replace') - self._logger.report_text(msg='SUMMARY LOG: {} {}'.format(tag, text), print_console=False) + # generic tensor + tensor_bytes = base64.b64decode('\n'.join(values['stringVal'])) + plugin_type = self._generic_tensor_type_name_lookup.get(tag) or \ + vdict.get('metadata', {}).get('pluginData', {}).get('pluginName', '').lower() + if plugin_type == 'audio': + self._generic_tensor_type_name_lookup[tag] = plugin_type + self._add_audio(tag, step, None, tensor_bytes) + elif plugin_type == 'text': + # text, just print to console + text = tensor_bytes.decode('utf-8', errors='replace') + self._logger.report_text(msg='SUMMARY LOG: {} {}'.format(tag, text), print_console=False) + else: + # we do not support it + pass elif metric == 'tensor' and values.get('dtype') == 'DT_FLOAT': self._add_plot(tag, step, values, vdict) else: @@ -1079,6 +1097,12 @@ class PatchTensorFlowEager(object): PatchTensorFlowEager._add_histogram_event_helper( event_writer, tag=str(tag), step=int(step.numpy()) if not isinstance(step, int) else step, hist_data=tensor.numpy()) + elif 'audio' in plugin_type: + audio_bytes_list = [a for a in tensor.numpy().flatten() if a] + for i, audio_bytes in enumerate(audio_bytes_list): + event_writer._add_audio(tag=str(tag) + ('/{}'.format(i) if len(audio_bytes_list) > 1 else ''), + step=int(step.numpy()) if not isinstance(step, int) else step, + values=None, audio_data=audio_bytes) else: pass # print('unsupported plugin_type', plugin_type) except Exception as ex: