mirror of
https://github.com/clearml/clearml
synced 2025-04-29 18:51:47 +00:00
Fix mpl exporter. Added support for legend.
This commit is contained in:
parent
14706c03e1
commit
fe86fbcac3
@ -31,8 +31,10 @@ class Exporter(object):
|
|||||||
def __init__(self, renderer, close_mpl=True):
|
def __init__(self, renderer, close_mpl=True):
|
||||||
self.close_mpl = close_mpl
|
self.close_mpl = close_mpl
|
||||||
self.renderer = renderer
|
self.renderer = renderer
|
||||||
|
self.has_legend = False
|
||||||
|
self.legend_as_annotation = True
|
||||||
|
|
||||||
def run(self, fig):
|
def run(self, fig, show_legend=True):
|
||||||
"""
|
"""
|
||||||
Run the exporter on the given figure
|
Run the exporter on the given figure
|
||||||
|
|
||||||
@ -40,9 +42,12 @@ class Exporter(object):
|
|||||||
---------
|
---------
|
||||||
fig : matplotlib.Figure instance
|
fig : matplotlib.Figure instance
|
||||||
The figure to export
|
The figure to export
|
||||||
|
|
||||||
|
show_legend: If True, plotly show legend if plt has one, If False add legend as annotations on plot.
|
||||||
"""
|
"""
|
||||||
# Calling savefig executes the draw() command, putting elements
|
# Calling savefig executes the draw() command, putting elements
|
||||||
# in the correct place.
|
# in the correct place.
|
||||||
|
self.legend_as_annotation = not show_legend
|
||||||
if fig.canvas is None:
|
if fig.canvas is None:
|
||||||
canvas = FigureCanvasAgg(fig) # noqa: F841
|
canvas = FigureCanvasAgg(fig) # noqa: F841
|
||||||
fig.savefig(io.BytesIO(), format='png', dpi=fig.dpi)
|
fig.savefig(io.BytesIO(), format='png', dpi=fig.dpi)
|
||||||
@ -50,6 +55,8 @@ class Exporter(object):
|
|||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
plt.close(fig)
|
plt.close(fig)
|
||||||
self.crawl_fig(fig)
|
self.crawl_fig(fig)
|
||||||
|
if show_legend and self.has_legend:
|
||||||
|
self.renderer.plotly_fig['layout']['showlegend'] = True
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def process_transform(transform, ax=None, data=None, return_trans=False,
|
def process_transform(transform, ax=None, data=None, return_trans=False,
|
||||||
@ -144,6 +151,8 @@ class Exporter(object):
|
|||||||
|
|
||||||
legend = ax.get_legend()
|
legend = ax.get_legend()
|
||||||
if legend is not None:
|
if legend is not None:
|
||||||
|
self.has_legend = True
|
||||||
|
if self.legend_as_annotation:
|
||||||
props = utils.get_legend_properties(ax, legend)
|
props = utils.get_legend_properties(ax, legend)
|
||||||
with self.renderer.draw_legend(legend=legend, props=props):
|
with self.renderer.draw_legend(legend=legend, props=props):
|
||||||
if props['visible']:
|
if props['visible']:
|
||||||
|
@ -8,6 +8,10 @@ import math
|
|||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
import matplotlib.dates
|
import matplotlib.dates
|
||||||
|
try:
|
||||||
|
from matplotlib.patches import FancyBboxPatch
|
||||||
|
except ImportError:
|
||||||
|
FancyBboxPatch = None
|
||||||
|
|
||||||
|
|
||||||
def check_bar_match(old_bar, new_bar):
|
def check_bar_match(old_bar, new_bar):
|
||||||
@ -389,6 +393,11 @@ def is_bar(bar_containers, **props):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def is_fancy_bbox(**props):
|
||||||
|
"""A test to decide whether a path is a simple FancyBboxPatch."""
|
||||||
|
return FancyBboxPatch and isinstance(props.get("mplobj"), FancyBboxPatch)
|
||||||
|
|
||||||
|
|
||||||
def make_bar(**props):
|
def make_bar(**props):
|
||||||
"""Make an intermediate bar dictionary.
|
"""Make an intermediate bar dictionary.
|
||||||
|
|
||||||
|
@ -509,6 +509,8 @@ class PlotlyRenderer(Renderer):
|
|||||||
is_bar = mpltools.is_bar(self.current_mpl_ax.containers, **props)
|
is_bar = mpltools.is_bar(self.current_mpl_ax.containers, **props)
|
||||||
if is_bar:
|
if is_bar:
|
||||||
self.current_bars += [props]
|
self.current_bars += [props]
|
||||||
|
elif mpltools.is_fancy_bbox(**props):
|
||||||
|
self.current_bars += [props]
|
||||||
else:
|
else:
|
||||||
self.msg += " This path isn't a bar, not drawing\n"
|
self.msg += " This path isn't a bar, not drawing\n"
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
@ -588,8 +590,8 @@ class PlotlyRenderer(Renderer):
|
|||||||
self.msg += " Text object is linked to 'data' " "coordinates\n"
|
self.msg += " Text object is linked to 'data' " "coordinates\n"
|
||||||
x, y = props["position"]
|
x, y = props["position"]
|
||||||
axis_ct = self.axis_ct
|
axis_ct = self.axis_ct
|
||||||
xaxis = self.plotly_fig["layout"]["xaxis{0}".format(axis_ct)]
|
xaxis = self.plotly_fig["layout"]["xaxis{0}".format(axis_ct or '')]
|
||||||
yaxis = self.plotly_fig["layout"]["yaxis{0}".format(axis_ct)]
|
yaxis = self.plotly_fig["layout"]["yaxis{0}".format(axis_ct or '')]
|
||||||
if (
|
if (
|
||||||
xaxis["range"][0] < x < xaxis["range"][1]
|
xaxis["range"][0] < x < xaxis["range"][1]
|
||||||
and yaxis["range"][0] < y < yaxis["range"][1]
|
and yaxis["range"][0] < y < yaxis["range"][1]
|
||||||
@ -714,7 +716,10 @@ class PlotlyRenderer(Renderer):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
self.msg += " Adding xlabel\n"
|
self.msg += " Adding xlabel\n"
|
||||||
axis_key = "xaxis{0}".format(self.axis_ct)
|
axis_key = "xaxis{0}".format(self.axis_ct or '')
|
||||||
|
# bugfix: add on last axis, self.axis_ct-1
|
||||||
|
if axis_key not in self.plotly_fig["layout"]:
|
||||||
|
axis_key = "xaxis{0}".format(max(0, self.axis_ct - 1) or '')
|
||||||
self.plotly_fig["layout"][axis_key]["title"] = str(props["text"])
|
self.plotly_fig["layout"][axis_key]["title"] = str(props["text"])
|
||||||
titlefont = dict(size=props["style"]["fontsize"], color=props["style"]["color"])
|
titlefont = dict(size=props["style"]["fontsize"], color=props["style"]["color"])
|
||||||
self.plotly_fig["layout"][axis_key]["titlefont"] = titlefont
|
self.plotly_fig["layout"][axis_key]["titlefont"] = titlefont
|
||||||
@ -743,7 +748,10 @@ class PlotlyRenderer(Renderer):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
self.msg += " Adding ylabel\n"
|
self.msg += " Adding ylabel\n"
|
||||||
axis_key = "yaxis{0}".format(self.axis_ct)
|
axis_key = "yaxis{0}".format(self.axis_ct or '')
|
||||||
|
# bugfix: add on last axis, self.axis_ct-1
|
||||||
|
if axis_key not in self.plotly_fig["layout"]:
|
||||||
|
axis_key = "yaxis{0}".format(max(0, self.axis_ct - 1) or '')
|
||||||
self.plotly_fig["layout"][axis_key]["title"] = props["text"]
|
self.plotly_fig["layout"][axis_key]["title"] = props["text"]
|
||||||
titlefont = dict(size=props["style"]["fontsize"], color=props["style"]["color"])
|
titlefont = dict(size=props["style"]["fontsize"], color=props["style"]["color"])
|
||||||
self.plotly_fig["layout"][axis_key]["titlefont"] = titlefont
|
self.plotly_fig["layout"][axis_key]["titlefont"] = titlefont
|
||||||
|
Loading…
Reference in New Issue
Block a user