diff --git a/trains/backend_interface/metrics/events.py b/trains/backend_interface/metrics/events.py
index 896540f8..4af86cbb 100644
--- a/trains/backend_interface/metrics/events.py
+++ b/trains/backend_interface/metrics/events.py
@@ -172,11 +172,12 @@ class ImageEvent(MetricsEventAdapter):
_metric_counters_lock = Lock()
_image_file_history_size = int(config.get('metrics.file_history_size', 5))
- def __init__(self, metric, variant, image_data, iter=0, upload_uri=None,
+ def __init__(self, metric, variant, image_data, local_image_path=None, iter=0, upload_uri=None,
image_file_history_size=None, **kwargs):
- if not hasattr(image_data, 'shape'):
+ if image_data is not None and not hasattr(image_data, 'shape'):
raise ValueError('Image must have a shape attribute')
self._image_data = image_data
+ self._local_image_path = local_image_path
self._url = None
self._key = None
self._count = self._get_metric_count(metric, variant)
@@ -187,6 +188,12 @@ class ImageEvent(MetricsEventAdapter):
else:
self._filename = '%s_%s_%08d' % (metric, variant, self._count % image_file_history_size)
self._upload_uri = upload_uri
+
+ # get upload uri upfront
+ image_format = self._format.lower() if self._image_data is not None else \
+ pathlib2.Path(self._local_image_path).suffix
+ self._upload_filename = str(pathlib2.Path(self._filename).with_suffix(image_format))
+
super(ImageEvent, self).__init__(metric, variant, iter=iter, **kwargs)
@classmethod
@@ -221,7 +228,7 @@ class ImageEvent(MetricsEventAdapter):
last_count = self._get_metric_count(self.metric, self.variant, next=False)
if abs(self._count - last_count) > self._image_file_history_size:
output = None
- else:
+ elif self._image_data is not None:
image_data = self._image_data
if not isinstance(image_data, np.ndarray):
# try conversion, if it fails we'll leave it to the user.
@@ -245,14 +252,24 @@ class ImageEvent(MetricsEventAdapter):
output = six.BytesIO(img_bytes.tostring())
output.seek(0)
-
- filename = str(pathlib2.Path(self._filename).with_suffix(self._format.lower()))
+ else:
+ with open(self._local_image_path, 'rb') as f:
+ output = six.BytesIO(f.read())
+ output.seek(0)
return self.FileEntry(
event=self,
- name=filename,
+ name=self._upload_filename,
stream=output,
url_prop='url',
key_prop='key',
upload_uri=self._upload_uri
)
+
+ def get_target_full_upload_uri(self, storage_uri, storage_key_prefix):
+ e_storage_uri = self._upload_uri or storage_uri
+ # if we have an entry (with or without a stream), we'll generate the URL and store it in the event
+ filename = self._upload_filename
+ key = '/'.join(x for x in (storage_key_prefix, self.metric, self.variant, filename.strip('/')) if x)
+ url = '/'.join(x.strip('/') for x in (e_storage_uri, key))
+ return key, url
diff --git a/trains/backend_interface/metrics/interface.py b/trains/backend_interface/metrics/interface.py
index 13a9b91b..f5b00ef9 100644
--- a/trains/backend_interface/metrics/interface.py
+++ b/trains/backend_interface/metrics/interface.py
@@ -116,12 +116,7 @@ class Metrics(InterfaceBase):
entry = ev.get_file_entry()
kwargs = {}
if entry:
- e_storage_uri = entry.upload_uri or storage_uri
- self._file_related_event_time = now
- # if we have an entry (with or without a stream), we'll generate the URL and store it in the event
- filename = entry.name
- key = '/'.join(x for x in (self._storage_key_prefix, ev.metric, ev.variant, filename.strip('/')) if x)
- url = '/'.join(x.strip('/') for x in (e_storage_uri, key))
+ key, url = ev.get_target_full_upload_uri(storage_uri, self.storage_key_prefix)
kwargs[entry.key_prop] = key
kwargs[entry.url_prop] = url
if not entry.stream:
diff --git a/trains/backend_interface/metrics/reporter.py b/trains/backend_interface/metrics/reporter.py
index 3716ecbf..e29074a6 100644
--- a/trains/backend_interface/metrics/reporter.py
+++ b/trains/backend_interface/metrics/reporter.py
@@ -8,7 +8,8 @@ from ..base import InterfaceBase
from ..setupuploadmixin import SetupUploadMixin
from ...utilities.async_manager import AsyncManagerMixin
from ...utilities.plotly_reporter import create_2d_histogram_plot, create_value_matrix, create_3d_surface, \
- create_2d_scatter_series, create_3d_scatter_series, create_line_plot, plotly_scatter3d_layout_dict
+ create_2d_scatter_series, create_3d_scatter_series, create_line_plot, plotly_scatter3d_layout_dict, \
+ create_image_plot
from ...utilities.py3_interop import AbstractContextManager
from .events import ScalarEvent, VectorEvent, ImageEvent, PlotEvent, ImageEventNoUpload
@@ -187,9 +188,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
raise ValueError('Expected only one of [filename, matrix]')
kwargs = dict(metric=self._normalize_name(title),
variant=self._normalize_name(series), iter=iter, image_file_history_size=max_image_history)
- if matrix is None:
- matrix = cv2.imread(path)
- ev = ImageEvent(image_data=matrix, upload_uri=upload_uri, **kwargs)
+ ev = ImageEvent(image_data=matrix, upload_uri=upload_uri, local_image_path=path, **kwargs)
self._report(ev)
def report_histogram(self, title, series, histogram, iter, labels=None, xlabels=None, comment=None):
@@ -445,6 +444,49 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
iter=iter,
)
+ def report_image_plot_and_upload(self, title, series, iter, path=None, matrix=None,
+ upload_uri=None, max_image_history=None):
+ """
+ Report an image as plot and upload its contents.
+ Image is uploaded to a preconfigured bucket (see setup_upload()) with a key (filename)
+ describing the task ID, title, series and iteration.
+ Then a plotly object is created and registered, this plotly objects points to the uploaded image
+ :param title: Title (AKA metric)
+ :type title: str
+ :param series: Series (AKA variant)
+ :type series: str
+ :param iter: Iteration number
+ :type value: int
+ :param path: A path to an image file. Required unless matrix is provided.
+ :type path: str
+ :param matrix: A 3D numpy.ndarray object containing image data (BGR). Required unless filename is provided.
+ :type matrix: str
+ :param max_image_history: maximum number of image to store per metric/variant combination
+ use negative value for unlimited. default is set in global configuration (default=5)
+ """
+ if not upload_uri and not self._storage_uri:
+ raise ValueError('Upload configuration is required (use setup_upload())')
+ if len([x for x in (path, matrix) if x is not None]) != 1:
+ raise ValueError('Expected only one of [filename, matrix]')
+ kwargs = dict(metric=self._normalize_name(title),
+ variant=self._normalize_name(series), iter=iter, image_file_history_size=max_image_history)
+ ev = ImageEvent(image_data=matrix, upload_uri=upload_uri, local_image_path=path, **kwargs)
+ _, url = ev.get_target_full_upload_uri(upload_uri or self._storage_uri, self._metrics.storage_key_prefix)
+ self._report(ev)
+ plotly_dict = create_image_plot(
+ image_src=url,
+ title=title + '/' + series,
+ width=matrix.shape[1] if matrix is not None else 640,
+ height=matrix.shape[0] if matrix is not None else 480,
+ )
+
+ return self.report_plot(
+ title=self._normalize_name(title),
+ series=self._normalize_name(series),
+ plot=plotly_dict,
+ iter=iter,
+ )
+
@classmethod
def _normalize_name(cls, name):
if not name:
diff --git a/trains/binding/matplotlib_bind.py b/trains/binding/matplotlib_bind.py
index 626fd3a8..b8db7f13 100644
--- a/trains/binding/matplotlib_bind.py
+++ b/trains/binding/matplotlib_bind.py
@@ -1,7 +1,7 @@
+import os
import sys
+from tempfile import mkstemp
-import cv2
-import numpy as np
import six
from six import BytesIO
@@ -129,6 +129,7 @@ class PatchedMatplotlib:
# convert to plotly
image = None
plotly_fig = None
+ image_format = 'svg'
if not force_save_as_image:
# noinspection PyBroadException
try:
@@ -140,17 +141,28 @@ class PatchedMatplotlib:
return renderer.plotly_fig
plotly_fig = our_mpl_to_plotly(mpl_fig)
- except Exception:
- pass
+ except Exception as ex:
+ # this was an image, change format to jpeg
+ if 'selfie' in str(ex):
+ image_format = 'jpeg'
# plotly could not serialize the plot, we should convert to image
if not plotly_fig:
plotly_fig = None
- buffer_ = BytesIO()
- plt.savefig(buffer_, format="png", bbox_inches='tight', pad_inches=0)
- buffer_.seek(0)
- buffer = buffer_.getbuffer() if not six.PY2 else buffer_.getvalue()
- image = cv2.imdecode(np.frombuffer(buffer, dtype=np.uint8), cv2.IMREAD_UNCHANGED)
+ # noinspection PyBroadException
+ try:
+ # first try SVG if we fail then fallback to png
+ buffer_ = BytesIO()
+ plt.savefig(buffer_, format=image_format, bbox_inches='tight', pad_inches=0)
+ buffer_.seek(0)
+ except Exception:
+ image_format = 'png'
+ buffer_ = BytesIO()
+ plt.savefig(buffer_, format=image_format, bbox_inches='tight', pad_inches=0)
+ buffer_.seek(0)
+ fd, image = mkstemp(suffix='.'+image_format)
+ os.write(fd, buffer_.read())
+ os.close(fd)
# check if we need to restore the active object
if set_active and not _pylab_helpers.Gcf.get_active():
@@ -185,7 +197,14 @@ class PatchedMatplotlib:
PatchedMatplotlib._global_image_counter += 1
logger = PatchedMatplotlib._current_task.get_logger()
title = plot_title or 'untitled %d' % PatchedMatplotlib._global_image_counter
- logger.report_image_and_upload(title=title, series='plot image', matrix=image,
+ # this is actually a failed plot, we should put it under plots:
+ # currently disabled
+ # if image_format == 'svg':
+ # logger.report_image_plot_and_upload(title=title, series='plot image', path=image,
+ # iteration=PatchedMatplotlib._global_image_counter
+ # if plot_title else 0)
+ # else:
+ logger.report_image_and_upload(title=title, series='plot image', path=image,
iteration=PatchedMatplotlib._global_image_counter
if plot_title else 0)
except Exception:
diff --git a/trains/logger.py b/trains/logger.py
index d4d300c4..8d208512 100644
--- a/trains/logger.py
+++ b/trains/logger.py
@@ -533,6 +533,49 @@ class Logger(object):
max_image_history=max_image_history,
)
+ def report_image_plot_and_upload(self, title, series, iteration, path=None, matrix=None, max_image_history=None):
+ """
+ Report an image, upload its contents, and present in plots section using plotly
+
+ Image is uploaded to a preconfigured bucket (see setup_upload()) with a key (filename)
+ describing the task ID, title, series and iteration.
+
+ :param title: Title (AKA metric)
+ :type title: str
+ :param series: Series (AKA variant)
+ :type series: str
+ :param iteration: Iteration number
+ :type iteration: int
+ :param path: A path to an image file. Required unless matrix is provided.
+ :type path: str
+ :param matrix: A 3D numpy.ndarray object containing image data (RGB). Required unless filename is provided.
+ :type matrix: str
+ :param max_image_history: maximum number of image to store per metric/variant combination \
+ use negative value for unlimited. default is set in global configuration (default=5)
+ :type max_image_history: int
+ """
+
+ # if task was not started, we have to start it
+ self._start_task_if_needed()
+ upload_uri = self._default_upload_destination or self._task._get_default_report_storage_uri()
+ if not upload_uri:
+ upload_uri = Path(get_cache_dir()) / 'debug_images'
+ upload_uri.mkdir(parents=True, exist_ok=True)
+ # Verify that we can upload to this destination
+ upload_uri = str(upload_uri)
+ storage = StorageHelper.get(upload_uri)
+ upload_uri = storage.verify_upload(folder_uri=upload_uri)
+
+ self._task.reporter.report_image_plot_and_upload(
+ title=title,
+ series=series,
+ path=path,
+ matrix=matrix,
+ iter=iteration,
+ upload_uri=upload_uri,
+ max_image_history=max_image_history,
+ )
+
def set_default_upload_destination(self, uri):
"""
Set the uri to upload all the debug images to.
diff --git a/trains/utilities/plotly_reporter.py b/trains/utilities/plotly_reporter.py
index ef454504..99a4f7b4 100644
--- a/trains/utilities/plotly_reporter.py
+++ b/trains/utilities/plotly_reporter.py
@@ -245,6 +245,36 @@ def create_3d_surface(np_value_matrix, title="3D Surface", xlabels=None, ylabels
return conf_matrix_plot
+def create_image_plot(image_src, title, width=640, height=480, series=None, comment=None):
+ image_plot = {
+ "data": [],
+ "layout": {
+ "xaxis": {"visible": False, "range": [0, width]},
+ "yaxis": {"visible": False, "range": [0, height]},
+ "width": width,
+ "height": height,
+ "margin": {'l': 0, 'r': 0, 't': 0, 'b': 0},
+ "images": [{
+ "sizex": width,
+ "sizey": height,
+ "xref": "x",
+ "yref": "y",
+ "opacity": 1.0,
+ "x": 0,
+ "y": int(height / 2),
+ "yanchor": "middle",
+ "sizing": "contain",
+ "layer": "below",
+ "source": image_src
+ }],
+ "showlegend": False,
+ "title": title if not comment else (title + '
' + comment + ''),
+ "name": series,
+ }
+ }
+ return image_plot
+
+
def _get_z_colorbar_data(z_data=None, values=None, colors=None):
if values is None:
values = [0, 1. / 10, 2. / 10, 6. / 10, 9. / 10]