Fix Matplotlib figure.show support

Improve Matplotlib support, plt.show will show under plots tabs, and plt.imshow under debug images
This commit is contained in:
allegroai 2019-07-28 21:06:47 +03:00
parent 787da4798b
commit 67df9774e1

View File

@ -11,12 +11,14 @@ from ..config import running_remotely
class PatchedMatplotlib: class PatchedMatplotlib:
_patched_original_plot = None _patched_original_plot = None
_patched_original_figure = None
__patched_original_imshow = None __patched_original_imshow = None
__patched_original_draw_all = None __patched_original_draw_all = None
__patched_draw_all_recursion_guard = False __patched_draw_all_recursion_guard = False
_global_plot_counter = -1 _global_plot_counter = -1
_global_image_counter = -1 _global_image_counter = -1
_current_task = None _current_task = None
_support_image_plot = False
class _PatchWarnings(object): class _PatchWarnings(object):
def __init__(self): def __init__(self):
@ -51,15 +53,19 @@ class PatchedMatplotlib:
import matplotlib.pyplot import matplotlib.pyplot
matplotlib.pyplot.switch_backend('agg') matplotlib.pyplot.switch_backend('agg')
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import matplotlib.figure as figure
from matplotlib import _pylab_helpers from matplotlib import _pylab_helpers
if six.PY2: if six.PY2:
PatchedMatplotlib._patched_original_plot = staticmethod(plt.show) PatchedMatplotlib._patched_original_plot = staticmethod(plt.show)
PatchedMatplotlib._patched_original_imshow = staticmethod(plt.imshow) PatchedMatplotlib._patched_original_imshow = staticmethod(plt.imshow)
PatchedMatplotlib._patched_original_figure = staticmethod(figure.Figure.show)
else: else:
PatchedMatplotlib._patched_original_plot = plt.show PatchedMatplotlib._patched_original_plot = plt.show
PatchedMatplotlib._patched_original_imshow = plt.imshow PatchedMatplotlib._patched_original_imshow = plt.imshow
PatchedMatplotlib._patched_original_figure = figure.Figure.show
plt.show = PatchedMatplotlib.patched_show plt.show = PatchedMatplotlib.patched_show
# sys.modules['matplotlib'].pyplot.imshow = PatchedMatplotlib.patched_imshow figure.Figure.show = PatchedMatplotlib.patched_figure_show
sys.modules['matplotlib'].pyplot.imshow = PatchedMatplotlib.patched_imshow
# patch plotly so we know it failed us. # patch plotly so we know it failed us.
from plotly.matplotlylib import renderer from plotly.matplotlylib import renderer
renderer.warnings = PatchedMatplotlib._PatchWarnings() renderer.warnings = PatchedMatplotlib._PatchWarnings()
@ -87,11 +93,26 @@ class PatchedMatplotlib:
def update_current_task(task): def update_current_task(task):
if PatchedMatplotlib.patch_matplotlib(): if PatchedMatplotlib.patch_matplotlib():
PatchedMatplotlib._current_task = task PatchedMatplotlib._current_task = task
from ..backend_api import Session
PatchedMatplotlib._support_image_plot = Session.api_version > '2.1'
@staticmethod @staticmethod
def patched_imshow(*args, **kw): def patched_imshow(*args, **kw):
ret = PatchedMatplotlib._patched_original_imshow(*args, **kw) ret = PatchedMatplotlib._patched_original_imshow(*args, **kw)
PatchedMatplotlib._report_figure(force_save_as_image=True) try:
from matplotlib import _pylab_helpers
# store on the plot that this is an imshow plot
stored_figure = _pylab_helpers.Gcf.get_active()
if stored_figure:
stored_figure._trains_is_imshow = True
except Exception:
pass
return ret
@staticmethod
def patched_figure_show(self, *args, **kw):
PatchedMatplotlib._report_figure(set_active=False, specific_fig=self)
ret = PatchedMatplotlib._patched_original_figure(self, *args, **kw)
return ret return ret
@staticmethod @staticmethod
@ -110,7 +131,7 @@ class PatchedMatplotlib:
return ret return ret
@staticmethod @staticmethod
def _report_figure(force_save_as_image=False, stored_figure=None, set_active=True): def _report_figure(force_save_as_image=False, stored_figure=None, set_active=True, specific_fig=None):
if not PatchedMatplotlib._current_task: if not PatchedMatplotlib._current_task:
return return
@ -119,18 +140,27 @@ class PatchedMatplotlib:
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from plotly import optional_imports from plotly import optional_imports
from matplotlib import _pylab_helpers from matplotlib import _pylab_helpers
# store the figure object we just created (if it is not already there) if specific_fig is None:
stored_figure = stored_figure or _pylab_helpers.Gcf.get_active() # store the figure object we just created (if it is not already there)
if not stored_figure: stored_figure = stored_figure or _pylab_helpers.Gcf.get_active()
# nothing for us to do if not stored_figure:
return # nothing for us to do
# get current figure return
mpl_fig = stored_figure.canvas.figure # plt.gcf() # check if this is an imshow
if hasattr(stored_figure, '_trains_is_imshow') and stored_figure._trains_is_imshow:
force_save_as_image = True
# flag will be cleared when calling clf() (object will be replaced)
# get current figure
mpl_fig = stored_figure.canvas.figure # plt.gcf()
else:
mpl_fig = specific_fig
# convert to plotly # convert to plotly
image = None image = None
plotly_fig = None plotly_fig = None
image_format = 'svg' image_format = 'jpeg'
if not force_save_as_image: if not force_save_as_image:
image_format = 'svg'
# noinspection PyBroadException # noinspection PyBroadException
try: try:
def our_mpl_to_plotly(fig): def our_mpl_to_plotly(fig):
@ -142,9 +172,8 @@ class PatchedMatplotlib:
plotly_fig = our_mpl_to_plotly(mpl_fig) plotly_fig = our_mpl_to_plotly(mpl_fig)
except Exception as ex: except Exception as ex:
# this was an image, change format to jpeg # this was an image, change format to png
if 'selfie' in str(ex): image_format = 'jpeg' if 'selfie' in str(ex) else 'png'
image_format = 'jpeg'
# plotly could not serialize the plot, we should convert to image # plotly could not serialize the plot, we should convert to image
if not plotly_fig: if not plotly_fig:
@ -153,12 +182,14 @@ class PatchedMatplotlib:
try: try:
# first try SVG if we fail then fallback to png # first try SVG if we fail then fallback to png
buffer_ = BytesIO() buffer_ = BytesIO()
plt.savefig(buffer_, format=image_format, bbox_inches='tight', pad_inches=0) a_plt = specific_fig if specific_fig is not None else plt
a_plt.savefig(buffer_, format=image_format, bbox_inches='tight', pad_inches=0, frameon=False)
buffer_.seek(0) buffer_.seek(0)
except Exception: except Exception:
image_format = 'png' image_format = 'png'
buffer_ = BytesIO() buffer_ = BytesIO()
plt.savefig(buffer_, format=image_format, bbox_inches='tight', pad_inches=0) a_plt = specific_fig if specific_fig is not None else plt
a_plt.savefig(buffer_, format=image_format, bbox_inches='tight', pad_inches=0, frameon=False)
buffer_.seek(0) buffer_.seek(0)
fd, image = mkstemp(suffix='.'+image_format) fd, image = mkstemp(suffix='.'+image_format)
os.write(fd, buffer_.read()) os.write(fd, buffer_.read())
@ -193,20 +224,29 @@ class PatchedMatplotlib:
reporter.report_plot(title=title, series='plot', plot=plotly_dict, reporter.report_plot(title=title, series='plot', plot=plotly_dict,
iter=PatchedMatplotlib._global_plot_counter if plot_title else 0) iter=PatchedMatplotlib._global_plot_counter if plot_title else 0)
else: else:
# send the plot as image
PatchedMatplotlib._global_image_counter += 1
logger = PatchedMatplotlib._current_task.get_logger() logger = PatchedMatplotlib._current_task.get_logger()
title = plot_title or 'untitled %d' % PatchedMatplotlib._global_image_counter
# this is actually a failed plot, we should put it under plots: # this is actually a failed plot, we should put it under plots:
# currently disabled # currently disabled
# if image_format == 'svg': if force_save_as_image or not PatchedMatplotlib._support_image_plot:
# logger.report_image_plot_and_upload(title=title, series='plot image', path=image, # send the plot as image
# iteration=PatchedMatplotlib._global_image_counter PatchedMatplotlib._global_image_counter += 1
# if plot_title else 0) title = plot_title or 'untitled %d' % PatchedMatplotlib._global_image_counter
# else:
logger.report_image_and_upload(title=title, series='plot image', path=image, logger.report_image_and_upload(title=title, series='plot image', path=image,
iteration=PatchedMatplotlib._global_image_counter delete_after_upload=True,
if plot_title else 0) iteration=PatchedMatplotlib._global_image_counter
if plot_title else 0)
else:
# send the plot as plotly with embedded image
PatchedMatplotlib._global_plot_counter += 1
title = plot_title or 'untitled %d' % PatchedMatplotlib._global_plot_counter
logger.report_image_plot_and_upload(title=title, series='plot image', path=image,
delete_after_upload=True,
iteration=PatchedMatplotlib._global_plot_counter
if plot_title else 0)
except Exception: except Exception:
# plotly failed # plotly failed
pass pass