From ad5a44e906fc6fa5bf9fd0879023b157053fead6 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Mon, 19 Aug 2019 21:19:44 +0300 Subject: [PATCH] Fix support for plotly 4.0 and matplotlib compatibility --- trains/binding/matplotlib_bind.py | 103 ++++++++++++++++++++---------- 1 file changed, 71 insertions(+), 32 deletions(-) diff --git a/trains/binding/matplotlib_bind.py b/trains/binding/matplotlib_bind.py index e39a4a04..46c2c9dc 100644 --- a/trains/binding/matplotlib_bind.py +++ b/trains/binding/matplotlib_bind.py @@ -2,6 +2,7 @@ import os import sys +from copy import deepcopy from tempfile import mkstemp import six @@ -22,7 +23,11 @@ class PatchedMatplotlib: _global_image_counter = -1 _current_task = None _support_image_plot = False + _matplotlylib = None + _plotly_renderer = None + _lock_renderer = threading.RLock() _recursion_guard = {} + _matplot_major_version = 2 class _PatchWarnings(object): def __init__(self): @@ -45,8 +50,8 @@ class PatchedMatplotlib: try: # we support matplotlib version 2.0.0 and above import matplotlib - matplot_major_version = int(matplotlib.__version__.split('.')[0]) - if matplot_major_version < 2: + PatchedMatplotlib._matplot_major_version = int(matplotlib.__version__.split('.')[0]) + if PatchedMatplotlib._matplot_major_version < 2: LoggerRoot.get_base_logger().warning( 'matplotlib binding supports version 2.0 and above, found version {}'.format( matplotlib.__version__)) @@ -75,6 +80,15 @@ class PatchedMatplotlib: # patch plotly so we know it failed us. from plotly.matplotlylib import renderer renderer.warnings = PatchedMatplotlib._PatchWarnings() + + # ignore deprecation warnings from plotly to matplotlib + try: + import warnings + warnings.filterwarnings(action='ignore', category=matplotlib.MatplotlibDeprecationWarning, + module='plotly') + except Exception: + pass + except Exception: return False @@ -101,6 +115,12 @@ class PatchedMatplotlib: PatchedMatplotlib._current_task = task from ..backend_api import Session PatchedMatplotlib._support_image_plot = Session.api_version > '2.1' + try: + from plotly import optional_imports + PatchedMatplotlib._matplotlylib = optional_imports.get_module('plotly.matplotlylib') + PatchedMatplotlib._plotly_renderer = PatchedMatplotlib._matplotlylib.PlotlyRenderer() + except Exception: + pass @staticmethod def patched_imshow(*args, **kw): @@ -160,8 +180,8 @@ class PatchedMatplotlib: # noinspection PyBroadException try: import matplotlib.pyplot as plt - from plotly import optional_imports from matplotlib import _pylab_helpers + from plotly.io import templates if specific_fig is None: # store the figure object we just created (if it is not already there) stored_figure = stored_figure or _pylab_helpers.Gcf.get_active() @@ -187,46 +207,55 @@ class PatchedMatplotlib: fig_dpi = None else: image_format = 'svg' + # protect with lock, so we support multiple threads using the same renderer + PatchedMatplotlib._lock_renderer.acquire() # noinspection PyBroadException try: def our_mpl_to_plotly(fig): - matplotlylib = optional_imports.get_module('plotly.matplotlylib') - if matplotlylib: - renderer = matplotlylib.PlotlyRenderer() - matplotlylib.Exporter(renderer, close_mpl=False).run(fig) - x_ticks = list(renderer.current_mpl_ax.get_xticklabels()) - if x_ticks: + if not PatchedMatplotlib._matplotlylib or not PatchedMatplotlib._plotly_renderer: + return None + PatchedMatplotlib._matplotlylib.Exporter(PatchedMatplotlib._plotly_renderer, + close_mpl=False).run(fig) + x_ticks = list(PatchedMatplotlib._plotly_renderer.current_mpl_ax.get_xticklabels()) + if x_ticks: + try: + # check if all values can be cast to float + values = [float(t.get_text().replace('−', '-')) for t in x_ticks] + except: try: - # check if all values can be cast to float - values = [float(t.get_text().replace('−', '-')) for t in x_ticks] + PatchedMatplotlib._plotly_renderer.plotly_fig['layout']['xaxis1'].update({ + 'ticktext': [t.get_text() for t in x_ticks], + 'tickvals': [t.get_position()[0] for t in x_ticks], + }) except: - try: - renderer.plotly_fig['layout']['xaxis1'].update({ - 'ticktext': [t.get_text() for t in x_ticks], - 'tickvals': [t.get_position()[0] for t in x_ticks], - }) - except: - pass - y_ticks = list(renderer.current_mpl_ax.get_yticklabels()) - if y_ticks: + pass + y_ticks = list(PatchedMatplotlib._plotly_renderer.current_mpl_ax.get_yticklabels()) + if y_ticks: + try: + # check if all values can be cast to float + values = [float(t.get_text().replace('−', '-')) for t in y_ticks] + except: try: - # check if all values can be cast to float - values = [float(t.get_text().replace('−', '-')) for t in y_ticks] + PatchedMatplotlib._plotly_renderer.plotly_fig['layout']['yaxis1'].update({ + 'ticktext': [t.get_text() for t in y_ticks], + 'tickvals': [t.get_position()[1] for t in y_ticks], + }) except: - try: - renderer.plotly_fig['layout']['yaxis1'].update({ - 'ticktext': [t.get_text() for t in y_ticks], - 'tickvals': [t.get_position()[1] for t in y_ticks], - }) - except: - pass - return renderer.plotly_fig + pass + return deepcopy(PatchedMatplotlib._plotly_renderer.plotly_fig) plotly_fig = our_mpl_to_plotly(mpl_fig) + try: + if 'none' in templates: + plotly_fig._layout_obj.template = templates['none'] + except: + pass except Exception as ex: # this was an image, change format to png image_format = 'jpeg' if 'selfie' in str(ex) else 'png' fig_dpi = 300 + finally: + PatchedMatplotlib._lock_renderer.release() # plotly could not serialize the plot, we should convert to image if not plotly_fig: @@ -236,13 +265,23 @@ class PatchedMatplotlib: # first try SVG if we fail then fallback to png buffer_ = BytesIO() a_plt = specific_fig if specific_fig is not None else plt - a_plt.savefig(buffer_, dpi=fig_dpi, format=image_format, bbox_inches='tight', pad_inches=0, frameon=False) + if PatchedMatplotlib._matplot_major_version < 3: + a_plt.savefig(buffer_, dpi=fig_dpi, format=image_format, bbox_inches='tight', pad_inches=0, + frameon=False) + else: + a_plt.savefig(buffer_, dpi=fig_dpi, format=image_format, bbox_inches='tight', pad_inches=0, + facecolor=None) buffer_.seek(0) except Exception: image_format = 'png' buffer_ = BytesIO() a_plt = specific_fig if specific_fig is not None else plt - a_plt.savefig(buffer_, dpi=fig_dpi, format=image_format, bbox_inches='tight', pad_inches=0, frameon=False) + if PatchedMatplotlib._matplot_major_version < 3: + a_plt.savefig(buffer_, dpi=fig_dpi, format=image_format, bbox_inches='tight', pad_inches=0, + frameon=False) + else: + a_plt.savefig(buffer_, dpi=fig_dpi, format=image_format, bbox_inches='tight', pad_inches=0, + facecolor=None) buffer_.seek(0) fd, image = mkstemp(suffix='.'+image_format) os.write(fd, buffer_.read())