From efde6e4135b64dec3cfeef47cddfa3ddd308bc26 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Tue, 6 Sep 2022 10:39:17 +0300 Subject: [PATCH] Fix uploading 3D plots with plt shows 2D plot on task results page --- clearml/binding/matplotlib_bind.py | 30 +++++++++++++++++++ .../plotlympl/mplexporter/exporter.py | 4 +-- clearml/utilities/plotlympl/mpltools.py | 30 ++++++++++++++++++- clearml/utilities/plotlympl/renderer.py | 16 ++++++++-- 4 files changed, 74 insertions(+), 6 deletions(-) diff --git a/clearml/binding/matplotlib_bind.py b/clearml/binding/matplotlib_bind.py index 9ae0a28b..49b32ea9 100644 --- a/clearml/binding/matplotlib_bind.py +++ b/clearml/binding/matplotlib_bind.py @@ -458,6 +458,36 @@ class PatchedMatplotlib: process_tick_text( list(plotly_renderer.current_mpl_ax.get_yticklabels()), ("yaxis", "yaxis0", "yaxis1"), 1 ) + # noinspection PyBroadException + try: + # check if we have a 3d plot + if ( + "zaxis" in plotly_renderer.plotly_fig.get("layout", {}) + or "zaxis0" in plotly_renderer.plotly_fig.get("layout", {}) + or "zaxis1" in plotly_renderer.plotly_fig.get("layout", {}) + or len(plotly_renderer.plotly_fig.get("data", [{}])[0].get("z", [])) > 0 + ): + process_tick_text( + list(plotly_renderer.current_mpl_ax.get_zticklabels()), + ("zaxis", "zaxis0", "zaxis1"), + 2, + ) + + # rotate the X axis -90 degrees such that it matches matplotlib + plotly_renderer.plotly_fig.setdefault("layout", {}).setdefault("scene", {}).setdefault( + "camera", {} + ).setdefault("eye", {}).setdefault("x", -1) + + # reverse the X and Y axes such that they match matplotlib + plotly_renderer.plotly_fig.setdefault("layout", {}).setdefault("scene", {}).setdefault( + "xaxis", {} + ).setdefault("autorange", "reversed") + + plotly_renderer.plotly_fig.setdefault("layout", {}).setdefault("scene", {}).setdefault( + "yaxis", {} + ).setdefault("autorange", "reversed") + except Exception: + pass # try to bring back legend # noinspection PyBroadException diff --git a/clearml/utilities/plotlympl/mplexporter/exporter.py b/clearml/utilities/plotlympl/mplexporter/exporter.py index 614569bf..8927c141 100644 --- a/clearml/utilities/plotlympl/mplexporter/exporter.py +++ b/clearml/utilities/plotlympl/mplexporter/exporter.py @@ -133,8 +133,8 @@ class Exporter(object): self.draw_line(ax, line) for text in ax.texts: self.draw_text(ax, text) - for (text, ttp) in zip([ax.xaxis.label, ax.yaxis.label, ax.title], - ["xlabel", "ylabel", "title"]): + for (text, ttp) in zip([ax.xaxis.label, ax.yaxis.label, ax.zaxis.label, ax.title], + ["xlabel", "ylabel", "zlabel", "title"]): if(hasattr(text, "get_text") and text.get_text()): self.draw_text(ax, text, force_trans=ax.transAxes, text_type=ttp) diff --git a/clearml/utilities/plotlympl/mpltools.py b/clearml/utilities/plotlympl/mpltools.py index 95fb3e89..dcd19089 100644 --- a/clearml/utilities/plotlympl/mpltools.py +++ b/clearml/utilities/plotlympl/mpltools.py @@ -449,7 +449,7 @@ def prep_ticks(ax, index, ax_type, props): positional arguments: ax - the mpl axes instance index - the index of the axis in `props` - ax_type - 'x' or 'y' (for now) + ax_type - 'x' or 'y' or 'z' props - an mplexporter poperties dictionary """ @@ -458,6 +458,8 @@ def prep_ticks(ax, index, ax_type, props): axis = ax.get_xaxis() elif ax_type == "y": axis = ax.get_yaxis() + elif ax_type == "z": + axis = ax.get_zaxis() else: return dict() # whoops! @@ -491,6 +493,9 @@ def prep_ticks(ax, index, ax_type, props): elif ax_type == "y" and "ylim" in props: axis_dict["range"] = [props["ylim"][0], props["ylim"][1]] axis_dict["custom_range"] = True + elif ax_type == "z" and "zlim" in props: + axis_dict["range"] = [props["zlim"][0], props["zlim"][1]] + axis_dict["custom_range"] = True elif scale == "log": try: axis_dict["tick0"] = props["axes"][index]["tickvalues"][0] @@ -515,6 +520,12 @@ def prep_ticks(ax, index, ax_type, props): math.log10(props["ylim"][1]), ] axis_dict["custom_range"] = True + elif ax_type == "z" and "zlim" in props: + axis_dict["range"] = [ + math.log10(props["zlim"][0]), + math.log10(props["zlim"][1]), + ] + axis_dict["custom_range"] = True else: axis_dict = dict(range=None, type="linear") warnings.warn( @@ -561,6 +572,23 @@ def prep_xy_axis(ax, props, x_bounds, y_bounds): yaxis.update(prep_ticks(ax, 1, "y", props)) return xaxis, yaxis +def prep_xyz_axis(ax, props, x_bounds, y_bounds): + # there is no z_bounds as they can't (at least easily) be extracted from an `Axes3DSubplot` object + xaxis, yaxis = prep_xy_axis(ax, props, x_bounds, y_bounds) + # noinspection PyBroadException + try: + zaxis = dict( + type=props["axes"][2]["scale"], + range=list(props.get("zlim", [])), + showgrid=props["axes"][2].get("grid", {}).get("gridOn", False), + side=props["axes"][2].get("position"), + tickfont=dict(size=props["axes"][2].get("fontsize", {})), + ) + zaxis.update(prep_ticks(ax, 2, "z", props)) + except Exception: + zaxis = {} + return xaxis, yaxis, zaxis + def mpl_dates_to_datestrings(dates, mpl_formatter): """Convert matplotlib dates to iso-formatted-like time strings. diff --git a/clearml/utilities/plotlympl/renderer.py b/clearml/utilities/plotlympl/renderer.py index c0d247f3..d86aa44f 100644 --- a/clearml/utilities/plotlympl/renderer.py +++ b/clearml/utilities/plotlympl/renderer.py @@ -160,12 +160,16 @@ class PlotlyRenderer(Renderer): yaxis = dict( anchor="x{0}".format(self.axis_ct or ''), zeroline=False, ticks="inside" ) + zaxis = dict( + anchor="x{0}".format(self.axis_ct or ''), zeroline=False, ticks="inside" + ) # update defaults with things set in mpl - mpl_xaxis, mpl_yaxis = mpltools.prep_xy_axis( + mpl_xaxis, mpl_yaxis, mpl_zaxis = mpltools.prep_xyz_axis( ax=ax, props=props, x_bounds=self.mpl_x_bounds, y_bounds=self.mpl_y_bounds ) xaxis.update(mpl_xaxis) yaxis.update(mpl_yaxis) + zaxis.update(mpl_zaxis) bottom_spine = mpltools.get_spine_visible(ax, "bottom") top_spine = mpltools.get_spine_visible(ax, "top") left_spine = mpltools.get_spine_visible(ax, "left") @@ -178,6 +182,8 @@ class PlotlyRenderer(Renderer): # put axes in our figure self.plotly_fig["layout"]["xaxis{0}".format(self.axis_ct or '')] = xaxis self.plotly_fig["layout"]["yaxis{0}".format(self.axis_ct or '')] = yaxis + if mpl_zaxis: + self.plotly_fig["layout"]["zaxis{0}".format(self.axis_ct or '')] = zaxis # let all subsequent dates be handled properly if required @@ -398,13 +404,17 @@ class PlotlyRenderer(Renderer): if isinstance(props["label"], six.string_types) else props["label"] ), - x=[xy_pair[0] for xy_pair in props["data"]], - y=[xy_pair[1] for xy_pair in props["data"]], + x=props["data"][0], + y=props["data"][1], xaxis="x{0}".format(self.axis_ct), yaxis="y{0}".format(self.axis_ct), line=line, marker=marker, ) + if len(props["data"]) >= 3: + marked_line["z"] = props["data"][2] + marked_line["zaxis"] = "z{0}".format(self.axis_ct) + marked_line["type"] = "scatter3d" if self.x_is_mpl_date: formatter = ( self.current_mpl_ax.get_xaxis()