Add Logger.report_matplotlib_figure() with examples

This commit is contained in:
allegroai 2020-10-15 23:20:17 +03:00
parent 7ce5bc0313
commit df395b67ba
5 changed files with 253 additions and 75 deletions

View File

@ -8,30 +8,39 @@ from trains import Task
task = Task.init(project_name='examples', task_name='Matplotlib example') task = Task.init(project_name='examples', task_name='Matplotlib example')
# create plot # Create a plot
N = 50 N = 50
x = np.random.rand(N) x = np.random.rand(N)
y = np.random.rand(N) y = np.random.rand(N)
colors = np.random.rand(N) colors = np.random.rand(N)
area = (30 * np.random.rand(N))**2 # 0 to 15 point radii area = (30 * np.random.rand(N))**2 # 0 to 15 point radii
plt.scatter(x, y, s=area, c=colors, alpha=0.5) plt.scatter(x, y, s=area, c=colors, alpha=0.5)
# Plot will be reported automatically
plt.show() plt.show()
# create another plot - with a name # Alternatively, in order to report the plot with a more meaningful title/series and iteration number
area = (40 * np.random.rand(N))**2
plt.scatter(x, y, s=area, c=colors, alpha=0.5)
task.logger.report_matplotlib_figure(title="My Plot Title", series="My Plot Series", iteration=10, figure=plt)
# Create another plot - with a name
x = np.linspace(0, 10, 30) x = np.linspace(0, 10, 30)
y = np.sin(x) y = np.sin(x)
plt.plot(x, y, 'o', color='black') plt.plot(x, y, 'o', color='black')
# Plot will be reported automatically
plt.show() plt.show()
# create image plot # Create image plot
m = np.eye(256, 256, dtype=np.uint8) m = np.eye(256, 256, dtype=np.uint8)
plt.imshow(m) plt.imshow(m)
# Plot will be reported automatically
plt.show() plt.show()
# create image plot - with a name # Create image plot - with a name
m = np.eye(256, 256, dtype=np.uint8) m = np.eye(256, 256, dtype=np.uint8)
plt.imshow(m) plt.imshow(m)
plt.title('Image Title') plt.title('Image Title')
# Plot will be reported automatically
plt.show() plt.show()
sns.set(style="darkgrid") sns.set(style="darkgrid")
@ -41,6 +50,7 @@ fmri = sns.load_dataset("fmri")
sns.lineplot(x="timepoint", y="signal", sns.lineplot(x="timepoint", y="signal",
hue="region", style="event", hue="region", style="event",
data=fmri) data=fmri)
# Plot will be reported automatically
plt.show() plt.show()
print('This is a Matplotlib & Seaborn example') print('This is a Matplotlib & Seaborn example')

View File

@ -0,0 +1,55 @@
# TRAINS - Example of Matplotlib and Seaborn integration and reporting
#
import numpy as np
import matplotlib.pyplot as plt
from trains import Task
# Create a new task, disable automatic matplotlib connect
task = Task.init(
project_name='examples',
task_name='Manual Matplotlib example',
auto_connect_frameworks={'matplotlib': False}
)
# Create plot and explicitly report as figure
N = 50
x = np.random.rand(N)
y = np.random.rand(N)
colors = np.random.rand(N)
area = (30 * np.random.rand(N))**2 # 0 to 15 point radii
plt.scatter(x, y, s=area, c=colors, alpha=0.5)
task.logger.report_matplotlib_figure(
title="Manual Reporting",
series="Just a plot",
iteration=0,
figure=plt,
)
# Show the plot
plt.show()
# Create plot and explicitly report as an image
plt.scatter(x, y, s=area, c=colors, alpha=0.5)
task.logger.report_matplotlib_figure(
title="Manual Reporting",
series="Plot as an image",
iteration=0,
figure=plt,
report_image=True,
)
# Create an image plot and explicitly report (as an image)
m = np.eye(256, 256, dtype=np.uint8)
plt.imshow(m)
task.logger.report_matplotlib_figure(
title="Manual Reporting",
series="Image plot",
iteration=0,
figure=plt,
report_image=True, # Note this is required for image plots
)
# Show the plot
plt.show()

