mirror of
https://github.com/clearml/clearml
synced 2025-02-01 01:26:49 +00:00
Fix matplotlib savefig patching
This commit is contained in:
parent
c5dd762d9b
commit
f0a27127bf
@ -18,7 +18,6 @@ class PatchedMatplotlib:
|
||||
_patched_original_plot = None
|
||||
_patched_original_figure = None
|
||||
_patched_original_savefig = None
|
||||
_patched_original_savefig_plt_pylab = None
|
||||
__patched_original_imshow = None
|
||||
__patched_original_draw_all = None
|
||||
__patched_draw_all_recursion_guard = False
|
||||
@ -51,6 +50,9 @@ class PatchedMatplotlib:
|
||||
# only once
|
||||
if PatchedMatplotlib._patched_original_plot is not None:
|
||||
return True
|
||||
# make sure we only patch once
|
||||
PatchedMatplotlib._patched_original_plot = False
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
# we support matplotlib version 2.0.0 and above
|
||||
@ -70,26 +72,29 @@ class PatchedMatplotlib:
|
||||
matplotlib.pyplot.switch_backend('agg')
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.figure as figure
|
||||
import matplotlib.pylab as plt_pylab
|
||||
from matplotlib import _pylab_helpers
|
||||
if six.PY2:
|
||||
PatchedMatplotlib._patched_original_plot = staticmethod(plt.show)
|
||||
PatchedMatplotlib._patched_original_imshow = staticmethod(plt.imshow)
|
||||
PatchedMatplotlib._patched_original_figure = staticmethod(figure.Figure.show)
|
||||
PatchedMatplotlib._patched_original_savefig = staticmethod(plt.savefig)
|
||||
PatchedMatplotlib._patched_original_savefig_plt_pylab = staticmethod(plt_pylab.savefig)
|
||||
PatchedMatplotlib._patched_original_savefig = staticmethod(figure.Figure.savefig)
|
||||
else:
|
||||
PatchedMatplotlib._patched_original_plot = plt.show
|
||||
PatchedMatplotlib._patched_original_imshow = plt.imshow
|
||||
PatchedMatplotlib._patched_original_figure = figure.Figure.show
|
||||
PatchedMatplotlib._patched_original_savefig = plt.savefig
|
||||
PatchedMatplotlib._patched_original_savefig_plt_pylab = plt_pylab.savefig
|
||||
PatchedMatplotlib._patched_original_savefig = figure.Figure.savefig
|
||||
|
||||
try:
|
||||
import matplotlib.pylab as pltlab
|
||||
if plt.show == pltlab.show:
|
||||
pltlab.show = PatchedMatplotlib.patched_show
|
||||
if plt.imshow == pltlab.imshow:
|
||||
pltlab.imshow = PatchedMatplotlib.patched_imshow
|
||||
except:
|
||||
pass
|
||||
plt.show = PatchedMatplotlib.patched_show
|
||||
figure.Figure.show = PatchedMatplotlib.patched_figure_show
|
||||
sys.modules['matplotlib'].pyplot.imshow = PatchedMatplotlib.patched_imshow
|
||||
sys.modules['matplotlib'].pyplot.savefig = PatchedMatplotlib.patched_savefig
|
||||
sys.modules['matplotlib'].pylab.savefig = PatchedMatplotlib.patched_savefig_pylab
|
||||
sys.modules['matplotlib'].figure.Figure.savefig = PatchedMatplotlib.patched_savefig
|
||||
# patch plotly so we know it failed us.
|
||||
from plotly.matplotlylib import renderer
|
||||
renderer.warnings = PatchedMatplotlib._PatchWarnings()
|
||||
@ -145,7 +150,8 @@ class PatchedMatplotlib:
|
||||
if PatchedMatplotlib._patched_original_plot is not None:
|
||||
PatchedMatplotlib._current_task = task
|
||||
# if matplotlib is not loaded yet, get a callback hook
|
||||
elif not running_remotely() and 'matplotlib.pyplot' not in sys.modules:
|
||||
elif not running_remotely() and \
|
||||
('matplotlib.pyplot' not in sys.modules and 'matplotlib.pylab' not in sys.modules):
|
||||
PatchedMatplotlib._current_task = task
|
||||
PostImportHookPatching.add_on_import('matplotlib.pyplot', PatchedMatplotlib.patch_matplotlib)
|
||||
PostImportHookPatching.add_on_import('matplotlib.pylab', PatchedMatplotlib.patch_matplotlib)
|
||||
@ -167,31 +173,19 @@ class PatchedMatplotlib:
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def patched_savefig(*args, **kw):
|
||||
ret = PatchedMatplotlib._patched_original_savefig(*args, **kw)
|
||||
PatchedMatplotlib._report_savefig_figure()
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def patched_savefig_pylab(*args, **kw):
|
||||
ret = PatchedMatplotlib._patched_original_savefig_plt_pylab(*args, **kw)
|
||||
PatchedMatplotlib._report_savefig_figure()
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def _report_savefig_figure():
|
||||
def patched_savefig(self, *args, **kw):
|
||||
ret = PatchedMatplotlib._patched_original_savefig(self, *args, **kw)
|
||||
tid = threading._get_ident() if six.PY2 else threading.get_ident()
|
||||
if not PatchedMatplotlib._recursion_guard.get(tid):
|
||||
PatchedMatplotlib._recursion_guard[tid] = True
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
figures = PatchedMatplotlib._get_output_figures(None, all_figures=True)
|
||||
for figure in figures:
|
||||
if figure.canvas.figure:
|
||||
PatchedMatplotlib._report_figure(stored_figure=figure)
|
||||
PatchedMatplotlib._report_figure(specific_fig=self, set_active=False)
|
||||
except Exception:
|
||||
pass
|
||||
PatchedMatplotlib._recursion_guard[tid] = False
|
||||
PatchedMatplotlib._recursion_guard[tid] = False
|
||||
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def patched_figure_show(self, *args, **kw):
|
||||
|
Loading…
Reference in New Issue
Block a user