diff --git a/trains/binding/matplotlib_bind.py b/trains/binding/matplotlib_bind.py index ca9e0e28..ad1dc21f 100644 --- a/trains/binding/matplotlib_bind.py +++ b/trains/binding/matplotlib_bind.py @@ -18,6 +18,7 @@ class PatchedMatplotlib: _patched_original_plot = None _patched_original_figure = None _patched_original_savefig = None + _patched_original_savefig_plt_pylab = None __patched_original_imshow = None __patched_original_draw_all = None __patched_draw_all_recursion_guard = False @@ -69,22 +70,26 @@ class PatchedMatplotlib: matplotlib.pyplot.switch_backend('agg') import matplotlib.pyplot as plt import matplotlib.figure as figure + import matplotlib.pylab as plt_pylab from matplotlib import _pylab_helpers if six.PY2: PatchedMatplotlib._patched_original_plot = staticmethod(plt.show) PatchedMatplotlib._patched_original_imshow = staticmethod(plt.imshow) PatchedMatplotlib._patched_original_figure = staticmethod(figure.Figure.show) PatchedMatplotlib._patched_original_savefig = staticmethod(plt.savefig) + PatchedMatplotlib._patched_original_savefig_plt_pylab = staticmethod(plt_pylab.savefig) else: PatchedMatplotlib._patched_original_plot = plt.show PatchedMatplotlib._patched_original_imshow = plt.imshow PatchedMatplotlib._patched_original_figure = figure.Figure.show PatchedMatplotlib._patched_original_savefig = plt.savefig + PatchedMatplotlib._patched_original_savefig_plt_pylab = plt_pylab.savefig plt.show = PatchedMatplotlib.patched_show figure.Figure.show = PatchedMatplotlib.patched_figure_show sys.modules['matplotlib'].pyplot.imshow = PatchedMatplotlib.patched_imshow sys.modules['matplotlib'].pyplot.savefig = PatchedMatplotlib.patched_savefig + sys.modules['matplotlib'].pylab.savefig = PatchedMatplotlib.patched_savefig_pylab # patch plotly so we know it failed us. from plotly.matplotlylib import renderer renderer.warnings = PatchedMatplotlib._PatchWarnings() @@ -143,6 +148,7 @@ class PatchedMatplotlib: elif not running_remotely() and 'matplotlib.pyplot' not in sys.modules: PatchedMatplotlib._current_task = task PostImportHookPatching.add_on_import('matplotlib.pyplot', PatchedMatplotlib.patch_matplotlib) + PostImportHookPatching.add_on_import('matplotlib.pylab', PatchedMatplotlib.patch_matplotlib) elif PatchedMatplotlib.patch_matplotlib(): PatchedMatplotlib._current_task = task @@ -163,6 +169,17 @@ class PatchedMatplotlib: @staticmethod def patched_savefig(*args, **kw): ret = PatchedMatplotlib._patched_original_savefig(*args, **kw) + PatchedMatplotlib._report_savefig_figure() + return ret + + @staticmethod + def patched_savefig_pylab(*args, **kw): + ret = PatchedMatplotlib._patched_original_savefig_plt_pylab(*args, **kw) + PatchedMatplotlib._report_savefig_figure() + return ret + + @staticmethod + def _report_savefig_figure(): tid = threading._get_ident() if six.PY2 else threading.get_ident() if not PatchedMatplotlib._recursion_guard.get(tid): PatchedMatplotlib._recursion_guard[tid] = True @@ -175,7 +192,6 @@ class PatchedMatplotlib: except Exception: pass PatchedMatplotlib._recursion_guard[tid] = False - return ret @staticmethod def patched_figure_show(self, *args, **kw):