From 648779380c1e157eb6f342bd1a9a6b6d3337b9ca Mon Sep 17 00:00:00 2001
From: allegroai <>
Date: Thu, 9 Apr 2020 13:14:14 +0300
Subject: [PATCH] Add media (audio) support for both Logger and Tensorboard bind

---
 trains/backend_interface/metrics/events.py   | 38 +++++++++---
 trains/backend_interface/metrics/reporter.py | 56 ++++++++++++++++-
 trains/binding/frameworks/tensorflow_bind.py | 30 ++++++++-
 trains/logger.py                             | 65 ++++++++++++++++++++
 4 files changed, 175 insertions(+), 14 deletions(-)

diff --git a/trains/backend_interface/metrics/events.py b/trains/backend_interface/metrics/events.py
index d551d84a..e547ead3 100644
--- a/trains/backend_interface/metrics/events.py
+++ b/trains/backend_interface/metrics/events.py
@@ -179,31 +179,31 @@ class UploadEvent(MetricsEventAdapter):
     _metric_counters = {}
     _metric_counters_lock = Lock()
-    _image_file_history_size = int(config.get('metrics.file_history_size', 5))
+    _file_history_size = int(config.get('metrics.file_history_size', 5))
 
     @staticmethod
     def _replace_slash(part):
         return part.replace('\\', '/').strip('/').replace('/', '.slash.')
 
     def __init__(self, metric, variant, image_data, local_image_path=None, iter=0, upload_uri=None,
-                 image_file_history_size=None, delete_after_upload=False, **kwargs):
+                 file_history_size=None, delete_after_upload=False, **kwargs):
         # param override_filename: override uploaded file name (notice extension will be added from local path
         # param override_filename_ext: override uploaded file extension
-        if image_data is not None and not hasattr(image_data, 'shape'):
+        if image_data is not None and (not hasattr(image_data, 'shape') and not isinstance(image_data, six.BytesIO)):
             raise ValueError('Image must have a shape attribute')
         self._image_data = image_data
         self._local_image_path = local_image_path
         self._url = None
         self._key = None
         self._count = self._get_metric_count(metric, variant)
-        if not image_file_history_size:
-            image_file_history_size = self._image_file_history_size
+        if not file_history_size:
+            file_history_size = self._file_history_size
         self._filename = kwargs.pop('override_filename', None)
         if not self._filename:
-            if image_file_history_size < 1:
+            if file_history_size < 1:
                 self._filename = '%s_%s_%08d' % (metric, variant, self._count)
             else:
-                self._filename = '%s_%s_%08d' % (metric, variant, self._count % image_file_history_size)
+                self._filename = '%s_%s_%08d' % (metric, variant, self._count % file_history_size)
 
         # make sure we have to '/' in the filename because it might access other folders,
         # and we don't want that to occur
@@ -253,8 +253,10 @@ class UploadEvent(MetricsEventAdapter):
         local_file = None
         # don't provide file in case this event is out of the history window
         last_count = self._get_metric_count(self.metric, self.variant, next=False)
-        if abs(self._count - last_count) > self._image_file_history_size:
+        if abs(self._count - last_count) > self._file_history_size:
             output = None
+        elif isinstance(self._image_data, six.BytesIO):
+            output = self._image_data
         elif self._image_data is not None:
             image_data = self._image_data
             if not isinstance(image_data, np.ndarray):
@@ -318,10 +320,26 @@ class UploadEvent(MetricsEventAdapter):
 
 class ImageEvent(UploadEvent):
     def __init__(self, metric, variant, image_data, local_image_path=None, iter=0, upload_uri=None,
-                 image_file_history_size=None, delete_after_upload=False, **kwargs):
+                 file_history_size=None, delete_after_upload=False, **kwargs):
         super(ImageEvent, self).__init__(metric, variant, image_data=image_data, local_image_path=local_image_path,
                                          iter=iter, upload_uri=upload_uri,
-                                         image_file_history_size=image_file_history_size,
+                                         file_history_size=file_history_size,
+                                         delete_after_upload=delete_after_upload, **kwargs)
+
+    def get_api_event(self):
+        return events.MetricsImageEvent(
+            url=self._url,
+            key=self._key,
+            **self._get_base_dict()
+        )
+
+
+class MediaEvent(UploadEvent):
+    def __init__(self, metric, variant, stream, local_image_path=None, iter=0, upload_uri=None,
+                 file_history_size=None, delete_after_upload=False, **kwargs):
+        super(MediaEvent, self).__init__(metric, variant, image_data=stream, local_image_path=local_image_path,
+                                         iter=iter, upload_uri=upload_uri,
+                                         file_history_size=file_history_size,
                                          delete_after_upload=delete_after_upload, **kwargs)
 
     def get_api_event(self):
diff --git a/trains/backend_interface/metrics/reporter.py b/trains/backend_interface/metrics/reporter.py
index b1be0bc5..d9143e83 100644
--- a/trains/backend_interface/metrics/reporter.py
+++ b/trains/backend_interface/metrics/reporter.py
@@ -16,7 +16,7 @@ from ...utilities.plotly_reporter import create_2d_histogram_plot, create_value_
     create_2d_scatter_series, create_3d_scatter_series, create_line_plot, plotly_scatter3d_layout_dict, \
     create_image_plot, create_plotly_table
 from ...utilities.py3_interop import AbstractContextManager
-from .events import ScalarEvent, VectorEvent, ImageEvent, PlotEvent, ImageEventNoUpload, UploadEvent
+from .events import ScalarEvent, VectorEvent, ImageEvent, PlotEvent, ImageEventNoUpload, UploadEvent, MediaEvent
 
 
 class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncManagerMixin):
@@ -204,11 +204,28 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
         ev = ImageEventNoUpload(metric=self._normalize_name(title), variant=self._normalize_name(series), iter=iter, src=src)
         self._report(ev)
 
+    def report_media(self, title, series, src, iter):
+        """
+        Report a media link.
+        :param title: Title (AKA metric)
+        :type title: str
+        :param series: Series (AKA variant)
+        :type series: str
+        :param src: Media source URI. This URI will be used by the webapp and workers when trying to obtain the media
+            for presentation or processing. Currently only http(s), file and s3 schemes are supported.
+        :type src: str
+        :param iter: Iteration number
+        :type iter: int
+        """
+        ev = ImageEventNoUpload(metric=self._normalize_name(title), variant=self._normalize_name(series), iter=iter, src=src)
+        self._report(ev)
+
     def report_image_and_upload(self, title, series, iter, path=None, image=None, upload_uri=None,
                                 max_image_history=None, delete_after_upload=False):
         """
         Report an image and upload its contents. Image is uploaded to a preconfigured bucket
         (see setup_upload()) with a key (filename) describing the task ID, title, series and iteration.
+
         :param title: Title (AKA metric)
         :type title: str
         :param series: Series (AKA variant)
@@ -228,11 +245,44 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
             raise ValueError('Upload configuration is required (use setup_upload())')
         if len([x for x in (path, image) if x is not None]) != 1:
             raise ValueError('Expected only one of [filename, image]')
-        kwargs = dict(metric=self._normalize_name(title), variant=self._normalize_name(series), iter=iter, image_file_history_size=max_image_history)
+        kwargs = dict(metric=self._normalize_name(title), variant=self._normalize_name(series), iter=iter, file_history_size=max_image_history)
         ev = ImageEvent(image_data=image, upload_uri=upload_uri, local_image_path=path,
                         delete_after_upload=delete_after_upload, **kwargs)
         self._report(ev)
 
+    def report_media_and_upload(self, title, series, iter, path=None, stream=None, upload_uri=None,
+                                file_extension=None, max_history=None, delete_after_upload=False):
+        """
+        Report a media file/stream and upload its contents.
+        Media is uploaded to a preconfigured bucket
+        (see setup_upload()) with a key (filename) describing the task ID, title, series and iteration.
+
+        :param title: Title (AKA metric)
+        :type title: str
+        :param series: Series (AKA variant)
+        :type series: str
+        :param iter: Iteration number
+        :type iter: int
+        :param path: A path to a media file. Required unless stream is provided.
+        :type path: str
+        :param stream: File stream
+        :param file_extension: file extension to use when stream is passed
+        :param max_history: maximum number of files to store per metric/variant combination
+            use negative value for unlimited. default is set in global configuration (default=5)
+        :param delete_after_upload: if True, once the file has been uploaded the local copy will be deleted
+        :type delete_after_upload: boolean
+        """
+        if not self._storage_uri and not upload_uri:
+            raise ValueError('Upload configuration is required (use setup_upload())')
+        if len([x for x in (path, stream) if x is not None]) != 1:
+            raise ValueError('Expected only one of [filename, stream]')
+        kwargs = dict(metric=self._normalize_name(title), variant=self._normalize_name(series), iter=iter,
+                      file_history_size=max_history)
+        ev = MediaEvent(stream=stream, upload_uri=upload_uri, local_image_path=path,
+                        override_filename_ext=file_extension,
+                        delete_after_upload=delete_after_upload, **kwargs)
+        self._report(ev)
+
     def report_histogram(self, title, series, histogram, iter, labels=None, xlabels=None,
                          xtitle=None, ytitle=None, comment=None):
         """
@@ -542,7 +592,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
             raise ValueError('Upload configuration is required (use setup_upload())')
         if len([x for x in (path, matrix) if x is not None]) != 1:
             raise ValueError('Expected only one of [filename, matrix]')
-        kwargs = dict(metric=self._normalize_name(title), variant=self._normalize_name(series), iter=iter, image_file_history_size=max_image_history)
+        kwargs = dict(metric=self._normalize_name(title), variant=self._normalize_name(series), iter=iter, file_history_size=max_image_history)
         ev = UploadEvent(image_data=matrix, upload_uri=upload_uri, local_image_path=path,
                         delete_after_upload=delete_after_upload, **kwargs)
         _, url = ev.get_target_full_upload_uri(upload_uri or self._storage_uri, self._metrics.storage_key_prefix)
diff --git a/trains/binding/frameworks/tensorflow_bind.py b/trains/binding/frameworks/tensorflow_bind.py
index fbc7b159..fc15b263 100644
--- a/trains/binding/frameworks/tensorflow_bind.py
+++ b/trains/binding/frameworks/tensorflow_bind.py
@@ -5,6 +5,7 @@ import threading
 from collections import defaultdict
 from functools import partial
 from io import BytesIO
+from mimetypes import guess_extension
 from typing import Any
 
 import numpy as np
@@ -454,6 +455,31 @@ class EventTrainsWriter(object):
         except Exception:
             pass
 
+    def _add_audio(self, tag, step, values):
+        # only report audio every specific interval (reusing the image report frequency)
+        if step % self.image_report_freq != 0:
+            return None
+
+        audio_str = values['encodedAudioString']
+        audio_data = base64.b64decode(audio_str)
+        if audio_data is None:
+            return
+
+        title, series = self.tag_splitter(tag, num_split_parts=3, default_title='Audio', logdir_header='title',
+                                          auto_reduce_num_split=True)
+        step = self._fix_step_counter(title, series, step)
+
+        stream = BytesIO(audio_data)
+        file_extension = guess_extension(values['contentType']) or '.{}'.format(values['contentType'].split('/')[-1])
+        self._logger.report_media(
+            title=title,
+            series=series,
+            iteration=step,
+            stream=stream,
+            file_extension=file_extension,
+            max_history=self.max_keep_images,
+        )
+
     def _fix_step_counter(self, title, series, step):
         key = (title, series)
         if key not in EventTrainsWriter._title_series_wraparound_counter:
@@ -473,7 +499,7 @@ class EventTrainsWriter(object):
 
     def add_event(self, event, step=None, walltime=None, **kwargs):
         supported_metrics = {
-            'simpleValue', 'image', 'histo', 'tensor'
+            'simpleValue', 'image', 'histo', 'tensor', 'audio'
         }
 
         def get_data(value_dict, metric_search_order):
@@ -531,6 +557,8 @@ class EventTrainsWriter(object):
             self._add_histogram(tag=tag, step=step, histo_data=values)
         elif metric == 'image':
             self._add_image(tag=tag, step=step, img_data=values)
+        elif metric == 'audio':
+            self._add_audio(tag, step, values)
         elif metric == 'tensor' and values.get('dtype') == 'DT_STRING':
             # text, just print to console
             text = base64.b64decode('\n'.join(values['stringVal'])).decode('utf-8')
diff --git a/trains/logger.py b/trains/logger.py
index 9c0286d7..aa09529c 100644
--- a/trains/logger.py
+++ b/trains/logger.py
@@ -477,6 +477,71 @@ class Logger(object):
             delete_after_upload=delete_after_upload,
         )
 
+    def report_media(self, title, series, iteration, local_path=None, stream=None,
+                     file_extension=None, max_history=None, delete_after_upload=False, url=None):
+        """
+        Report a media file/stream and upload its contents, or report the URL of a pre-uploaded media file.
+
+        Media is uploaded to a preconfigured bucket (see setup_upload()) with a key (filename)
+        describing the task ID, title, series and iteration.
+
+        .. note::
+            :paramref:`~.Logger.report_media.local_path`, :paramref:`~.Logger.report_media.url` and
+            :paramref:`~.Logger.report_media.stream` are mutually exclusive, and exactly one must be provided.
+
+        :param str title: Title (AKA metric)
+        :param str series: Series (AKA variant)
+        :param int iteration: Iteration number
+        :param str local_path: A path to a media file.
+        :param stream: BytesIO stream to upload (must provide file extension if used)
+        :param str url: A URL to the location of a pre-uploaded media file.
+        :param file_extension: file extension to use when stream is passed
+        :param int max_history: maximum number of media files to store per metric/variant combination
+            use negative value for unlimited. default is set in global configuration (default=5)
+        :param bool delete_after_upload: if True, once the file has been uploaded the local copy will be deleted
+        """
+        mutually_exclusive(
+            UsageError, _check_none=True,
+            local_path=local_path or None, url=url or None, stream=stream,
+        )
+        if stream is not None and not file_extension:
+            raise ValueError("No file extension provided for stream media upload")
+
+        # if task was not started, we have to start it
+        self._start_task_if_needed()
+
+        self._touch_title_series(title, series)
+
+        if url:
+            self._task.reporter.report_media(
+                title=title,
+                series=series,
+                src=url,
+                iter=iteration,
+            )
+
+        else:
+            upload_uri = self.get_default_upload_destination()
+            if not upload_uri:
+                upload_uri = Path(get_cache_dir()) / 'debug_images'
+                upload_uri.mkdir(parents=True, exist_ok=True)
+            # Verify that we can upload to this destination
+            upload_uri = str(upload_uri)
+            storage = StorageHelper.get(upload_uri)
+            upload_uri = storage.verify_upload(folder_uri=upload_uri)
+
+            self._task.reporter.report_media_and_upload(
+                title=title,
+                series=series,
+                path=local_path,
+                stream=stream,
+                iter=iteration,
+                upload_uri=upload_uri,
+                max_history=max_history,
+                delete_after_upload=delete_after_upload,
+                file_extension=file_extension,
+            )
+
     def set_default_upload_destination(self, uri):
         """
         Set the uri to upload all the debug images to.
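
Example usage of the new Logger.report_media API introduced above (an illustrative sketch, not taken from the repository; it assumes a configured trains environment, and the project/task names, file paths and URL are placeholders):

# Usage sketch for the Logger.report_media API added by this patch.
# Project/task names, paths and the URL are placeholders.
from io import BytesIO

from trains import Task

task = Task.init(project_name='examples', task_name='audio reporting')
logger = task.get_logger()

# 1. Upload a local media file (keyed by title/series/iteration, like a debug sample).
logger.report_media(
    title='speech', series='sample', iteration=1,
    local_path='/tmp/sample.wav',
)

# 2. Upload an in-memory stream; a file extension is required so the media type is known.
with open('/tmp/sample.wav', 'rb') as f:
    logger.report_media(
        title='speech', series='from_stream', iteration=2,
        stream=BytesIO(f.read()), file_extension='.wav',
    )

# 3. Register a pre-uploaded media URL without uploading anything.
logger.report_media(
    title='speech', series='external', iteration=3,
    url='https://example.com/media/sample.wav',
)

TensorBoard audio summaries need no explicit call: add_event() now routes 'audio' values to _add_audio(), which decodes the summary, wraps the bytes in a BytesIO stream and forwards it to Logger.report_media() with the file extension guessed from the summary's content type.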