Add support for savefig in matplotlib binding

This commit is contained in:
allegroai 2020-01-13 12:06:47 +02:00
parent 66b251a62b
commit 1c6be01e38

View File

@ -17,6 +17,7 @@ from .import_bind import PostImportHookPatching
class PatchedMatplotlib:
_patched_original_plot = None
_patched_original_figure = None
_patched_original_savefig = None
__patched_original_imshow = None
__patched_original_draw_all = None
__patched_draw_all_recursion_guard = False
@ -73,14 +74,17 @@ class PatchedMatplotlib:
PatchedMatplotlib._patched_original_plot = staticmethod(plt.show)
PatchedMatplotlib._patched_original_imshow = staticmethod(plt.imshow)
PatchedMatplotlib._patched_original_figure = staticmethod(figure.Figure.show)
PatchedMatplotlib._patched_original_savefig = staticmethod(plt.savefig)
else:
PatchedMatplotlib._patched_original_plot = plt.show
PatchedMatplotlib._patched_original_imshow = plt.imshow
PatchedMatplotlib._patched_original_figure = figure.Figure.show
PatchedMatplotlib._patched_original_savefig = plt.savefig
plt.show = PatchedMatplotlib.patched_show
figure.Figure.show = PatchedMatplotlib.patched_figure_show
sys.modules['matplotlib'].pyplot.imshow = PatchedMatplotlib.patched_imshow
sys.modules['matplotlib'].pyplot.savefig = PatchedMatplotlib.patched_savefig
# patch plotly so we know it failed us.
from plotly.matplotlylib import renderer
renderer.warnings = PatchedMatplotlib._PatchWarnings()
@ -156,6 +160,23 @@ class PatchedMatplotlib:
pass
return ret
@staticmethod
def patched_savefig(*args, **kw):
ret = PatchedMatplotlib._patched_original_savefig(*args, **kw)
tid = threading._get_ident() if six.PY2 else threading.get_ident()
if not PatchedMatplotlib._recursion_guard.get(tid):
PatchedMatplotlib._recursion_guard[tid] = True
# noinspection PyBroadException
try:
figures = PatchedMatplotlib._get_output_figures(None, all_figures=True)
for figure in figures:
if figure.canvas.figure:
PatchedMatplotlib._report_figure(stored_figure=figure)
except Exception:
pass
PatchedMatplotlib._recursion_guard[tid] = False
return ret
@staticmethod
def patched_figure_show(self, *args, **kw):
tid = threading._get_ident() if six.PY2 else threading.get_ident()