From 67df9774e1986f6ab9b940ca5cbc6a088c3b07e1 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Sun, 28 Jul 2019 21:06:47 +0300 Subject: [PATCH] Fix Matplotlib figure.show support Improve Matplotlib support, plt.show will show under plots tabs, and plt.imshow under debug images --- trains/binding/matplotlib_bind.py | 94 ++++++++++++++++++++++--------- 1 file changed, 67 insertions(+), 27 deletions(-) diff --git a/trains/binding/matplotlib_bind.py b/trains/binding/matplotlib_bind.py index b8db7f13..11364a48 100644 --- a/trains/binding/matplotlib_bind.py +++ b/trains/binding/matplotlib_bind.py @@ -11,12 +11,14 @@ from ..config import running_remotely class PatchedMatplotlib: _patched_original_plot = None + _patched_original_figure = None __patched_original_imshow = None __patched_original_draw_all = None __patched_draw_all_recursion_guard = False _global_plot_counter = -1 _global_image_counter = -1 _current_task = None + _support_image_plot = False class _PatchWarnings(object): def __init__(self): @@ -51,15 +53,19 @@ class PatchedMatplotlib: import matplotlib.pyplot matplotlib.pyplot.switch_backend('agg') import matplotlib.pyplot as plt + import matplotlib.figure as figure from matplotlib import _pylab_helpers if six.PY2: PatchedMatplotlib._patched_original_plot = staticmethod(plt.show) PatchedMatplotlib._patched_original_imshow = staticmethod(plt.imshow) + PatchedMatplotlib._patched_original_figure = staticmethod(figure.Figure.show) else: PatchedMatplotlib._patched_original_plot = plt.show PatchedMatplotlib._patched_original_imshow = plt.imshow + PatchedMatplotlib._patched_original_figure = figure.Figure.show plt.show = PatchedMatplotlib.patched_show - # sys.modules['matplotlib'].pyplot.imshow = PatchedMatplotlib.patched_imshow + figure.Figure.show = PatchedMatplotlib.patched_figure_show + sys.modules['matplotlib'].pyplot.imshow = PatchedMatplotlib.patched_imshow # patch plotly so we know it failed us. from plotly.matplotlylib import renderer renderer.warnings = PatchedMatplotlib._PatchWarnings() @@ -87,11 +93,26 @@ class PatchedMatplotlib: def update_current_task(task): if PatchedMatplotlib.patch_matplotlib(): PatchedMatplotlib._current_task = task + from ..backend_api import Session + PatchedMatplotlib._support_image_plot = Session.api_version > '2.1' @staticmethod def patched_imshow(*args, **kw): ret = PatchedMatplotlib._patched_original_imshow(*args, **kw) - PatchedMatplotlib._report_figure(force_save_as_image=True) + try: + from matplotlib import _pylab_helpers + # store on the plot that this is an imshow plot + stored_figure = _pylab_helpers.Gcf.get_active() + if stored_figure: + stored_figure._trains_is_imshow = True + except Exception: + pass + return ret + + @staticmethod + def patched_figure_show(self, *args, **kw): + PatchedMatplotlib._report_figure(set_active=False, specific_fig=self) + ret = PatchedMatplotlib._patched_original_figure(self, *args, **kw) return ret @staticmethod @@ -110,7 +131,7 @@ class PatchedMatplotlib: return ret @staticmethod - def _report_figure(force_save_as_image=False, stored_figure=None, set_active=True): + def _report_figure(force_save_as_image=False, stored_figure=None, set_active=True, specific_fig=None): if not PatchedMatplotlib._current_task: return @@ -119,18 +140,27 @@ class PatchedMatplotlib: import matplotlib.pyplot as plt from plotly import optional_imports from matplotlib import _pylab_helpers - # store the figure object we just created (if it is not already there) - stored_figure = stored_figure or _pylab_helpers.Gcf.get_active() - if not stored_figure: - # nothing for us to do - return - # get current figure - mpl_fig = stored_figure.canvas.figure # plt.gcf() + if specific_fig is None: + # store the figure object we just created (if it is not already there) + stored_figure = stored_figure or _pylab_helpers.Gcf.get_active() + if not stored_figure: + # nothing for us to do + return + # check if this is an imshow + if hasattr(stored_figure, '_trains_is_imshow') and stored_figure._trains_is_imshow: + force_save_as_image = True + # flag will be cleared when calling clf() (object will be replaced) + # get current figure + mpl_fig = stored_figure.canvas.figure # plt.gcf() + else: + mpl_fig = specific_fig + # convert to plotly image = None plotly_fig = None - image_format = 'svg' + image_format = 'jpeg' if not force_save_as_image: + image_format = 'svg' # noinspection PyBroadException try: def our_mpl_to_plotly(fig): @@ -142,9 +172,8 @@ class PatchedMatplotlib: plotly_fig = our_mpl_to_plotly(mpl_fig) except Exception as ex: - # this was an image, change format to jpeg - if 'selfie' in str(ex): - image_format = 'jpeg' + # this was an image, change format to png + image_format = 'jpeg' if 'selfie' in str(ex) else 'png' # plotly could not serialize the plot, we should convert to image if not plotly_fig: @@ -153,12 +182,14 @@ class PatchedMatplotlib: try: # first try SVG if we fail then fallback to png buffer_ = BytesIO() - plt.savefig(buffer_, format=image_format, bbox_inches='tight', pad_inches=0) + a_plt = specific_fig if specific_fig is not None else plt + a_plt.savefig(buffer_, format=image_format, bbox_inches='tight', pad_inches=0, frameon=False) buffer_.seek(0) except Exception: image_format = 'png' buffer_ = BytesIO() - plt.savefig(buffer_, format=image_format, bbox_inches='tight', pad_inches=0) + a_plt = specific_fig if specific_fig is not None else plt + a_plt.savefig(buffer_, format=image_format, bbox_inches='tight', pad_inches=0, frameon=False) buffer_.seek(0) fd, image = mkstemp(suffix='.'+image_format) os.write(fd, buffer_.read()) @@ -193,20 +224,29 @@ class PatchedMatplotlib: reporter.report_plot(title=title, series='plot', plot=plotly_dict, iter=PatchedMatplotlib._global_plot_counter if plot_title else 0) else: - # send the plot as image - PatchedMatplotlib._global_image_counter += 1 logger = PatchedMatplotlib._current_task.get_logger() - title = plot_title or 'untitled %d' % PatchedMatplotlib._global_image_counter + # this is actually a failed plot, we should put it under plots: # currently disabled - # if image_format == 'svg': - # logger.report_image_plot_and_upload(title=title, series='plot image', path=image, - # iteration=PatchedMatplotlib._global_image_counter - # if plot_title else 0) - # else: - logger.report_image_and_upload(title=title, series='plot image', path=image, - iteration=PatchedMatplotlib._global_image_counter - if plot_title else 0) + if force_save_as_image or not PatchedMatplotlib._support_image_plot: + # send the plot as image + PatchedMatplotlib._global_image_counter += 1 + title = plot_title or 'untitled %d' % PatchedMatplotlib._global_image_counter + + logger.report_image_and_upload(title=title, series='plot image', path=image, + delete_after_upload=True, + iteration=PatchedMatplotlib._global_image_counter + if plot_title else 0) + else: + # send the plot as plotly with embedded image + PatchedMatplotlib._global_plot_counter += 1 + title = plot_title or 'untitled %d' % PatchedMatplotlib._global_plot_counter + + logger.report_image_plot_and_upload(title=title, series='plot image', path=image, + delete_after_upload=True, + iteration=PatchedMatplotlib._global_plot_counter + if plot_title else 0) + except Exception: # plotly failed pass