Add support for pylab.savefig in matplotlib binding

This commit is contained in:
allegroai 2020-01-13 17:16:56 +02:00
parent 0ecd734fd1
commit fcaff82980

View File

@ -18,6 +18,7 @@ class PatchedMatplotlib:
_patched_original_plot = None _patched_original_plot = None
_patched_original_figure = None _patched_original_figure = None
_patched_original_savefig = None _patched_original_savefig = None
_patched_original_savefig_plt_pylab = None
__patched_original_imshow = None __patched_original_imshow = None
__patched_original_draw_all = None __patched_original_draw_all = None
__patched_draw_all_recursion_guard = False __patched_draw_all_recursion_guard = False
@ -69,22 +70,26 @@ class PatchedMatplotlib:
matplotlib.pyplot.switch_backend('agg') matplotlib.pyplot.switch_backend('agg')
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import matplotlib.figure as figure import matplotlib.figure as figure
import matplotlib.pylab as plt_pylab
from matplotlib import _pylab_helpers from matplotlib import _pylab_helpers
if six.PY2: if six.PY2:
PatchedMatplotlib._patched_original_plot = staticmethod(plt.show) PatchedMatplotlib._patched_original_plot = staticmethod(plt.show)
PatchedMatplotlib._patched_original_imshow = staticmethod(plt.imshow) PatchedMatplotlib._patched_original_imshow = staticmethod(plt.imshow)
PatchedMatplotlib._patched_original_figure = staticmethod(figure.Figure.show) PatchedMatplotlib._patched_original_figure = staticmethod(figure.Figure.show)
PatchedMatplotlib._patched_original_savefig = staticmethod(plt.savefig) PatchedMatplotlib._patched_original_savefig = staticmethod(plt.savefig)
PatchedMatplotlib._patched_original_savefig_plt_pylab = staticmethod(plt_pylab.savefig)
else: else:
PatchedMatplotlib._patched_original_plot = plt.show PatchedMatplotlib._patched_original_plot = plt.show
PatchedMatplotlib._patched_original_imshow = plt.imshow PatchedMatplotlib._patched_original_imshow = plt.imshow
PatchedMatplotlib._patched_original_figure = figure.Figure.show PatchedMatplotlib._patched_original_figure = figure.Figure.show
PatchedMatplotlib._patched_original_savefig = plt.savefig PatchedMatplotlib._patched_original_savefig = plt.savefig
PatchedMatplotlib._patched_original_savefig_plt_pylab = plt_pylab.savefig
plt.show = PatchedMatplotlib.patched_show plt.show = PatchedMatplotlib.patched_show
figure.Figure.show = PatchedMatplotlib.patched_figure_show figure.Figure.show = PatchedMatplotlib.patched_figure_show
sys.modules['matplotlib'].pyplot.imshow = PatchedMatplotlib.patched_imshow sys.modules['matplotlib'].pyplot.imshow = PatchedMatplotlib.patched_imshow
sys.modules['matplotlib'].pyplot.savefig = PatchedMatplotlib.patched_savefig sys.modules['matplotlib'].pyplot.savefig = PatchedMatplotlib.patched_savefig
sys.modules['matplotlib'].pylab.savefig = PatchedMatplotlib.patched_savefig_pylab
# patch plotly so we know it failed us. # patch plotly so we know it failed us.
from plotly.matplotlylib import renderer from plotly.matplotlylib import renderer
renderer.warnings = PatchedMatplotlib._PatchWarnings() renderer.warnings = PatchedMatplotlib._PatchWarnings()
@ -143,6 +148,7 @@ class PatchedMatplotlib:
elif not running_remotely() and 'matplotlib.pyplot' not in sys.modules: elif not running_remotely() and 'matplotlib.pyplot' not in sys.modules:
PatchedMatplotlib._current_task = task PatchedMatplotlib._current_task = task
PostImportHookPatching.add_on_import('matplotlib.pyplot', PatchedMatplotlib.patch_matplotlib) PostImportHookPatching.add_on_import('matplotlib.pyplot', PatchedMatplotlib.patch_matplotlib)
PostImportHookPatching.add_on_import('matplotlib.pylab', PatchedMatplotlib.patch_matplotlib)
elif PatchedMatplotlib.patch_matplotlib(): elif PatchedMatplotlib.patch_matplotlib():
PatchedMatplotlib._current_task = task PatchedMatplotlib._current_task = task
@ -163,6 +169,17 @@ class PatchedMatplotlib:
@staticmethod @staticmethod
def patched_savefig(*args, **kw): def patched_savefig(*args, **kw):
ret = PatchedMatplotlib._patched_original_savefig(*args, **kw) ret = PatchedMatplotlib._patched_original_savefig(*args, **kw)
PatchedMatplotlib._report_savefig_figure()
return ret
@staticmethod
def patched_savefig_pylab(*args, **kw):
ret = PatchedMatplotlib._patched_original_savefig_plt_pylab(*args, **kw)
PatchedMatplotlib._report_savefig_figure()
return ret
@staticmethod
def _report_savefig_figure():
tid = threading._get_ident() if six.PY2 else threading.get_ident() tid = threading._get_ident() if six.PY2 else threading.get_ident()
if not PatchedMatplotlib._recursion_guard.get(tid): if not PatchedMatplotlib._recursion_guard.get(tid):
PatchedMatplotlib._recursion_guard[tid] = True PatchedMatplotlib._recursion_guard[tid] = True
@ -175,7 +192,6 @@ class PatchedMatplotlib:
except Exception: except Exception:
pass pass
PatchedMatplotlib._recursion_guard[tid] = False PatchedMatplotlib._recursion_guard[tid] = False
return ret
@staticmethod @staticmethod
def patched_figure_show(self, *args, **kw): def patched_figure_show(self, *args, **kw):