From 1c6be01e381e81802aec393fc869c52bb2f6158a Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Mon, 13 Jan 2020 12:06:47 +0200 Subject: [PATCH] Add support for savefig in matplotlib binding --- trains/binding/matplotlib_bind.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/trains/binding/matplotlib_bind.py b/trains/binding/matplotlib_bind.py index 742a4dfd..ca9e0e28 100644 --- a/trains/binding/matplotlib_bind.py +++ b/trains/binding/matplotlib_bind.py @@ -17,6 +17,7 @@ from .import_bind import PostImportHookPatching class PatchedMatplotlib: _patched_original_plot = None _patched_original_figure = None + _patched_original_savefig = None __patched_original_imshow = None __patched_original_draw_all = None __patched_draw_all_recursion_guard = False @@ -73,14 +74,17 @@ class PatchedMatplotlib: PatchedMatplotlib._patched_original_plot = staticmethod(plt.show) PatchedMatplotlib._patched_original_imshow = staticmethod(plt.imshow) PatchedMatplotlib._patched_original_figure = staticmethod(figure.Figure.show) + PatchedMatplotlib._patched_original_savefig = staticmethod(plt.savefig) else: PatchedMatplotlib._patched_original_plot = plt.show PatchedMatplotlib._patched_original_imshow = plt.imshow PatchedMatplotlib._patched_original_figure = figure.Figure.show + PatchedMatplotlib._patched_original_savefig = plt.savefig plt.show = PatchedMatplotlib.patched_show figure.Figure.show = PatchedMatplotlib.patched_figure_show sys.modules['matplotlib'].pyplot.imshow = PatchedMatplotlib.patched_imshow + sys.modules['matplotlib'].pyplot.savefig = PatchedMatplotlib.patched_savefig # patch plotly so we know it failed us. from plotly.matplotlylib import renderer renderer.warnings = PatchedMatplotlib._PatchWarnings() @@ -156,6 +160,23 @@ class PatchedMatplotlib: pass return ret + @staticmethod + def patched_savefig(*args, **kw): + ret = PatchedMatplotlib._patched_original_savefig(*args, **kw) + tid = threading._get_ident() if six.PY2 else threading.get_ident() + if not PatchedMatplotlib._recursion_guard.get(tid): + PatchedMatplotlib._recursion_guard[tid] = True + # noinspection PyBroadException + try: + figures = PatchedMatplotlib._get_output_figures(None, all_figures=True) + for figure in figures: + if figure.canvas.figure: + PatchedMatplotlib._report_figure(stored_figure=figure) + except Exception: + pass + PatchedMatplotlib._recursion_guard[tid] = False + return ret + @staticmethod def patched_figure_show(self, *args, **kw): tid = threading._get_ident() if six.PY2 else threading.get_ident()