Add media (audio) support for both Logger and Tensorboard bind

allegroai 2020-04-09 13:14:14 +03:00
parent 7ac7e088a1
commit 648779380c
4 changed files with 175 additions and 14 deletions


@@ -179,31 +179,31 @@ class UploadEvent(MetricsEventAdapter):
_metric_counters = {}
_metric_counters_lock = Lock()
_image_file_history_size = int(config.get('metrics.file_history_size', 5))
_file_history_size = int(config.get('metrics.file_history_size', 5))
@staticmethod
def _replace_slash(part):
return part.replace('\\', '/').strip('/').replace('/', '.slash.')
def __init__(self, metric, variant, image_data, local_image_path=None, iter=0, upload_uri=None,
image_file_history_size=None, delete_after_upload=False, **kwargs):
file_history_size=None, delete_after_upload=False, **kwargs):
# param override_filename: override the uploaded file name (note: the extension is added from the local path)
# param override_filename_ext: override the uploaded file extension
if image_data is not None and not hasattr(image_data, 'shape'):
if image_data is not None and (not hasattr(image_data, 'shape') and not isinstance(image_data, six.BytesIO)):
raise ValueError('Image must have a shape attribute or be a BytesIO stream')
self._image_data = image_data
self._local_image_path = local_image_path
self._url = None
self._key = None
self._count = self._get_metric_count(metric, variant)
if not image_file_history_size:
image_file_history_size = self._image_file_history_size
if not file_history_size:
file_history_size = self._file_history_size
self._filename = kwargs.pop('override_filename', None)
if not self._filename:
if image_file_history_size < 1:
if file_history_size < 1:
self._filename = '%s_%s_%08d' % (metric, variant, self._count)
else:
self._filename = '%s_%s_%08d' % (metric, variant, self._count % image_file_history_size)
self._filename = '%s_%s_%08d' % (metric, variant, self._count % file_history_size)
# make sure there is no '/' in the filename, because it could allow access to other folders,
# and we don't want that to happen
@@ -253,8 +253,10 @@ class UploadEvent(MetricsEventAdapter):
local_file = None
# don't provide file in case this event is out of the history window
last_count = self._get_metric_count(self.metric, self.variant, next=False)
if abs(self._count - last_count) > self._image_file_history_size:
if abs(self._count - last_count) > self._file_history_size:
output = None
elif isinstance(self._image_data, six.BytesIO):
output = self._image_data
elif self._image_data is not None:
image_data = self._image_data
if not isinstance(image_data, np.ndarray):
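
A minimal sketch (not part of the diff) of how the renamed file_history_size drives the filename rotation shown above, assuming hypothetical metric/variant names and the default window of 5:

# Illustrative sketch: upload filename rotation per metric/variant
file_history_size = 5                          # default from metrics.file_history_size
metric, variant = 'debug', 'sample'            # hypothetical names
for count in range(8):
    if file_history_size < 1:                  # negative value means unlimited history
        filename = '%s_%s_%08d' % (metric, variant, count)
    else:                                      # wrap around inside the history window
        filename = '%s_%s_%08d' % (metric, variant, count % file_history_size)
    print(filename)                            # wraps back to ..._00000000 after 5 files
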
@@ -318,10 +320,26 @@ class UploadEvent(MetricsEventAdapter):
class ImageEvent(UploadEvent):
def __init__(self, metric, variant, image_data, local_image_path=None, iter=0, upload_uri=None,
image_file_history_size=None, delete_after_upload=False, **kwargs):
file_history_size=None, delete_after_upload=False, **kwargs):
super(ImageEvent, self).__init__(metric, variant, image_data=image_data, local_image_path=local_image_path,
iter=iter, upload_uri=upload_uri,
image_file_history_size=image_file_history_size,
file_history_size=file_history_size,
delete_after_upload=delete_after_upload, **kwargs)
def get_api_event(self):
return events.MetricsImageEvent(
url=self._url,
key=self._key,
**self._get_base_dict()
)
class MediaEvent(UploadEvent):
def __init__(self, metric, variant, stream, local_image_path=None, iter=0, upload_uri=None,
file_history_size=None, delete_after_upload=False, **kwargs):
super(MediaEvent, self).__init__(metric, variant, image_data=stream, local_image_path=local_image_path,
iter=iter, upload_uri=upload_uri,
file_history_size=file_history_size,
delete_after_upload=delete_after_upload, **kwargs)
def get_api_event(self):


@@ -16,7 +16,7 @@ from ...utilities.plotly_reporter import create_2d_histogram_plot, create_value_
create_2d_scatter_series, create_3d_scatter_series, create_line_plot, plotly_scatter3d_layout_dict, \
create_image_plot, create_plotly_table
from ...utilities.py3_interop import AbstractContextManager
from .events import ScalarEvent, VectorEvent, ImageEvent, PlotEvent, ImageEventNoUpload, UploadEvent
from .events import ScalarEvent, VectorEvent, ImageEvent, PlotEvent, ImageEventNoUpload, UploadEvent, MediaEvent
class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncManagerMixin):
@@ -204,11 +204,28 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
ev = ImageEventNoUpload(metric=self._normalize_name(title), variant=self._normalize_name(series), iter=iter, src=src)
self._report(ev)
def report_media(self, title, series, src, iter):
"""
Report a media link.
:param title: Title (AKA metric)
:type title: str
:param series: Series (AKA variant)
:type series: str
:param src: Media source URI. This URI will be used by the webapp and workers when trying to obtain the media
for presentation or processing. Currently only http(s), file and s3 schemes are supported.
:type src: str
:param iter: Iteration number
:type iter: int
"""
ev = ImageEventNoUpload(metric=self._normalize_name(title), variant=self._normalize_name(series), iter=iter, src=src)
self._report(ev)
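
A brief usage sketch of the new link-only path, assuming an already configured Reporter instance and a hypothetical pre-uploaded URI (only http(s), file and s3 schemes are accepted):

def report_media_link(reporter):
    # `reporter` is assumed to be a configured Reporter instance (hypothetical helper)
    reporter.report_media(
        title='audio samples',                  # AKA metric
        series='validation',                    # AKA variant
        src='s3://my-bucket/sample.wav',        # hypothetical pre-uploaded media URI
        iter=7,
    )
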
def report_image_and_upload(self, title, series, iter, path=None, image=None, upload_uri=None,
max_image_history=None, delete_after_upload=False):
"""
Report an image and upload its contents. Image is uploaded to a preconfigured bucket (see setup_upload()) with
a key (filename) describing the task ID, title, series and iteration.
:param title: Title (AKA metric)
:type title: str
:param series: Series (AKA variant)
@@ -228,11 +245,44 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
raise ValueError('Upload configuration is required (use setup_upload())')
if len([x for x in (path, image) if x is not None]) != 1:
raise ValueError('Expected only one of [filename, image]')
kwargs = dict(metric=self._normalize_name(title), variant=self._normalize_name(series), iter=iter, image_file_history_size=max_image_history)
kwargs = dict(metric=self._normalize_name(title), variant=self._normalize_name(series), iter=iter, file_history_size=max_image_history)
ev = ImageEvent(image_data=image, upload_uri=upload_uri, local_image_path=path,
delete_after_upload=delete_after_upload, **kwargs)
self._report(ev)
def report_media_and_upload(self, title, series, iter, path=None, stream=None, upload_uri=None,
file_extension=None, max_history=None, delete_after_upload=False):
"""
Report a media file/stream and upload its contents.
Media is uploaded to a preconfigured bucket
(see setup_upload()) with a key (filename) describing the task ID, title, series and iteration.
:param title: Title (AKA metric)
:type title: str
:param series: Series (AKA variant)
:type series: str
:param iter: Iteration number
:type iter: int
:param path: A path to a media file. Required unless stream is provided.
:type path: str
:param stream: File stream (a file-like object) to upload
:param file_extension: file extension to use when a stream is passed
:param max_history: maximum number of files to store per metric/variant combination;
use a negative value for unlimited. The default is set in the global configuration (default=5)
:param delete_after_upload: if True, the local copy will be deleted once the file has been uploaded
:type delete_after_upload: boolean
"""
if not self._storage_uri and not upload_uri:
raise ValueError('Upload configuration is required (use setup_upload())')
if len([x for x in (path, stream) if x is not None]) != 1:
raise ValueError('Expected only one of [filename, stream]')
kwargs = dict(metric=self._normalize_name(title), variant=self._normalize_name(series), iter=iter,
file_history_size=max_history)
ev = MediaEvent(stream=stream, upload_uri=upload_uri, local_image_path=path,
override_filename_ext=file_extension,
delete_after_upload=delete_after_upload, **kwargs)
self._report(ev)
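
A hedged sketch of the upload path, assuming a configured Reporter (or an explicit upload_uri) and a raw audio payload held in memory; names are hypothetical:

from io import BytesIO

def report_media_stream(reporter, wav_bytes, step):
    # `reporter` and `wav_bytes` are hypothetical; file_extension is required for streams
    reporter.report_media_and_upload(
        title='audio samples',
        series='training',
        iter=step,
        stream=BytesIO(wav_bytes),
        file_extension='.wav',
        max_history=5,                          # keep the last 5 files per metric/variant
    )
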
def report_histogram(self, title, series, histogram, iter, labels=None, xlabels=None,
xtitle=None, ytitle=None, comment=None):
"""
@@ -542,7 +592,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
raise ValueError('Upload configuration is required (use setup_upload())')
if len([x for x in (path, matrix) if x is not None]) != 1:
raise ValueError('Expected only one of [filename, matrix]')
kwargs = dict(metric=self._normalize_name(title), variant=self._normalize_name(series), iter=iter, image_file_history_size=max_image_history)
kwargs = dict(metric=self._normalize_name(title), variant=self._normalize_name(series), iter=iter, file_history_size=max_image_history)
ev = UploadEvent(image_data=matrix, upload_uri=upload_uri, local_image_path=path,
delete_after_upload=delete_after_upload, **kwargs)
_, url = ev.get_target_full_upload_uri(upload_uri or self._storage_uri, self._metrics.storage_key_prefix)


@@ -5,6 +5,7 @@ import threading
from collections import defaultdict
from functools import partial
from io import BytesIO
from mimetypes import guess_extension
from typing import Any
import numpy as np
@@ -454,6 +455,31 @@ class EventTrainsWriter(object):
except Exception:
pass
def _add_audio(self, tag, step, values):
# only report audio every image_report_freq iterations (same interval as images)
if step % self.image_report_freq != 0:
return None
audio_str = values['encodedAudioString']
audio_data = base64.b64decode(audio_str)
if audio_data is None:
return
title, series = self.tag_splitter(tag, num_split_parts=3, default_title='Images', logdir_header='title',
auto_reduce_num_split=True)
step = self._fix_step_counter(title, series, step)
stream = BytesIO(audio_data)
file_extension = guess_extension(values['contentType']) or '.{}'.format(values['contentType'].split('/')[-1])
self._logger.report_media(
title=title,
series=series,
iteration=step,
stream=stream,
file_extension=file_extension,
max_history=self.max_keep_images,
)
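
The file extension is derived from the summary's content type, falling back to the MIME subtype when mimetypes has no mapping; a self-contained sketch of that fallback (the content types below are just examples):

from mimetypes import guess_extension

for content_type in ('audio/wav', 'audio/x-made-up'):   # example content types
    ext = guess_extension(content_type) or '.{}'.format(content_type.split('/')[-1])
    print(content_type, '->', ext)                       # the second one falls back to '.x-made-up'
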
def _fix_step_counter(self, title, series, step):
key = (title, series)
if key not in EventTrainsWriter._title_series_wraparound_counter:
@@ -473,7 +499,7 @@ class EventTrainsWriter(object):
def add_event(self, event, step=None, walltime=None, **kwargs):
supported_metrics = {
'simpleValue', 'image', 'histo', 'tensor'
'simpleValue', 'image', 'histo', 'tensor', 'audio'
}
def get_data(value_dict, metric_search_order):
@@ -531,6 +557,8 @@ class EventTrainsWriter(object):
self._add_histogram(tag=tag, step=step, histo_data=values)
elif metric == 'image':
self._add_image(tag=tag, step=step, img_data=values)
elif metric == 'audio':
self._add_audio(tag, step, values)
elif metric == 'tensor' and values.get('dtype') == 'DT_STRING':
# text, just print to console
text = base64.b64decode('\n'.join(values['stringVal'])).decode('utf-8')
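
For reference, the 'audio' branch expects the decoded summary value to carry a base64 payload plus a content type, matching the keys _add_audio reads above; a hedged sketch with a made-up payload:

import base64

fake_payload = b'RIFF....WAVEfmt '                       # not a real WAV file
values = {
    'encodedAudioString': base64.b64encode(fake_payload).decode('ascii'),
    'contentType': 'audio/wav',
}
assert base64.b64decode(values['encodedAudioString']) == fake_payload
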


@@ -477,6 +477,71 @@ class Logger(object):
delete_after_upload=delete_after_upload,
)
def report_media(self, title, series, iteration, local_path=None, stream=None,
file_extension=None, max_history=None, delete_after_upload=False, url=None):
"""
Report a media file/stream and upload its contents.
The media is uploaded to a preconfigured bucket (see setup_upload()) with a key (filename)
describing the task ID, title, series and iteration.
.. note::
:paramref:`~.Logger.report_media.local_path`, :paramref:`~.Logger.report_media.url` and :paramref:`~.Logger.report_media.stream`
are mutually exclusive, and exactly one must be provided.
:param str title: Title (AKA metric)
:param str series: Series (AKA variant)
:param int iteration: Iteration number
:param str local_path: A path to a local media file.
:param stream: BytesIO stream to upload (a file extension must be provided if used)
:param str url: A URL to the location of a pre-uploaded media file.
:param file_extension: file extension to use when a stream is passed
:param int max_history: maximum number of media files to store per metric/variant combination;
use a negative value for unlimited. The default is set in the global configuration (default=5)
:param bool delete_after_upload: if True, the local copy will be deleted once the file has been uploaded
"""
mutually_exclusive(
UsageError, _check_none=True,
local_path=local_path or None, url=url or None, stream=stream,
)
if stream is not None and not file_extension:
raise ValueError("No file extension provided for stream media upload")
# if task was not started, we have to start it
self._start_task_if_needed()
self._touch_title_series(title, series)
if url:
self._task.reporter.report_media(
title=title,
series=series,
src=url,
iter=iteration,
)
else:
upload_uri = self.get_default_upload_destination()
if not upload_uri:
upload_uri = Path(get_cache_dir()) / 'debug_images'
upload_uri.mkdir(parents=True, exist_ok=True)
# Verify that we can upload to this destination
upload_uri = str(upload_uri)
storage = StorageHelper.get(upload_uri)
upload_uri = storage.verify_upload(folder_uri=upload_uri)
self._task.reporter.report_media_and_upload(
title=title,
series=series,
path=local_path,
stream=stream,
iter=iteration,
upload_uri=upload_uri,
max_history=max_history,
delete_after_upload=delete_after_upload,
file_extension=file_extension,
)
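
A short usage sketch of the new Logger API, assuming a Logger obtained from the running task (e.g. task.get_logger()) and an in-memory audio payload; titles, series and the URL are hypothetical:

from io import BytesIO

def log_audio_samples(logger, step, audio_bytes):
    # Upload a stream (file_extension is mandatory when a stream is passed)
    logger.report_media(
        title='samples', series='generated', iteration=step,
        stream=BytesIO(audio_bytes), file_extension='.wav',
        max_history=5,
    )
    # Or just register a link to media that is already accessible
    logger.report_media(
        title='samples', series='reference', iteration=step,
        url='https://example.com/reference.wav',
    )
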
def set_default_upload_destination(self, uri):
"""
Set the uri to upload all the debug images to.