mirror of
https://github.com/clearml/clearml
synced 2025-04-16 21:42:10 +00:00
Fix Matplotlib figure.show support
Improve Matplotlib support, plt.show will show under plots tabs, and plt.imshow under debug images
This commit is contained in:
parent
787da4798b
commit
67df9774e1
@ -11,12 +11,14 @@ from ..config import running_remotely
|
|||||||
|
|
||||||
class PatchedMatplotlib:
|
class PatchedMatplotlib:
|
||||||
_patched_original_plot = None
|
_patched_original_plot = None
|
||||||
|
_patched_original_figure = None
|
||||||
__patched_original_imshow = None
|
__patched_original_imshow = None
|
||||||
__patched_original_draw_all = None
|
__patched_original_draw_all = None
|
||||||
__patched_draw_all_recursion_guard = False
|
__patched_draw_all_recursion_guard = False
|
||||||
_global_plot_counter = -1
|
_global_plot_counter = -1
|
||||||
_global_image_counter = -1
|
_global_image_counter = -1
|
||||||
_current_task = None
|
_current_task = None
|
||||||
|
_support_image_plot = False
|
||||||
|
|
||||||
class _PatchWarnings(object):
|
class _PatchWarnings(object):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@ -51,15 +53,19 @@ class PatchedMatplotlib:
|
|||||||
import matplotlib.pyplot
|
import matplotlib.pyplot
|
||||||
matplotlib.pyplot.switch_backend('agg')
|
matplotlib.pyplot.switch_backend('agg')
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
import matplotlib.figure as figure
|
||||||
from matplotlib import _pylab_helpers
|
from matplotlib import _pylab_helpers
|
||||||
if six.PY2:
|
if six.PY2:
|
||||||
PatchedMatplotlib._patched_original_plot = staticmethod(plt.show)
|
PatchedMatplotlib._patched_original_plot = staticmethod(plt.show)
|
||||||
PatchedMatplotlib._patched_original_imshow = staticmethod(plt.imshow)
|
PatchedMatplotlib._patched_original_imshow = staticmethod(plt.imshow)
|
||||||
|
PatchedMatplotlib._patched_original_figure = staticmethod(figure.Figure.show)
|
||||||
else:
|
else:
|
||||||
PatchedMatplotlib._patched_original_plot = plt.show
|
PatchedMatplotlib._patched_original_plot = plt.show
|
||||||
PatchedMatplotlib._patched_original_imshow = plt.imshow
|
PatchedMatplotlib._patched_original_imshow = plt.imshow
|
||||||
|
PatchedMatplotlib._patched_original_figure = figure.Figure.show
|
||||||
plt.show = PatchedMatplotlib.patched_show
|
plt.show = PatchedMatplotlib.patched_show
|
||||||
# sys.modules['matplotlib'].pyplot.imshow = PatchedMatplotlib.patched_imshow
|
figure.Figure.show = PatchedMatplotlib.patched_figure_show
|
||||||
|
sys.modules['matplotlib'].pyplot.imshow = PatchedMatplotlib.patched_imshow
|
||||||
# patch plotly so we know it failed us.
|
# patch plotly so we know it failed us.
|
||||||
from plotly.matplotlylib import renderer
|
from plotly.matplotlylib import renderer
|
||||||
renderer.warnings = PatchedMatplotlib._PatchWarnings()
|
renderer.warnings = PatchedMatplotlib._PatchWarnings()
|
||||||
@ -87,11 +93,26 @@ class PatchedMatplotlib:
|
|||||||
def update_current_task(task):
|
def update_current_task(task):
|
||||||
if PatchedMatplotlib.patch_matplotlib():
|
if PatchedMatplotlib.patch_matplotlib():
|
||||||
PatchedMatplotlib._current_task = task
|
PatchedMatplotlib._current_task = task
|
||||||
|
from ..backend_api import Session
|
||||||
|
PatchedMatplotlib._support_image_plot = Session.api_version > '2.1'
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def patched_imshow(*args, **kw):
|
def patched_imshow(*args, **kw):
|
||||||
ret = PatchedMatplotlib._patched_original_imshow(*args, **kw)
|
ret = PatchedMatplotlib._patched_original_imshow(*args, **kw)
|
||||||
PatchedMatplotlib._report_figure(force_save_as_image=True)
|
try:
|
||||||
|
from matplotlib import _pylab_helpers
|
||||||
|
# store on the plot that this is an imshow plot
|
||||||
|
stored_figure = _pylab_helpers.Gcf.get_active()
|
||||||
|
if stored_figure:
|
||||||
|
stored_figure._trains_is_imshow = True
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return ret
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def patched_figure_show(self, *args, **kw):
|
||||||
|
PatchedMatplotlib._report_figure(set_active=False, specific_fig=self)
|
||||||
|
ret = PatchedMatplotlib._patched_original_figure(self, *args, **kw)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -110,7 +131,7 @@ class PatchedMatplotlib:
|
|||||||
return ret
|
return ret
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _report_figure(force_save_as_image=False, stored_figure=None, set_active=True):
|
def _report_figure(force_save_as_image=False, stored_figure=None, set_active=True, specific_fig=None):
|
||||||
if not PatchedMatplotlib._current_task:
|
if not PatchedMatplotlib._current_task:
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -119,18 +140,27 @@ class PatchedMatplotlib:
|
|||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
from plotly import optional_imports
|
from plotly import optional_imports
|
||||||
from matplotlib import _pylab_helpers
|
from matplotlib import _pylab_helpers
|
||||||
# store the figure object we just created (if it is not already there)
|
if specific_fig is None:
|
||||||
stored_figure = stored_figure or _pylab_helpers.Gcf.get_active()
|
# store the figure object we just created (if it is not already there)
|
||||||
if not stored_figure:
|
stored_figure = stored_figure or _pylab_helpers.Gcf.get_active()
|
||||||
# nothing for us to do
|
if not stored_figure:
|
||||||
return
|
# nothing for us to do
|
||||||
# get current figure
|
return
|
||||||
mpl_fig = stored_figure.canvas.figure # plt.gcf()
|
# check if this is an imshow
|
||||||
|
if hasattr(stored_figure, '_trains_is_imshow') and stored_figure._trains_is_imshow:
|
||||||
|
force_save_as_image = True
|
||||||
|
# flag will be cleared when calling clf() (object will be replaced)
|
||||||
|
# get current figure
|
||||||
|
mpl_fig = stored_figure.canvas.figure # plt.gcf()
|
||||||
|
else:
|
||||||
|
mpl_fig = specific_fig
|
||||||
|
|
||||||
# convert to plotly
|
# convert to plotly
|
||||||
image = None
|
image = None
|
||||||
plotly_fig = None
|
plotly_fig = None
|
||||||
image_format = 'svg'
|
image_format = 'jpeg'
|
||||||
if not force_save_as_image:
|
if not force_save_as_image:
|
||||||
|
image_format = 'svg'
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
def our_mpl_to_plotly(fig):
|
def our_mpl_to_plotly(fig):
|
||||||
@ -142,9 +172,8 @@ class PatchedMatplotlib:
|
|||||||
|
|
||||||
plotly_fig = our_mpl_to_plotly(mpl_fig)
|
plotly_fig = our_mpl_to_plotly(mpl_fig)
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
# this was an image, change format to jpeg
|
# this was an image, change format to png
|
||||||
if 'selfie' in str(ex):
|
image_format = 'jpeg' if 'selfie' in str(ex) else 'png'
|
||||||
image_format = 'jpeg'
|
|
||||||
|
|
||||||
# plotly could not serialize the plot, we should convert to image
|
# plotly could not serialize the plot, we should convert to image
|
||||||
if not plotly_fig:
|
if not plotly_fig:
|
||||||
@ -153,12 +182,14 @@ class PatchedMatplotlib:
|
|||||||
try:
|
try:
|
||||||
# first try SVG if we fail then fallback to png
|
# first try SVG if we fail then fallback to png
|
||||||
buffer_ = BytesIO()
|
buffer_ = BytesIO()
|
||||||
plt.savefig(buffer_, format=image_format, bbox_inches='tight', pad_inches=0)
|
a_plt = specific_fig if specific_fig is not None else plt
|
||||||
|
a_plt.savefig(buffer_, format=image_format, bbox_inches='tight', pad_inches=0, frameon=False)
|
||||||
buffer_.seek(0)
|
buffer_.seek(0)
|
||||||
except Exception:
|
except Exception:
|
||||||
image_format = 'png'
|
image_format = 'png'
|
||||||
buffer_ = BytesIO()
|
buffer_ = BytesIO()
|
||||||
plt.savefig(buffer_, format=image_format, bbox_inches='tight', pad_inches=0)
|
a_plt = specific_fig if specific_fig is not None else plt
|
||||||
|
a_plt.savefig(buffer_, format=image_format, bbox_inches='tight', pad_inches=0, frameon=False)
|
||||||
buffer_.seek(0)
|
buffer_.seek(0)
|
||||||
fd, image = mkstemp(suffix='.'+image_format)
|
fd, image = mkstemp(suffix='.'+image_format)
|
||||||
os.write(fd, buffer_.read())
|
os.write(fd, buffer_.read())
|
||||||
@ -193,20 +224,29 @@ class PatchedMatplotlib:
|
|||||||
reporter.report_plot(title=title, series='plot', plot=plotly_dict,
|
reporter.report_plot(title=title, series='plot', plot=plotly_dict,
|
||||||
iter=PatchedMatplotlib._global_plot_counter if plot_title else 0)
|
iter=PatchedMatplotlib._global_plot_counter if plot_title else 0)
|
||||||
else:
|
else:
|
||||||
# send the plot as image
|
|
||||||
PatchedMatplotlib._global_image_counter += 1
|
|
||||||
logger = PatchedMatplotlib._current_task.get_logger()
|
logger = PatchedMatplotlib._current_task.get_logger()
|
||||||
title = plot_title or 'untitled %d' % PatchedMatplotlib._global_image_counter
|
|
||||||
# this is actually a failed plot, we should put it under plots:
|
# this is actually a failed plot, we should put it under plots:
|
||||||
# currently disabled
|
# currently disabled
|
||||||
# if image_format == 'svg':
|
if force_save_as_image or not PatchedMatplotlib._support_image_plot:
|
||||||
# logger.report_image_plot_and_upload(title=title, series='plot image', path=image,
|
# send the plot as image
|
||||||
# iteration=PatchedMatplotlib._global_image_counter
|
PatchedMatplotlib._global_image_counter += 1
|
||||||
# if plot_title else 0)
|
title = plot_title or 'untitled %d' % PatchedMatplotlib._global_image_counter
|
||||||
# else:
|
|
||||||
logger.report_image_and_upload(title=title, series='plot image', path=image,
|
logger.report_image_and_upload(title=title, series='plot image', path=image,
|
||||||
iteration=PatchedMatplotlib._global_image_counter
|
delete_after_upload=True,
|
||||||
if plot_title else 0)
|
iteration=PatchedMatplotlib._global_image_counter
|
||||||
|
if plot_title else 0)
|
||||||
|
else:
|
||||||
|
# send the plot as plotly with embedded image
|
||||||
|
PatchedMatplotlib._global_plot_counter += 1
|
||||||
|
title = plot_title or 'untitled %d' % PatchedMatplotlib._global_plot_counter
|
||||||
|
|
||||||
|
logger.report_image_plot_and_upload(title=title, series='plot image', path=image,
|
||||||
|
delete_after_upload=True,
|
||||||
|
iteration=PatchedMatplotlib._global_plot_counter
|
||||||
|
if plot_title else 0)
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
# plotly failed
|
# plotly failed
|
||||||
pass
|
pass
|
||||||
|
Loading…
Reference in New Issue
Block a user