mirror of
https://github.com/clearml/clearml
synced 2025-04-08 06:34:37 +00:00
Add seaborn support and SVG support for matplotlib
This commit is contained in:
parent
c9221e3fbb
commit
cac4ac12b8
trains
@ -172,11 +172,12 @@ class ImageEvent(MetricsEventAdapter):
|
||||
_metric_counters_lock = Lock()
|
||||
_image_file_history_size = int(config.get('metrics.file_history_size', 5))
|
||||
|
||||
def __init__(self, metric, variant, image_data, iter=0, upload_uri=None,
|
||||
def __init__(self, metric, variant, image_data, local_image_path=None, iter=0, upload_uri=None,
|
||||
image_file_history_size=None, **kwargs):
|
||||
if not hasattr(image_data, 'shape'):
|
||||
if image_data is not None and not hasattr(image_data, 'shape'):
|
||||
raise ValueError('Image must have a shape attribute')
|
||||
self._image_data = image_data
|
||||
self._local_image_path = local_image_path
|
||||
self._url = None
|
||||
self._key = None
|
||||
self._count = self._get_metric_count(metric, variant)
|
||||
@ -187,6 +188,12 @@ class ImageEvent(MetricsEventAdapter):
|
||||
else:
|
||||
self._filename = '%s_%s_%08d' % (metric, variant, self._count % image_file_history_size)
|
||||
self._upload_uri = upload_uri
|
||||
|
||||
# get upload uri upfront
|
||||
image_format = self._format.lower() if self._image_data is not None else \
|
||||
pathlib2.Path(self._local_image_path).suffix
|
||||
self._upload_filename = str(pathlib2.Path(self._filename).with_suffix(image_format))
|
||||
|
||||
super(ImageEvent, self).__init__(metric, variant, iter=iter, **kwargs)
|
||||
|
||||
@classmethod
|
||||
@ -221,7 +228,7 @@ class ImageEvent(MetricsEventAdapter):
|
||||
last_count = self._get_metric_count(self.metric, self.variant, next=False)
|
||||
if abs(self._count - last_count) > self._image_file_history_size:
|
||||
output = None
|
||||
else:
|
||||
elif self._image_data is not None:
|
||||
image_data = self._image_data
|
||||
if not isinstance(image_data, np.ndarray):
|
||||
# try conversion, if it fails we'll leave it to the user.
|
||||
@ -245,14 +252,24 @@ class ImageEvent(MetricsEventAdapter):
|
||||
|
||||
output = six.BytesIO(img_bytes.tostring())
|
||||
output.seek(0)
|
||||
|
||||
filename = str(pathlib2.Path(self._filename).with_suffix(self._format.lower()))
|
||||
else:
|
||||
with open(self._local_image_path, 'rb') as f:
|
||||
output = six.BytesIO(f.read())
|
||||
output.seek(0)
|
||||
|
||||
return self.FileEntry(
|
||||
event=self,
|
||||
name=filename,
|
||||
name=self._upload_filename,
|
||||
stream=output,
|
||||
url_prop='url',
|
||||
key_prop='key',
|
||||
upload_uri=self._upload_uri
|
||||
)
|
||||
|
||||
def get_target_full_upload_uri(self, storage_uri, storage_key_prefix):
|
||||
e_storage_uri = self._upload_uri or storage_uri
|
||||
# if we have an entry (with or without a stream), we'll generate the URL and store it in the event
|
||||
filename = self._upload_filename
|
||||
key = '/'.join(x for x in (storage_key_prefix, self.metric, self.variant, filename.strip('/')) if x)
|
||||
url = '/'.join(x.strip('/') for x in (e_storage_uri, key))
|
||||
return key, url
|
||||
|
@ -116,12 +116,7 @@ class Metrics(InterfaceBase):
|
||||
entry = ev.get_file_entry()
|
||||
kwargs = {}
|
||||
if entry:
|
||||
e_storage_uri = entry.upload_uri or storage_uri
|
||||
self._file_related_event_time = now
|
||||
# if we have an entry (with or without a stream), we'll generate the URL and store it in the event
|
||||
filename = entry.name
|
||||
key = '/'.join(x for x in (self._storage_key_prefix, ev.metric, ev.variant, filename.strip('/')) if x)
|
||||
url = '/'.join(x.strip('/') for x in (e_storage_uri, key))
|
||||
key, url = ev.get_target_full_upload_uri(storage_uri, self.storage_key_prefix)
|
||||
kwargs[entry.key_prop] = key
|
||||
kwargs[entry.url_prop] = url
|
||||
if not entry.stream:
|
||||
|
@ -8,7 +8,8 @@ from ..base import InterfaceBase
|
||||
from ..setupuploadmixin import SetupUploadMixin
|
||||
from ...utilities.async_manager import AsyncManagerMixin
|
||||
from ...utilities.plotly_reporter import create_2d_histogram_plot, create_value_matrix, create_3d_surface, \
|
||||
create_2d_scatter_series, create_3d_scatter_series, create_line_plot, plotly_scatter3d_layout_dict
|
||||
create_2d_scatter_series, create_3d_scatter_series, create_line_plot, plotly_scatter3d_layout_dict, \
|
||||
create_image_plot
|
||||
from ...utilities.py3_interop import AbstractContextManager
|
||||
from .events import ScalarEvent, VectorEvent, ImageEvent, PlotEvent, ImageEventNoUpload
|
||||
|
||||
@ -187,9 +188,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
|
||||
raise ValueError('Expected only one of [filename, matrix]')
|
||||
kwargs = dict(metric=self._normalize_name(title),
|
||||
variant=self._normalize_name(series), iter=iter, image_file_history_size=max_image_history)
|
||||
if matrix is None:
|
||||
matrix = cv2.imread(path)
|
||||
ev = ImageEvent(image_data=matrix, upload_uri=upload_uri, **kwargs)
|
||||
ev = ImageEvent(image_data=matrix, upload_uri=upload_uri, local_image_path=path, **kwargs)
|
||||
self._report(ev)
|
||||
|
||||
def report_histogram(self, title, series, histogram, iter, labels=None, xlabels=None, comment=None):
|
||||
@ -445,6 +444,49 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
|
||||
iter=iter,
|
||||
)
|
||||
|
||||
def report_image_plot_and_upload(self, title, series, iter, path=None, matrix=None,
|
||||
upload_uri=None, max_image_history=None):
|
||||
"""
|
||||
Report an image as plot and upload its contents.
|
||||
Image is uploaded to a preconfigured bucket (see setup_upload()) with a key (filename)
|
||||
describing the task ID, title, series and iteration.
|
||||
Then a plotly object is created and registered, this plotly objects points to the uploaded image
|
||||
:param title: Title (AKA metric)
|
||||
:type title: str
|
||||
:param series: Series (AKA variant)
|
||||
:type series: str
|
||||
:param iter: Iteration number
|
||||
:type value: int
|
||||
:param path: A path to an image file. Required unless matrix is provided.
|
||||
:type path: str
|
||||
:param matrix: A 3D numpy.ndarray object containing image data (BGR). Required unless filename is provided.
|
||||
:type matrix: str
|
||||
:param max_image_history: maximum number of image to store per metric/variant combination
|
||||
use negative value for unlimited. default is set in global configuration (default=5)
|
||||
"""
|
||||
if not upload_uri and not self._storage_uri:
|
||||
raise ValueError('Upload configuration is required (use setup_upload())')
|
||||
if len([x for x in (path, matrix) if x is not None]) != 1:
|
||||
raise ValueError('Expected only one of [filename, matrix]')
|
||||
kwargs = dict(metric=self._normalize_name(title),
|
||||
variant=self._normalize_name(series), iter=iter, image_file_history_size=max_image_history)
|
||||
ev = ImageEvent(image_data=matrix, upload_uri=upload_uri, local_image_path=path, **kwargs)
|
||||
_, url = ev.get_target_full_upload_uri(upload_uri or self._storage_uri, self._metrics.storage_key_prefix)
|
||||
self._report(ev)
|
||||
plotly_dict = create_image_plot(
|
||||
image_src=url,
|
||||
title=title + '/' + series,
|
||||
width=matrix.shape[1] if matrix is not None else 640,
|
||||
height=matrix.shape[0] if matrix is not None else 480,
|
||||
)
|
||||
|
||||
return self.report_plot(
|
||||
title=self._normalize_name(title),
|
||||
series=self._normalize_name(series),
|
||||
plot=plotly_dict,
|
||||
iter=iter,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _normalize_name(cls, name):
|
||||
if not name:
|
||||
|
@ -1,7 +1,7 @@
|
||||
import os
|
||||
import sys
|
||||
from tempfile import mkstemp
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import six
|
||||
from six import BytesIO
|
||||
|
||||
@ -129,6 +129,7 @@ class PatchedMatplotlib:
|
||||
# convert to plotly
|
||||
image = None
|
||||
plotly_fig = None
|
||||
image_format = 'svg'
|
||||
if not force_save_as_image:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
@ -140,17 +141,28 @@ class PatchedMatplotlib:
|
||||
return renderer.plotly_fig
|
||||
|
||||
plotly_fig = our_mpl_to_plotly(mpl_fig)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as ex:
|
||||
# this was an image, change format to jpeg
|
||||
if 'selfie' in str(ex):
|
||||
image_format = 'jpeg'
|
||||
|
||||
# plotly could not serialize the plot, we should convert to image
|
||||
if not plotly_fig:
|
||||
plotly_fig = None
|
||||
buffer_ = BytesIO()
|
||||
plt.savefig(buffer_, format="png", bbox_inches='tight', pad_inches=0)
|
||||
buffer_.seek(0)
|
||||
buffer = buffer_.getbuffer() if not six.PY2 else buffer_.getvalue()
|
||||
image = cv2.imdecode(np.frombuffer(buffer, dtype=np.uint8), cv2.IMREAD_UNCHANGED)
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
# first try SVG if we fail then fallback to png
|
||||
buffer_ = BytesIO()
|
||||
plt.savefig(buffer_, format=image_format, bbox_inches='tight', pad_inches=0)
|
||||
buffer_.seek(0)
|
||||
except Exception:
|
||||
image_format = 'png'
|
||||
buffer_ = BytesIO()
|
||||
plt.savefig(buffer_, format=image_format, bbox_inches='tight', pad_inches=0)
|
||||
buffer_.seek(0)
|
||||
fd, image = mkstemp(suffix='.'+image_format)
|
||||
os.write(fd, buffer_.read())
|
||||
os.close(fd)
|
||||
|
||||
# check if we need to restore the active object
|
||||
if set_active and not _pylab_helpers.Gcf.get_active():
|
||||
@ -185,7 +197,14 @@ class PatchedMatplotlib:
|
||||
PatchedMatplotlib._global_image_counter += 1
|
||||
logger = PatchedMatplotlib._current_task.get_logger()
|
||||
title = plot_title or 'untitled %d' % PatchedMatplotlib._global_image_counter
|
||||
logger.report_image_and_upload(title=title, series='plot image', matrix=image,
|
||||
# this is actually a failed plot, we should put it under plots:
|
||||
# currently disabled
|
||||
# if image_format == 'svg':
|
||||
# logger.report_image_plot_and_upload(title=title, series='plot image', path=image,
|
||||
# iteration=PatchedMatplotlib._global_image_counter
|
||||
# if plot_title else 0)
|
||||
# else:
|
||||
logger.report_image_and_upload(title=title, series='plot image', path=image,
|
||||
iteration=PatchedMatplotlib._global_image_counter
|
||||
if plot_title else 0)
|
||||
except Exception:
|
||||
|
@ -533,6 +533,49 @@ class Logger(object):
|
||||
max_image_history=max_image_history,
|
||||
)
|
||||
|
||||
def report_image_plot_and_upload(self, title, series, iteration, path=None, matrix=None, max_image_history=None):
|
||||
"""
|
||||
Report an image, upload its contents, and present in plots section using plotly
|
||||
|
||||
Image is uploaded to a preconfigured bucket (see setup_upload()) with a key (filename)
|
||||
describing the task ID, title, series and iteration.
|
||||
|
||||
:param title: Title (AKA metric)
|
||||
:type title: str
|
||||
:param series: Series (AKA variant)
|
||||
:type series: str
|
||||
:param iteration: Iteration number
|
||||
:type iteration: int
|
||||
:param path: A path to an image file. Required unless matrix is provided.
|
||||
:type path: str
|
||||
:param matrix: A 3D numpy.ndarray object containing image data (RGB). Required unless filename is provided.
|
||||
:type matrix: str
|
||||
:param max_image_history: maximum number of image to store per metric/variant combination \
|
||||
use negative value for unlimited. default is set in global configuration (default=5)
|
||||
:type max_image_history: int
|
||||
"""
|
||||
|
||||
# if task was not started, we have to start it
|
||||
self._start_task_if_needed()
|
||||
upload_uri = self._default_upload_destination or self._task._get_default_report_storage_uri()
|
||||
if not upload_uri:
|
||||
upload_uri = Path(get_cache_dir()) / 'debug_images'
|
||||
upload_uri.mkdir(parents=True, exist_ok=True)
|
||||
# Verify that we can upload to this destination
|
||||
upload_uri = str(upload_uri)
|
||||
storage = StorageHelper.get(upload_uri)
|
||||
upload_uri = storage.verify_upload(folder_uri=upload_uri)
|
||||
|
||||
self._task.reporter.report_image_plot_and_upload(
|
||||
title=title,
|
||||
series=series,
|
||||
path=path,
|
||||
matrix=matrix,
|
||||
iter=iteration,
|
||||
upload_uri=upload_uri,
|
||||
max_image_history=max_image_history,
|
||||
)
|
||||
|
||||
def set_default_upload_destination(self, uri):
|
||||
"""
|
||||
Set the uri to upload all the debug images to.
|
||||
|
@ -245,6 +245,36 @@ def create_3d_surface(np_value_matrix, title="3D Surface", xlabels=None, ylabels
|
||||
return conf_matrix_plot
|
||||
|
||||
|
||||
def create_image_plot(image_src, title, width=640, height=480, series=None, comment=None):
|
||||
image_plot = {
|
||||
"data": [],
|
||||
"layout": {
|
||||
"xaxis": {"visible": False, "range": [0, width]},
|
||||
"yaxis": {"visible": False, "range": [0, height]},
|
||||
"width": width,
|
||||
"height": height,
|
||||
"margin": {'l': 0, 'r': 0, 't': 0, 'b': 0},
|
||||
"images": [{
|
||||
"sizex": width,
|
||||
"sizey": height,
|
||||
"xref": "x",
|
||||
"yref": "y",
|
||||
"opacity": 1.0,
|
||||
"x": 0,
|
||||
"y": int(height / 2),
|
||||
"yanchor": "middle",
|
||||
"sizing": "contain",
|
||||
"layer": "below",
|
||||
"source": image_src
|
||||
}],
|
||||
"showlegend": False,
|
||||
"title": title if not comment else (title + '<br><sup>' + comment + '</sup>'),
|
||||
"name": series,
|
||||
}
|
||||
}
|
||||
return image_plot
|
||||
|
||||
|
||||
def _get_z_colorbar_data(z_data=None, values=None, colors=None):
|
||||
if values is None:
|
||||
values = [0, 1. / 10, 2. / 10, 6. / 10, 9. / 10]
|
||||
|
Loading…
Reference in New Issue
Block a user