1
0
mirror of https://github.com/clearml/clearml synced 2025-04-08 06:34:37 +00:00

Add seaborn support and SVG support for matplotlib

This commit is contained in:
allegroai 2019-07-13 23:53:19 +03:00
parent c9221e3fbb
commit cac4ac12b8
6 changed files with 172 additions and 26 deletions

View File

@ -172,11 +172,12 @@ class ImageEvent(MetricsEventAdapter):
_metric_counters_lock = Lock()
_image_file_history_size = int(config.get('metrics.file_history_size', 5))
def __init__(self, metric, variant, image_data, iter=0, upload_uri=None,
def __init__(self, metric, variant, image_data, local_image_path=None, iter=0, upload_uri=None,
image_file_history_size=None, **kwargs):
if not hasattr(image_data, 'shape'):
if image_data is not None and not hasattr(image_data, 'shape'):
raise ValueError('Image must have a shape attribute')
self._image_data = image_data
self._local_image_path = local_image_path
self._url = None
self._key = None
self._count = self._get_metric_count(metric, variant)
@ -187,6 +188,12 @@ class ImageEvent(MetricsEventAdapter):
else:
self._filename = '%s_%s_%08d' % (metric, variant, self._count % image_file_history_size)
self._upload_uri = upload_uri
# get upload uri upfront
image_format = self._format.lower() if self._image_data is not None else \
pathlib2.Path(self._local_image_path).suffix
self._upload_filename = str(pathlib2.Path(self._filename).with_suffix(image_format))
super(ImageEvent, self).__init__(metric, variant, iter=iter, **kwargs)
@classmethod
@ -221,7 +228,7 @@ class ImageEvent(MetricsEventAdapter):
last_count = self._get_metric_count(self.metric, self.variant, next=False)
if abs(self._count - last_count) > self._image_file_history_size:
output = None
else:
elif self._image_data is not None:
image_data = self._image_data
if not isinstance(image_data, np.ndarray):
# try conversion, if it fails we'll leave it to the user.
@ -245,14 +252,24 @@ class ImageEvent(MetricsEventAdapter):
output = six.BytesIO(img_bytes.tostring())
output.seek(0)
filename = str(pathlib2.Path(self._filename).with_suffix(self._format.lower()))
else:
with open(self._local_image_path, 'rb') as f:
output = six.BytesIO(f.read())
output.seek(0)
return self.FileEntry(
event=self,
name=filename,
name=self._upload_filename,
stream=output,
url_prop='url',
key_prop='key',
upload_uri=self._upload_uri
)
def get_target_full_upload_uri(self, storage_uri, storage_key_prefix):
e_storage_uri = self._upload_uri or storage_uri
# if we have an entry (with or without a stream), we'll generate the URL and store it in the event
filename = self._upload_filename
key = '/'.join(x for x in (storage_key_prefix, self.metric, self.variant, filename.strip('/')) if x)
url = '/'.join(x.strip('/') for x in (e_storage_uri, key))
return key, url

View File

@ -116,12 +116,7 @@ class Metrics(InterfaceBase):
entry = ev.get_file_entry()
kwargs = {}
if entry:
e_storage_uri = entry.upload_uri or storage_uri
self._file_related_event_time = now
# if we have an entry (with or without a stream), we'll generate the URL and store it in the event
filename = entry.name
key = '/'.join(x for x in (self._storage_key_prefix, ev.metric, ev.variant, filename.strip('/')) if x)
url = '/'.join(x.strip('/') for x in (e_storage_uri, key))
key, url = ev.get_target_full_upload_uri(storage_uri, self.storage_key_prefix)
kwargs[entry.key_prop] = key
kwargs[entry.url_prop] = url
if not entry.stream:

View File

@ -8,7 +8,8 @@ from ..base import InterfaceBase
from ..setupuploadmixin import SetupUploadMixin
from ...utilities.async_manager import AsyncManagerMixin
from ...utilities.plotly_reporter import create_2d_histogram_plot, create_value_matrix, create_3d_surface, \
create_2d_scatter_series, create_3d_scatter_series, create_line_plot, plotly_scatter3d_layout_dict
create_2d_scatter_series, create_3d_scatter_series, create_line_plot, plotly_scatter3d_layout_dict, \
create_image_plot
from ...utilities.py3_interop import AbstractContextManager
from .events import ScalarEvent, VectorEvent, ImageEvent, PlotEvent, ImageEventNoUpload
@ -187,9 +188,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
raise ValueError('Expected only one of [filename, matrix]')
kwargs = dict(metric=self._normalize_name(title),
variant=self._normalize_name(series), iter=iter, image_file_history_size=max_image_history)
if matrix is None:
matrix = cv2.imread(path)
ev = ImageEvent(image_data=matrix, upload_uri=upload_uri, **kwargs)
ev = ImageEvent(image_data=matrix, upload_uri=upload_uri, local_image_path=path, **kwargs)
self._report(ev)
def report_histogram(self, title, series, histogram, iter, labels=None, xlabels=None, comment=None):
@ -445,6 +444,49 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
iter=iter,
)
def report_image_plot_and_upload(self, title, series, iter, path=None, matrix=None,
upload_uri=None, max_image_history=None):
"""
Report an image as plot and upload its contents.
Image is uploaded to a preconfigured bucket (see setup_upload()) with a key (filename)
describing the task ID, title, series and iteration.
Then a plotly object is created and registered, this plotly objects points to the uploaded image
:param title: Title (AKA metric)
:type title: str
:param series: Series (AKA variant)
:type series: str
:param iter: Iteration number
:type value: int
:param path: A path to an image file. Required unless matrix is provided.
:type path: str
:param matrix: A 3D numpy.ndarray object containing image data (BGR). Required unless filename is provided.
:type matrix: str
:param max_image_history: maximum number of image to store per metric/variant combination
use negative value for unlimited. default is set in global configuration (default=5)
"""
if not upload_uri and not self._storage_uri:
raise ValueError('Upload configuration is required (use setup_upload())')
if len([x for x in (path, matrix) if x is not None]) != 1:
raise ValueError('Expected only one of [filename, matrix]')
kwargs = dict(metric=self._normalize_name(title),
variant=self._normalize_name(series), iter=iter, image_file_history_size=max_image_history)
ev = ImageEvent(image_data=matrix, upload_uri=upload_uri, local_image_path=path, **kwargs)
_, url = ev.get_target_full_upload_uri(upload_uri or self._storage_uri, self._metrics.storage_key_prefix)
self._report(ev)
plotly_dict = create_image_plot(
image_src=url,
title=title + '/' + series,
width=matrix.shape[1] if matrix is not None else 640,
height=matrix.shape[0] if matrix is not None else 480,
)
return self.report_plot(
title=self._normalize_name(title),
series=self._normalize_name(series),
plot=plotly_dict,
iter=iter,
)
@classmethod
def _normalize_name(cls, name):
if not name:

