Improve matplotlib integration, issue #140

This commit is contained in:
allegroai 2020-06-13 22:09:45 +03:00
parent 2784a48c47
commit a5b1ed0330
3 changed files with 84 additions and 34 deletions

View File

@ -629,6 +629,19 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
raise ValueError('Expected only one of [filename, matrix]') raise ValueError('Expected only one of [filename, matrix]')
kwargs = dict(metric=self._normalize_name(title), variant=self._normalize_name(series), iter=iter, kwargs = dict(metric=self._normalize_name(title), variant=self._normalize_name(series), iter=iter,
file_history_size=max_image_history) file_history_size=max_image_history)
if matrix is not None:
width = matrix.shape[1]
height = matrix.shape[0]
else:
# noinspection PyBroadException
try:
from PIL import Image
width, height = Image.open(path).size
except Exception:
width = 640
height = 480
ev = UploadEvent(image_data=matrix, upload_uri=upload_uri, local_image_path=path, ev = UploadEvent(image_data=matrix, upload_uri=upload_uri, local_image_path=path,
delete_after_upload=delete_after_upload, **kwargs) delete_after_upload=delete_after_upload, **kwargs)
_, url = ev.get_target_full_upload_uri(upload_uri or self._storage_uri, self._metrics.storage_key_prefix) _, url = ev.get_target_full_upload_uri(upload_uri or self._storage_uri, self._metrics.storage_key_prefix)
@ -643,8 +656,8 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
plotly_dict = create_image_plot( plotly_dict = create_image_plot(
image_src=url, image_src=url,
title=title + '/' + series, title=title + '/' + series,
width=matrix.shape[1] if matrix is not None else 640, width=640,
height=matrix.shape[0] if matrix is not None else 480, height=int(640*float(height or 480)/float(width or 640)),
) )
return self.report_plot( return self.report_plot(

View File

@ -25,7 +25,7 @@ class PatchedMatplotlib:
_global_plot_counter = -1 _global_plot_counter = -1
_global_image_counter = -1 _global_image_counter = -1
_global_image_counter_limit = None _global_image_counter_limit = None
_last_iteration_plot_titles = (-1, []) _last_iteration_plot_titles = {}
_current_task = None _current_task = None
_support_image_plot = False _support_image_plot = False
_matplotlylib = None _matplotlylib = None
@ -179,6 +179,21 @@ class PatchedMatplotlib:
@staticmethod @staticmethod
def patched_savefig(self, *args, **kw): def patched_savefig(self, *args, **kw):
ret = PatchedMatplotlib._patched_original_savefig(self, *args, **kw) ret = PatchedMatplotlib._patched_original_savefig(self, *args, **kw)
# noinspection PyBroadException
try:
fname = kw.get('fname') or args[0]
from pathlib2 import Path
if six.PY3:
from pathlib import Path as Path3
else:
Path3 = Path
# if we are not storing into a file (str/Path) do not log the matplotlib
if not isinstance(fname, (str, Path, Path3)):
return ret
except Exception:
pass
tid = threading._get_ident() if six.PY2 else threading.get_ident() tid = threading._get_ident() if six.PY2 else threading.get_ident()
if not PatchedMatplotlib._recursion_guard.get(tid): if not PatchedMatplotlib._recursion_guard.get(tid):
PatchedMatplotlib._recursion_guard[tid] = True PatchedMatplotlib._recursion_guard[tid] = True
@ -273,35 +288,36 @@ class PatchedMatplotlib:
def our_mpl_to_plotly(fig): def our_mpl_to_plotly(fig):
if not PatchedMatplotlib._matplotlylib or not PatchedMatplotlib._plotly_renderer: if not PatchedMatplotlib._matplotlylib or not PatchedMatplotlib._plotly_renderer:
return None return None
PatchedMatplotlib._matplotlylib.Exporter(PatchedMatplotlib._plotly_renderer, plotly_renderer = PatchedMatplotlib._matplotlylib.PlotlyRenderer()
close_mpl=False).run(fig) PatchedMatplotlib._matplotlylib.Exporter(plotly_renderer, close_mpl=False).run(fig)
x_ticks = list(PatchedMatplotlib._plotly_renderer.current_mpl_ax.get_xticklabels())
x_ticks = list(plotly_renderer.current_mpl_ax.get_xticklabels())
if x_ticks: if x_ticks:
try: try:
# check if all values can be cast to float # check if all values can be cast to float
values = [float(t.get_text().replace('', '-')) for t in x_ticks] values = [float(t.get_text().replace('', '-')) for t in x_ticks]
except: except:
try: try:
PatchedMatplotlib._plotly_renderer.plotly_fig['layout']['xaxis1'].update({ plotly_renderer.plotly_fig['layout']['xaxis1'].update({
'ticktext': [t.get_text() for t in x_ticks], 'ticktext': [t.get_text() for t in x_ticks],
'tickvals': [t.get_position()[0] for t in x_ticks], 'tickvals': [t.get_position()[0] for t in x_ticks],
}) })
except: except:
pass pass
y_ticks = list(PatchedMatplotlib._plotly_renderer.current_mpl_ax.get_yticklabels()) y_ticks = list(plotly_renderer.current_mpl_ax.get_yticklabels())
if y_ticks: if y_ticks:
try: try:
# check if all values can be cast to float # check if all values can be cast to float
values = [float(t.get_text().replace('', '-')) for t in y_ticks] values = [float(t.get_text().replace('', '-')) for t in y_ticks]
except: except:
try: try:
PatchedMatplotlib._plotly_renderer.plotly_fig['layout']['yaxis1'].update({ plotly_renderer.plotly_fig['layout']['yaxis1'].update({
'ticktext': [t.get_text() for t in y_ticks], 'ticktext': [t.get_text() for t in y_ticks],
'tickvals': [t.get_position()[1] for t in y_ticks], 'tickvals': [t.get_position()[1] for t in y_ticks],
}) })
except: except:
pass pass
return deepcopy(PatchedMatplotlib._plotly_renderer.plotly_fig) return deepcopy(plotly_renderer.plotly_fig)
plotly_fig = our_mpl_to_plotly(mpl_fig) plotly_fig = our_mpl_to_plotly(mpl_fig)
try: try:
@ -366,7 +382,7 @@ class PatchedMatplotlib:
title = PatchedMatplotlib._enforce_unique_title_per_iteration(plot_title, last_iteration) title = PatchedMatplotlib._enforce_unique_title_per_iteration(plot_title, last_iteration)
else: else:
PatchedMatplotlib._global_plot_counter += 1 PatchedMatplotlib._global_plot_counter += 1
title = 'untitled %d' % PatchedMatplotlib._global_plot_counter title = 'untitled %02d' % PatchedMatplotlib._global_plot_counter
plotly_fig.layout.margin = {} plotly_fig.layout.margin = {}
plotly_fig.layout.autosize = True plotly_fig.layout.autosize = True
@ -392,8 +408,8 @@ class PatchedMatplotlib:
title = PatchedMatplotlib._enforce_unique_title_per_iteration(plot_title, last_iteration) title = PatchedMatplotlib._enforce_unique_title_per_iteration(plot_title, last_iteration)
else: else:
PatchedMatplotlib._global_image_counter += 1 PatchedMatplotlib._global_image_counter += 1
title = 'untitled %d' % (PatchedMatplotlib._global_image_counter % title = 'untitled %02d' % (PatchedMatplotlib._global_image_counter %
PatchedMatplotlib._global_image_counter_limit) PatchedMatplotlib._global_image_counter_limit)
PatchedMatplotlib._matplotlib_reported_titles.add(title) PatchedMatplotlib._matplotlib_reported_titles.add(title)
logger.report_image(title=title, series='plot image', local_path=image, logger.report_image(title=title, series='plot image', local_path=image,
@ -405,12 +421,14 @@ class PatchedMatplotlib:
title = PatchedMatplotlib._enforce_unique_title_per_iteration(plot_title, last_iteration) title = PatchedMatplotlib._enforce_unique_title_per_iteration(plot_title, last_iteration)
else: else:
PatchedMatplotlib._global_plot_counter += 1 PatchedMatplotlib._global_plot_counter += 1
title = 'untitled %d' % (PatchedMatplotlib._global_plot_counter % title = 'untitled %02d' % (PatchedMatplotlib._global_plot_counter %
PatchedMatplotlib._global_image_counter_limit) PatchedMatplotlib._global_image_counter_limit)
PatchedMatplotlib._matplotlib_reported_titles.add(title) PatchedMatplotlib._matplotlib_reported_titles.add(title)
logger._report_image_plot_and_upload(title=title, series='plot image', path=image, # noinspection PyProtectedMember
delete_after_upload=True, iteration=last_iteration) logger._report_image_plot_and_upload(
title=title, series='plot image', path=image,
delete_after_upload=True, iteration=last_iteration)
except Exception: except Exception:
# plotly failed # plotly failed
pass pass
@ -419,19 +437,37 @@ class PatchedMatplotlib:
@staticmethod @staticmethod
def _enforce_unique_title_per_iteration(title, last_iteration): def _enforce_unique_title_per_iteration(title, last_iteration):
if last_iteration != PatchedMatplotlib._last_iteration_plot_titles[0]: # type: (str, int) -> str
PatchedMatplotlib._last_iteration_plot_titles = (last_iteration, [title]) """
elif title not in PatchedMatplotlib._last_iteration_plot_titles[1]: Matplotlib with specific title will reset the title counter on every new iteration.
PatchedMatplotlib._last_iteration_plot_titles[1].append(title) Calling title twice each iteration will produce "title" and "title/1" for every iteration
:param title: original matplotlib title
:param last_iteration: the current "last_iteration"
:return: new title to use (with counter attached if necessary)
"""
# check if we already encountered the title
if title in PatchedMatplotlib._last_iteration_plot_titles:
# if we have check the last iteration
title_last_iteration, title_counter = PatchedMatplotlib._last_iteration_plot_titles[title]
# if this is a new iteration start from the beginning
if last_iteration == title_last_iteration:
title_counter += 1
else: # if this is a new iteration start from the beginning
title_last_iteration = last_iteration
title_counter = 0
else: else:
base_title = title # this is a new title
counter = 1 title_last_iteration = last_iteration
while title in PatchedMatplotlib._last_iteration_plot_titles[1]: title_counter = 0
# we already used this title in this iteration, we should change the title
title = base_title + ' %d' % counter base_title = title
counter += 1 # if this is the zero counter to not add the counter to the title
# store the new title if title_counter != 0:
PatchedMatplotlib._last_iteration_plot_titles[1].append(title) title = base_title + '/%d' % title_counter
# update back the title iteration counter
PatchedMatplotlib._last_iteration_plot_titles[base_title] = (title_last_iteration, title_counter)
return title return title
@staticmethod @staticmethod

View File

@ -314,9 +314,9 @@ def create_image_plot(image_src, title, width=640, height=480, series=None, comm
"data": [], "data": [],
"layout": { "layout": {
"xaxis": {"visible": False, "range": [0, width]}, "xaxis": {"visible": False, "range": [0, width]},
"yaxis": {"visible": False, "range": [0, height]}, "yaxis": {"visible": False, "range": [0, height], "scaleanchor": "x"},
# "width": width, "width": width,
# "height": height, "height": height,
"margin": {'l': 0, 'r': 0, 't': 0, 'b': 0}, "margin": {'l': 0, 'r': 0, 't': 0, 'b': 0},
"images": [{ "images": [{
"sizex": width, "sizex": width,
@ -325,8 +325,9 @@ def create_image_plot(image_src, title, width=640, height=480, series=None, comm
"yref": "y", "yref": "y",
"opacity": 1.0, "opacity": 1.0,
"x": 0, "x": 0,
"y": int(height / 2), "y": height,
"yanchor": "middle", # "xanchor": "left",
# "yanchor": "bottom",
"sizing": "contain", "sizing": "contain",
"layer": "below", "layer": "below",
"source": image_src "source": image_src