diff --git a/trains/backend_interface/metrics/events.py b/trains/backend_interface/metrics/events.py index 896540f8..4af86cbb 100644 --- a/trains/backend_interface/metrics/events.py +++ b/trains/backend_interface/metrics/events.py @@ -172,11 +172,12 @@ class ImageEvent(MetricsEventAdapter): _metric_counters_lock = Lock() _image_file_history_size = int(config.get('metrics.file_history_size', 5)) - def __init__(self, metric, variant, image_data, iter=0, upload_uri=None, + def __init__(self, metric, variant, image_data, local_image_path=None, iter=0, upload_uri=None, image_file_history_size=None, **kwargs): - if not hasattr(image_data, 'shape'): + if image_data is not None and not hasattr(image_data, 'shape'): raise ValueError('Image must have a shape attribute') self._image_data = image_data + self._local_image_path = local_image_path self._url = None self._key = None self._count = self._get_metric_count(metric, variant) @@ -187,6 +188,12 @@ class ImageEvent(MetricsEventAdapter): else: self._filename = '%s_%s_%08d' % (metric, variant, self._count % image_file_history_size) self._upload_uri = upload_uri + + # get upload uri upfront + image_format = self._format.lower() if self._image_data is not None else \ + pathlib2.Path(self._local_image_path).suffix + self._upload_filename = str(pathlib2.Path(self._filename).with_suffix(image_format)) + super(ImageEvent, self).__init__(metric, variant, iter=iter, **kwargs) @classmethod @@ -221,7 +228,7 @@ class ImageEvent(MetricsEventAdapter): last_count = self._get_metric_count(self.metric, self.variant, next=False) if abs(self._count - last_count) > self._image_file_history_size: output = None - else: + elif self._image_data is not None: image_data = self._image_data if not isinstance(image_data, np.ndarray): # try conversion, if it fails we'll leave it to the user. @@ -245,14 +252,24 @@ class ImageEvent(MetricsEventAdapter): output = six.BytesIO(img_bytes.tostring()) output.seek(0) - - filename = str(pathlib2.Path(self._filename).with_suffix(self._format.lower())) + else: + with open(self._local_image_path, 'rb') as f: + output = six.BytesIO(f.read()) + output.seek(0) return self.FileEntry( event=self, - name=filename, + name=self._upload_filename, stream=output, url_prop='url', key_prop='key', upload_uri=self._upload_uri ) + + def get_target_full_upload_uri(self, storage_uri, storage_key_prefix): + e_storage_uri = self._upload_uri or storage_uri + # if we have an entry (with or without a stream), we'll generate the URL and store it in the event + filename = self._upload_filename + key = '/'.join(x for x in (storage_key_prefix, self.metric, self.variant, filename.strip('/')) if x) + url = '/'.join(x.strip('/') for x in (e_storage_uri, key)) + return key, url diff --git a/trains/backend_interface/metrics/interface.py b/trains/backend_interface/metrics/interface.py index 13a9b91b..f5b00ef9 100644 --- a/trains/backend_interface/metrics/interface.py +++ b/trains/backend_interface/metrics/interface.py @@ -116,12 +116,7 @@ class Metrics(InterfaceBase): entry = ev.get_file_entry() kwargs = {} if entry: - e_storage_uri = entry.upload_uri or storage_uri - self._file_related_event_time = now - # if we have an entry (with or without a stream), we'll generate the URL and store it in the event - filename = entry.name - key = '/'.join(x for x in (self._storage_key_prefix, ev.metric, ev.variant, filename.strip('/')) if x) - url = '/'.join(x.strip('/') for x in (e_storage_uri, key)) + key, url = ev.get_target_full_upload_uri(storage_uri, self.storage_key_prefix) kwargs[entry.key_prop] = key kwargs[entry.url_prop] = url if not entry.stream: diff --git a/trains/backend_interface/metrics/reporter.py b/trains/backend_interface/metrics/reporter.py index 3716ecbf..e29074a6 100644 --- a/trains/backend_interface/metrics/reporter.py +++ b/trains/backend_interface/metrics/reporter.py @@ -8,7 +8,8 @@ from ..base import InterfaceBase from ..setupuploadmixin import SetupUploadMixin from ...utilities.async_manager import AsyncManagerMixin from ...utilities.plotly_reporter import create_2d_histogram_plot, create_value_matrix, create_3d_surface, \ - create_2d_scatter_series, create_3d_scatter_series, create_line_plot, plotly_scatter3d_layout_dict + create_2d_scatter_series, create_3d_scatter_series, create_line_plot, plotly_scatter3d_layout_dict, \ + create_image_plot from ...utilities.py3_interop import AbstractContextManager from .events import ScalarEvent, VectorEvent, ImageEvent, PlotEvent, ImageEventNoUpload @@ -187,9 +188,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan raise ValueError('Expected only one of [filename, matrix]') kwargs = dict(metric=self._normalize_name(title), variant=self._normalize_name(series), iter=iter, image_file_history_size=max_image_history) - if matrix is None: - matrix = cv2.imread(path) - ev = ImageEvent(image_data=matrix, upload_uri=upload_uri, **kwargs) + ev = ImageEvent(image_data=matrix, upload_uri=upload_uri, local_image_path=path, **kwargs) self._report(ev) def report_histogram(self, title, series, histogram, iter, labels=None, xlabels=None, comment=None): @@ -445,6 +444,49 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan iter=iter, ) + def report_image_plot_and_upload(self, title, series, iter, path=None, matrix=None, + upload_uri=None, max_image_history=None): + """ + Report an image as plot and upload its contents. + Image is uploaded to a preconfigured bucket (see setup_upload()) with a key (filename) + describing the task ID, title, series and iteration. + Then a plotly object is created and registered, this plotly objects points to the uploaded image + :param title: Title (AKA metric) + :type title: str + :param series: Series (AKA variant) + :type series: str + :param iter: Iteration number + :type value: int + :param path: A path to an image file. Required unless matrix is provided. + :type path: str + :param matrix: A 3D numpy.ndarray object containing image data (BGR). Required unless filename is provided. + :type matrix: str + :param max_image_history: maximum number of image to store per metric/variant combination + use negative value for unlimited. default is set in global configuration (default=5) + """ + if not upload_uri and not self._storage_uri: + raise ValueError('Upload configuration is required (use setup_upload())') + if len([x for x in (path, matrix) if x is not None]) != 1: + raise ValueError('Expected only one of [filename, matrix]') + kwargs = dict(metric=self._normalize_name(title), + variant=self._normalize_name(series), iter=iter, image_file_history_size=max_image_history) + ev = ImageEvent(image_data=matrix, upload_uri=upload_uri, local_image_path=path, **kwargs) + _, url = ev.get_target_full_upload_uri(upload_uri or self._storage_uri, self._metrics.storage_key_prefix) + self._report(ev) + plotly_dict = create_image_plot( + image_src=url, + title=title + '/' + series, + width=matrix.shape[1] if matrix is not None else 640, + height=matrix.shape[0] if matrix is not None else 480, + ) + + return self.report_plot( + title=self._normalize_name(title), + series=self._normalize_name(series), + plot=plotly_dict, + iter=iter, + ) + @classmethod def _normalize_name(cls, name): if not name: diff --git a/trains/binding/matplotlib_bind.py b/trains/binding/matplotlib_bind.py index 626fd3a8..b8db7f13 100644 --- a/trains/binding/matplotlib_bind.py +++ b/trains/binding/matplotlib_bind.py @@ -1,7 +1,7 @@ +import os import sys +from tempfile import mkstemp -import cv2 -import numpy as np import six from six import BytesIO @@ -129,6 +129,7 @@ class PatchedMatplotlib: # convert to plotly image = None plotly_fig = None + image_format = 'svg' if not force_save_as_image: # noinspection PyBroadException try: @@ -140,17 +141,28 @@ class PatchedMatplotlib: return renderer.plotly_fig plotly_fig = our_mpl_to_plotly(mpl_fig) - except Exception: - pass + except Exception as ex: + # this was an image, change format to jpeg + if 'selfie' in str(ex): + image_format = 'jpeg' # plotly could not serialize the plot, we should convert to image if not plotly_fig: plotly_fig = None - buffer_ = BytesIO() - plt.savefig(buffer_, format="png", bbox_inches='tight', pad_inches=0) - buffer_.seek(0) - buffer = buffer_.getbuffer() if not six.PY2 else buffer_.getvalue() - image = cv2.imdecode(np.frombuffer(buffer, dtype=np.uint8), cv2.IMREAD_UNCHANGED) + # noinspection PyBroadException + try: + # first try SVG if we fail then fallback to png + buffer_ = BytesIO() + plt.savefig(buffer_, format=image_format, bbox_inches='tight', pad_inches=0) + buffer_.seek(0) + except Exception: + image_format = 'png' + buffer_ = BytesIO() + plt.savefig(buffer_, format=image_format, bbox_inches='tight', pad_inches=0) + buffer_.seek(0) + fd, image = mkstemp(suffix='.'+image_format) + os.write(fd, buffer_.read()) + os.close(fd) # check if we need to restore the active object if set_active and not _pylab_helpers.Gcf.get_active(): @@ -185,7 +197,14 @@ class PatchedMatplotlib: PatchedMatplotlib._global_image_counter += 1 logger = PatchedMatplotlib._current_task.get_logger() title = plot_title or 'untitled %d' % PatchedMatplotlib._global_image_counter - logger.report_image_and_upload(title=title, series='plot image', matrix=image, + # this is actually a failed plot, we should put it under plots: + # currently disabled + # if image_format == 'svg': + # logger.report_image_plot_and_upload(title=title, series='plot image', path=image, + # iteration=PatchedMatplotlib._global_image_counter + # if plot_title else 0) + # else: + logger.report_image_and_upload(title=title, series='plot image', path=image, iteration=PatchedMatplotlib._global_image_counter if plot_title else 0) except Exception: diff --git a/trains/logger.py b/trains/logger.py index d4d300c4..8d208512 100644 --- a/trains/logger.py +++ b/trains/logger.py @@ -533,6 +533,49 @@ class Logger(object): max_image_history=max_image_history, ) + def report_image_plot_and_upload(self, title, series, iteration, path=None, matrix=None, max_image_history=None): + """ + Report an image, upload its contents, and present in plots section using plotly + + Image is uploaded to a preconfigured bucket (see setup_upload()) with a key (filename) + describing the task ID, title, series and iteration. + + :param title: Title (AKA metric) + :type title: str + :param series: Series (AKA variant) + :type series: str + :param iteration: Iteration number + :type iteration: int + :param path: A path to an image file. Required unless matrix is provided. + :type path: str + :param matrix: A 3D numpy.ndarray object containing image data (RGB). Required unless filename is provided. + :type matrix: str + :param max_image_history: maximum number of image to store per metric/variant combination \ + use negative value for unlimited. default is set in global configuration (default=5) + :type max_image_history: int + """ + + # if task was not started, we have to start it + self._start_task_if_needed() + upload_uri = self._default_upload_destination or self._task._get_default_report_storage_uri() + if not upload_uri: + upload_uri = Path(get_cache_dir()) / 'debug_images' + upload_uri.mkdir(parents=True, exist_ok=True) + # Verify that we can upload to this destination + upload_uri = str(upload_uri) + storage = StorageHelper.get(upload_uri) + upload_uri = storage.verify_upload(folder_uri=upload_uri) + + self._task.reporter.report_image_plot_and_upload( + title=title, + series=series, + path=path, + matrix=matrix, + iter=iteration, + upload_uri=upload_uri, + max_image_history=max_image_history, + ) + def set_default_upload_destination(self, uri): """ Set the uri to upload all the debug images to. diff --git a/trains/utilities/plotly_reporter.py b/trains/utilities/plotly_reporter.py index ef454504..99a4f7b4 100644 --- a/trains/utilities/plotly_reporter.py +++ b/trains/utilities/plotly_reporter.py @@ -245,6 +245,36 @@ def create_3d_surface(np_value_matrix, title="3D Surface", xlabels=None, ylabels return conf_matrix_plot +def create_image_plot(image_src, title, width=640, height=480, series=None, comment=None): + image_plot = { + "data": [], + "layout": { + "xaxis": {"visible": False, "range": [0, width]}, + "yaxis": {"visible": False, "range": [0, height]}, + "width": width, + "height": height, + "margin": {'l': 0, 'r': 0, 't': 0, 'b': 0}, + "images": [{ + "sizex": width, + "sizey": height, + "xref": "x", + "yref": "y", + "opacity": 1.0, + "x": 0, + "y": int(height / 2), + "yanchor": "middle", + "sizing": "contain", + "layer": "below", + "source": image_src + }], + "showlegend": False, + "title": title if not comment else (title + '
' + comment + ''), + "name": series, + } + } + return image_plot + + def _get_z_colorbar_data(z_data=None, values=None, colors=None): if values is None: values = [0, 1. / 10, 2. / 10, 6. / 10, 9. / 10]