View File

@ -1,7 +1,7 @@
import os
import sys
from tempfile import mkstemp
import cv2
import numpy as np
import six
from six import BytesIO
@ -129,6 +129,7 @@ class PatchedMatplotlib:
# convert to plotly
image = None
plotly_fig = None
image_format = 'svg'
if not force_save_as_image:
# noinspection PyBroadException
try:
@ -140,17 +141,28 @@ class PatchedMatplotlib:
return renderer.plotly_fig
plotly_fig = our_mpl_to_plotly(mpl_fig)
except Exception:
pass
except Exception as ex:
# this was an image, change format to jpeg
if 'selfie' in str(ex):
image_format = 'jpeg'
# plotly could not serialize the plot, we should convert to image
if not plotly_fig:
plotly_fig = None
buffer_ = BytesIO()
plt.savefig(buffer_, format="png", bbox_inches='tight', pad_inches=0)
buffer_.seek(0)
buffer = buffer_.getbuffer() if not six.PY2 else buffer_.getvalue()
image = cv2.imdecode(np.frombuffer(buffer, dtype=np.uint8), cv2.IMREAD_UNCHANGED)
# noinspection PyBroadException
try:
# first try SVG if we fail then fallback to png
buffer_ = BytesIO()
plt.savefig(buffer_, format=image_format, bbox_inches='tight', pad_inches=0)
buffer_.seek(0)
except Exception:
image_format = 'png'
buffer_ = BytesIO()
plt.savefig(buffer_, format=image_format, bbox_inches='tight', pad_inches=0)
buffer_.seek(0)
fd, image = mkstemp(suffix='.'+image_format)
os.write(fd, buffer_.read())
os.close(fd)
# check if we need to restore the active object
if set_active and not _pylab_helpers.Gcf.get_active():
@ -185,7 +197,14 @@ class PatchedMatplotlib:
PatchedMatplotlib._global_image_counter += 1
logger = PatchedMatplotlib._current_task.get_logger()
title = plot_title or 'untitled %d' % PatchedMatplotlib._global_image_counter
logger.report_image_and_upload(title=title, series='plot image', matrix=image,
# this is actually a failed plot, we should put it under plots:
# currently disabled
# if image_format == 'svg':
# logger.report_image_plot_and_upload(title=title, series='plot image', path=image,
# iteration=PatchedMatplotlib._global_image_counter
# if plot_title else 0)
# else:
logger.report_image_and_upload(title=title, series='plot image', path=image,
iteration=PatchedMatplotlib._global_image_counter
if plot_title else 0)
except Exception:

View File

@ -533,6 +533,49 @@ class Logger(object):
max_image_history=max_image_history,
)
def report_image_plot_and_upload(self, title, series, iteration, path=None, matrix=None, max_image_history=None):
"""
Report an image, upload its contents, and present in plots section using plotly
Image is uploaded to a preconfigured bucket (see setup_upload()) with a key (filename)
describing the task ID, title, series and iteration.
:param title: Title (AKA metric)
:type title: str
:param series: Series (AKA variant)
:type series: str
:param iteration: Iteration number
:type iteration: int
:param path: A path to an image file. Required unless matrix is provided.
:type path: str
:param matrix: A 3D numpy.ndarray object containing image data (RGB). Required unless filename is provided.
:type matrix: str
:param max_image_history: maximum number of image to store per metric/variant combination \
use negative value for unlimited. default is set in global configuration (default=5)
:type max_image_history: int
"""
# if task was not started, we have to start it
self._start_task_if_needed()
upload_uri = self._default_upload_destination or self._task._get_default_report_storage_uri()
if not upload_uri:
upload_uri = Path(get_cache_dir()) / 'debug_images'
upload_uri.mkdir(parents=True, exist_ok=True)
# Verify that we can upload to this destination
upload_uri = str(upload_uri)
storage = StorageHelper.get(upload_uri)
upload_uri = storage.verify_upload(folder_uri=upload_uri)
self._task.reporter.report_image_plot_and_upload(
title=title,
series=series,
path=path,
matrix=matrix,
iter=iteration,
upload_uri=upload_uri,
max_image_history=max_image_history,
)
def set_default_upload_destination(self, uri):
"""
Set the uri to upload all the debug images to.

View File

@ -245,6 +245,36 @@ def create_3d_surface(np_value_matrix, title="3D Surface", xlabels=None, ylabels
return conf_matrix_plot
def create_image_plot(image_src, title, width=640, height=480, series=None, comment=None):
image_plot = {
"data": [],
"layout": {
"xaxis": {"visible": False, "range": [0, width]},
"yaxis": {"visible": False, "range": [0, height]},
"width": width,
"height": height,
"margin": {'l': 0, 'r': 0, 't': 0, 'b': 0},
"images": [{
"sizex": width,
"sizey": height,
"xref": "x",
"yref": "y",
"opacity": 1.0,
"x": 0,
"y": int(height / 2),
"yanchor": "middle",
"sizing": "contain",
"layer": "below",
"source": image_src
}],
"showlegend": False,
"title": title if not comment else (title + '<br><sup>' + comment + '</sup>'),
"name": series,
}
}
return image_plot
def _get_z_colorbar_data(z_data=None, values=None, colors=None):
if values is None:
values = [0, 1. / 10, 2. / 10, 6. / 10, 9. / 10]