Fix matplotlib with Agg backend (or in remote execution)

This commit is contained in:
allegroai 2021-03-03 15:05:22 +02:00
parent ecc539ffb6
commit 912f6f5ba2

View File

@ -195,7 +195,7 @@ class PatchedMatplotlib:
# store on the plot that this is an imshow plot # store on the plot that this is an imshow plot
stored_figure = _pylab_helpers.Gcf.get_active() stored_figure = _pylab_helpers.Gcf.get_active()
if stored_figure: if stored_figure:
stored_figure._trains_is_imshow = 1 if not hasattr(stored_figure, '_trains_is_imshow') \ stored_figure._trains_is_imshow = 1 if getattr(stored_figure, '_trains_is_imshow', None) is None \
else stored_figure._trains_is_imshow + 1 else stored_figure._trains_is_imshow + 1
except Exception: except Exception:
pass pass
@ -253,7 +253,8 @@ class PatchedMatplotlib:
figures = PatchedMatplotlib._get_output_figures(None, all_figures=True) figures = PatchedMatplotlib._get_output_figures(None, all_figures=True)
for figure in figures: for figure in figures:
# if this is a stale figure (just updated) we should send it, the rest will not be stale # if this is a stale figure (just updated) we should send it, the rest will not be stale
if figure.canvas.figure.stale or (hasattr(figure, '_trains_is_imshow') and figure._trains_is_imshow): if figure.canvas.figure.stale or (
getattr(figure, '_trains_is_imshow', None) is not None and figure._trains_is_imshow):
PatchedMatplotlib._report_figure(stored_figure=figure) PatchedMatplotlib._report_figure(stored_figure=figure)
except Exception: except Exception:
pass pass
@ -263,6 +264,14 @@ class PatchedMatplotlib:
# noinspection PyBroadException # noinspection PyBroadException
try: try:
if sys.modules['matplotlib'].rcParams['backend'] == 'agg': if sys.modules['matplotlib'].rcParams['backend'] == 'agg':
# noinspection PyBroadException
try:
from matplotlib import _pylab_helpers
stored_figure = _pylab_helpers.Gcf.get_active()
stored_figure._trains_is_imshow = None
stored_figure.canvas.figure._trains_explicit = False
except Exception:
pass
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
plt.clf() plt.clf()
except Exception: except Exception:
@ -324,7 +333,7 @@ class PatchedMatplotlib:
# nothing for us to do # nothing for us to do
return return
# check if this is an imshow # check if this is an imshow
if hasattr(stored_figure, '_trains_is_imshow'): if getattr(stored_figure, '_trains_is_imshow', None) is not None:
# flag will be cleared when calling clf() (object will be replaced) # flag will be cleared when calling clf() (object will be replaced)
stored_figure._trains_is_imshow = max(0, stored_figure._trains_is_imshow - 1) stored_figure._trains_is_imshow = max(0, stored_figure._trains_is_imshow - 1)
force_save_as_image = True force_save_as_image = True