diff --git a/trains/binding/matplotlib_bind.py b/trains/binding/matplotlib_bind.py index ad1dc21f..5da8d363 100644 --- a/trains/binding/matplotlib_bind.py +++ b/trains/binding/matplotlib_bind.py @@ -18,7 +18,6 @@ class PatchedMatplotlib: _patched_original_plot = None _patched_original_figure = None _patched_original_savefig = None - _patched_original_savefig_plt_pylab = None __patched_original_imshow = None __patched_original_draw_all = None __patched_draw_all_recursion_guard = False @@ -51,6 +50,9 @@ class PatchedMatplotlib: # only once if PatchedMatplotlib._patched_original_plot is not None: return True + # make sure we only patch once + PatchedMatplotlib._patched_original_plot = False + # noinspection PyBroadException try: # we support matplotlib version 2.0.0 and above @@ -70,26 +72,29 @@ class PatchedMatplotlib: matplotlib.pyplot.switch_backend('agg') import matplotlib.pyplot as plt import matplotlib.figure as figure - import matplotlib.pylab as plt_pylab - from matplotlib import _pylab_helpers if six.PY2: PatchedMatplotlib._patched_original_plot = staticmethod(plt.show) PatchedMatplotlib._patched_original_imshow = staticmethod(plt.imshow) PatchedMatplotlib._patched_original_figure = staticmethod(figure.Figure.show) - PatchedMatplotlib._patched_original_savefig = staticmethod(plt.savefig) - PatchedMatplotlib._patched_original_savefig_plt_pylab = staticmethod(plt_pylab.savefig) + PatchedMatplotlib._patched_original_savefig = staticmethod(figure.Figure.savefig) else: PatchedMatplotlib._patched_original_plot = plt.show PatchedMatplotlib._patched_original_imshow = plt.imshow PatchedMatplotlib._patched_original_figure = figure.Figure.show - PatchedMatplotlib._patched_original_savefig = plt.savefig - PatchedMatplotlib._patched_original_savefig_plt_pylab = plt_pylab.savefig + PatchedMatplotlib._patched_original_savefig = figure.Figure.savefig + try: + import matplotlib.pylab as pltlab + if plt.show == pltlab.show: + pltlab.show = PatchedMatplotlib.patched_show + if plt.imshow == pltlab.imshow: + pltlab.imshow = PatchedMatplotlib.patched_imshow + except: + pass plt.show = PatchedMatplotlib.patched_show figure.Figure.show = PatchedMatplotlib.patched_figure_show sys.modules['matplotlib'].pyplot.imshow = PatchedMatplotlib.patched_imshow - sys.modules['matplotlib'].pyplot.savefig = PatchedMatplotlib.patched_savefig - sys.modules['matplotlib'].pylab.savefig = PatchedMatplotlib.patched_savefig_pylab + sys.modules['matplotlib'].figure.Figure.savefig = PatchedMatplotlib.patched_savefig # patch plotly so we know it failed us. from plotly.matplotlylib import renderer renderer.warnings = PatchedMatplotlib._PatchWarnings() @@ -145,7 +150,8 @@ class PatchedMatplotlib: if PatchedMatplotlib._patched_original_plot is not None: PatchedMatplotlib._current_task = task # if matplotlib is not loaded yet, get a callback hook - elif not running_remotely() and 'matplotlib.pyplot' not in sys.modules: + elif not running_remotely() and \ + ('matplotlib.pyplot' not in sys.modules and 'matplotlib.pylab' not in sys.modules): PatchedMatplotlib._current_task = task PostImportHookPatching.add_on_import('matplotlib.pyplot', PatchedMatplotlib.patch_matplotlib) PostImportHookPatching.add_on_import('matplotlib.pylab', PatchedMatplotlib.patch_matplotlib) @@ -167,31 +173,19 @@ class PatchedMatplotlib: return ret @staticmethod - def patched_savefig(*args, **kw): - ret = PatchedMatplotlib._patched_original_savefig(*args, **kw) - PatchedMatplotlib._report_savefig_figure() - return ret - - @staticmethod - def patched_savefig_pylab(*args, **kw): - ret = PatchedMatplotlib._patched_original_savefig_plt_pylab(*args, **kw) - PatchedMatplotlib._report_savefig_figure() - return ret - - @staticmethod - def _report_savefig_figure(): + def patched_savefig(self, *args, **kw): + ret = PatchedMatplotlib._patched_original_savefig(self, *args, **kw) tid = threading._get_ident() if six.PY2 else threading.get_ident() if not PatchedMatplotlib._recursion_guard.get(tid): PatchedMatplotlib._recursion_guard[tid] = True # noinspection PyBroadException try: - figures = PatchedMatplotlib._get_output_figures(None, all_figures=True) - for figure in figures: - if figure.canvas.figure: - PatchedMatplotlib._report_figure(stored_figure=figure) + PatchedMatplotlib._report_figure(specific_fig=self, set_active=False) except Exception: pass - PatchedMatplotlib._recursion_guard[tid] = False + PatchedMatplotlib._recursion_guard[tid] = False + + return ret @staticmethod def patched_figure_show(self, *args, **kw):