mirror of
https://github.com/clearml/clearml
synced 2025-03-09 21:40:51 +00:00
Fix support for plotly 4.0 and matplotlib compatibility
This commit is contained in:
parent
a896f5b465
commit
ad5a44e906
@ -2,6 +2,7 @@
|
||||
|
||||
import os
|
||||
import sys
|
||||
from copy import deepcopy
|
||||
from tempfile import mkstemp
|
||||
|
||||
import six
|
||||
@ -22,7 +23,11 @@ class PatchedMatplotlib:
|
||||
_global_image_counter = -1
|
||||
_current_task = None
|
||||
_support_image_plot = False
|
||||
_matplotlylib = None
|
||||
_plotly_renderer = None
|
||||
_lock_renderer = threading.RLock()
|
||||
_recursion_guard = {}
|
||||
_matplot_major_version = 2
|
||||
|
||||
class _PatchWarnings(object):
|
||||
def __init__(self):
|
||||
@ -45,8 +50,8 @@ class PatchedMatplotlib:
|
||||
try:
|
||||
# we support matplotlib version 2.0.0 and above
|
||||
import matplotlib
|
||||
matplot_major_version = int(matplotlib.__version__.split('.')[0])
|
||||
if matplot_major_version < 2:
|
||||
PatchedMatplotlib._matplot_major_version = int(matplotlib.__version__.split('.')[0])
|
||||
if PatchedMatplotlib._matplot_major_version < 2:
|
||||
LoggerRoot.get_base_logger().warning(
|
||||
'matplotlib binding supports version 2.0 and above, found version {}'.format(
|
||||
matplotlib.__version__))
|
||||
@ -75,6 +80,15 @@ class PatchedMatplotlib:
|
||||
# patch plotly so we know it failed us.
|
||||
from plotly.matplotlylib import renderer
|
||||
renderer.warnings = PatchedMatplotlib._PatchWarnings()
|
||||
|
||||
# ignore deprecation warnings from plotly to matplotlib
|
||||
try:
|
||||
import warnings
|
||||
warnings.filterwarnings(action='ignore', category=matplotlib.MatplotlibDeprecationWarning,
|
||||
module='plotly')
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@ -101,6 +115,12 @@ class PatchedMatplotlib:
|
||||
PatchedMatplotlib._current_task = task
|
||||
from ..backend_api import Session
|
||||
PatchedMatplotlib._support_image_plot = Session.api_version > '2.1'
|
||||
try:
|
||||
from plotly import optional_imports
|
||||
PatchedMatplotlib._matplotlylib = optional_imports.get_module('plotly.matplotlylib')
|
||||
PatchedMatplotlib._plotly_renderer = PatchedMatplotlib._matplotlylib.PlotlyRenderer()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def patched_imshow(*args, **kw):
|
||||
@ -160,8 +180,8 @@ class PatchedMatplotlib:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
import matplotlib.pyplot as plt
|
||||
from plotly import optional_imports
|
||||
from matplotlib import _pylab_helpers
|
||||
from plotly.io import templates
|
||||
if specific_fig is None:
|
||||
# store the figure object we just created (if it is not already there)
|
||||
stored_figure = stored_figure or _pylab_helpers.Gcf.get_active()
|
||||
@ -187,46 +207,55 @@ class PatchedMatplotlib:
|
||||
fig_dpi = None
|
||||
else:
|
||||
image_format = 'svg'
|
||||
# protect with lock, so we support multiple threads using the same renderer
|
||||
PatchedMatplotlib._lock_renderer.acquire()
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
def our_mpl_to_plotly(fig):
|
||||
matplotlylib = optional_imports.get_module('plotly.matplotlylib')
|
||||
if matplotlylib:
|
||||
renderer = matplotlylib.PlotlyRenderer()
|
||||
matplotlylib.Exporter(renderer, close_mpl=False).run(fig)
|
||||
x_ticks = list(renderer.current_mpl_ax.get_xticklabels())
|
||||
if x_ticks:
|
||||
if not PatchedMatplotlib._matplotlylib or not PatchedMatplotlib._plotly_renderer:
|
||||
return None
|
||||
PatchedMatplotlib._matplotlylib.Exporter(PatchedMatplotlib._plotly_renderer,
|
||||
close_mpl=False).run(fig)
|
||||
x_ticks = list(PatchedMatplotlib._plotly_renderer.current_mpl_ax.get_xticklabels())
|
||||
if x_ticks:
|
||||
try:
|
||||
# check if all values can be cast to float
|
||||
values = [float(t.get_text().replace('−', '-')) for t in x_ticks]
|
||||
except:
|
||||
try:
|
||||
# check if all values can be cast to float
|
||||
values = [float(t.get_text().replace('−', '-')) for t in x_ticks]
|
||||
PatchedMatplotlib._plotly_renderer.plotly_fig['layout']['xaxis1'].update({
|
||||
'ticktext': [t.get_text() for t in x_ticks],
|
||||
'tickvals': [t.get_position()[0] for t in x_ticks],
|
||||
})
|
||||
except:
|
||||
try:
|
||||
renderer.plotly_fig['layout']['xaxis1'].update({
|
||||
'ticktext': [t.get_text() for t in x_ticks],
|
||||
'tickvals': [t.get_position()[0] for t in x_ticks],
|
||||
})
|
||||
except:
|
||||
pass
|
||||
y_ticks = list(renderer.current_mpl_ax.get_yticklabels())
|
||||
if y_ticks:
|
||||
pass
|
||||
y_ticks = list(PatchedMatplotlib._plotly_renderer.current_mpl_ax.get_yticklabels())
|
||||
if y_ticks:
|
||||
try:
|
||||
# check if all values can be cast to float
|
||||
values = [float(t.get_text().replace('−', '-')) for t in y_ticks]
|
||||
except:
|
||||
try:
|
||||
# check if all values can be cast to float
|
||||
values = [float(t.get_text().replace('−', '-')) for t in y_ticks]
|
||||
PatchedMatplotlib._plotly_renderer.plotly_fig['layout']['yaxis1'].update({
|
||||
'ticktext': [t.get_text() for t in y_ticks],
|
||||
'tickvals': [t.get_position()[1] for t in y_ticks],
|
||||
})
|
||||
except:
|
||||
try:
|
||||
renderer.plotly_fig['layout']['yaxis1'].update({
|
||||
'ticktext': [t.get_text() for t in y_ticks],
|
||||
'tickvals': [t.get_position()[1] for t in y_ticks],
|
||||
})
|
||||
except:
|
||||
pass
|
||||
return renderer.plotly_fig
|
||||
pass
|
||||
return deepcopy(PatchedMatplotlib._plotly_renderer.plotly_fig)
|
||||
|
||||
plotly_fig = our_mpl_to_plotly(mpl_fig)
|
||||
try:
|
||||
if 'none' in templates:
|
||||
plotly_fig._layout_obj.template = templates['none']
|
||||
except:
|
||||
pass
|
||||
except Exception as ex:
|
||||
# this was an image, change format to png
|
||||
image_format = 'jpeg' if 'selfie' in str(ex) else 'png'
|
||||
fig_dpi = 300
|
||||
finally:
|
||||
PatchedMatplotlib._lock_renderer.release()
|
||||
|
||||
# plotly could not serialize the plot, we should convert to image
|
||||
if not plotly_fig:
|
||||
@ -236,13 +265,23 @@ class PatchedMatplotlib:
|
||||
# first try SVG if we fail then fallback to png
|
||||
buffer_ = BytesIO()
|
||||
a_plt = specific_fig if specific_fig is not None else plt
|
||||
a_plt.savefig(buffer_, dpi=fig_dpi, format=image_format, bbox_inches='tight', pad_inches=0, frameon=False)
|
||||
if PatchedMatplotlib._matplot_major_version < 3:
|
||||
a_plt.savefig(buffer_, dpi=fig_dpi, format=image_format, bbox_inches='tight', pad_inches=0,
|
||||
frameon=False)
|
||||
else:
|
||||
a_plt.savefig(buffer_, dpi=fig_dpi, format=image_format, bbox_inches='tight', pad_inches=0,
|
||||
facecolor=None)
|
||||
buffer_.seek(0)
|
||||
except Exception:
|
||||
image_format = 'png'
|
||||
buffer_ = BytesIO()
|
||||
a_plt = specific_fig if specific_fig is not None else plt
|
||||
a_plt.savefig(buffer_, dpi=fig_dpi, format=image_format, bbox_inches='tight', pad_inches=0, frameon=False)
|
||||
if PatchedMatplotlib._matplot_major_version < 3:
|
||||
a_plt.savefig(buffer_, dpi=fig_dpi, format=image_format, bbox_inches='tight', pad_inches=0,
|
||||
frameon=False)
|
||||
else:
|
||||
a_plt.savefig(buffer_, dpi=fig_dpi, format=image_format, bbox_inches='tight', pad_inches=0,
|
||||
facecolor=None)
|
||||
buffer_.seek(0)
|
||||
fd, image = mkstemp(suffix='.'+image_format)
|
||||
os.write(fd, buffer_.read())
|
||||
|
Loading…
Reference in New Issue
Block a user