From df395b67ba42143f7dc8b7a17a252a3878ebb309 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Thu, 15 Oct 2020 23:20:17 +0300 Subject: [PATCH] Add Logger.report_matplotlib_figure() with examples --- .../matplotlib/matplotlib_example.py | 18 +- .../reporting/manual_matplotlib_reporting.py | 55 ++++++ trains/backend_interface/metrics/reporter.py | 14 +- trains/binding/matplotlib_bind.py | 181 +++++++++++------- trains/logger.py | 60 +++++- 5 files changed, 253 insertions(+), 75 deletions(-) create mode 100644 examples/reporting/manual_matplotlib_reporting.py diff --git a/examples/frameworks/matplotlib/matplotlib_example.py b/examples/frameworks/matplotlib/matplotlib_example.py index 2398b250..2d4be825 100644 --- a/examples/frameworks/matplotlib/matplotlib_example.py +++ b/examples/frameworks/matplotlib/matplotlib_example.py @@ -8,30 +8,39 @@ from trains import Task task = Task.init(project_name='examples', task_name='Matplotlib example') -# create plot +# Create a plot N = 50 x = np.random.rand(N) y = np.random.rand(N) colors = np.random.rand(N) area = (30 * np.random.rand(N))**2 # 0 to 15 point radii plt.scatter(x, y, s=area, c=colors, alpha=0.5) +# Plot will be reported automatically plt.show() -# create another plot - with a name +# Alternatively, in order to report the plot with a more meaningful title/series and iteration number +area = (40 * np.random.rand(N))**2 +plt.scatter(x, y, s=area, c=colors, alpha=0.5) +task.logger.report_matplotlib_figure(title="My Plot Title", series="My Plot Series", iteration=10, figure=plt) + +# Create another plot - with a name x = np.linspace(0, 10, 30) y = np.sin(x) plt.plot(x, y, 'o', color='black') +# Plot will be reported automatically plt.show() -# create image plot +# Create image plot m = np.eye(256, 256, dtype=np.uint8) plt.imshow(m) +# Plot will be reported automatically plt.show() -# create image plot - with a name +# Create image plot - with a name m = np.eye(256, 256, dtype=np.uint8) plt.imshow(m) plt.title('Image Title') +# Plot will be reported automatically plt.show() sns.set(style="darkgrid") @@ -41,6 +50,7 @@ fmri = sns.load_dataset("fmri") sns.lineplot(x="timepoint", y="signal", hue="region", style="event", data=fmri) +# Plot will be reported automatically plt.show() print('This is a Matplotlib & Seaborn example') diff --git a/examples/reporting/manual_matplotlib_reporting.py b/examples/reporting/manual_matplotlib_reporting.py new file mode 100644 index 00000000..19b932e3 --- /dev/null +++ b/examples/reporting/manual_matplotlib_reporting.py @@ -0,0 +1,55 @@ +# TRAINS - Example of Matplotlib and Seaborn integration and reporting +# +import numpy as np +import matplotlib.pyplot as plt +from trains import Task + +# Create a new task, disable automatic matplotlib connect +task = Task.init( + project_name='examples', + task_name='Manual Matplotlib example', + auto_connect_frameworks={'matplotlib': False} +) + +# Create plot and explicitly report as figure +N = 50 +x = np.random.rand(N) +y = np.random.rand(N) +colors = np.random.rand(N) +area = (30 * np.random.rand(N))**2 # 0 to 15 point radii +plt.scatter(x, y, s=area, c=colors, alpha=0.5) +task.logger.report_matplotlib_figure( + title="Manual Reporting", + series="Just a plot", + iteration=0, + figure=plt, +) + +# Show the plot +plt.show() + +# Create plot and explicitly report as an image +plt.scatter(x, y, s=area, c=colors, alpha=0.5) +task.logger.report_matplotlib_figure( + title="Manual Reporting", + series="Plot as an image", + iteration=0, + figure=plt, + report_image=True, +) + + +# Create an image plot and explicitly report (as an image) +m = np.eye(256, 256, dtype=np.uint8) +plt.imshow(m) +task.logger.report_matplotlib_figure( + title="Manual Reporting", + series="Image plot", + iteration=0, + figure=plt, + report_image=True, # Note this is required for image plots +) + +# Show the plot +plt.show() + diff --git a/trains/backend_interface/metrics/reporter.py b/trains/backend_interface/metrics/reporter.py index 9b3d26ce..fc0c6b40 100644 --- a/trains/backend_interface/metrics/reporter.py +++ b/trains/backend_interface/metrics/reporter.py @@ -167,6 +167,18 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan iter=iter) self._report(ev) + def report_matplotlib(self, title, series, figure, iter, force_save_as_image=False, logger=None): + from trains.binding.matplotlib_bind import PatchedMatplotlib + PatchedMatplotlib.report_figure( + title=title, + series=series, + figure=figure, + iter=iter, + force_save_as_image=force_save_as_image, + reporter=self, + logger=logger, + ) + def report_plot(self, title, series, plot, iter, round_digits=None): """ Report a Plotly chart @@ -500,7 +512,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan :param iter: Iteration number :type iter: int :param labels: label (text) per point in the scatter (in the same order) - :type labels: str + :type labels: list(str) :param mode: (type str) 'lines'/'markers'/'lines+markers' :param color: list of RGBA colors [(217, 217, 217, 0.14),] :param marker_size: marker size in px diff --git a/trains/binding/matplotlib_bind.py b/trains/binding/matplotlib_bind.py index 3b2a9896..ec821bdb 100644 --- a/trains/binding/matplotlib_bind.py +++ b/trains/binding/matplotlib_bind.py @@ -33,7 +33,7 @@ class PatchedMatplotlib: _patched_mpltools_get_spine_visible = False _lock_renderer = threading.RLock() _recursion_guard = {} - _matplot_major_version = 2 + _matplot_major_version = 0 _matplot_minor_version = 0 _logger_started_reporting = False _matplotlib_reported_titles = set() @@ -62,9 +62,7 @@ class PatchedMatplotlib: try: # we support matplotlib version 2.0.0 and above import matplotlib - version_split = matplotlib.__version__.split('.') - PatchedMatplotlib._matplot_major_version = int(version_split[0]) - PatchedMatplotlib._matplot_minor_version = int(version_split[1]) + PatchedMatplotlib._update_matplotlib_version() if PatchedMatplotlib._matplot_major_version < 2: LoggerRoot.get_base_logger().warning( 'matplotlib binding supports version 2.0 and above, found version {}'.format( @@ -137,6 +135,36 @@ class PatchedMatplotlib: # update api version from ..backend_api import Session PatchedMatplotlib._support_image_plot = Session.check_min_api_version('2.2') + # load plotly + PatchedMatplotlib._update_plotly_renderers() + + return True + + @staticmethod + def _update_matplotlib_version(): + if PatchedMatplotlib._matplot_major_version: + return + + # we support matplotlib version 2.0.0 and above + try: + import matplotlib + version_split = matplotlib.__version__.split('.') + PatchedMatplotlib._matplot_major_version = int(version_split[0]) + PatchedMatplotlib._matplot_minor_version = int(version_split[1]) + + if running_remotely(): + # disable GUI backend - make headless + matplotlib.rcParams['backend'] = 'agg' + import matplotlib.pyplot + matplotlib.pyplot.switch_backend('agg') + + except Exception: + pass + + @staticmethod + def _update_plotly_renderers(): + if PatchedMatplotlib._matplotlylib and PatchedMatplotlib._plotly_renderer: + return True # create plotly renderer try: @@ -144,7 +172,7 @@ class PatchedMatplotlib: PatchedMatplotlib._matplotlylib = optional_imports.get_module('plotly.matplotlylib') PatchedMatplotlib._plotly_renderer = PatchedMatplotlib._matplotlylib.PlotlyRenderer() except Exception: - pass + return False return True @@ -251,10 +279,47 @@ class PatchedMatplotlib: return ret @staticmethod - def _report_figure(force_save_as_image=False, stored_figure=None, set_active=True, specific_fig=None): - if not PatchedMatplotlib._current_task: + def report_figure(title, series, figure, iter, force_save_as_image=False, reporter=None, logger=None): + PatchedMatplotlib._report_figure( + force_save_as_image=force_save_as_image, + specific_fig=figure.gcf() if hasattr(figure, 'gcf') else figure, + title=title, + series=series, + iter=iter, + reporter=reporter, + logger=logger + ) + + @staticmethod + def _report_figure( + force_save_as_image=False, + stored_figure=None, + set_active=True, + specific_fig=None, + title=None, + series=None, + iter=None, + reporter=None, + logger=None, + ): + # get the main task + if not PatchedMatplotlib._current_task and not reporter and not logger: return + # check if this is explicit reporting + is_explicit = reporter and logger + + # noinspection PyProtectedMember + reporter = reporter or PatchedMatplotlib._current_task._reporter + if not reporter: + return + logger = logger or PatchedMatplotlib._current_task.get_logger() + if not logger: + return + + # make sure we have matplotlib ready + PatchedMatplotlib._update_matplotlib_version() + # noinspection PyBroadException try: import matplotlib.pyplot as plt @@ -276,6 +341,13 @@ class PatchedMatplotlib: else: mpl_fig = specific_fig + if is_explicit: + # marked displayed explicitly + mpl_fig._trains_explicit = True + elif getattr(mpl_fig, '_trains_explicit', False): + # if auto bind (i.e. plt.show) and plot already displayed explicitly, do nothing. + return + # convert to plotly image = None plotly_fig = None @@ -284,6 +356,8 @@ class PatchedMatplotlib: if force_save_as_image: # if this is an image, store as is. fig_dpi = None + if isinstance(force_save_as_image, str): + image_format = force_save_as_image else: image_format = 'svg' # protect with lock, so we support multiple threads using the same renderer @@ -291,7 +365,7 @@ class PatchedMatplotlib: # noinspection PyBroadException try: def our_mpl_to_plotly(fig): - if not PatchedMatplotlib._matplotlylib or not PatchedMatplotlib._plotly_renderer: + if not PatchedMatplotlib._update_plotly_renderers(): return None if not PatchedMatplotlib._patched_mpltools_get_spine_visible and \ PatchedMatplotlib._matplot_major_version and \ @@ -383,76 +457,53 @@ class PatchedMatplotlib: if set_active and not _pylab_helpers.Gcf.get_active(): _pylab_helpers.Gcf.set_active(stored_figure) - # get the main task - # noinspection PyProtectedMember - reporter = PatchedMatplotlib._current_task._reporter - if reporter is not None: + last_iteration = iter if iter is not None else PatchedMatplotlib._get_last_iteration() + + if not title: if mpl_fig.texts: plot_title = mpl_fig.texts[0].get_text() else: gca = mpl_fig.gca() plot_title = gca.title.get_text() if gca.title else None - # remove borders and size, we should let the web take care of that - if plotly_fig: - last_iteration = PatchedMatplotlib._get_last_iteration() - if plot_title: - title = PatchedMatplotlib._enforce_unique_title_per_iteration(plot_title, last_iteration) - else: - PatchedMatplotlib._global_plot_counter += 1 - title = 'untitled %02d' % PatchedMatplotlib._global_plot_counter - - plotly_fig.layout.margin = {} - plotly_fig.layout.autosize = True - plotly_fig.layout.height = None - plotly_fig.layout.width = None - # send the plot event - plotly_dict = plotly_fig.to_plotly_json() - if not plotly_dict.get('layout'): - plotly_dict['layout'] = {} - plotly_dict['layout']['title'] = title - - PatchedMatplotlib._matplotlib_reported_titles.add(title) - reporter.report_plot(title=title, series='plot', plot=plotly_dict, iter=last_iteration) + if plot_title: + title = PatchedMatplotlib._enforce_unique_title_per_iteration(plot_title, last_iteration) else: - logger = PatchedMatplotlib._current_task.get_logger() + PatchedMatplotlib._global_plot_counter += 1 + mod_ = 1 if plotly_fig else PatchedMatplotlib._global_image_counter_limit + title = 'untitled %02d' % (PatchedMatplotlib._global_plot_counter % mod_) - # this is actually a failed plot, we should put it under plots: - # currently disabled - if force_save_as_image or not PatchedMatplotlib._support_image_plot: - last_iteration = PatchedMatplotlib._get_last_iteration() - # send the plot as image - if plot_title: - title = PatchedMatplotlib._enforce_unique_title_per_iteration(plot_title, last_iteration) - else: - PatchedMatplotlib._global_image_counter += 1 - title = 'untitled %02d' % (PatchedMatplotlib._global_image_counter % - PatchedMatplotlib._global_image_counter_limit) + # remove borders and size, we should let the web take care of that + if plotly_fig: + plotly_fig.layout.margin = {} + plotly_fig.layout.autosize = True + plotly_fig.layout.height = None + plotly_fig.layout.width = None + # send the plot event + plotly_dict = plotly_fig.to_plotly_json() + if not plotly_dict.get('layout'): + plotly_dict['layout'] = {} + plotly_dict['layout']['title'] = series or title - PatchedMatplotlib._matplotlib_reported_titles.add(title) - logger.report_image(title=title, series='plot image', local_path=image, - delete_after_upload=True, iteration=last_iteration) - else: - # send the plot as plotly with embedded image - last_iteration = PatchedMatplotlib._get_last_iteration() - if plot_title: - title = PatchedMatplotlib._enforce_unique_title_per_iteration(plot_title, last_iteration) - else: - PatchedMatplotlib._global_plot_counter += 1 - title = 'untitled %02d' % (PatchedMatplotlib._global_plot_counter % - PatchedMatplotlib._global_image_counter_limit) - - PatchedMatplotlib._matplotlib_reported_titles.add(title) - # noinspection PyProtectedMember - logger._report_image_plot_and_upload( - title=title, series='plot image', path=image, - delete_after_upload=True, iteration=last_iteration) + PatchedMatplotlib._matplotlib_reported_titles.add(title) + reporter.report_plot(title=title, series=series or 'plot', plot=plotly_dict, iter=last_iteration) + else: + # this is actually a failed plot, we should put it under plots: + # currently disabled + if force_save_as_image or not PatchedMatplotlib._support_image_plot: + PatchedMatplotlib._matplotlib_reported_titles.add(title) + logger.report_image(title=title, series=series or 'plot image', local_path=image, + delete_after_upload=True, iteration=last_iteration) + else: + PatchedMatplotlib._matplotlib_reported_titles.add(title) + # noinspection PyProtectedMember + logger._report_image_plot_and_upload( + title=title, series=series or 'plot image', path=image, + delete_after_upload=True, iteration=last_iteration) except Exception: # plotly failed pass - return - @staticmethod def _enforce_unique_title_per_iteration(title, last_iteration): # type: (str, int) -> str diff --git a/trains/logger.py b/trains/logger.py index 7567320b..29467d75 100644 --- a/trains/logger.py +++ b/trains/logger.py @@ -1,7 +1,7 @@ import logging import math import warnings -from typing import Any, Sequence, Union, List, Optional, Tuple, Dict +from typing import Any, Sequence, Union, List, Optional, Tuple, Dict, TYPE_CHECKING import numpy as np import six @@ -28,6 +28,12 @@ from .utilities.plotly_reporter import SeriesInfo warnings.filterwarnings('always', category=DeprecationWarning, module=__name__) +if TYPE_CHECKING: + from matplotlib.figure import Figure as MatplotlibFigure + from matplotlib import pyplot + from plotly.graph_objects import Figure + + class Logger(object): """ The ``Logger`` class is the Trains console log and metric statistics interface, and contains methods for explicit @@ -142,6 +148,7 @@ class Logger(object): # if task was not started, we have to start it self._start_task_if_needed() self._touch_title_series(title, series) + # noinspection PyProtectedMember return self._task._reporter.report_scalar(title=title, series=series, value=float(value), iter=iteration) def report_vector( @@ -239,6 +246,7 @@ class Logger(object): # if task was not started, we have to start it self._start_task_if_needed() self._touch_title_series(title, series) + # noinspection PyProtectedMember return self._task._reporter.report_histogram( title=title, series=series, @@ -321,7 +329,7 @@ class Logger(object): replace("NaN", np.nan, math.nan) replace("Inf", np.inf, math.inf) replace("-Inf", -np.inf, np.NINF, -math.inf) - + # noinspection PyProtectedMember return self._task._reporter.report_table( title=title, series=series, @@ -375,6 +383,7 @@ class Logger(object): # if task was not started, we have to start it self._start_task_if_needed() self._touch_title_series(title, series[0].name if series else '') + # noinspection PyProtectedMember return self._task._reporter.report_line_plot( title=title, series=series, @@ -453,6 +462,7 @@ class Logger(object): # if task was not started, we have to start it self._start_task_if_needed() self._touch_title_series(title, series) + # noinspection PyProtectedMember return self._task._reporter.report_2d_scatter( title=title, series=series, @@ -547,6 +557,7 @@ class Logger(object): # if task was not started, we have to start it self._start_task_if_needed() self._touch_title_series(title, series) + # noinspection PyProtectedMember return self._task._reporter.report_3d_scatter( title=title, series=series, @@ -607,6 +618,7 @@ class Logger(object): # if task was not started, we have to start it self._start_task_if_needed() self._touch_title_series(title, series) + # noinspection PyProtectedMember return self._task._reporter.report_value_matrix( title=title, series=series, @@ -707,6 +719,7 @@ class Logger(object): # if task was not started, we have to start it self._start_task_if_needed() self._touch_title_series(title, series) + # noinspection PyProtectedMember return self._task._reporter.report_value_surface( title=title, series=series, @@ -797,6 +810,7 @@ class Logger(object): self._touch_title_series(title, series) if url: + # noinspection PyProtectedMember self._task._reporter.report_image( title=title, series=series, @@ -816,7 +830,7 @@ class Logger(object): if isinstance(image, Image.Image): image = np.array(image) - + # noinspection PyProtectedMember self._task._reporter.report_image_and_upload( title=title, series=series, @@ -882,6 +896,7 @@ class Logger(object): self._touch_title_series(title, series) if url: + # noinspection PyProtectedMember self._task._reporter.report_media( title=title, series=series, @@ -898,7 +913,7 @@ class Logger(object): upload_uri = str(upload_uri) storage = StorageHelper.get(upload_uri) upload_uri = storage.verify_upload(folder_uri=upload_uri) - + # noinspection PyProtectedMember self._task._reporter.report_media_and_upload( title=title, series=series, @@ -939,6 +954,7 @@ class Logger(object): plot['layout']['title'] = series except Exception: pass + # noinspection PyProtectedMember self._task._reporter.report_plot( title=title, series=series, @@ -946,6 +962,39 @@ class Logger(object): iter=iteration, ) + def report_matplotlib_figure( + self, + title, # type: str + series, # type: str + iteration, # type: int + figure, # type: Union[MatplotlibFigure, pyplot] + report_image=False, # type: bool + ): + """ + Report a ``matplotlib`` figure / plot directly + + ``matplotlib.figure.Figure`` / ``matplotlib.pyplot`` + + :param str title: The title (metric) of the plot. + :param str series: The series name (variant) of the reported plot. + :param int iteration: The iteration number. + :param MatplotlibFigure figure: A ``matplotlib`` Figure object + :param report_image: Default False. If True the plot will be uploaded as a debug sample (png image), + and will appear under the debug samples tab (instead of the Plots tab). + """ + # if task was not started, we have to start it + self._start_task_if_needed() + + # noinspection PyProtectedMember + self._task._reporter.report_matplotlib( + title=title, + series=series, + figure=figure, + iter=iteration, + logger=self, + force_save_as_image='png' if report_image else False, + ) + def set_default_upload_destination(self, uri): # type: (str) -> None """ @@ -1185,6 +1234,7 @@ class Logger(object): storage = StorageHelper.get(upload_uri) upload_uri = storage.verify_upload(folder_uri=upload_uri) + # noinspection PyProtectedMember self._task._reporter.report_image_plot_and_upload( title=title, series=series, @@ -1236,7 +1286,7 @@ class Logger(object): upload_uri = str(upload_uri) storage = StorageHelper.get(upload_uri) upload_uri = storage.verify_upload(folder_uri=upload_uri) - + # noinspection PyProtectedMember self._task._reporter.report_image_and_upload( title=title, series=series,