Fix uploading 3D plots with plt shows 2D plot on task results page

This commit is contained in:
allegroai 2022-09-06 10:39:17 +03:00
parent 4020c8cd42
commit efde6e4135
4 changed files with 74 additions and 6 deletions

View File

@ -458,6 +458,36 @@ class PatchedMatplotlib:
process_tick_text(
list(plotly_renderer.current_mpl_ax.get_yticklabels()), ("yaxis", "yaxis0", "yaxis1"), 1
)
# noinspection PyBroadException
try:
# check if we have a 3d plot
if (
"zaxis" in plotly_renderer.plotly_fig.get("layout", {})
or "zaxis0" in plotly_renderer.plotly_fig.get("layout", {})
or "zaxis1" in plotly_renderer.plotly_fig.get("layout", {})
or len(plotly_renderer.plotly_fig.get("data", [{}])[0].get("z", [])) > 0
):
process_tick_text(
list(plotly_renderer.current_mpl_ax.get_zticklabels()),
("zaxis", "zaxis0", "zaxis1"),
2,
)
# rotate the X axis -90 degrees such that it matches matplotlib
plotly_renderer.plotly_fig.setdefault("layout", {}).setdefault("scene", {}).setdefault(
"camera", {}
).setdefault("eye", {}).setdefault("x", -1)
# reverse the X and Y axes such that they match matplotlib
plotly_renderer.plotly_fig.setdefault("layout", {}).setdefault("scene", {}).setdefault(
"xaxis", {}
).setdefault("autorange", "reversed")
plotly_renderer.plotly_fig.setdefault("layout", {}).setdefault("scene", {}).setdefault(
"yaxis", {}
).setdefault("autorange", "reversed")
except Exception:
pass
# try to bring back legend
# noinspection PyBroadException

View File

@ -133,8 +133,8 @@ class Exporter(object):
self.draw_line(ax, line)
for text in ax.texts:
self.draw_text(ax, text)
for (text, ttp) in zip([ax.xaxis.label, ax.yaxis.label, ax.title],
["xlabel", "ylabel", "title"]):
for (text, ttp) in zip([ax.xaxis.label, ax.yaxis.label, ax.zaxis.label, ax.title],
["xlabel", "ylabel", "zlabel", "title"]):
if(hasattr(text, "get_text") and text.get_text()):
self.draw_text(ax, text, force_trans=ax.transAxes,
text_type=ttp)

View File

@ -449,7 +449,7 @@ def prep_ticks(ax, index, ax_type, props):
positional arguments:
ax - the mpl axes instance
index - the index of the axis in `props`
ax_type - 'x' or 'y' (for now)
ax_type - 'x' or 'y' or 'z'
props - an mplexporter poperties dictionary
"""
@ -458,6 +458,8 @@ def prep_ticks(ax, index, ax_type, props):
axis = ax.get_xaxis()
elif ax_type == "y":
axis = ax.get_yaxis()
elif ax_type == "z":
axis = ax.get_zaxis()
else:
return dict() # whoops!
@ -491,6 +493,9 @@ def prep_ticks(ax, index, ax_type, props):
elif ax_type == "y" and "ylim" in props:
axis_dict["range"] = [props["ylim"][0], props["ylim"][1]]
axis_dict["custom_range"] = True
elif ax_type == "z" and "zlim" in props:
axis_dict["range"] = [props["zlim"][0], props["zlim"][1]]
axis_dict["custom_range"] = True
elif scale == "log":
try:
axis_dict["tick0"] = props["axes"][index]["tickvalues"][0]
@ -515,6 +520,12 @@ def prep_ticks(ax, index, ax_type, props):
math.log10(props["ylim"][1]),
]
axis_dict["custom_range"] = True
elif ax_type == "z" and "zlim" in props:
axis_dict["range"] = [
math.log10(props["zlim"][0]),
math.log10(props["zlim"][1]),
]
axis_dict["custom_range"] = True
else:
axis_dict = dict(range=None, type="linear")
warnings.warn(
@ -561,6 +572,23 @@ def prep_xy_axis(ax, props, x_bounds, y_bounds):
yaxis.update(prep_ticks(ax, 1, "y", props))
return xaxis, yaxis
def prep_xyz_axis(ax, props, x_bounds, y_bounds):
# there is no z_bounds as they can't (at least easily) be extracted from an `Axes3DSubplot` object
xaxis, yaxis = prep_xy_axis(ax, props, x_bounds, y_bounds)
# noinspection PyBroadException
try:
zaxis = dict(
type=props["axes"][2]["scale"],
range=list(props.get("zlim", [])),
showgrid=props["axes"][2].get("grid", {}).get("gridOn", False),
side=props["axes"][2].get("position"),
tickfont=dict(size=props["axes"][2].get("fontsize", {})),
)
zaxis.update(prep_ticks(ax, 2, "z", props))
except Exception:
zaxis = {}
return xaxis, yaxis, zaxis
def mpl_dates_to_datestrings(dates, mpl_formatter):
"""Convert matplotlib dates to iso-formatted-like time strings.

View File

@ -160,12 +160,16 @@ class PlotlyRenderer(Renderer):
yaxis = dict(
anchor="x{0}".format(self.axis_ct or ''), zeroline=False, ticks="inside"
)
zaxis = dict(
anchor="x{0}".format(self.axis_ct or ''), zeroline=False, ticks="inside"
)
# update defaults with things set in mpl
mpl_xaxis, mpl_yaxis = mpltools.prep_xy_axis(
mpl_xaxis, mpl_yaxis, mpl_zaxis = mpltools.prep_xyz_axis(
ax=ax, props=props, x_bounds=self.mpl_x_bounds, y_bounds=self.mpl_y_bounds
)
xaxis.update(mpl_xaxis)
yaxis.update(mpl_yaxis)
zaxis.update(mpl_zaxis)
bottom_spine = mpltools.get_spine_visible(ax, "bottom")
top_spine = mpltools.get_spine_visible(ax, "top")
left_spine = mpltools.get_spine_visible(ax, "left")
@ -178,6 +182,8 @@ class PlotlyRenderer(Renderer):
# put axes in our figure
self.plotly_fig["layout"]["xaxis{0}".format(self.axis_ct or '')] = xaxis
self.plotly_fig["layout"]["yaxis{0}".format(self.axis_ct or '')] = yaxis
if mpl_zaxis:
self.plotly_fig["layout"]["zaxis{0}".format(self.axis_ct or '')] = zaxis
# let all subsequent dates be handled properly if required
@ -398,13 +404,17 @@ class PlotlyRenderer(Renderer):
if isinstance(props["label"], six.string_types)
else props["label"]
),
x=[xy_pair[0] for xy_pair in props["data"]],
y=[xy_pair[1] for xy_pair in props["data"]],
x=props["data"][0],
y=props["data"][1],
xaxis="x{0}".format(self.axis_ct),
yaxis="y{0}".format(self.axis_ct),
line=line,
marker=marker,
)
if len(props["data"]) >= 3:
marked_line["z"] = props["data"][2]
marked_line["zaxis"] = "z{0}".format(self.axis_ct)
marked_line["type"] = "scatter3d"
if self.x_is_mpl_date:
formatter = (
self.current_mpl_ax.get_xaxis()