Fix matplotlib auto-magic detect bar graph series name (issue #518)

This commit is contained in:
allegroai 2021-12-22 13:26:51 +02:00
parent 570fcfd061
commit 7e32278ebf

View File

@ -56,6 +56,7 @@ class PlotlyRenderer(Renderer):
self.current_mpl_ax = None
self.bar_containers = None
self.current_bars = []
self.current_bars_names = []
self.axis_ct = 0
self.x_is_mpl_date = False
self.mpl_x_bounds = (0, 1)
@ -197,6 +198,12 @@ class PlotlyRenderer(Renderer):
ax -- an mpl axes object, not required at this time.
"""
if self.current_bars:
# noinspection PyBroadException
try:
self.current_bars_names = [n.get_text() for n in ax.legend().texts]
except Exception:
pass
self.draw_bars(self.current_bars)
self.msg += " Closing axes\n"
self.x_is_mpl_date = False
@ -213,10 +220,10 @@ class PlotlyRenderer(Renderer):
if bar_props["mplobj"] in container
]
)
for trace in mpl_traces:
self.draw_bar(trace)
for i, trace in enumerate(mpl_traces):
self.draw_bar(trace, self.current_bars_names[i] if i < len(self.current_bars_names) else None)
def draw_bar(self, coll):
def draw_bar(self, coll, name=None):
"""Draw a collection of similar patches as a bar chart.
After bars are sorted, an appropriate data dictionary must be created
@ -302,6 +309,8 @@ class PlotlyRenderer(Renderer):
line=dict(width=trace[0]["edgewidth"]),
),
) # TODO ditto
if name:
bar["name"] = name
if len(bar["x"]) > 1:
self.msg += " Heck yeah, I drew that bar chart\n"
self.plotly_fig['data'].append(bar)