View File

@ -167,6 +167,18 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
iter=iter) iter=iter)
self._report(ev) self._report(ev)
def report_matplotlib(self, title, series, figure, iter, force_save_as_image=False, logger=None):
from trains.binding.matplotlib_bind import PatchedMatplotlib
PatchedMatplotlib.report_figure(
title=title,
series=series,
figure=figure,
iter=iter,
force_save_as_image=force_save_as_image,
reporter=self,
logger=logger,
)
def report_plot(self, title, series, plot, iter, round_digits=None): def report_plot(self, title, series, plot, iter, round_digits=None):
""" """
Report a Plotly chart Report a Plotly chart
@ -500,7 +512,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
:param iter: Iteration number :param iter: Iteration number
:type iter: int :type iter: int
:param labels: label (text) per point in the scatter (in the same order) :param labels: label (text) per point in the scatter (in the same order)
:type labels: str :type labels: list(str)
:param mode: (type str) 'lines'/'markers'/'lines+markers' :param mode: (type str) 'lines'/'markers'/'lines+markers'
:param color: list of RGBA colors [(217, 217, 217, 0.14),] :param color: list of RGBA colors [(217, 217, 217, 0.14),]
:param marker_size: marker size in px :param marker_size: marker size in px

View File

@ -33,7 +33,7 @@ class PatchedMatplotlib:
_patched_mpltools_get_spine_visible = False _patched_mpltools_get_spine_visible = False
_lock_renderer = threading.RLock() _lock_renderer = threading.RLock()
_recursion_guard = {} _recursion_guard = {}
_matplot_major_version = 2 _matplot_major_version = 0
_matplot_minor_version = 0 _matplot_minor_version = 0
_logger_started_reporting = False _logger_started_reporting = False
_matplotlib_reported_titles = set() _matplotlib_reported_titles = set()
@ -62,9 +62,7 @@ class PatchedMatplotlib:
try: try:
# we support matplotlib version 2.0.0 and above # we support matplotlib version 2.0.0 and above
import matplotlib import matplotlib
version_split = matplotlib.__version__.split('.') PatchedMatplotlib._update_matplotlib_version()
PatchedMatplotlib._matplot_major_version = int(version_split[0])
PatchedMatplotlib._matplot_minor_version = int(version_split[1])
if PatchedMatplotlib._matplot_major_version < 2: if PatchedMatplotlib._matplot_major_version < 2:
LoggerRoot.get_base_logger().warning( LoggerRoot.get_base_logger().warning(
'matplotlib binding supports version 2.0 and above, found version {}'.format( 'matplotlib binding supports version 2.0 and above, found version {}'.format(
@ -137,6 +135,36 @@ class PatchedMatplotlib:
# update api version # update api version
from ..backend_api import Session from ..backend_api import Session
PatchedMatplotlib._support_image_plot = Session.check_min_api_version('2.2') PatchedMatplotlib._support_image_plot = Session.check_min_api_version('2.2')
# load plotly
PatchedMatplotlib._update_plotly_renderers()
return True
@staticmethod
def _update_matplotlib_version():
if PatchedMatplotlib._matplot_major_version:
return
# we support matplotlib version 2.0.0 and above
try:
import matplotlib
version_split = matplotlib.__version__.split('.')
PatchedMatplotlib._matplot_major_version = int(version_split[0])
PatchedMatplotlib._matplot_minor_version = int(version_split[1])
if running_remotely():
# disable GUI backend - make headless
matplotlib.rcParams['backend'] = 'agg'
import matplotlib.pyplot
matplotlib.pyplot.switch_backend('agg')
except Exception:
pass
@staticmethod
def _update_plotly_renderers():
if PatchedMatplotlib._matplotlylib and PatchedMatplotlib._plotly_renderer:
return True
# create plotly renderer # create plotly renderer
try: try:
@ -144,7 +172,7 @@ class PatchedMatplotlib:
PatchedMatplotlib._matplotlylib = optional_imports.get_module('plotly.matplotlylib') PatchedMatplotlib._matplotlylib = optional_imports.get_module('plotly.matplotlylib')
PatchedMatplotlib._plotly_renderer = PatchedMatplotlib._matplotlylib.PlotlyRenderer() PatchedMatplotlib._plotly_renderer = PatchedMatplotlib._matplotlylib.PlotlyRenderer()
except Exception: except Exception:
pass return False
return True return True
@ -251,10 +279,47 @@ class PatchedMatplotlib:
return ret return ret
@staticmethod @staticmethod
def _report_figure(force_save_as_image=False, stored_figure=None, set_active=True, specific_fig=None): def report_figure(title, series, figure, iter, force_save_as_image=False, reporter=None, logger=None):
if not PatchedMatplotlib._current_task: PatchedMatplotlib._report_figure(
force_save_as_image=force_save_as_image,
specific_fig=figure.gcf() if hasattr(figure, 'gcf') else figure,
title=title,
series=series,
iter=iter,
reporter=reporter,
logger=logger
)
@staticmethod
def _report_figure(
force_save_as_image=False,
stored_figure=None,
set_active=True,
specific_fig=None,
title=None,
series=None,
iter=None,
reporter=None,
logger=None,
):
# get the main task
if not PatchedMatplotlib._current_task and not reporter and not logger:
return return
# check if this is explicit reporting
is_explicit = reporter and logger
# noinspection PyProtectedMember
reporter = reporter or PatchedMatplotlib._current_task._reporter
if not reporter:
return
logger = logger or PatchedMatplotlib._current_task.get_logger()
if not logger:
return
# make sure we have matplotlib ready
PatchedMatplotlib._update_matplotlib_version()
# noinspection PyBroadException # noinspection PyBroadException
try: try:
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@ -276,6 +341,13 @@ class PatchedMatplotlib:
else: else:
mpl_fig = specific_fig mpl_fig = specific_fig
if is_explicit:
# marked displayed explicitly
mpl_fig._trains_explicit = True
elif getattr(mpl_fig, '_trains_explicit', False):
# if auto bind (i.e. plt.show) and plot already displayed explicitly, do nothing.
return
# convert to plotly # convert to plotly
image = None image = None
plotly_fig = None plotly_fig = None
@ -284,6 +356,8 @@ class PatchedMatplotlib:
if force_save_as_image: if force_save_as_image:
# if this is an image, store as is. # if this is an image, store as is.
fig_dpi = None fig_dpi = None
if isinstance(force_save_as_image, str):
image_format = force_save_as_image
else: else:
image_format = 'svg' image_format = 'svg'
# protect with lock, so we support multiple threads using the same renderer # protect with lock, so we support multiple threads using the same renderer
@ -291,7 +365,7 @@ class PatchedMatplotlib:
# noinspection PyBroadException # noinspection PyBroadException
try: try:
def our_mpl_to_plotly(fig): def our_mpl_to_plotly(fig):
if not PatchedMatplotlib._matplotlylib or not PatchedMatplotlib._plotly_renderer: if not PatchedMatplotlib._update_plotly_renderers():
return None return None
if not PatchedMatplotlib._patched_mpltools_get_spine_visible and \ if not PatchedMatplotlib._patched_mpltools_get_spine_visible and \
PatchedMatplotlib._matplot_major_version and \ PatchedMatplotlib._matplot_major_version and \
@ -383,25 +457,24 @@ class PatchedMatplotlib:
if set_active and not _pylab_helpers.Gcf.get_active(): if set_active and not _pylab_helpers.Gcf.get_active():
_pylab_helpers.Gcf.set_active(stored_figure) _pylab_helpers.Gcf.set_active(stored_figure)
# get the main task last_iteration = iter if iter is not None else PatchedMatplotlib._get_last_iteration()
# noinspection PyProtectedMember
reporter = PatchedMatplotlib._current_task._reporter if not title:
if reporter is not None:
if mpl_fig.texts: if mpl_fig.texts:
plot_title = mpl_fig.texts[0].get_text() plot_title = mpl_fig.texts[0].get_text()
else: else:
gca = mpl_fig.gca() gca = mpl_fig.gca()
plot_title = gca.title.get_text() if gca.title else None plot_title = gca.title.get_text() if gca.title else None
# remove borders and size, we should let the web take care of that
if plotly_fig:
last_iteration = PatchedMatplotlib._get_last_iteration()
if plot_title: if plot_title:
title = PatchedMatplotlib._enforce_unique_title_per_iteration(plot_title, last_iteration) title = PatchedMatplotlib._enforce_unique_title_per_iteration(plot_title, last_iteration)
else: else:
PatchedMatplotlib._global_plot_counter += 1 PatchedMatplotlib._global_plot_counter += 1
title = 'untitled %02d' % PatchedMatplotlib._global_plot_counter mod_ = 1 if plotly_fig else PatchedMatplotlib._global_image_counter_limit
title = 'untitled %02d' % (PatchedMatplotlib._global_plot_counter % mod_)
# remove borders and size, we should let the web take care of that
if plotly_fig:
plotly_fig.layout.margin = {} plotly_fig.layout.margin = {}
plotly_fig.layout.autosize = True plotly_fig.layout.autosize = True
plotly_fig.layout.height = None plotly_fig.layout.height = None
@ -410,49 +483,27 @@ class PatchedMatplotlib:
plotly_dict = plotly_fig.to_plotly_json() plotly_dict = plotly_fig.to_plotly_json()
if not plotly_dict.get('layout'): if not plotly_dict.get('layout'):
plotly_dict['layout'] = {} plotly_dict['layout'] = {}
plotly_dict['layout']['title'] = title plotly_dict['layout']['title'] = series or title
PatchedMatplotlib._matplotlib_reported_titles.add(title) PatchedMatplotlib._matplotlib_reported_titles.add(title)
reporter.report_plot(title=title, series='plot', plot=plotly_dict, iter=last_iteration) reporter.report_plot(title=title, series=series or 'plot', plot=plotly_dict, iter=last_iteration)
else: else:
logger = PatchedMatplotlib._current_task.get_logger()
# this is actually a failed plot, we should put it under plots: # this is actually a failed plot, we should put it under plots:
# currently disabled # currently disabled
if force_save_as_image or not PatchedMatplotlib._support_image_plot: if force_save_as_image or not PatchedMatplotlib._support_image_plot:
last_iteration = PatchedMatplotlib._get_last_iteration()
# send the plot as image
if plot_title:
title = PatchedMatplotlib._enforce_unique_title_per_iteration(plot_title, last_iteration)
else:
PatchedMatplotlib._global_image_counter += 1
title = 'untitled %02d' % (PatchedMatplotlib._global_image_counter %
PatchedMatplotlib._global_image_counter_limit)
PatchedMatplotlib._matplotlib_reported_titles.add(title) PatchedMatplotlib._matplotlib_reported_titles.add(title)
logger.report_image(title=title, series='plot image', local_path=image, logger.report_image(title=title, series=series or 'plot image', local_path=image,
delete_after_upload=True, iteration=last_iteration) delete_after_upload=True, iteration=last_iteration)
else: else:
# send the plot as plotly with embedded image
last_iteration = PatchedMatplotlib._get_last_iteration()
if plot_title:
title = PatchedMatplotlib._enforce_unique_title_per_iteration(plot_title, last_iteration)
else:
PatchedMatplotlib._global_plot_counter += 1
title = 'untitled %02d' % (PatchedMatplotlib._global_plot_counter %
PatchedMatplotlib._global_image_counter_limit)
PatchedMatplotlib._matplotlib_reported_titles.add(title) PatchedMatplotlib._matplotlib_reported_titles.add(title)
# noinspection PyProtectedMember # noinspection PyProtectedMember
logger._report_image_plot_and_upload( logger._report_image_plot_and_upload(
title=title, series='plot image', path=image, title=title, series=series or 'plot image', path=image,
delete_after_upload=True, iteration=last_iteration) delete_after_upload=True, iteration=last_iteration)
except Exception: except Exception:
# plotly failed # plotly failed
pass pass
return
@staticmethod @staticmethod
def _enforce_unique_title_per_iteration(title, last_iteration): def _enforce_unique_title_per_iteration(title, last_iteration):
# type: (str, int) -> str # type: (str, int) -> str

View File

@ -1,7 +1,7 @@
import logging import logging
import math import math
import warnings import warnings
from typing import Any, Sequence, Union, List, Optional, Tuple, Dict from typing import Any, Sequence, Union, List, Optional, Tuple, Dict, TYPE_CHECKING
import numpy as np import numpy as np
import six import six
@ -28,6 +28,12 @@ from .utilities.plotly_reporter import SeriesInfo
warnings.filterwarnings('always', category=DeprecationWarning, module=__name__) warnings.filterwarnings('always', category=DeprecationWarning, module=__name__)
if TYPE_CHECKING:
from matplotlib.figure import Figure as MatplotlibFigure
from matplotlib import pyplot
from plotly.graph_objects import Figure
class Logger(object): class Logger(object):
""" """
The ``Logger`` class is the Trains console log and metric statistics interface, and contains methods for explicit The ``Logger`` class is the Trains console log and metric statistics interface, and contains methods for explicit
@ -142,6 +148,7 @@ class Logger(object):
# if task was not started, we have to start it # if task was not started, we have to start it
self._start_task_if_needed() self._start_task_if_needed()
self._touch_title_series(title, series) self._touch_title_series(title, series)
# noinspection PyProtectedMember
return self._task._reporter.report_scalar(title=title, series=series, value=float(value), iter=iteration) return self._task._reporter.report_scalar(title=title, series=series, value=float(value), iter=iteration)
def report_vector( def report_vector(
@ -239,6 +246,7 @@ class Logger(object):
# if task was not started, we have to start it # if task was not started, we have to start it
self._start_task_if_needed() self._start_task_if_needed()
self._touch_title_series(title, series) self._touch_title_series(title, series)
# noinspection PyProtectedMember
return self._task._reporter.report_histogram( return self._task._reporter.report_histogram(
title=title, title=title,
series=series, series=series,
@ -321,7 +329,7 @@ class Logger(object):
replace("NaN", np.nan, math.nan) replace("NaN", np.nan, math.nan)
replace("Inf", np.inf, math.inf) replace("Inf", np.inf, math.inf)
replace("-Inf", -np.inf, np.NINF, -math.inf) replace("-Inf", -np.inf, np.NINF, -math.inf)
# noinspection PyProtectedMember
return self._task._reporter.report_table( return self._task._reporter.report_table(
title=title, title=title,
series=series, series=series,
@ -375,6 +383,7 @@ class Logger(object):
# if task was not started, we have to start it # if task was not started, we have to start it
self._start_task_if_needed() self._start_task_if_needed()
self._touch_title_series(title, series[0].name if series else '') self._touch_title_series(title, series[0].name if series else '')
# noinspection PyProtectedMember
return self._task._reporter.report_line_plot( return self._task._reporter.report_line_plot(
title=title, title=title,
series=series, series=series,
@ -453,6 +462,7 @@ class Logger(object):
# if task was not started, we have to start it # if task was not started, we have to start it
self._start_task_if_needed() self._start_task_if_needed()
self._touch_title_series(title, series) self._touch_title_series(title, series)
# noinspection PyProtectedMember
return self._task._reporter.report_2d_scatter( return self._task._reporter.report_2d_scatter(
title=title, title=title,
series=series, series=series,
@ -547,6 +557,7 @@ class Logger(object):
# if task was not started, we have to start it # if task was not started, we have to start it
self._start_task_if_needed() self._start_task_if_needed()
self._touch_title_series(title, series) self._touch_title_series(title, series)
# noinspection PyProtectedMember
return self._task._reporter.report_3d_scatter( return self._task._reporter.report_3d_scatter(
title=title, title=title,
series=series, series=series,
@ -607,6 +618,7 @@ class Logger(object):
# if task was not started, we have to start it # if task was not started, we have to start it
self._start_task_if_needed() self._start_task_if_needed()
self._touch_title_series(title, series) self._touch_title_series(title, series)
# noinspection PyProtectedMember
return self._task._reporter.report_value_matrix( return self._task._reporter.report_value_matrix(
title=title, title=title,
series=series, series=series,
@ -707,6 +719,7 @@ class Logger(object):
# if task was not started, we have to start it # if task was not started, we have to start it
self._start_task_if_needed() self._start_task_if_needed()
self._touch_title_series(title, series) self._touch_title_series(title, series)
# noinspection PyProtectedMember
return self._task._reporter.report_value_surface( return self._task._reporter.report_value_surface(
title=title, title=title,
series=series, series=series,
@ -797,6 +810,7 @@ class Logger(object):
self._touch_title_series(title, series) self._touch_title_series(title, series)
if url: if url:
# noinspection PyProtectedMember
self._task._reporter.report_image( self._task._reporter.report_image(
title=title, title=title,
series=series, series=series,
@ -816,7 +830,7 @@ class Logger(object):
if isinstance(image, Image.Image): if isinstance(image, Image.Image):
image = np.array(image) image = np.array(image)
# noinspection PyProtectedMember
self._task._reporter.report_image_and_upload( self._task._reporter.report_image_and_upload(
title=title, title=title,
series=series, series=series,
@ -882,6 +896,7 @@ class Logger(object):
self._touch_title_series(title, series) self._touch_title_series(title, series)
if url: if url:
# noinspection PyProtectedMember
self._task._reporter.report_media( self._task._reporter.report_media(
title=title, title=title,
series=series, series=series,
@ -898,7 +913,7 @@ class Logger(object):
upload_uri = str(upload_uri) upload_uri = str(upload_uri)
storage = StorageHelper.get(upload_uri) storage = StorageHelper.get(upload_uri)
upload_uri = storage.verify_upload(folder_uri=upload_uri) upload_uri = storage.verify_upload(folder_uri=upload_uri)
# noinspection PyProtectedMember
self._task._reporter.report_media_and_upload( self._task._reporter.report_media_and_upload(
title=title, title=title,
series=series, series=series,
@ -939,6 +954,7 @@ class Logger(object):
plot['layout']['title'] = series plot['layout']['title'] = series
except Exception: except Exception:
pass pass
# noinspection PyProtectedMember
self._task._reporter.report_plot( self._task._reporter.report_plot(
title=title, title=title,
series=series, series=series,
@ -946,6 +962,39 @@ class Logger(object):
iter=iteration, iter=iteration,
) )
def report_matplotlib_figure(
self,
title, # type: str
series, # type: str
iteration, # type: int
figure, # type: Union[MatplotlibFigure, pyplot]
report_image=False, # type: bool
):
"""
Report a ``matplotlib`` figure / plot directly
``matplotlib.figure.Figure`` / ``matplotlib.pyplot``
:param str title: The title (metric) of the plot.
:param str series: The series name (variant) of the reported plot.
:param int iteration: The iteration number.
:param MatplotlibFigure figure: A ``matplotlib`` Figure object
:param report_image: Default False. If True the plot will be uploaded as a debug sample (png image),
and will appear under the debug samples tab (instead of the Plots tab).
"""
# if task was not started, we have to start it
self._start_task_if_needed()
# noinspection PyProtectedMember
self._task._reporter.report_matplotlib(
title=title,
series=series,
figure=figure,
iter=iteration,
logger=self,
force_save_as_image='png' if report_image else False,
)
def set_default_upload_destination(self, uri): def set_default_upload_destination(self, uri):
# type: (str) -> None # type: (str) -> None
""" """
@ -1185,6 +1234,7 @@ class Logger(object):
storage = StorageHelper.get(upload_uri) storage = StorageHelper.get(upload_uri)
upload_uri = storage.verify_upload(folder_uri=upload_uri) upload_uri = storage.verify_upload(folder_uri=upload_uri)
# noinspection PyProtectedMember
self._task._reporter.report_image_plot_and_upload( self._task._reporter.report_image_plot_and_upload(
title=title, title=title,
series=series, series=series,
@ -1236,7 +1286,7 @@ class Logger(object):
upload_uri = str(upload_uri) upload_uri = str(upload_uri)
storage = StorageHelper.get(upload_uri) storage = StorageHelper.get(upload_uri)
upload_uri = storage.verify_upload(folder_uri=upload_uri) upload_uri = storage.verify_upload(folder_uri=upload_uri)
# noinspection PyProtectedMember
self._task._reporter.report_image_and_upload( self._task._reporter.report_image_and_upload(
title=title, title=title,
series=series, series=series,