diff --git a/trains/binding/matplotlib_bind.py b/trains/binding/matplotlib_bind.py index 1a4a42d1..715c5be7 100644 --- a/trains/binding/matplotlib_bind.py +++ b/trains/binding/matplotlib_bind.py @@ -30,9 +30,11 @@ class PatchedMatplotlib: _support_image_plot = False _matplotlylib = None _plotly_renderer = None + _patched_mpltools_get_spine_visible = False _lock_renderer = threading.RLock() _recursion_guard = {} _matplot_major_version = 2 + _matplot_minor_version = 0 _logger_started_reporting = False _matplotlib_reported_titles = set() @@ -60,7 +62,9 @@ class PatchedMatplotlib: try: # we support matplotlib version 2.0.0 and above import matplotlib - PatchedMatplotlib._matplot_major_version = int(matplotlib.__version__.split('.')[0]) + version_split = matplotlib.__version__.split('.') + PatchedMatplotlib._matplot_major_version = int(version_split[0]) + PatchedMatplotlib._matplot_minor_version = int(version_split[1]) if PatchedMatplotlib._matplot_major_version < 2: LoggerRoot.get_base_logger().warning( 'matplotlib binding supports version 2.0 and above, found version {}'.format( @@ -289,6 +293,13 @@ class PatchedMatplotlib: def our_mpl_to_plotly(fig): if not PatchedMatplotlib._matplotlylib or not PatchedMatplotlib._plotly_renderer: return None + if not PatchedMatplotlib._patched_mpltools_get_spine_visible and \ + PatchedMatplotlib._matplot_major_version and \ + PatchedMatplotlib._matplot_major_version >= 3 and \ + PatchedMatplotlib._matplot_minor_version >= 3: + from plotly.matplotlylib import mpltools + mpltools.get_spine_visible = lambda *_, **__: True + PatchedMatplotlib._patched_mpltools_get_spine_visible = True plotly_renderer = PatchedMatplotlib._matplotlylib.PlotlyRenderer() PatchedMatplotlib._matplotlylib.Exporter(plotly_renderer, close_mpl=False).run(fig)