Fix support for plotly 4.0 and matplotlib compatibility

This commit is contained in:
allegroai 2019-08-19 21:19:44 +03:00
parent a896f5b465
commit ad5a44e906

View File

@ -2,6 +2,7 @@
import os
import sys
from copy import deepcopy
from tempfile import mkstemp
import six
@ -22,7 +23,11 @@ class PatchedMatplotlib:
_global_image_counter = -1
_current_task = None
_support_image_plot = False
_matplotlylib = None
_plotly_renderer = None
_lock_renderer = threading.RLock()
_recursion_guard = {}
_matplot_major_version = 2
class _PatchWarnings(object):
def __init__(self):
@ -45,8 +50,8 @@ class PatchedMatplotlib:
try:
# we support matplotlib version 2.0.0 and above
import matplotlib
matplot_major_version = int(matplotlib.__version__.split('.')[0])
if matplot_major_version < 2:
PatchedMatplotlib._matplot_major_version = int(matplotlib.__version__.split('.')[0])
if PatchedMatplotlib._matplot_major_version < 2:
LoggerRoot.get_base_logger().warning(
'matplotlib binding supports version 2.0 and above, found version {}'.format(
matplotlib.__version__))
@ -75,6 +80,15 @@ class PatchedMatplotlib:
# patch plotly so we know it failed us.
from plotly.matplotlylib import renderer
renderer.warnings = PatchedMatplotlib._PatchWarnings()
# ignore deprecation warnings from plotly to matplotlib
try:
import warnings
warnings.filterwarnings(action='ignore', category=matplotlib.MatplotlibDeprecationWarning,
module='plotly')
except Exception:
pass
except Exception:
return False
@ -101,6 +115,12 @@ class PatchedMatplotlib:
PatchedMatplotlib._current_task = task
from ..backend_api import Session
PatchedMatplotlib._support_image_plot = Session.api_version > '2.1'
try:
from plotly import optional_imports
PatchedMatplotlib._matplotlylib = optional_imports.get_module('plotly.matplotlylib')
PatchedMatplotlib._plotly_renderer = PatchedMatplotlib._matplotlylib.PlotlyRenderer()
except Exception:
pass
@staticmethod
def patched_imshow(*args, **kw):
@ -160,8 +180,8 @@ class PatchedMatplotlib:
# noinspection PyBroadException
try:
import matplotlib.pyplot as plt
from plotly import optional_imports
from matplotlib import _pylab_helpers
from plotly.io import templates
if specific_fig is None:
# store the figure object we just created (if it is not already there)
stored_figure = stored_figure or _pylab_helpers.Gcf.get_active()
@ -187,46 +207,55 @@ class PatchedMatplotlib:
fig_dpi = None
else:
image_format = 'svg'
# protect with lock, so we support multiple threads using the same renderer
PatchedMatplotlib._lock_renderer.acquire()
# noinspection PyBroadException
try:
def our_mpl_to_plotly(fig):
matplotlylib = optional_imports.get_module('plotly.matplotlylib')
if matplotlylib:
renderer = matplotlylib.PlotlyRenderer()
matplotlylib.Exporter(renderer, close_mpl=False).run(fig)
x_ticks = list(renderer.current_mpl_ax.get_xticklabels())
if x_ticks:
if not PatchedMatplotlib._matplotlylib or not PatchedMatplotlib._plotly_renderer:
return None
PatchedMatplotlib._matplotlylib.Exporter(PatchedMatplotlib._plotly_renderer,
close_mpl=False).run(fig)
x_ticks = list(PatchedMatplotlib._plotly_renderer.current_mpl_ax.get_xticklabels())
if x_ticks:
try:
# check if all values can be cast to float
values = [float(t.get_text().replace('', '-')) for t in x_ticks]
except:
try:
# check if all values can be cast to float
values = [float(t.get_text().replace('', '-')) for t in x_ticks]
PatchedMatplotlib._plotly_renderer.plotly_fig['layout']['xaxis1'].update({
'ticktext': [t.get_text() for t in x_ticks],
'tickvals': [t.get_position()[0] for t in x_ticks],
})
except:
try:
renderer.plotly_fig['layout']['xaxis1'].update({
'ticktext': [t.get_text() for t in x_ticks],
'tickvals': [t.get_position()[0] for t in x_ticks],
})
except:
pass
y_ticks = list(renderer.current_mpl_ax.get_yticklabels())
if y_ticks:
pass
y_ticks = list(PatchedMatplotlib._plotly_renderer.current_mpl_ax.get_yticklabels())
if y_ticks:
try:
# check if all values can be cast to float
values = [float(t.get_text().replace('', '-')) for t in y_ticks]
except:
try:
# check if all values can be cast to float
values = [float(t.get_text().replace('', '-')) for t in y_ticks]
PatchedMatplotlib._plotly_renderer.plotly_fig['layout']['yaxis1'].update({
'ticktext': [t.get_text() for t in y_ticks],
'tickvals': [t.get_position()[1] for t in y_ticks],
})
except:
try:
renderer.plotly_fig['layout']['yaxis1'].update({
'ticktext': [t.get_text() for t in y_ticks],
'tickvals': [t.get_position()[1] for t in y_ticks],
})
except:
pass
return renderer.plotly_fig
pass
return deepcopy(PatchedMatplotlib._plotly_renderer.plotly_fig)
plotly_fig = our_mpl_to_plotly(mpl_fig)
try:
if 'none' in templates:
plotly_fig._layout_obj.template = templates['none']
except:
pass
except Exception as ex:
# this was an image, change format to png
image_format = 'jpeg' if 'selfie' in str(ex) else 'png'
fig_dpi = 300
finally:
PatchedMatplotlib._lock_renderer.release()
# plotly could not serialize the plot, we should convert to image
if not plotly_fig:
@ -236,13 +265,23 @@ class PatchedMatplotlib:
# first try SVG if we fail then fallback to png
buffer_ = BytesIO()
a_plt = specific_fig if specific_fig is not None else plt
a_plt.savefig(buffer_, dpi=fig_dpi, format=image_format, bbox_inches='tight', pad_inches=0, frameon=False)
if PatchedMatplotlib._matplot_major_version < 3:
a_plt.savefig(buffer_, dpi=fig_dpi, format=image_format, bbox_inches='tight', pad_inches=0,
frameon=False)
else:
a_plt.savefig(buffer_, dpi=fig_dpi, format=image_format, bbox_inches='tight', pad_inches=0,
facecolor=None)
buffer_.seek(0)
except Exception:
image_format = 'png'
buffer_ = BytesIO()
a_plt = specific_fig if specific_fig is not None else plt
a_plt.savefig(buffer_, dpi=fig_dpi, format=image_format, bbox_inches='tight', pad_inches=0, frameon=False)
if PatchedMatplotlib._matplot_major_version < 3:
a_plt.savefig(buffer_, dpi=fig_dpi, format=image_format, bbox_inches='tight', pad_inches=0,
frameon=False)
else:
a_plt.savefig(buffer_, dpi=fig_dpi, format=image_format, bbox_inches='tight', pad_inches=0,
facecolor=None)
buffer_.seek(0)
fd, image = mkstemp(suffix='.'+image_format)
os.write(fd, buffer_.read())