Add Logger.report_matplotlib_figure() with examples

This commit is contained in:
allegroai
2020-10-15 23:20:17 +03:00
parent 7ce5bc0313
commit df395b67ba
5 changed files with 253 additions and 75 deletions

View File

@@ -33,7 +33,7 @@ class PatchedMatplotlib:
_patched_mpltools_get_spine_visible = False
_lock_renderer = threading.RLock()
_recursion_guard = {}
_matplot_major_version = 2
_matplot_major_version = 0
_matplot_minor_version = 0
_logger_started_reporting = False
_matplotlib_reported_titles = set()
@@ -62,9 +62,7 @@ class PatchedMatplotlib:
try:
# we support matplotlib version 2.0.0 and above
import matplotlib
version_split = matplotlib.__version__.split('.')
PatchedMatplotlib._matplot_major_version = int(version_split[0])
PatchedMatplotlib._matplot_minor_version = int(version_split[1])
PatchedMatplotlib._update_matplotlib_version()
if PatchedMatplotlib._matplot_major_version < 2:
LoggerRoot.get_base_logger().warning(
'matplotlib binding supports version 2.0 and above, found version {}'.format(
@@ -137,6 +135,36 @@ class PatchedMatplotlib:
# update api version
from ..backend_api import Session
PatchedMatplotlib._support_image_plot = Session.check_min_api_version('2.2')
# load plotly
PatchedMatplotlib._update_plotly_renderers()
return True
@staticmethod
def _update_matplotlib_version():
if PatchedMatplotlib._matplot_major_version:
return
# we support matplotlib version 2.0.0 and above
try:
import matplotlib
version_split = matplotlib.__version__.split('.')
PatchedMatplotlib._matplot_major_version = int(version_split[0])
PatchedMatplotlib._matplot_minor_version = int(version_split[1])
if running_remotely():
# disable GUI backend - make headless
matplotlib.rcParams['backend'] = 'agg'
import matplotlib.pyplot
matplotlib.pyplot.switch_backend('agg')
except Exception:
pass
@staticmethod
def _update_plotly_renderers():
if PatchedMatplotlib._matplotlylib and PatchedMatplotlib._plotly_renderer:
return True
# create plotly renderer
try:
@@ -144,7 +172,7 @@ class PatchedMatplotlib:
PatchedMatplotlib._matplotlylib = optional_imports.get_module('plotly.matplotlylib')
PatchedMatplotlib._plotly_renderer = PatchedMatplotlib._matplotlylib.PlotlyRenderer()
except Exception:
pass
return False
return True
@@ -251,10 +279,47 @@ class PatchedMatplotlib:
return ret
@staticmethod
def _report_figure(force_save_as_image=False, stored_figure=None, set_active=True, specific_fig=None):
if not PatchedMatplotlib._current_task:
def report_figure(title, series, figure, iter, force_save_as_image=False, reporter=None, logger=None):
PatchedMatplotlib._report_figure(
force_save_as_image=force_save_as_image,
specific_fig=figure.gcf() if hasattr(figure, 'gcf') else figure,
title=title,
series=series,
iter=iter,
reporter=reporter,
logger=logger
)
@staticmethod
def _report_figure(
force_save_as_image=False,
stored_figure=None,
set_active=True,
specific_fig=None,
title=None,
series=None,
iter=None,
reporter=None,
logger=None,
):
# get the main task
if not PatchedMatplotlib._current_task and not reporter and not logger:
return
# check if this is explicit reporting
is_explicit = reporter and logger
# noinspection PyProtectedMember
reporter = reporter or PatchedMatplotlib._current_task._reporter
if not reporter:
return
logger = logger or PatchedMatplotlib._current_task.get_logger()
if not logger:
return
# make sure we have matplotlib ready
PatchedMatplotlib._update_matplotlib_version()
# noinspection PyBroadException
try:
import matplotlib.pyplot as plt
@@ -276,6 +341,13 @@ class PatchedMatplotlib:
else:
mpl_fig = specific_fig
if is_explicit:
# marked displayed explicitly
mpl_fig._trains_explicit = True
elif getattr(mpl_fig, '_trains_explicit', False):
# if auto bind (i.e. plt.show) and plot already displayed explicitly, do nothing.
return
# convert to plotly
image = None
plotly_fig = None
@@ -284,6 +356,8 @@ class PatchedMatplotlib:
if force_save_as_image:
# if this is an image, store as is.
fig_dpi = None
if isinstance(force_save_as_image, str):
image_format = force_save_as_image
else:
image_format = 'svg'
# protect with lock, so we support multiple threads using the same renderer
@@ -291,7 +365,7 @@ class PatchedMatplotlib:
# noinspection PyBroadException
try:
def our_mpl_to_plotly(fig):
if not PatchedMatplotlib._matplotlylib or not PatchedMatplotlib._plotly_renderer:
if not PatchedMatplotlib._update_plotly_renderers():
return None
if not PatchedMatplotlib._patched_mpltools_get_spine_visible and \
PatchedMatplotlib._matplot_major_version and \
@@ -383,76 +457,53 @@ class PatchedMatplotlib:
if set_active and not _pylab_helpers.Gcf.get_active():
_pylab_helpers.Gcf.set_active(stored_figure)
# get the main task
# noinspection PyProtectedMember
reporter = PatchedMatplotlib._current_task._reporter
if reporter is not None:
last_iteration = iter if iter is not None else PatchedMatplotlib._get_last_iteration()
if not title:
if mpl_fig.texts:
plot_title = mpl_fig.texts[0].get_text()
else:
gca = mpl_fig.gca()
plot_title = gca.title.get_text() if gca.title else None
# remove borders and size, we should let the web take care of that
if plotly_fig:
last_iteration = PatchedMatplotlib._get_last_iteration()
if plot_title:
title = PatchedMatplotlib._enforce_unique_title_per_iteration(plot_title, last_iteration)
else:
PatchedMatplotlib._global_plot_counter += 1
title = 'untitled %02d' % PatchedMatplotlib._global_plot_counter
plotly_fig.layout.margin = {}
plotly_fig.layout.autosize = True
plotly_fig.layout.height = None
plotly_fig.layout.width = None
# send the plot event
plotly_dict = plotly_fig.to_plotly_json()
if not plotly_dict.get('layout'):
plotly_dict['layout'] = {}
plotly_dict['layout']['title'] = title
PatchedMatplotlib._matplotlib_reported_titles.add(title)
reporter.report_plot(title=title, series='plot', plot=plotly_dict, iter=last_iteration)
if plot_title:
title = PatchedMatplotlib._enforce_unique_title_per_iteration(plot_title, last_iteration)
else:
logger = PatchedMatplotlib._current_task.get_logger()
PatchedMatplotlib._global_plot_counter += 1
mod_ = 1 if plotly_fig else PatchedMatplotlib._global_image_counter_limit
title = 'untitled %02d' % (PatchedMatplotlib._global_plot_counter % mod_)
# this is actually a failed plot, we should put it under plots:
# currently disabled
if force_save_as_image or not PatchedMatplotlib._support_image_plot:
last_iteration = PatchedMatplotlib._get_last_iteration()
# send the plot as image
if plot_title:
title = PatchedMatplotlib._enforce_unique_title_per_iteration(plot_title, last_iteration)
else:
PatchedMatplotlib._global_image_counter += 1
title = 'untitled %02d' % (PatchedMatplotlib._global_image_counter %
PatchedMatplotlib._global_image_counter_limit)
# remove borders and size, we should let the web take care of that
if plotly_fig:
plotly_fig.layout.margin = {}
plotly_fig.layout.autosize = True
plotly_fig.layout.height = None
plotly_fig.layout.width = None
# send the plot event
plotly_dict = plotly_fig.to_plotly_json()
if not plotly_dict.get('layout'):
plotly_dict['layout'] = {}
plotly_dict['layout']['title'] = series or title
PatchedMatplotlib._matplotlib_reported_titles.add(title)
logger.report_image(title=title, series='plot image', local_path=image,
delete_after_upload=True, iteration=last_iteration)
else:
# send the plot as plotly with embedded image
last_iteration = PatchedMatplotlib._get_last_iteration()
if plot_title:
title = PatchedMatplotlib._enforce_unique_title_per_iteration(plot_title, last_iteration)
else:
PatchedMatplotlib._global_plot_counter += 1
title = 'untitled %02d' % (PatchedMatplotlib._global_plot_counter %
PatchedMatplotlib._global_image_counter_limit)
PatchedMatplotlib._matplotlib_reported_titles.add(title)
# noinspection PyProtectedMember
logger._report_image_plot_and_upload(
title=title, series='plot image', path=image,
delete_after_upload=True, iteration=last_iteration)
PatchedMatplotlib._matplotlib_reported_titles.add(title)
reporter.report_plot(title=title, series=series or 'plot', plot=plotly_dict, iter=last_iteration)
else:
# this is actually a failed plot, we should put it under plots:
# currently disabled
if force_save_as_image or not PatchedMatplotlib._support_image_plot:
PatchedMatplotlib._matplotlib_reported_titles.add(title)
logger.report_image(title=title, series=series or 'plot image', local_path=image,
delete_after_upload=True, iteration=last_iteration)
else:
PatchedMatplotlib._matplotlib_reported_titles.add(title)
# noinspection PyProtectedMember
logger._report_image_plot_and_upload(
title=title, series=series or 'plot image', path=image,
delete_after_upload=True, iteration=last_iteration)
except Exception:
# plotly failed
pass
return
@staticmethod
def _enforce_unique_title_per_iteration(title, last_iteration):
# type: (str, int) -> str