From 0be016d3a545bde05b5071d36e0a0e8848804de7 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Wed, 25 Nov 2020 11:00:34 +0200 Subject: [PATCH] Use a built-in matplotlib convertor --- trains/binding/matplotlib_bind.py | 59 +- trains/utilities/plotlympl/__init__.py | 15 + .../plotlympl/mplexporter/__init__.py | 2 + .../plotlympl/mplexporter/_py3k_compat.py | 22 + .../plotlympl/mplexporter/exporter.py | 286 +++++++ .../mplexporter/renderers/__init__.py | 12 + .../plotlympl/mplexporter/renderers/base.py | 388 +++++++++ .../mplexporter/renderers/fake_renderer.py | 68 ++ .../mplexporter/renderers/vega_renderer.py | 138 ++++ .../mplexporter/renderers/vincent_renderer.py | 52 ++ .../utilities/plotlympl/mplexporter/tools.py | 52 ++ .../utilities/plotlympl/mplexporter/utils.py | 362 +++++++++ trains/utilities/plotlympl/mpltools.py | 600 ++++++++++++++ trains/utilities/plotlympl/renderer.py | 768 ++++++++++++++++++ 14 files changed, 2784 insertions(+), 40 deletions(-) create mode 100644 trains/utilities/plotlympl/__init__.py create mode 100644 trains/utilities/plotlympl/mplexporter/__init__.py create mode 100644 trains/utilities/plotlympl/mplexporter/_py3k_compat.py create mode 100644 trains/utilities/plotlympl/mplexporter/exporter.py create mode 100644 trains/utilities/plotlympl/mplexporter/renderers/__init__.py create mode 100644 trains/utilities/plotlympl/mplexporter/renderers/base.py create mode 100644 trains/utilities/plotlympl/mplexporter/renderers/fake_renderer.py create mode 100644 trains/utilities/plotlympl/mplexporter/renderers/vega_renderer.py create mode 100644 trains/utilities/plotlympl/mplexporter/renderers/vincent_renderer.py create mode 100644 trains/utilities/plotlympl/mplexporter/tools.py create mode 100644 trains/utilities/plotlympl/mplexporter/utils.py create mode 100644 trains/utilities/plotlympl/mpltools.py create mode 100644 trains/utilities/plotlympl/renderer.py diff --git a/trains/binding/matplotlib_bind.py b/trains/binding/matplotlib_bind.py index a6138ad8..5d4e68c7 100644 --- a/trains/binding/matplotlib_bind.py +++ b/trains/binding/matplotlib_bind.py @@ -30,7 +30,6 @@ class PatchedMatplotlib: _support_image_plot = False _matplotlylib = None _plotly_renderer = None - _patched_mpltools_get_spine_visible = False _lock_renderer = threading.RLock() _recursion_guard = {} _matplot_major_version = 0 @@ -102,18 +101,10 @@ class PatchedMatplotlib: sys.modules['matplotlib'].pyplot.imshow = PatchedMatplotlib.patched_imshow sys.modules['matplotlib'].figure.Figure.savefig = PatchedMatplotlib.patched_savefig # patch plotly so we know it failed us. - from plotly.matplotlylib import renderer + from ..utilities.plotlympl import renderer + # noinspection PyProtectedMember renderer.warnings = PatchedMatplotlib._PatchWarnings() - # ignore deprecation warnings from plotly to matplotlib - try: - import warnings - warnings.filterwarnings(action='ignore', category=matplotlib.MatplotlibDeprecationWarning, - module='plotly') - warnings.filterwarnings(action='ignore', category=UserWarning, module='plotly') - except Exception: - pass - except Exception: return False @@ -167,10 +158,11 @@ class PatchedMatplotlib: return True # create plotly renderer + # noinspection PyBroadException try: - from plotly import optional_imports - PatchedMatplotlib._matplotlylib = optional_imports.get_module('plotly.matplotlylib') - PatchedMatplotlib._plotly_renderer = PatchedMatplotlib._matplotlylib.PlotlyRenderer() + from ..utilities import plotlympl + PatchedMatplotlib._matplotlylib = plotlympl + PatchedMatplotlib._plotly_renderer = plotlympl.PlotlyRenderer() except Exception: return False @@ -324,7 +316,7 @@ class PatchedMatplotlib: try: import matplotlib.pyplot as plt from matplotlib import _pylab_helpers - from plotly.io import templates + if specific_fig is None: # store the figure object we just created (if it is not already there) stored_figure = stored_figure or _pylab_helpers.Gcf.get_active() @@ -350,7 +342,7 @@ class PatchedMatplotlib: # convert to plotly image = None - plotly_fig = None + plotly_dict = None image_format = 'jpeg' fig_dpi = 300 if force_save_as_image: @@ -367,13 +359,7 @@ class PatchedMatplotlib: def our_mpl_to_plotly(fig): if not PatchedMatplotlib._update_plotly_renderers(): return None - if not PatchedMatplotlib._patched_mpltools_get_spine_visible and \ - PatchedMatplotlib._matplot_major_version and \ - PatchedMatplotlib._matplot_major_version >= 3 and \ - PatchedMatplotlib._matplot_minor_version >= 3: - from plotly.matplotlylib import mpltools - mpltools.get_spine_visible = lambda *_, **__: True - PatchedMatplotlib._patched_mpltools_get_spine_visible = True + plotly_renderer = PatchedMatplotlib._matplotlylib.PlotlyRenderer() PatchedMatplotlib._matplotlylib.Exporter(plotly_renderer, close_mpl=False).run(fig) @@ -382,7 +368,7 @@ class PatchedMatplotlib: # noinspection PyBroadException try: # check if all values can be cast to float - [float(t.get_text().replace('−', '-')) for t in x_ticks] + _ = [float(t.get_text().replace('−', '-')) for t in x_ticks] except Exception: # noinspection PyBroadException try: @@ -409,13 +395,7 @@ class PatchedMatplotlib: pass return deepcopy(plotly_renderer.plotly_fig) - plotly_fig = our_mpl_to_plotly(mpl_fig) - # noinspection PyBroadException - try: - if 'none' in templates: - plotly_fig._layout_obj.template = templates['none'] - except Exception: - pass + plotly_dict = our_mpl_to_plotly(mpl_fig) except Exception as ex: # this was an image, change format to png image_format = 'jpeg' if 'selfie' in str(ex) else 'png' @@ -424,8 +404,8 @@ class PatchedMatplotlib: PatchedMatplotlib._lock_renderer.release() # plotly could not serialize the plot, we should convert to image - if not plotly_fig: - plotly_fig = None + if not plotly_dict: + plotly_dict = None # noinspection PyBroadException try: # first try SVG if we fail then fallback to png @@ -459,7 +439,7 @@ class PatchedMatplotlib: last_iteration = iter if iter is not None else PatchedMatplotlib._get_last_iteration() - report_as_debug_sample = not plotly_fig and ( + report_as_debug_sample = not plotly_dict and ( force_save_as_image or not PatchedMatplotlib._support_image_plot) if not title: @@ -486,15 +466,14 @@ class PatchedMatplotlib: PatchedMatplotlib._matplotlib_reported_titles.add(title) # remove borders and size, we should let the web take care of that - if plotly_fig: - plotly_fig.layout.margin = {} - plotly_fig.layout.autosize = True - plotly_fig.layout.height = None - plotly_fig.layout.width = None + if plotly_dict: # send the plot event - plotly_dict = plotly_fig.to_plotly_json() if not plotly_dict.get('layout'): plotly_dict['layout'] = {} + plotly_dict['layout']['margin'] = {} + plotly_dict['layout']['autosize'] = True + plotly_dict['layout']['height'] = None + plotly_dict['layout']['width'] = None plotly_dict['layout']['title'] = series or title reporter.report_plot(title=title, series=series or 'plot', plot=plotly_dict, iter=last_iteration) diff --git a/trains/utilities/plotlympl/__init__.py b/trains/utilities/plotlympl/__init__.py new file mode 100644 index 00000000..6e7d836b --- /dev/null +++ b/trains/utilities/plotlympl/__init__.py @@ -0,0 +1,15 @@ +""" +matplotlylib +============ + +This module converts matplotlib figure objects into JSON structures which can +be understood and visualized by Plotly. + +Most of the functionality should be accessed through the parent directory's +'tools' module or 'plotly' package. + +""" +from __future__ import absolute_import + +from .renderer import PlotlyRenderer +from .mplexporter import Exporter diff --git a/trains/utilities/plotlympl/mplexporter/__init__.py b/trains/utilities/plotlympl/mplexporter/__init__.py new file mode 100644 index 00000000..970731c6 --- /dev/null +++ b/trains/utilities/plotlympl/mplexporter/__init__.py @@ -0,0 +1,2 @@ +from .renderers import Renderer +from .exporter import Exporter diff --git a/trains/utilities/plotlympl/mplexporter/_py3k_compat.py b/trains/utilities/plotlympl/mplexporter/_py3k_compat.py new file mode 100644 index 00000000..9ca84550 --- /dev/null +++ b/trains/utilities/plotlympl/mplexporter/_py3k_compat.py @@ -0,0 +1,22 @@ +""" +Simple fixes for Python 2/3 compatibility +""" +import sys +PY3K = sys.version_info[0] >= 3 + + +if PY3K: + import builtins + import functools + reduce = functools.reduce + zip = builtins.zip + xrange = builtins.range + map = builtins.map +else: + import __builtin__ + import itertools + builtins = __builtin__ + reduce = __builtin__.reduce + zip = itertools.izip + xrange = __builtin__.xrange + map = itertools.imap diff --git a/trains/utilities/plotlympl/mplexporter/exporter.py b/trains/utilities/plotlympl/mplexporter/exporter.py new file mode 100644 index 00000000..e49fb46d --- /dev/null +++ b/trains/utilities/plotlympl/mplexporter/exporter.py @@ -0,0 +1,286 @@ +""" +Matplotlib Exporter +=================== +This submodule contains tools for crawling a matplotlib figure and exporting +relevant pieces to a renderer. +""" +import warnings +import io +from . import utils + +import matplotlib +from matplotlib import transforms, collections +from matplotlib.backends.backend_agg import FigureCanvasAgg + +class Exporter(object): + """Matplotlib Exporter + + Parameters + ---------- + renderer : Renderer object + The renderer object called by the exporter to create a figure + visualization. See mplexporter.Renderer for information on the + methods which should be defined within the renderer. + close_mpl : bool + If True (default), close the matplotlib figure as it is rendered. This + is useful for when the exporter is used within the notebook, or with + an interactive matplotlib backend. + """ + + def __init__(self, renderer, close_mpl=True): + self.close_mpl = close_mpl + self.renderer = renderer + + def run(self, fig): + """ + Run the exporter on the given figure + + Parmeters + --------- + fig : matplotlib.Figure instance + The figure to export + """ + # Calling savefig executes the draw() command, putting elements + # in the correct place. + if fig.canvas is None: + canvas = FigureCanvasAgg(fig) + fig.savefig(io.BytesIO(), format='png', dpi=fig.dpi) + if self.close_mpl: + import matplotlib.pyplot as plt + plt.close(fig) + self.crawl_fig(fig) + + @staticmethod + def process_transform(transform, ax=None, data=None, return_trans=False, + force_trans=None): + """Process the transform and convert data to figure or data coordinates + + Parameters + ---------- + transform : matplotlib Transform object + The transform applied to the data + ax : matplotlib Axes object (optional) + The axes the data is associated with + data : ndarray (optional) + The array of data to be transformed. + return_trans : bool (optional) + If true, return the final transform of the data + force_trans : matplotlib.transform instance (optional) + If supplied, first force the data to this transform + + Returns + ------- + code : string + Code is either "data", "axes", "figure", or "display", indicating + the type of coordinates output. + transform : matplotlib transform + the transform used to map input data to output data. + Returned only if return_trans is True + new_data : ndarray + Data transformed to match the given coordinate code. + Returned only if data is specified + """ + if isinstance(transform, transforms.BlendedGenericTransform): + warnings.warn("Blended transforms not yet supported. " + "Zoom behavior may not work as expected.") + + if force_trans is not None: + if data is not None: + data = (transform - force_trans).transform(data) + transform = force_trans + + code = "display" + if ax is not None: + for (c, trans) in [("data", ax.transData), + ("axes", ax.transAxes), + ("figure", ax.figure.transFigure), + ("display", transforms.IdentityTransform())]: + if transform.contains_branch(trans): + code, transform = (c, transform - trans) + break + + if data is not None: + if return_trans: + return code, transform.transform(data), transform + else: + return code, transform.transform(data) + else: + if return_trans: + return code, transform + else: + return code + + def crawl_fig(self, fig): + """Crawl the figure and process all axes""" + with self.renderer.draw_figure(fig=fig, + props=utils.get_figure_properties(fig)): + for ax in fig.axes: + self.crawl_ax(ax) + + def crawl_ax(self, ax): + """Crawl the axes and process all elements within""" + with self.renderer.draw_axes(ax=ax, + props=utils.get_axes_properties(ax)): + for line in ax.lines: + self.draw_line(ax, line) + for text in ax.texts: + self.draw_text(ax, text) + for (text, ttp) in zip([ax.xaxis.label, ax.yaxis.label, ax.title], + ["xlabel", "ylabel", "title"]): + if(hasattr(text, 'get_text') and text.get_text()): + self.draw_text(ax, text, force_trans=ax.transAxes, + text_type=ttp) + for artist in ax.artists: + # TODO: process other artists + if isinstance(artist, matplotlib.text.Text): + self.draw_text(ax, artist) + for patch in ax.patches: + self.draw_patch(ax, patch) + for collection in ax.collections: + self.draw_collection(ax, collection) + for image in ax.images: + self.draw_image(ax, image) + + legend = ax.get_legend() + if legend is not None: + props = utils.get_legend_properties(ax, legend) + with self.renderer.draw_legend(legend=legend, props=props): + if props['visible']: + self.crawl_legend(ax, legend) + + def crawl_legend(self, ax, legend): + """ + Recursively look through objects in legend children + """ + legendElements = list(utils.iter_all_children(legend._legend_box, + skipContainers=True)) + legendElements.append(legend.legendPatch) + for child in legendElements: + # force a large zorder so it appears on top + child.set_zorder(1E6 + child.get_zorder()) + + # reorder border box to make sure marks are visible + if isinstance(child, matplotlib.patches.FancyBboxPatch): + child.set_zorder(child.get_zorder()-1) + + try: + # What kind of object... + if isinstance(child, matplotlib.patches.Patch): + self.draw_patch(ax, child, force_trans=ax.transAxes) + elif isinstance(child, matplotlib.text.Text): + if child.get_text() != 'None': + self.draw_text(ax, child, force_trans=ax.transAxes) + elif isinstance(child, matplotlib.lines.Line2D): + self.draw_line(ax, child, force_trans=ax.transAxes) + elif isinstance(child, matplotlib.collections.Collection): + self.draw_collection(ax, child, + force_pathtrans=ax.transAxes) + else: + warnings.warn("Legend element %s not impemented" % child) + except NotImplementedError: + warnings.warn("Legend element %s not impemented" % child) + + def draw_line(self, ax, line, force_trans=None): + """Process a matplotlib line and call renderer.draw_line""" + coordinates, data = self.process_transform(line.get_transform(), + ax, line.get_xydata(), + force_trans=force_trans) + linestyle = utils.get_line_style(line) + if (linestyle['dasharray'] is None + and linestyle['drawstyle'] == 'default'): + linestyle = None + markerstyle = utils.get_marker_style(line) + if (markerstyle['marker'] in ['None', 'none', None] + or markerstyle['markerpath'][0].size == 0): + markerstyle = None + label = line.get_label() + if markerstyle or linestyle: + self.renderer.draw_marked_line(data=data, coordinates=coordinates, + linestyle=linestyle, + markerstyle=markerstyle, + label=label, + mplobj=line) + + def draw_text(self, ax, text, force_trans=None, text_type=None): + """Process a matplotlib text object and call renderer.draw_text""" + content = text.get_text() + if content: + transform = text.get_transform() + position = text.get_position() + coords, position = self.process_transform(transform, ax, + position, + force_trans=force_trans) + style = utils.get_text_style(text) + self.renderer.draw_text(text=content, position=position, + coordinates=coords, + text_type=text_type, + style=style, mplobj=text) + + def draw_patch(self, ax, patch, force_trans=None): + """Process a matplotlib patch object and call renderer.draw_path""" + vertices, pathcodes = utils.SVG_path(patch.get_path()) + transform = patch.get_transform() + coordinates, vertices = self.process_transform(transform, + ax, vertices, + force_trans=force_trans) + linestyle = utils.get_path_style(patch, fill=patch.get_fill()) + self.renderer.draw_path(data=vertices, + coordinates=coordinates, + pathcodes=pathcodes, + style=linestyle, + mplobj=patch) + + def draw_collection(self, ax, collection, + force_pathtrans=None, + force_offsettrans=None): + """Process a matplotlib collection and call renderer.draw_collection""" + (transform, transOffset, + offsets, paths) = collection._prepare_points() + + offset_coords, offsets = self.process_transform( + transOffset, ax, offsets, force_trans=force_offsettrans) + path_coords = self.process_transform( + transform, ax, force_trans=force_pathtrans) + + processed_paths = [utils.SVG_path(path) for path in paths] + processed_paths = [(self.process_transform( + transform, ax, path[0], force_trans=force_pathtrans)[1], path[1]) + for path in processed_paths] + + path_transforms = collection.get_transforms() + try: + # matplotlib 1.3: path_transforms are transform objects. + # Convert them to numpy arrays. + path_transforms = [t.get_matrix() for t in path_transforms] + except AttributeError: + # matplotlib 1.4: path transforms are already numpy arrays. + pass + + styles = {'linewidth': collection.get_linewidths(), + 'facecolor': collection.get_facecolors(), + 'edgecolor': collection.get_edgecolors(), + 'alpha': collection._alpha, + 'zorder': collection.get_zorder()} + + offset_dict = {"data": "before", + "screen": "after"} + # protected for removing _offset_position() and default to "screen" + offset_order = offset_dict[getattr(collection, '_offset_position', 'screen')] + + self.renderer.draw_path_collection(paths=processed_paths, + path_coordinates=path_coords, + path_transforms=path_transforms, + offsets=offsets, + offset_coordinates=offset_coords, + offset_order=offset_order, + styles=styles, + mplobj=collection) + + def draw_image(self, ax, image): + """Process a matplotlib image object and call renderer.draw_image""" + self.renderer.draw_image(imdata=utils.image_to_base64(image), + extent=image.get_extent(), + coordinates="data", + style={"alpha": image.get_alpha(), + "zorder": image.get_zorder()}, + mplobj=image) diff --git a/trains/utilities/plotlympl/mplexporter/renderers/__init__.py b/trains/utilities/plotlympl/mplexporter/renderers/__init__.py new file mode 100644 index 00000000..ba85b1aa --- /dev/null +++ b/trains/utilities/plotlympl/mplexporter/renderers/__init__.py @@ -0,0 +1,12 @@ +""" +Matplotlib Renderers +==================== +This submodule contains renderer objects which define renderer behavior used +within the Exporter class. The base renderer class is :class:`Renderer`, an +abstract base class +""" + +from .base import Renderer +from .vega_renderer import VegaRenderer, fig_to_vega +from .vincent_renderer import VincentRenderer, fig_to_vincent +from .fake_renderer import FakeRenderer, FullFakeRenderer diff --git a/trains/utilities/plotlympl/mplexporter/renderers/base.py b/trains/utilities/plotlympl/mplexporter/renderers/base.py new file mode 100644 index 00000000..6bf5acb4 --- /dev/null +++ b/trains/utilities/plotlympl/mplexporter/renderers/base.py @@ -0,0 +1,388 @@ +import warnings +import itertools +from contextlib import contextmanager +from distutils.version import LooseVersion + +import numpy as np +import matplotlib as mpl +from matplotlib import transforms + +from .. import utils +from .. import _py3k_compat as py3k + + +class Renderer(object): + @staticmethod + def ax_zoomable(ax): + return bool(ax and ax.get_navigate()) + + @staticmethod + def ax_has_xgrid(ax): + return bool(ax and ax.xaxis._gridOnMajor and ax.yaxis.get_gridlines()) + + @staticmethod + def ax_has_ygrid(ax): + return bool(ax and ax.yaxis._gridOnMajor and ax.yaxis.get_gridlines()) + + @property + def current_ax_zoomable(self): + return self.ax_zoomable(self._current_ax) + + @property + def current_ax_has_xgrid(self): + return self.ax_has_xgrid(self._current_ax) + + @property + def current_ax_has_ygrid(self): + return self.ax_has_ygrid(self._current_ax) + + @contextmanager + def draw_figure(self, fig, props): + if hasattr(self, "_current_fig") and self._current_fig is not None: + warnings.warn("figure embedded in figure: something is wrong") + self._current_fig = fig + self._fig_props = props + self.open_figure(fig=fig, props=props) + yield + self.close_figure(fig=fig) + self._current_fig = None + self._fig_props = {} + + @contextmanager + def draw_axes(self, ax, props): + if hasattr(self, "_current_ax") and self._current_ax is not None: + warnings.warn("axes embedded in axes: something is wrong") + self._current_ax = ax + self._ax_props = props + self.open_axes(ax=ax, props=props) + yield + self.close_axes(ax=ax) + self._current_ax = None + self._ax_props = {} + + @contextmanager + def draw_legend(self, legend, props): + self._current_legend = legend + self._legend_props = props + self.open_legend(legend=legend, props=props) + yield + self.close_legend(legend=legend) + self._current_legend = None + self._legend_props = {} + + # Following are the functions which should be overloaded in subclasses + + def open_figure(self, fig, props): + """ + Begin commands for a particular figure. + + Parameters + ---------- + fig : matplotlib.Figure + The Figure which will contain the ensuing axes and elements + props : dictionary + The dictionary of figure properties + """ + pass + + def close_figure(self, fig): + """ + Finish commands for a particular figure. + + Parameters + ---------- + fig : matplotlib.Figure + The figure which is finished being drawn. + """ + pass + + def open_axes(self, ax, props): + """ + Begin commands for a particular axes. + + Parameters + ---------- + ax : matplotlib.Axes + The Axes which will contain the ensuing axes and elements + props : dictionary + The dictionary of axes properties + """ + pass + + def close_axes(self, ax): + """ + Finish commands for a particular axes. + + Parameters + ---------- + ax : matplotlib.Axes + The Axes which is finished being drawn. + """ + pass + + def open_legend(self, legend, props): + """ + Beging commands for a particular legend. + + Parameters + ---------- + legend : matplotlib.legend.Legend + The Legend that will contain the ensuing elements + props : dictionary + The dictionary of legend properties + """ + pass + + def close_legend(self, legend): + """ + Finish commands for a particular legend. + + Parameters + ---------- + legend : matplotlib.legend.Legend + The Legend which is finished being drawn + """ + pass + + def draw_marked_line(self, data, coordinates, linestyle, markerstyle, + label, mplobj=None): + """Draw a line that also has markers. + + If this isn't reimplemented by a renderer object, by default, it will + make a call to BOTH draw_line and draw_markers when both markerstyle + and linestyle are not None in the same Line2D object. + + """ + if linestyle is not None: + self.draw_line(data, coordinates, linestyle, label, mplobj) + if markerstyle is not None: + self.draw_markers(data, coordinates, markerstyle, label, mplobj) + + def draw_line(self, data, coordinates, style, label, mplobj=None): + """ + Draw a line. By default, draw the line via the draw_path() command. + Some renderers might wish to override this and provide more + fine-grained behavior. + + In matplotlib, lines are generally created via the plt.plot() command, + though this command also can create marker collections. + + Parameters + ---------- + data : array_like + A shape (N, 2) array of datapoints. + coordinates : string + A string code, which should be either 'data' for data coordinates, + or 'figure' for figure (pixel) coordinates. + style : dictionary + a dictionary specifying the appearance of the line. + mplobj : matplotlib object + the matplotlib plot element which generated this line + """ + pathcodes = ['M'] + (data.shape[0] - 1) * ['L'] + pathstyle = dict(facecolor='none', **style) + pathstyle['edgecolor'] = pathstyle.pop('color') + pathstyle['edgewidth'] = pathstyle.pop('linewidth') + self.draw_path(data=data, coordinates=coordinates, + pathcodes=pathcodes, style=pathstyle, mplobj=mplobj) + + @staticmethod + def _iter_path_collection(paths, path_transforms, offsets, styles): + """Build an iterator over the elements of the path collection""" + N = max(len(paths), len(offsets)) + + # Before mpl 1.4.0, path_transform can be a false-y value, not a valid + # transformation matrix. + if LooseVersion(mpl.__version__) < LooseVersion('1.4.0'): + if path_transforms is None: + path_transforms = [np.eye(3)] + + edgecolor = styles['edgecolor'] + if np.size(edgecolor) == 0: + edgecolor = ['none'] + facecolor = styles['facecolor'] + if np.size(facecolor) == 0: + facecolor = ['none'] + + elements = [paths, path_transforms, offsets, + edgecolor, styles['linewidth'], facecolor] + + it = itertools + return it.islice(py3k.zip(*py3k.map(it.cycle, elements)), N) + + def draw_path_collection(self, paths, path_coordinates, path_transforms, + offsets, offset_coordinates, offset_order, + styles, mplobj=None): + """ + Draw a collection of paths. The paths, offsets, and styles are all + iterables, and the number of paths is max(len(paths), len(offsets)). + + By default, this is implemented via multiple calls to the draw_path() + function. For efficiency, Renderers may choose to customize this + implementation. + + Examples of path collections created by matplotlib are scatter plots, + histograms, contour plots, and many others. + + Parameters + ---------- + paths : list + list of tuples, where each tuple has two elements: + (data, pathcodes). See draw_path() for a description of these. + path_coordinates: string + the coordinates code for the paths, which should be either + 'data' for data coordinates, or 'figure' for figure (pixel) + coordinates. + path_transforms: array_like + an array of shape (*, 3, 3), giving a series of 2D Affine + transforms for the paths. These encode translations, rotations, + and scalings in the standard way. + offsets: array_like + An array of offsets of shape (N, 2) + offset_coordinates : string + the coordinates code for the offsets, which should be either + 'data' for data coordinates, or 'figure' for figure (pixel) + coordinates. + offset_order : string + either "before" or "after". This specifies whether the offset + is applied before the path transform, or after. The matplotlib + backend equivalent is "before"->"data", "after"->"screen". + styles: dictionary + A dictionary in which each value is a list of length N, containing + the style(s) for the paths. + mplobj : matplotlib object + the matplotlib plot element which generated this collection + """ + if offset_order == "before": + raise NotImplementedError("offset before transform") + + for tup in self._iter_path_collection(paths, path_transforms, + offsets, styles): + (path, path_transform, offset, ec, lw, fc) = tup + vertices, pathcodes = path + path_transform = transforms.Affine2D(path_transform) + vertices = path_transform.transform(vertices) + # This is a hack: + if path_coordinates == "figure": + path_coordinates = "points" + style = {"edgecolor": utils.export_color(ec), + "facecolor": utils.export_color(fc), + "edgewidth": lw, + "dasharray": "10,0", + "alpha": styles['alpha'], + "zorder": styles['zorder']} + self.draw_path(data=vertices, coordinates=path_coordinates, + pathcodes=pathcodes, style=style, offset=offset, + offset_coordinates=offset_coordinates, + mplobj=mplobj) + + def draw_markers(self, data, coordinates, style, label, mplobj=None): + """ + Draw a set of markers. By default, this is done by repeatedly + calling draw_path(), but renderers should generally overload + this method to provide a more efficient implementation. + + In matplotlib, markers are created using the plt.plot() command. + + Parameters + ---------- + data : array_like + A shape (N, 2) array of datapoints. + coordinates : string + A string code, which should be either 'data' for data coordinates, + or 'figure' for figure (pixel) coordinates. + style : dictionary + a dictionary specifying the appearance of the markers. + mplobj : matplotlib object + the matplotlib plot element which generated this marker collection + """ + vertices, pathcodes = style['markerpath'] + pathstyle = dict((key, style[key]) for key in ['alpha', 'edgecolor', + 'facecolor', 'zorder', + 'edgewidth']) + pathstyle['dasharray'] = "10,0" + for vertex in data: + self.draw_path(data=vertices, coordinates="points", + pathcodes=pathcodes, style=pathstyle, + offset=vertex, offset_coordinates=coordinates, + mplobj=mplobj) + + def draw_text(self, text, position, coordinates, style, + text_type=None, mplobj=None): + """ + Draw text on the image. + + Parameters + ---------- + text : string + The text to draw + position : tuple + The (x, y) position of the text + coordinates : string + A string code, which should be either 'data' for data coordinates, + or 'figure' for figure (pixel) coordinates. + style : dictionary + a dictionary specifying the appearance of the text. + text_type : string or None + if specified, a type of text such as "xlabel", "ylabel", "title" + mplobj : matplotlib object + the matplotlib plot element which generated this text + """ + raise NotImplementedError() + + def draw_path(self, data, coordinates, pathcodes, style, + offset=None, offset_coordinates="data", mplobj=None): + """ + Draw a path. + + In matplotlib, paths are created by filled regions, histograms, + contour plots, patches, etc. + + Parameters + ---------- + data : array_like + A shape (N, 2) array of datapoints. + coordinates : string + A string code, which should be either 'data' for data coordinates, + 'figure' for figure (pixel) coordinates, or "points" for raw + point coordinates (useful in conjunction with offsets, below). + pathcodes : list + A list of single-character SVG pathcodes associated with the data. + Path codes are one of ['M', 'm', 'L', 'l', 'Q', 'q', 'T', 't', + 'S', 's', 'C', 'c', 'Z', 'z'] + See the SVG specification for details. Note that some path codes + consume more than one datapoint (while 'Z' consumes none), so + in general, the length of the pathcodes list will not be the same + as that of the data array. + style : dictionary + a dictionary specifying the appearance of the line. + offset : list (optional) + the (x, y) offset of the path. If not given, no offset will + be used. + offset_coordinates : string (optional) + A string code, which should be either 'data' for data coordinates, + or 'figure' for figure (pixel) coordinates. + mplobj : matplotlib object + the matplotlib plot element which generated this path + """ + raise NotImplementedError() + + def draw_image(self, imdata, extent, coordinates, style, mplobj=None): + """ + Draw an image. + + Parameters + ---------- + imdata : string + base64 encoded png representation of the image + extent : list + the axes extent of the image: [xmin, xmax, ymin, ymax] + coordinates: string + A string code, which should be either 'data' for data coordinates, + or 'figure' for figure (pixel) coordinates. + style : dictionary + a dictionary specifying the appearance of the image + mplobj : matplotlib object + the matplotlib plot object which generated this image + """ + raise NotImplementedError() diff --git a/trains/utilities/plotlympl/mplexporter/renderers/fake_renderer.py b/trains/utilities/plotlympl/mplexporter/renderers/fake_renderer.py new file mode 100644 index 00000000..2c4c708c --- /dev/null +++ b/trains/utilities/plotlympl/mplexporter/renderers/fake_renderer.py @@ -0,0 +1,68 @@ +from .base import Renderer + + +class FakeRenderer(Renderer): + """ + Fake Renderer + + This is a fake renderer which simply outputs a text tree representing the + elements found in the plot(s). This is used in the unit tests for the + package. + + Below are the methods your renderer must implement. You are free to do + anything you wish within the renderer (i.e. build an XML or JSON + representation, call an external API, etc.) Here the renderer just + builds a simple string representation for testing purposes. + """ + def __init__(self): + self.output = "" + + def open_figure(self, fig, props): + self.output += "opening figure\n" + + def close_figure(self, fig): + self.output += "closing figure\n" + + def open_axes(self, ax, props): + self.output += " opening axes\n" + + def close_axes(self, ax): + self.output += " closing axes\n" + + def open_legend(self, legend, props): + self.output += " opening legend\n" + + def close_legend(self, legend): + self.output += " closing legend\n" + + def draw_text(self, text, position, coordinates, style, + text_type=None, mplobj=None): + self.output += " draw text '{0}' {1}\n".format(text, text_type) + + def draw_path(self, data, coordinates, pathcodes, style, + offset=None, offset_coordinates="data", mplobj=None): + self.output += " draw path with {0} vertices\n".format(data.shape[0]) + + def draw_image(self, imdata, extent, coordinates, style, mplobj=None): + self.output += " draw image of size {0}\n".format(len(imdata)) + + +class FullFakeRenderer(FakeRenderer): + """ + Renderer with the full complement of methods. + + When the following are left undefined, they will be implemented via + other methods in the class. They can be defined explicitly for + more efficient or specialized use within the renderer implementation. + """ + def draw_line(self, data, coordinates, style, label, mplobj=None): + self.output += " draw line with {0} points\n".format(data.shape[0]) + + def draw_markers(self, data, coordinates, style, label, mplobj=None): + self.output += " draw {0} markers\n".format(data.shape[0]) + + def draw_path_collection(self, paths, path_coordinates, path_transforms, + offsets, offset_coordinates, offset_order, + styles, mplobj=None): + self.output += (" draw path collection " + "with {0} offsets\n".format(offsets.shape[0])) diff --git a/trains/utilities/plotlympl/mplexporter/renderers/vega_renderer.py b/trains/utilities/plotlympl/mplexporter/renderers/vega_renderer.py new file mode 100644 index 00000000..82a30bd9 --- /dev/null +++ b/trains/utilities/plotlympl/mplexporter/renderers/vega_renderer.py @@ -0,0 +1,138 @@ +import warnings +import json +import random +from .base import Renderer +from ..exporter import Exporter + + +class VegaRenderer(Renderer): + def open_figure(self, fig, props): + self.props = props + self.figwidth = int(props['figwidth'] * props['dpi']) + self.figheight = int(props['figheight'] * props['dpi']) + self.data = [] + self.scales = [] + self.axes = [] + self.marks = [] + + def open_axes(self, ax, props): + if len(self.axes) > 0: + warnings.warn("multiple axes not yet supported") + self.axes = [dict(type="x", scale="x", ticks=10), + dict(type="y", scale="y", ticks=10)] + self.scales = [dict(name="x", + domain=props['xlim'], + type="linear", + range="width", + ), + dict(name="y", + domain=props['ylim'], + type="linear", + range="height", + ),] + + def draw_line(self, data, coordinates, style, label, mplobj=None): + if coordinates != 'data': + warnings.warn("Only data coordinates supported. Skipping this") + dataname = "table{0:03d}".format(len(self.data) + 1) + + # TODO: respect the other style settings + self.data.append({'name': dataname, + 'values': [dict(x=d[0], y=d[1]) for d in data]}) + self.marks.append({'type': 'line', + 'from': {'data': dataname}, + 'properties': { + "enter": { + "interpolate": {"value": "monotone"}, + "x": {"scale": "x", "field": "data.x"}, + "y": {"scale": "y", "field": "data.y"}, + "stroke": {"value": style['color']}, + "strokeOpacity": {"value": style['alpha']}, + "strokeWidth": {"value": style['linewidth']}, + } + } + }) + + def draw_markers(self, data, coordinates, style, label, mplobj=None): + if coordinates != 'data': + warnings.warn("Only data coordinates supported. Skipping this") + dataname = "table{0:03d}".format(len(self.data) + 1) + + # TODO: respect the other style settings + self.data.append({'name': dataname, + 'values': [dict(x=d[0], y=d[1]) for d in data]}) + self.marks.append({'type': 'symbol', + 'from': {'data': dataname}, + 'properties': { + "enter": { + "interpolate": {"value": "monotone"}, + "x": {"scale": "x", "field": "data.x"}, + "y": {"scale": "y", "field": "data.y"}, + "fill": {"value": style['facecolor']}, + "fillOpacity": {"value": style['alpha']}, + "stroke": {"value": style['edgecolor']}, + "strokeOpacity": {"value": style['alpha']}, + "strokeWidth": {"value": style['edgewidth']}, + } + } + }) + + def draw_text(self, text, position, coordinates, style, + text_type=None, mplobj=None): + if text_type == 'xlabel': + self.axes[0]['title'] = text + elif text_type == 'ylabel': + self.axes[1]['title'] = text + + +class VegaHTML(object): + def __init__(self, renderer): + self.specification = dict(width=renderer.figwidth, + height=renderer.figheight, + data=renderer.data, + scales=renderer.scales, + axes=renderer.axes, + marks=renderer.marks) + + def html(self): + """Build the HTML representation for IPython.""" + id = random.randint(0, 2 ** 16) + html = '
' % id + html += '\n' + return html + + def _repr_html_(self): + return self.html() + + +def fig_to_vega(fig, notebook=False): + """Convert a matplotlib figure to vega dictionary + + if notebook=True, then return an object which will display in a notebook + otherwise, return an HTML string. + """ + renderer = VegaRenderer() + Exporter(renderer).run(fig) + vega_html = VegaHTML(renderer) + if notebook: + return vega_html + else: + return vega_html.html() + + +VEGA_TEMPLATE = """ +( function() { + var _do_plot = function() { + if ( (typeof vg == 'undefined') && (typeof IPython != 'undefined')) { + $([IPython.events]).on("vega_loaded.vincent", _do_plot); + return; + } + vg.parse.spec(%s, function(chart) { + chart({el: "#vis%d"}).update(); + }); + }; + _do_plot(); +})(); +""" diff --git a/trains/utilities/plotlympl/mplexporter/renderers/vincent_renderer.py b/trains/utilities/plotlympl/mplexporter/renderers/vincent_renderer.py new file mode 100644 index 00000000..73691eab --- /dev/null +++ b/trains/utilities/plotlympl/mplexporter/renderers/vincent_renderer.py @@ -0,0 +1,52 @@ +import warnings +from .base import Renderer +from ..exporter import Exporter + + +class VincentRenderer(Renderer): + def open_figure(self, fig, props): + self.chart = None + self.figwidth = int(props['figwidth'] * props['dpi']) + self.figheight = int(props['figheight'] * props['dpi']) + + def draw_line(self, data, coordinates, style, label, mplobj=None): + import vincent # only import if VincentRenderer is used + if coordinates != 'data': + warnings.warn("Only data coordinates supported. Skipping this") + linedata = {'x': data[:, 0], + 'y': data[:, 1]} + line = vincent.Line(linedata, iter_idx='x', + width=self.figwidth, height=self.figheight) + + # TODO: respect the other style settings + line.scales['color'].range = [style['color']] + + if self.chart is None: + self.chart = line + else: + warnings.warn("Multiple plot elements not yet supported") + + def draw_markers(self, data, coordinates, style, label, mplobj=None): + import vincent # only import if VincentRenderer is used + if coordinates != 'data': + warnings.warn("Only data coordinates supported. Skipping this") + markerdata = {'x': data[:, 0], + 'y': data[:, 1]} + markers = vincent.Scatter(markerdata, iter_idx='x', + width=self.figwidth, height=self.figheight) + + # TODO: respect the other style settings + markers.scales['color'].range = [style['facecolor']] + + if self.chart is None: + self.chart = markers + else: + warnings.warn("Multiple plot elements not yet supported") + + +def fig_to_vincent(fig): + """Convert a matplotlib figure to a vincent object""" + renderer = VincentRenderer() + exporter = Exporter(renderer) + exporter.run(fig) + return renderer.chart diff --git a/trains/utilities/plotlympl/mplexporter/tools.py b/trains/utilities/plotlympl/mplexporter/tools.py new file mode 100644 index 00000000..551e8bea --- /dev/null +++ b/trains/utilities/plotlympl/mplexporter/tools.py @@ -0,0 +1,52 @@ +""" +Tools for matplotlib plot exporting +""" + + +def ipynb_vega_init(): + """Initialize the IPython notebook display elements + + This function borrows heavily from the excellent vincent package: + http://github.com/wrobstory/vincent + """ + try: + from IPython.core.display import display, HTML + except ImportError: + print('IPython Notebook could not be loaded.') + + require_js = ''' + if (window['d3'] === undefined) {{ + require.config({{ paths: {{d3: "http://d3js.org/d3.v3.min"}} }}); + require(["d3"], function(d3) {{ + window.d3 = d3; + {0} + }}); + }}; + if (window['topojson'] === undefined) {{ + require.config( + {{ paths: {{topojson: "http://d3js.org/topojson.v1.min"}} }} + ); + require(["topojson"], function(topojson) {{ + window.topojson = topojson; + }}); + }}; + ''' + d3_geo_projection_js_url = "http://d3js.org/d3.geo.projection.v0.min.js" + d3_layout_cloud_js_url = ("http://wrobstory.github.io/d3-cloud/" + "d3.layout.cloud.js") + topojson_js_url = "http://d3js.org/topojson.v1.min.js" + vega_js_url = 'http://trifacta.github.com/vega/vega.js' + + dep_libs = '''$.getScript("%s", function() { + $.getScript("%s", function() { + $.getScript("%s", function() { + $.getScript("%s", function() { + $([IPython.events]).trigger("vega_loaded.vincent"); + }) + }) + }) + });''' % (d3_geo_projection_js_url, d3_layout_cloud_js_url, + topojson_js_url, vega_js_url) + load_js = require_js.format(dep_libs) + html = '' + display(HTML(html)) diff --git a/trains/utilities/plotlympl/mplexporter/utils.py b/trains/utilities/plotlympl/mplexporter/utils.py new file mode 100644 index 00000000..4059a6b6 --- /dev/null +++ b/trains/utilities/plotlympl/mplexporter/utils.py @@ -0,0 +1,362 @@ +""" +Utility Routines for Working with Matplotlib Objects +==================================================== +""" +import itertools +import io +import base64 + +import numpy as np + +import warnings + +import matplotlib +from matplotlib.colors import colorConverter +from matplotlib.path import Path +from matplotlib.markers import MarkerStyle +from matplotlib.transforms import Affine2D +from matplotlib import ticker + + +def export_color(color): + """Convert matplotlib color code to hex color or RGBA color""" + if color is None or colorConverter.to_rgba(color)[3] == 0: + return 'none' + elif colorConverter.to_rgba(color)[3] == 1: + rgb = colorConverter.to_rgb(color) + return '#{0:02X}{1:02X}{2:02X}'.format(*(int(255 * c) for c in rgb)) + else: + c = colorConverter.to_rgba(color) + return "rgba(" + ", ".join(str(int(np.round(val * 255))) + for val in c[:3])+', '+str(c[3])+")" + + +def _many_to_one(input_dict): + """Convert a many-to-one mapping to a one-to-one mapping""" + return dict((key, val) + for keys, val in input_dict.items() + for key in keys) + +LINESTYLES = _many_to_one({('solid', '-', (None, None)): 'none', + ('dashed', '--'): "6,6", + ('dotted', ':'): "2,2", + ('dashdot', '-.'): "4,4,2,4", + ('', ' ', 'None', 'none'): None}) + + +def get_dasharray(obj): + """Get an SVG dash array for the given matplotlib linestyle + + Parameters + ---------- + obj : matplotlib object + The matplotlib line or path object, which must have a get_linestyle() + method which returns a valid matplotlib line code + + Returns + ------- + dasharray : string + The HTML/SVG dasharray code associated with the object. + """ + if obj.__dict__.get('_dashSeq', None) is not None: + return ','.join(map(str, obj._dashSeq)) + else: + ls = obj.get_linestyle() + dasharray = LINESTYLES.get(ls, 'not found') + if dasharray == 'not found': + warnings.warn("line style '{0}' not understood: " + "defaulting to solid line.".format(ls)) + dasharray = LINESTYLES['solid'] + return dasharray + + +PATH_DICT = {Path.LINETO: 'L', + Path.MOVETO: 'M', + Path.CURVE3: 'S', + Path.CURVE4: 'C', + Path.CLOSEPOLY: 'Z'} + + +def SVG_path(path, transform=None, simplify=False): + """Construct the vertices and SVG codes for the path + + Parameters + ---------- + path : matplotlib.Path object + + transform : matplotlib transform (optional) + if specified, the path will be transformed before computing the output. + + Returns + ------- + vertices : array + The shape (M, 2) array of vertices of the Path. Note that some Path + codes require multiple vertices, so the length of these vertices may + be longer than the list of path codes. + path_codes : list + A length N list of single-character path codes, N <= M. Each code is + a single character, in ['L','M','S','C','Z']. See the standard SVG + path specification for a description of these. + """ + if transform is not None: + path = path.transformed(transform) + + vc_tuples = [(vertices if path_code != Path.CLOSEPOLY else [], + PATH_DICT[path_code]) + for (vertices, path_code) + in path.iter_segments(simplify=simplify)] + + if not vc_tuples: + # empty path is a special case + return np.zeros((0, 2)), [] + else: + vertices, codes = zip(*vc_tuples) + vertices = np.array(list(itertools.chain(*vertices))).reshape(-1, 2) + return vertices, list(codes) + + +def get_path_style(path, fill=True): + """Get the style dictionary for matplotlib path objects""" + style = {} + style['alpha'] = path.get_alpha() + if style['alpha'] is None: + style['alpha'] = 1 + style['edgecolor'] = export_color(path.get_edgecolor()) + if fill: + style['facecolor'] = export_color(path.get_facecolor()) + else: + style['facecolor'] = 'none' + style['edgewidth'] = path.get_linewidth() + style['dasharray'] = get_dasharray(path) + style['zorder'] = path.get_zorder() + return style + + +def get_line_style(line): + """Get the style dictionary for matplotlib line objects""" + style = {} + style['alpha'] = line.get_alpha() + if style['alpha'] is None: + style['alpha'] = 1 + style['color'] = export_color(line.get_color()) + style['linewidth'] = line.get_linewidth() + style['dasharray'] = get_dasharray(line) + style['zorder'] = line.get_zorder() + style['drawstyle'] = line.get_drawstyle() + return style + + +def get_marker_style(line): + """Get the style dictionary for matplotlib marker objects""" + style = {} + style['alpha'] = line.get_alpha() + if style['alpha'] is None: + style['alpha'] = 1 + + style['facecolor'] = export_color(line.get_markerfacecolor()) + style['edgecolor'] = export_color(line.get_markeredgecolor()) + style['edgewidth'] = line.get_markeredgewidth() + + style['marker'] = line.get_marker() + markerstyle = MarkerStyle(line.get_marker()) + markersize = line.get_markersize() + markertransform = (markerstyle.get_transform() + + Affine2D().scale(markersize, -markersize)) + style['markerpath'] = SVG_path(markerstyle.get_path(), + markertransform) + style['markersize'] = markersize + style['zorder'] = line.get_zorder() + return style + + +def get_text_style(text): + """Return the text style dict for a text instance""" + style = {} + style['alpha'] = text.get_alpha() + if style['alpha'] is None: + style['alpha'] = 1 + style['fontsize'] = text.get_size() + style['color'] = export_color(text.get_color()) + style['halign'] = text.get_horizontalalignment() # left, center, right + style['valign'] = text.get_verticalalignment() # baseline, center, top + style['malign'] = text._multialignment # text alignment when '\n' in text + style['rotation'] = text.get_rotation() + style['zorder'] = text.get_zorder() + return style + + +def get_axis_properties(axis): + """Return the property dictionary for a matplotlib.Axis instance""" + props = {} + label1On = axis._major_tick_kw.get('label1On', True) + + if isinstance(axis, matplotlib.axis.XAxis): + if label1On: + props['position'] = "bottom" + else: + props['position'] = "top" + elif isinstance(axis, matplotlib.axis.YAxis): + if label1On: + props['position'] = "left" + else: + props['position'] = "right" + else: + raise ValueError("{0} should be an Axis instance".format(axis)) + + # Use tick values if appropriate + locator = axis.get_major_locator() + props['nticks'] = len(locator()) + if isinstance(locator, ticker.FixedLocator): + props['tickvalues'] = list(locator()) + else: + props['tickvalues'] = None + + # Find tick formats + formatter = axis.get_major_formatter() + if isinstance(formatter, ticker.NullFormatter): + props['tickformat'] = "" + elif isinstance(formatter, ticker.FixedFormatter): + props['tickformat'] = list(formatter.seq) + elif not any(label.get_visible() for label in axis.get_ticklabels()): + props['tickformat'] = "" + else: + props['tickformat'] = None + + # Get axis scale + props['scale'] = axis.get_scale() + + # Get major tick label size (assumes that's all we really care about!) + labels = axis.get_ticklabels() + if labels: + props['fontsize'] = labels[0].get_fontsize() + else: + props['fontsize'] = None + + # Get associated grid + props['grid'] = get_grid_style(axis) + + # get axis visibility + props['visible'] = axis.get_visible() + + return props + + +def get_grid_style(axis): + gridlines = axis.get_gridlines() + if axis._gridOnMajor and len(gridlines) > 0: + color = export_color(gridlines[0].get_color()) + alpha = gridlines[0].get_alpha() + dasharray = get_dasharray(gridlines[0]) + return dict(gridOn=True, + color=color, + dasharray=dasharray, + alpha=alpha) + else: + return {"gridOn": False} + + +def get_figure_properties(fig): + return {'figwidth': fig.get_figwidth(), + 'figheight': fig.get_figheight(), + 'dpi': fig.dpi} + + +def get_axes_properties(ax): + props = {'axesbg': export_color(ax.patch.get_facecolor()), + 'axesbgalpha': ax.patch.get_alpha(), + 'bounds': ax.get_position().bounds, + 'dynamic': ax.get_navigate(), + 'axison': ax.axison, + 'frame_on': ax.get_frame_on(), + 'patch_visible':ax.patch.get_visible(), + 'axes': [get_axis_properties(ax.xaxis), + get_axis_properties(ax.yaxis)]} + + for axname in ['x', 'y']: + axis = getattr(ax, axname + 'axis') + domain = getattr(ax, 'get_{0}lim'.format(axname))() + lim = domain + if isinstance(axis.converter, matplotlib.dates.DateConverter): + scale = 'date' + try: + import pandas as pd + from pandas.tseries.converter import PeriodConverter + except ImportError: + pd = None + + if (pd is not None and isinstance(axis.converter, + PeriodConverter)): + _dates = [pd.Period(ordinal=int(d), freq=axis.freq) + for d in domain] + domain = [(d.year, d.month - 1, d.day, + d.hour, d.minute, d.second, 0) + for d in _dates] + else: + domain = [(d.year, d.month - 1, d.day, + d.hour, d.minute, d.second, + d.microsecond * 1E-3) + for d in matplotlib.dates.num2date(domain)] + else: + scale = axis.get_scale() + + if scale not in ['date', 'linear', 'log']: + raise ValueError("Unknown axis scale: " + "{0}".format(axis.get_scale())) + + props[axname + 'scale'] = scale + props[axname + 'lim'] = lim + props[axname + 'domain'] = domain + + return props + + +def iter_all_children(obj, skipContainers=False): + """ + Returns an iterator over all childen and nested children using + obj's get_children() method + + if skipContainers is true, only childless objects are returned. + """ + if hasattr(obj, 'get_children') and len(obj.get_children()) > 0: + for child in obj.get_children(): + if not skipContainers: + yield child + # could use `yield from` in python 3... + for grandchild in iter_all_children(child, skipContainers): + yield grandchild + else: + yield obj + + +def get_legend_properties(ax, legend): + handles, labels = ax.get_legend_handles_labels() + visible = legend.get_visible() + return {'handles': handles, 'labels': labels, 'visible': visible} + + +def image_to_base64(image): + """ + Convert a matplotlib image to a base64 png representation + + Parameters + ---------- + image : matplotlib image object + The image to be converted. + + Returns + ------- + image_base64 : string + The UTF8-encoded base64 string representation of the png image. + """ + ax = image.axes + binary_buffer = io.BytesIO() + + # image is saved in axes coordinates: we need to temporarily + # set the correct limits to get the correct image + lim = ax.axis() + ax.axis(image.get_extent()) + image.write_png(binary_buffer) + ax.axis(lim) + + binary_buffer.seek(0) + return base64.b64encode(binary_buffer.read()).decode('utf-8') diff --git a/trains/utilities/plotlympl/mpltools.py b/trains/utilities/plotlympl/mpltools.py new file mode 100644 index 00000000..a756ebdc --- /dev/null +++ b/trains/utilities/plotlympl/mpltools.py @@ -0,0 +1,600 @@ +""" +Tools + +A module for converting from mpl language to plotly language. + +""" +import math + +import warnings +import matplotlib.dates + + +def check_bar_match(old_bar, new_bar): + """Check if two bars belong in the same collection (bar chart). + + Positional arguments: + old_bar -- a previously sorted bar dictionary. + new_bar -- a new bar dictionary that needs to be sorted. + + """ + tests = [] + tests += (new_bar["orientation"] == old_bar["orientation"],) + tests += (new_bar["facecolor"] == old_bar["facecolor"],) + if new_bar["orientation"] == "v": + new_width = new_bar["x1"] - new_bar["x0"] + old_width = old_bar["x1"] - old_bar["x0"] + tests += (new_width - old_width < 0.000001,) + tests += (new_bar["y0"] == old_bar["y0"],) + elif new_bar["orientation"] == "h": + new_height = new_bar["y1"] - new_bar["y0"] + old_height = old_bar["y1"] - old_bar["y0"] + tests += (new_height - old_height < 0.000001,) + tests += (new_bar["x0"] == old_bar["x0"],) + if all(tests): + return True + else: + return False + + +def check_corners(inner_obj, outer_obj): + inner_corners = inner_obj.get_window_extent().corners() + outer_corners = outer_obj.get_window_extent().corners() + if inner_corners[0][0] < outer_corners[0][0]: + return False + elif inner_corners[0][1] < outer_corners[0][1]: + return False + elif inner_corners[3][0] > outer_corners[3][0]: + return False + elif inner_corners[3][1] > outer_corners[3][1]: + return False + else: + return True + + +def convert_dash(mpl_dash): + """Convert mpl line symbol to plotly line symbol and return symbol.""" + if mpl_dash in DASH_MAP: + return DASH_MAP[mpl_dash] + else: + dash_array = mpl_dash.split(",") + + if len(dash_array) < 2: + return "solid" + + # Catch the exception where the off length is zero, in case + # matplotlib 'solid' changes from '10,0' to 'N,0' + if math.isclose(float(dash_array[1]), 0.0): + return "solid" + + # If we can't find the dash pattern in the map, convert it + # into custom values in px, e.g. '7,5' -> '7px,5px' + dashpx = ",".join([x + "px" for x in dash_array]) + + # TODO: rewrite the convert_dash code + # only strings 'solid', 'dashed', etc allowed + if dashpx == "7.4px,3.2px": + dashpx = "dashed" + elif dashpx == "12.8px,3.2px,2.0px,3.2px": + dashpx = "dashdot" + elif dashpx == "2.0px,3.3px": + dashpx = "dotted" + return dashpx + + +def convert_path(path): + verts = path[0] # may use this later + code = tuple(path[1]) + if code in PATH_MAP: + return PATH_MAP[code] + else: + return None + + +def convert_symbol(mpl_symbol): + """Convert mpl marker symbol to plotly symbol and return symbol.""" + if isinstance(mpl_symbol, list): + symbol = list() + for s in mpl_symbol: + symbol += [convert_symbol(s)] + return symbol + elif mpl_symbol in SYMBOL_MAP: + return SYMBOL_MAP[mpl_symbol] + else: + return "circle" # default + + +def hex_to_rgb(value): + """ + Change a hex color to an rgb tuple + + :param (str|unicode) value: The hex string we want to convert. + :return: (int, int, int) The red, green, blue int-tuple. + + Example: + + '#FFFFFF' --> (255, 255, 255) + + """ + value = value.lstrip("#") + lv = len(value) + return tuple(int(value[i : i + lv // 3], 16) for i in range(0, lv, lv // 3)) + + +def merge_color_and_opacity(color, opacity): + """ + Merge hex color with an alpha (opacity) to get an rgba tuple. + + :param (str|unicode) color: A hex color string. + :param (float|int) opacity: A value [0, 1] for the 'a' in 'rgba'. + :return: (int, int, int, float) The rgba color and alpha tuple. + + """ + if color is None: # None can be used as a placeholder, just bail. + return None + + rgb_tup = hex_to_rgb(color) + if opacity is None: + return "rgb {}".format(rgb_tup) + + rgba_tup = rgb_tup + (opacity,) + return "rgba {}".format(rgba_tup) + + +def convert_va(mpl_va): + """Convert mpl vertical alignment word to equivalent HTML word. + + Text alignment specifiers from mpl differ very slightly from those used + in HTML. See the VA_MAP for more details. + + Positional arguments: + mpl_va -- vertical mpl text alignment spec. + + """ + if mpl_va in VA_MAP: + return VA_MAP[mpl_va] + else: + return None # let plotly figure it out! + + +def convert_x_domain(mpl_plot_bounds, mpl_max_x_bounds): + """Map x dimension of current plot to plotly's domain space. + + The bbox used to locate an axes object in mpl differs from the + method used to locate axes in plotly. The mpl version locates each + axes in the figure so that axes in a single-plot figure might have + the bounds, [0.125, 0.125, 0.775, 0.775] (x0, y0, width, height), + in mpl's figure coordinates. However, the axes all share one space in + plotly such that the domain will always be [0, 0, 1, 1] + (x0, y0, x1, y1). To convert between the two, the mpl figure bounds + need to be mapped to a [0, 1] domain for x and y. The margins set + upon opening a new figure will appropriately match the mpl margins. + + Optionally, setting margins=0 and simply copying the domains from + mpl to plotly would place axes appropriately. However, + this would throw off axis and title labeling. + + Positional arguments: + mpl_plot_bounds -- the (x0, y0, width, height) params for current ax ** + mpl_max_x_bounds -- overall (x0, x1) bounds for all axes ** + + ** these are all specified in mpl figure coordinates + + """ + mpl_x_dom = [mpl_plot_bounds[0], mpl_plot_bounds[0] + mpl_plot_bounds[2]] + plotting_width = mpl_max_x_bounds[1] - mpl_max_x_bounds[0] + x0 = (mpl_x_dom[0] - mpl_max_x_bounds[0]) / plotting_width + x1 = (mpl_x_dom[1] - mpl_max_x_bounds[0]) / plotting_width + return [x0, x1] + + +def convert_y_domain(mpl_plot_bounds, mpl_max_y_bounds): + """Map y dimension of current plot to plotly's domain space. + + The bbox used to locate an axes object in mpl differs from the + method used to locate axes in plotly. The mpl version locates each + axes in the figure so that axes in a single-plot figure might have + the bounds, [0.125, 0.125, 0.775, 0.775] (x0, y0, width, height), + in mpl's figure coordinates. However, the axes all share one space in + plotly such that the domain will always be [0, 0, 1, 1] + (x0, y0, x1, y1). To convert between the two, the mpl figure bounds + need to be mapped to a [0, 1] domain for x and y. The margins set + upon opening a new figure will appropriately match the mpl margins. + + Optionally, setting margins=0 and simply copying the domains from + mpl to plotly would place axes appropriately. However, + this would throw off axis and title labeling. + + Positional arguments: + mpl_plot_bounds -- the (x0, y0, width, height) params for current ax ** + mpl_max_y_bounds -- overall (y0, y1) bounds for all axes ** + + ** these are all specified in mpl figure coordinates + + """ + mpl_y_dom = [mpl_plot_bounds[1], mpl_plot_bounds[1] + mpl_plot_bounds[3]] + plotting_height = mpl_max_y_bounds[1] - mpl_max_y_bounds[0] + y0 = (mpl_y_dom[0] - mpl_max_y_bounds[0]) / plotting_height + y1 = (mpl_y_dom[1] - mpl_max_y_bounds[0]) / plotting_height + return [y0, y1] + + +def display_to_paper(x, y, layout): + """Convert mpl display coordinates to plotly paper coordinates. + + Plotly references object positions with an (x, y) coordinate pair in either + 'data' or 'paper' coordinates which reference actual data in a plot or + the entire plotly axes space where the bottom-left of the bottom-left + plot has the location (x, y) = (0, 0) and the top-right of the top-right + plot has the location (x, y) = (1, 1). Display coordinates in mpl reference + objects with an (x, y) pair in pixel coordinates, where the bottom-left + corner is at the location (x, y) = (0, 0) and the top-right corner is at + the location (x, y) = (figwidth*dpi, figheight*dpi). Here, figwidth and + figheight are in inches and dpi are the dots per inch resolution. + + """ + num_x = x - layout["margin"]["l"] + den_x = layout["width"] - (layout["margin"]["l"] + layout["margin"]["r"]) + num_y = y - layout["margin"]["b"] + den_y = layout["height"] - (layout["margin"]["b"] + layout["margin"]["t"]) + return num_x / den_x, num_y / den_y + + +def get_axes_bounds(fig): + """Return the entire axes space for figure. + + An axes object in mpl is specified by its relation to the figure where + (0,0) corresponds to the bottom-left part of the figure and (1,1) + corresponds to the top-right. Margins exist in matplotlib because axes + objects normally don't go to the edges of the figure. + + In plotly, the axes area (where all subplots go) is always specified with + the domain [0,1] for both x and y. This function finds the smallest box, + specified by two points, that all of the mpl axes objects fit into. This + box is then used to map mpl axes domains to plotly axes domains. + + """ + x_min, x_max, y_min, y_max = [], [], [], [] + for axes_obj in fig.get_axes(): + bounds = axes_obj.get_position().bounds + x_min.append(bounds[0]) + x_max.append(bounds[0] + bounds[2]) + y_min.append(bounds[1]) + y_max.append(bounds[1] + bounds[3]) + x_min, y_min, x_max, y_max = min(x_min), min(y_min), max(x_max), max(y_max) + return (x_min, x_max), (y_min, y_max) + + +def get_axis_mirror(main_spine, mirror_spine): + if main_spine and mirror_spine: + return "ticks" + elif main_spine and not mirror_spine: + return False + elif not main_spine and mirror_spine: + return False # can't handle this case yet! + else: + return False # nuttin'! + + +def get_bar_gap(bar_starts, bar_ends, tol=1e-10): + if len(bar_starts) == len(bar_ends) and len(bar_starts) > 1: + sides1 = bar_starts[1:] + sides2 = bar_ends[:-1] + gaps = [s2 - s1 for s2, s1 in zip(sides1, sides2)] + gap0 = gaps[0] + uniform = all([abs(gap0 - gap) < tol for gap in gaps]) + if uniform: + return gap0 + + +def convert_rgba_array(color_list): + clean_color_list = list() + for c in color_list: + clean_color_list += [ + (dict(r=int(c[0] * 255), g=int(c[1] * 255), b=int(c[2] * 255), a=c[3])) + ] + plotly_colors = list() + for rgba in clean_color_list: + plotly_colors += ["rgba({r},{g},{b},{a})".format(**rgba)] + if len(plotly_colors) == 1: + return plotly_colors[0] + else: + return plotly_colors + + +def convert_path_array(path_array): + symbols = list() + for path in path_array: + symbols += [convert_path(path)] + if len(symbols) == 1: + return symbols[0] + else: + return symbols + + +def convert_linewidth_array(width_array): + if len(width_array) == 1: + return width_array[0] + else: + return width_array + + +def convert_size_array(size_array): + size = [math.sqrt(s) for s in size_array] + if len(size) == 1: + return size[0] + else: + return size + + +def get_markerstyle_from_collection(props): + markerstyle = dict( + alpha=None, + facecolor=convert_rgba_array(props["styles"]["facecolor"]), + marker=convert_path_array(props["paths"]), + edgewidth=convert_linewidth_array(props["styles"]["linewidth"]), + # markersize=convert_size_array(props['styles']['size']), # TODO! + markersize=convert_size_array(props["mplobj"].get_sizes()), + edgecolor=convert_rgba_array(props["styles"]["edgecolor"]), + ) + return markerstyle + + +def get_rect_xmin(data): + """Find minimum x value from four (x,y) vertices.""" + return min(data[0][0], data[1][0], data[2][0], data[3][0]) + + +def get_rect_xmax(data): + """Find maximum x value from four (x,y) vertices.""" + return max(data[0][0], data[1][0], data[2][0], data[3][0]) + + +def get_rect_ymin(data): + """Find minimum y value from four (x,y) vertices.""" + return min(data[0][1], data[1][1], data[2][1], data[3][1]) + + +def get_rect_ymax(data): + """Find maximum y value from four (x,y) vertices.""" + return max(data[0][1], data[1][1], data[2][1], data[3][1]) + + +def get_spine_visible(ax, spine_key): + """Return some spine parameters for the spine, `spine_key`.""" + spine = ax.spines[spine_key] + ax_frame_on = ax.get_frame_on() + spine_frame_like = spine.is_frame_like() if hasattr(spine, 'is_frame_like') else True + if not spine.get_visible(): + return False + elif not spine._edgecolor[-1]: # user's may have set edgecolor alpha==0 + return False + elif not ax_frame_on and spine_frame_like: + return False + elif ax_frame_on and spine_frame_like: + return True + elif not ax_frame_on and not spine_frame_like: + return True # we've already checked for that it's visible. + else: + return False # oh man, and i thought we exhausted the options... + + +def is_bar(bar_containers, **props): + """A test to decide whether a path is a bar from a vertical bar chart.""" + + # is this patch in a bar container? + for container in bar_containers: + if props["mplobj"] in container: + return True + return False + + +def make_bar(**props): + """Make an intermediate bar dictionary. + + This creates a bar dictionary which aids in the comparison of new bars to + old bars from other bar chart (patch) collections. This is not the + dictionary that needs to get passed to plotly as a data dictionary. That + happens in PlotlyRenderer in that class's draw_bar method. In other + words, this dictionary describes a SINGLE bar, whereas, plotly will + require a set of bars to be passed in a data dictionary. + + """ + return { + "bar": props["mplobj"], + "x0": get_rect_xmin(props["data"]), + "y0": get_rect_ymin(props["data"]), + "x1": get_rect_xmax(props["data"]), + "y1": get_rect_ymax(props["data"]), + "alpha": props["style"]["alpha"], + "edgecolor": props["style"]["edgecolor"], + "facecolor": props["style"]["facecolor"], + "edgewidth": props["style"]["edgewidth"], + "dasharray": props["style"]["dasharray"], + "zorder": props["style"]["zorder"], + } + + +def prep_ticks(ax, index, ax_type, props): + """Prepare axis obj belonging to axes obj. + + positional arguments: + ax - the mpl axes instance + index - the index of the axis in `props` + ax_type - 'x' or 'y' (for now) + props - an mplexporter poperties dictionary + + """ + axis_dict = dict() + if ax_type == "x": + axis = ax.get_xaxis() + elif ax_type == "y": + axis = ax.get_yaxis() + else: + return dict() # whoops! + + scale = props["axes"][index]["scale"] + if scale == "linear": + # get tick location information + try: + tickvalues = props["axes"][index]["tickvalues"] + tick0 = tickvalues[0] + dticks = [ + round(tickvalues[i] - tickvalues[i - 1], 12) + for i in range(1, len(tickvalues) - 1) + ] + if all([dticks[i] == dticks[i - 1] for i in range(1, len(dticks) - 1)]): + dtick = tickvalues[1] - tickvalues[0] + else: + warnings.warn( + "'linear' {0}-axis tick spacing not even, " + "ignoring mpl tick formatting.".format(ax_type) + ) + raise TypeError + except (IndexError, TypeError): + axis_dict["nticks"] = props["axes"][index]["nticks"] + else: + axis_dict["tick0"] = tick0 + axis_dict["dtick"] = dtick + axis_dict["tickmode"] = None + elif scale == "log": + try: + axis_dict["tick0"] = props["axes"][index]["tickvalues"][0] + axis_dict["dtick"] = ( + props["axes"][index]["tickvalues"][1] + - props["axes"][index]["tickvalues"][0] + ) + axis_dict["tickmode"] = None + except (IndexError, TypeError): + axis_dict = dict(nticks=props["axes"][index]["nticks"]) + base = axis.get_transform().base + if base == 10: + if ax_type == "x": + axis_dict["range"] = [ + math.log10(props["xlim"][0]), + math.log10(props["xlim"][1]), + ] + elif ax_type == "y": + axis_dict["range"] = [ + math.log10(props["ylim"][0]), + math.log10(props["ylim"][1]), + ] + else: + axis_dict = dict(range=None, type="linear") + warnings.warn( + "Converted non-base10 {0}-axis log scale to 'linear'" "".format(ax_type) + ) + else: + return dict() + # get tick label formatting information + formatter = axis.get_major_formatter().__class__.__name__ + if ax_type == "x" and "DateFormatter" in formatter: + axis_dict["type"] = "date" + try: + axis_dict["tick0"] = mpl_dates_to_datestrings(axis_dict["tick0"], formatter) + except KeyError: + pass + finally: + axis_dict.pop("dtick", None) + axis_dict.pop("tickmode", None) + axis_dict["range"] = mpl_dates_to_datestrings(props["xlim"], formatter) + + if formatter == "LogFormatterMathtext": + axis_dict["exponentformat"] = "e" + return axis_dict + + +def prep_xy_axis(ax, props, x_bounds, y_bounds): + xaxis = dict( + type=props["axes"][0]["scale"], + range=list(props["xlim"]), + showgrid=props["axes"][0]["grid"]["gridOn"], + domain=convert_x_domain(props["bounds"], x_bounds), + side=props["axes"][0]["position"], + tickfont=dict(size=props["axes"][0]["fontsize"]), + ) + xaxis.update(prep_ticks(ax, 0, "x", props)) + yaxis = dict( + type=props["axes"][1]["scale"], + range=list(props["ylim"]), + showgrid=props["axes"][1]["grid"]["gridOn"], + domain=convert_y_domain(props["bounds"], y_bounds), + side=props["axes"][1]["position"], + tickfont=dict(size=props["axes"][1]["fontsize"]), + ) + yaxis.update(prep_ticks(ax, 1, "y", props)) + return xaxis, yaxis + + +def mpl_dates_to_datestrings(dates, mpl_formatter): + """Convert matplotlib dates to iso-formatted-like time strings. + + Plotly's accepted format: "YYYY-MM-DD HH:MM:SS" (e.g., 2001-01-01 00:00:00) + + Info on mpl dates: http://matplotlib.org/api/dates_api.html + + """ + _dates = dates + + # this is a pandas datetime formatter, times show up in floating point days + # since the epoch (1970-01-01T00:00:00+00:00) + if mpl_formatter == "TimeSeries_DateFormatter": + try: + dates = matplotlib.dates.epoch2num([date * 24 * 60 * 60 for date in dates]) + dates = matplotlib.dates.num2date(dates) + except: + return _dates + + # the rest of mpl dates are in floating point days since + # (0001-01-01T00:00:00+00:00) + 1. I.e., (0001-01-01T00:00:00+00:00) == 1.0 + # according to mpl --> try num2date(1) + else: + try: + dates = matplotlib.dates.num2date(dates) + except: + return _dates + + time_stings = [ + " ".join(date.isoformat().split("+")[0].split("T")) for date in dates + ] + return time_stings + + +# dashed is dash in matplotlib +DASH_MAP = { + "10,0": "solid", + "6,6": "dash", + "2,2": "circle", + "4,4,2,4": "dashdot", + "none": "solid", + "7.4,3.2": "dash", +} + +PATH_MAP = { + ("M", "C", "C", "C", "C", "C", "C", "C", "C", "Z"): "o", + ("M", "L", "L", "L", "L", "L", "L", "L", "L", "L", "Z"): "*", + ("M", "L", "L", "L", "L", "L", "L", "L", "Z"): "8", + ("M", "L", "L", "L", "L", "L", "Z"): "h", + ("M", "L", "L", "L", "L", "Z"): "p", + ("M", "L", "M", "L", "M", "L"): "1", + ("M", "L", "L", "L", "Z"): "s", + ("M", "L", "M", "L"): "+", + ("M", "L", "L", "Z"): "^", + ("M", "L"): "|", +} + +SYMBOL_MAP = { + "o": "circle", + "v": "triangle-down", + "^": "triangle-up", + "<": "triangle-left", + ">": "triangle-right", + "s": "square", + "+": "cross", + "x": "x", + "*": "star", + "D": "diamond", + "d": "diamond", +} + +VA_MAP = {"center": "middle", "baseline": "bottom", "top": "top"} diff --git a/trains/utilities/plotlympl/renderer.py b/trains/utilities/plotlympl/renderer.py new file mode 100644 index 00000000..ab6c5ac7 --- /dev/null +++ b/trains/utilities/plotlympl/renderer.py @@ -0,0 +1,768 @@ +""" +Renderer Module + +This module defines the PlotlyRenderer class and a single function, +fig_to_plotly, which is intended to be the main way that user's will interact +with the matplotlylib package. + +""" +from __future__ import absolute_import + +import six +import warnings + +from .mplexporter import Renderer +from . import mpltools + + +# Warning format +def warning_on_one_line(msg, category, filename, lineno, file=None, line=None): + return "%s:%s: %s:\n\n%s\n\n" % (filename, lineno, category.__name__, msg) + + +warnings.formatwarning = warning_on_one_line + + +class PlotlyRenderer(Renderer): + """A renderer class inheriting from base for rendering mpl plots in plotly. + + A renderer class to be used with an exporter for rendering matplotlib + plots in Plotly. This module defines the PlotlyRenderer class which handles + the creation of the JSON structures that get sent to plotly. + + All class attributes available are defined in __init__(). + + Basic Usage: + + # (mpl code) # + fig = gcf() + renderer = PlotlyRenderer(fig) + exporter = Exporter(renderer) + exporter.run(fig) # ... et voila + + """ + + def __init__(self): + """Initialize PlotlyRenderer obj. + + PlotlyRenderer obj is called on by an Exporter object to draw + matplotlib objects like figures, axes, text, etc. + + All class attributes are listed here in the __init__ method. + + """ + self.plotly_fig = dict(data=[], layout={}) + self.mpl_fig = None + self.current_mpl_ax = None + self.bar_containers = None + self.current_bars = [] + self.axis_ct = 0 + self.x_is_mpl_date = False + self.mpl_x_bounds = (0, 1) + self.mpl_y_bounds = (0, 1) + self.msg = "Initialized PlotlyRenderer\n" + + def open_figure(self, fig, props): + """Creates a new figure by beginning to fill out layout dict. + + The 'autosize' key is set to false so that the figure will mirror + sizes set by mpl. The 'hovermode' key controls what shows up when you + mouse around a figure in plotly, it's set to show the 'closest' point. + + Positional agurments: + fig -- a matplotlib.figure.Figure object. + props.keys(): [ + 'figwidth', + 'figheight', + 'dpi' + ] + + """ + self.msg += "Opening figure\n" + self.mpl_fig = fig + self.plotly_fig["layout"] = dict( + width=int(props["figwidth"] * props["dpi"]), + height=int(props["figheight"] * props["dpi"]), + autosize=False, + hovermode="closest", + ) + self.mpl_x_bounds, self.mpl_y_bounds = mpltools.get_axes_bounds(fig) + margin = dict( + l=int(self.mpl_x_bounds[0] * self.plotly_fig["layout"]["width"]), + r=int((1 - self.mpl_x_bounds[1]) * self.plotly_fig["layout"]["width"]), + t=int((1 - self.mpl_y_bounds[1]) * self.plotly_fig["layout"]["height"]), + b=int(self.mpl_y_bounds[0] * self.plotly_fig["layout"]["height"]), + pad=0, + ) + self.plotly_fig["layout"]["margin"] = margin + + def close_figure(self, fig): + """Closes figure by cleaning up data and layout dictionaries. + + The PlotlyRenderer's job is to create an appropriate set of data and + layout dictionaries. When the figure is closed, some cleanup and + repair is necessary. This method removes inappropriate dictionary + entries, freeing up Plotly to use defaults and best judgements to + complete the entries. This method is called by an Exporter object. + + Positional arguments: + fig -- a matplotlib.figure.Figure object. + + """ + self.plotly_fig["layout"]["showlegend"] = False + self.msg += "Closing figure\n" + + def open_axes(self, ax, props): + """Setup a new axes object (subplot in plotly). + + Plotly stores information about subplots in different 'xaxis' and + 'yaxis' objects which are numbered. These are just dictionaries + included in the layout dictionary. This function takes information + from the Exporter, fills in appropriate dictionary entries, + and updates the layout dictionary. PlotlyRenderer keeps track of the + number of plots by incrementing the axis_ct attribute. + + Setting the proper plot domain in plotly is a bit tricky. Refer to + the documentation for mpltools.convert_x_domain and + mpltools.convert_y_domain. + + Positional arguments: + ax -- an mpl axes object. This will become a subplot in plotly. + props.keys() -- [ + 'axesbg', (background color for axes obj) + 'axesbgalpha', (alpha, or opacity for background) + 'bounds', ((x0, y0, width, height) for axes) + 'dynamic', (zoom/pan-able?) + 'axes', (list: [xaxis, yaxis]) + 'xscale', (log, linear, or date) + 'yscale', + 'xlim', (range limits for x) + 'ylim', + 'xdomain' (xdomain=xlim, unless it's a date) + 'ydomain' + ] + + """ + self.msg += " Opening axes\n" + self.current_mpl_ax = ax + self.bar_containers = [ + c + for c in ax.containers # empty is OK + if c.__class__.__name__ == "BarContainer" + ] + self.current_bars = [] + + # set defaults in axes + xaxis = dict( + anchor="y{0}".format(self.axis_ct or ''), zeroline=False, ticks="inside" + ) + yaxis = dict( + anchor="x{0}".format(self.axis_ct or ''), zeroline=False, ticks="inside" + ) + # update defaults with things set in mpl + mpl_xaxis, mpl_yaxis = mpltools.prep_xy_axis( + ax=ax, props=props, x_bounds=self.mpl_x_bounds, y_bounds=self.mpl_y_bounds + ) + xaxis.update(mpl_xaxis) + yaxis.update(mpl_yaxis) + bottom_spine = mpltools.get_spine_visible(ax, "bottom") + top_spine = mpltools.get_spine_visible(ax, "top") + left_spine = mpltools.get_spine_visible(ax, "left") + right_spine = mpltools.get_spine_visible(ax, "right") + xaxis["mirror"] = mpltools.get_axis_mirror(bottom_spine, top_spine) + yaxis["mirror"] = mpltools.get_axis_mirror(left_spine, right_spine) + xaxis["showline"] = bottom_spine + yaxis["showline"] = top_spine + + # put axes in our figure + self.plotly_fig["layout"]["xaxis{0}".format(self.axis_ct or '')] = xaxis + self.plotly_fig["layout"]["yaxis{0}".format(self.axis_ct or '')] = yaxis + + # let all subsequent dates be handled properly if required + + if "type" in dir(xaxis) and xaxis["type"] == "date": + self.x_is_mpl_date = True + + self.axis_ct += 1 + + def close_axes(self, ax): + """Close the axes object and clean up. + + Bars from bar charts are given to PlotlyRenderer one-by-one, + thus they need to be taken care of at the close of each axes object. + The self.current_bars variable should be empty unless a bar + chart has been created. + + Positional arguments: + ax -- an mpl axes object, not required at this time. + + """ + self.draw_bars(self.current_bars) + self.msg += " Closing axes\n" + self.x_is_mpl_date = False + + def draw_bars(self, bars): + + # sort bars according to bar containers + mpl_traces = [] + for container in self.bar_containers: + mpl_traces.append( + [ + bar_props + for bar_props in self.current_bars + if bar_props["mplobj"] in container + ] + ) + for trace in mpl_traces: + self.draw_bar(trace) + + def draw_bar(self, coll): + """Draw a collection of similar patches as a bar chart. + + After bars are sorted, an appropriate data dictionary must be created + to tell plotly about this data. Just like draw_line or draw_markers, + draw_bar translates patch/path information into something plotly + understands. + + Positional arguments: + patch_coll -- a collection of patches to be drawn as a bar chart. + + """ + tol = 1e-10 + trace = [mpltools.make_bar(**bar_props) for bar_props in coll] + widths = [bar_props["x1"] - bar_props["x0"] for bar_props in trace] + heights = [bar_props["y1"] - bar_props["y0"] for bar_props in trace] + vertical = abs(sum(widths[0] - widths[iii] for iii in range(len(widths)))) < tol + horizontal = ( + abs(sum(heights[0] - heights[iii] for iii in range(len(heights)))) < tol + ) + if vertical and horizontal: + # Check for monotonic x. Can't both be true! + x_zeros = [bar_props["x0"] for bar_props in trace] + if all( + (x_zeros[iii + 1] > x_zeros[iii] for iii in range(len(x_zeros[:-1]))) + ): + orientation = "v" + else: + orientation = "h" + elif vertical: + orientation = "v" + else: + orientation = "h" + if orientation == "v": + self.msg += " Attempting to draw a vertical bar chart\n" + old_heights = [bar_props["y1"] for bar_props in trace] + for bar in trace: + bar["y0"], bar["y1"] = 0, bar["y1"] - bar["y0"] + new_heights = [bar_props["y1"] for bar_props in trace] + # check if we're stacked or not... + for old, new in zip(old_heights, new_heights): + if abs(old - new) > tol: + self.plotly_fig["layout"]["barmode"] = "stack" + self.plotly_fig["layout"]["hovermode"] = "x" + x = [bar["x0"] + (bar["x1"] - bar["x0"]) / 2 for bar in trace] + y = [bar["y1"] for bar in trace] + bar_gap = mpltools.get_bar_gap( + [bar["x0"] for bar in trace], [bar["x1"] for bar in trace] + ) + if self.x_is_mpl_date: + x = [bar["x0"] for bar in trace] + formatter = ( + self.current_mpl_ax.get_xaxis() + .get_major_formatter() + .__class__.__name__ + ) + x = mpltools.mpl_dates_to_datestrings(x, formatter) + else: + self.msg += " Attempting to draw a horizontal bar chart\n" + old_rights = [bar_props["x1"] for bar_props in trace] + for bar in trace: + bar["x0"], bar["x1"] = 0, bar["x1"] - bar["x0"] + new_rights = [bar_props["x1"] for bar_props in trace] + # check if we're stacked or not... + for old, new in zip(old_rights, new_rights): + if abs(old - new) > tol: + self.plotly_fig["layout"]["barmode"] = "stack" + self.plotly_fig["layout"]["hovermode"] = "y" + x = [bar["x1"] for bar in trace] + y = [bar["y0"] + (bar["y1"] - bar["y0"]) / 2 for bar in trace] + bar_gap = mpltools.get_bar_gap( + [bar["y0"] for bar in trace], [bar["y1"] for bar in trace] + ) + bar = dict( + type="bar", + orientation=orientation, + x=x, + y=y, + xaxis="x{0}".format(self.axis_ct), + yaxis="y{0}".format(self.axis_ct), + opacity=trace[0]["alpha"], # TODO: get all alphas if array? + marker=dict( + color=trace[0]["facecolor"], # TODO: get all + line=dict(width=trace[0]["edgewidth"]), + ), + ) # TODO ditto + if len(bar["x"]) > 1: + self.msg += " Heck yeah, I drew that bar chart\n" + self.plotly_fig['data'].append(bar) + if bar_gap is not None: + self.plotly_fig["layout"]["bargap"] = bar_gap + else: + self.msg += " Bar chart not drawn\n" + warnings.warn( + "found box chart data with length <= 1, " + "assuming data redundancy, not plotting." + ) + + def draw_marked_line(self, **props): + """Create a data dict for a line obj. + + This will draw 'lines', 'markers', or 'lines+markers'. + + props.keys() -- [ + 'coordinates', ('data', 'axes', 'figure', or 'display') + 'data', (a list of xy pairs) + 'mplobj', (the matplotlib.lines.Line2D obj being rendered) + 'label', (the name of the Line2D obj being rendered) + 'linestyle', (linestyle dict, can be None, see below) + 'markerstyle', (markerstyle dict, can be None, see below) + ] + + props['linestyle'].keys() -- [ + 'alpha', (opacity of Line2D obj) + 'color', (color of the line if it exists, not the marker) + 'linewidth', + 'dasharray', (code for linestyle, see DASH_MAP in mpltools.py) + 'zorder', (viewing precedence when stacked with other objects) + ] + + props['markerstyle'].keys() -- [ + 'alpha', (opacity of Line2D obj) + 'marker', (the mpl marker symbol, see SYMBOL_MAP in mpltools.py) + 'facecolor', (color of the marker face) + 'edgecolor', (color of the marker edge) + 'edgewidth', (width of marker edge) + 'markerpath', (an SVG path for drawing the specified marker) + 'zorder', (viewing precedence when stacked with other objects) + ] + + """ + self.msg += " Attempting to draw a line " + line, marker = {}, {} + if props["linestyle"] and props["markerstyle"]: + self.msg += "... with both lines+markers\n" + mode = "lines+markers" + elif props["linestyle"]: + self.msg += "... with just lines\n" + mode = "lines" + elif props["markerstyle"]: + self.msg += "... with just markers\n" + mode = "markers" + if props["linestyle"]: + color = mpltools.merge_color_and_opacity( + props["linestyle"]["color"], props["linestyle"]["alpha"] + ) + + # print(mpltools.convert_dash(props['linestyle']['dasharray'])) + line = dict( + color=color, + width=props["linestyle"]["linewidth"], + dash=mpltools.convert_dash(props["linestyle"]["dasharray"]), + ) + if props["markerstyle"]: + marker = dict( + opacity=props["markerstyle"]["alpha"], + color=props["markerstyle"]["facecolor"], + symbol=mpltools.convert_symbol(props["markerstyle"]["marker"]), + size=props["markerstyle"]["markersize"], + line=dict( + color=props["markerstyle"]["edgecolor"], + width=props["markerstyle"]["edgewidth"], + ), + ) + if props["coordinates"] == "data": + marked_line = dict( + type="scatter", + mode=mode, + name=( + str(props["label"]) + if isinstance(props["label"], six.string_types) + else props["label"] + ), + x=[xy_pair[0] for xy_pair in props["data"]], + y=[xy_pair[1] for xy_pair in props["data"]], + xaxis="x{0}".format(self.axis_ct), + yaxis="y{0}".format(self.axis_ct), + line=line, + marker=marker, + ) + if self.x_is_mpl_date: + formatter = ( + self.current_mpl_ax.get_xaxis() + .get_major_formatter() + .__class__.__name__ + ) + marked_line["x"] = mpltools.mpl_dates_to_datestrings( + marked_line["x"], formatter + ) + self.plotly_fig['data'].append(marked_line) + self.msg += " Heck yeah, I drew that line\n" + else: + self.msg += " Line didn't have 'data' coordinates, " "not drawing\n" + warnings.warn( + "Bummer! Plotly can currently only draw Line2D " + "objects from matplotlib that are in 'data' " + "coordinates!" + ) + + def draw_image(self, **props): + """Draw image. + + Not implemented yet! + + """ + self.msg += " Attempting to draw image\n" + self.msg += " Not drawing image\n" + warnings.warn( + "Aw. Snap! You're gonna have to hold off on " + "the selfies for now. Plotly can't import " + "images from matplotlib yet!" + ) + + def draw_path_collection(self, **props): + """Add a path collection to data list as a scatter plot. + + Current implementation defaults such collections as scatter plots. + Matplotlib supports collections that have many of the same parameters + in common like color, size, path, etc. However, they needn't all be + the same. Plotly does not currently support such functionality and + therefore, the style for the first object is taken and used to define + the remaining paths in the collection. + + props.keys() -- [ + 'paths', (structure: [vertices, path_code]) + 'path_coordinates', ('data', 'axes', 'figure', or 'display') + 'path_transforms', (mpl transform, including Affine2D matrix) + 'offsets', (offset from axes, helpful if in 'data') + 'offset_coordinates', ('data', 'axes', 'figure', or 'display') + 'offset_order', + 'styles', (style dict, see below) + 'mplobj' (the collection obj being drawn) + ] + + props['styles'].keys() -- [ + 'linewidth', (one or more linewidths) + 'facecolor', (one or more facecolors for path) + 'edgecolor', (one or more edgecolors for path) + 'alpha', (one or more opacites for path) + 'zorder', (precedence when stacked) + ] + + """ + self.msg += " Attempting to draw a path collection\n" + if props["offset_coordinates"] == "data": + markerstyle = mpltools.get_markerstyle_from_collection(props) + scatter_props = { + "coordinates": "data", + "data": props["offsets"], + "label": None, + "markerstyle": markerstyle, + "linestyle": None, + } + self.msg += " Drawing path collection as markers\n" + self.draw_marked_line(**scatter_props) + else: + self.msg += " Path collection not linked to 'data', " "not drawing\n" + warnings.warn( + "Dang! That path collection is out of this " + "world. I totally don't know what to do with " + "it yet! Plotly can only import path " + "collections linked to 'data' coordinates" + ) + + def draw_path(self, **props): + """Draw path, currently only attempts to draw bar charts. + + This function attempts to sort a given path into a collection of + horizontal or vertical bar charts. Most of the actual code takes + place in functions from mpltools.py. + + props.keys() -- [ + 'data', (a list of verticies for the path) + 'coordinates', ('data', 'axes', 'figure', or 'display') + 'pathcodes', (code for the path, structure: ['M', 'L', 'Z', etc.]) + 'style', (style dict, see below) + 'mplobj' (the mpl path object) + ] + + props['style'].keys() -- [ + 'alpha', (opacity of path obj) + 'edgecolor', + 'facecolor', + 'edgewidth', + 'dasharray', (style for path's enclosing line) + 'zorder' (precedence of obj when stacked) + ] + + """ + self.msg += " Attempting to draw a path\n" + is_bar = mpltools.is_bar(self.current_mpl_ax.containers, **props) + if is_bar: + self.current_bars += [props] + else: + self.msg += " This path isn't a bar, not drawing\n" + warnings.warn( + "I found a path object that I don't think is part " + "of a bar chart. Ignoring." + ) + + def draw_text(self, **props): + """Create an annotation dict for a text obj. + + Currently, plotly uses either 'page' or 'data' to reference + annotation locations. These refer to 'display' and 'data', + respectively for the 'coordinates' key used in the Exporter. + Appropriate measures are taken to transform text locations to + reference one of these two options. + + props.keys() -- [ + 'text', (actual content string, not the text obj) + 'position', (an x, y pair, not an mpl Bbox) + 'coordinates', ('data', 'axes', 'figure', 'display') + 'text_type', ('title', 'xlabel', or 'ylabel') + 'style', (style dict, see below) + 'mplobj' (actual mpl text object) + ] + + props['style'].keys() -- [ + 'alpha', (opacity of text) + 'fontsize', (size in points of text) + 'color', (hex color) + 'halign', (horizontal alignment, 'left', 'center', or 'right') + 'valign', (vertical alignment, 'baseline', 'center', or 'top') + 'rotation', + 'zorder', (precedence of text when stacked with other objs) + ] + + """ + self.msg += " Attempting to draw an mpl text object\n" + if not mpltools.check_corners(props["mplobj"], self.mpl_fig): + warnings.warn( + "Looks like the annotation(s) you are trying \n" + "to draw lies/lay outside the given figure size.\n\n" + "Therefore, the resulting Plotly figure may not be \n" + "large enough to view the full text. To adjust \n" + "the size of the figure, use the 'width' and \n" + "'height' keys in the Layout object. Alternatively,\n" + "use the Margin object to adjust the figure's margins." + ) + align = props["mplobj"]._multialignment + if not align: + align = props["style"]["halign"] # mpl default + if "annotations" not in self.plotly_fig["layout"]: + self.plotly_fig["layout"]["annotations"] = [] + if props["text_type"] == "xlabel": + self.msg += " Text object is an xlabel\n" + self.draw_xlabel(**props) + elif props["text_type"] == "ylabel": + self.msg += " Text object is a ylabel\n" + self.draw_ylabel(**props) + elif props["text_type"] == "title": + self.msg += " Text object is a title\n" + self.draw_title(**props) + else: # just a regular text annotation... + self.msg += " Text object is a normal annotation\n" + if props["coordinates"] != "data": + self.msg += ( + " Text object isn't linked to 'data' " "coordinates\n" + ) + x_px, y_px = ( + props["mplobj"].get_transform().transform(props["position"]) + ) + x, y = mpltools.display_to_paper(x_px, y_px, self.plotly_fig["layout"]) + xref = "paper" + yref = "paper" + xanchor = props["style"]["halign"] # no difference here! + yanchor = mpltools.convert_va(props["style"]["valign"]) + else: + self.msg += " Text object is linked to 'data' " "coordinates\n" + x, y = props["position"] + axis_ct = self.axis_ct + xaxis = self.plotly_fig["layout"]["xaxis{0}".format(axis_ct)] + yaxis = self.plotly_fig["layout"]["yaxis{0}".format(axis_ct)] + if ( + xaxis["range"][0] < x < xaxis["range"][1] + and yaxis["range"][0] < y < yaxis["range"][1] + ): + xref = "x{0}".format(self.axis_ct) + yref = "y{0}".format(self.axis_ct) + else: + self.msg += ( + " Text object is outside " + "plotting area, making 'paper' reference.\n" + ) + x_px, y_px = ( + props["mplobj"].get_transform().transform(props["position"]) + ) + x, y = mpltools.display_to_paper( + x_px, y_px, self.plotly_fig["layout"] + ) + xref = "paper" + yref = "paper" + xanchor = props["style"]["halign"] # no difference here! + yanchor = mpltools.convert_va(props["style"]["valign"]) + annotation = dict( + text=( + str(props["text"]) + if isinstance(props["text"], six.string_types) + else props["text"] + ), + opacity=props["style"]["alpha"], + x=x, + y=y, + xref=xref, + yref=yref, + align=align, + xanchor=xanchor, + yanchor=yanchor, + showarrow=False, # change this later? + font=dict( + color=props["style"]["color"], size=props["style"]["fontsize"] + ), + ) + self.plotly_fig["layout"]["annotations"] += (annotation,) + self.msg += " Heck, yeah I drew that annotation\n" + + def draw_title(self, **props): + """Add a title to the current subplot in layout dictionary. + + If there exists more than a single plot in the figure, titles revert + to 'page'-referenced annotations. + + props.keys() -- [ + 'text', (actual content string, not the text obj) + 'position', (an x, y pair, not an mpl Bbox) + 'coordinates', ('data', 'axes', 'figure', 'display') + 'text_type', ('title', 'xlabel', or 'ylabel') + 'style', (style dict, see below) + 'mplobj' (actual mpl text object) + ] + + props['style'].keys() -- [ + 'alpha', (opacity of text) + 'fontsize', (size in points of text) + 'color', (hex color) + 'halign', (horizontal alignment, 'left', 'center', or 'right') + 'valign', (vertical alignment, 'baseline', 'center', or 'top') + 'rotation', + 'zorder', (precedence of text when stacked with other objs) + ] + + """ + self.msg += " Attempting to draw a title\n" + if len(self.mpl_fig.axes) > 1: + self.msg += ( + " More than one subplot, adding title as " "annotation\n" + ) + x_px, y_px = props["mplobj"].get_transform().transform(props["position"]) + x, y = mpltools.display_to_paper(x_px, y_px, self.plotly_fig["layout"]) + annotation = dict( + text=props["text"], + font=dict( + color=props["style"]["color"], size=props["style"]["fontsize"] + ), + xref="paper", + yref="paper", + x=x, + y=y, + xanchor="center", + yanchor="bottom", + showarrow=False, # no arrow for a title! + ) + self.plotly_fig["layout"]["annotations"] += (annotation,) + else: + self.msg += ( + " Only one subplot found, adding as a " "plotly title\n" + ) + self.plotly_fig["layout"]["title"] = props["text"] + titlefont = dict( + size=props["style"]["fontsize"], color=props["style"]["color"] + ) + self.plotly_fig["layout"]["titlefont"] = titlefont + + def draw_xlabel(self, **props): + """Add an xaxis label to the current subplot in layout dictionary. + + props.keys() -- [ + 'text', (actual content string, not the text obj) + 'position', (an x, y pair, not an mpl Bbox) + 'coordinates', ('data', 'axes', 'figure', 'display') + 'text_type', ('title', 'xlabel', or 'ylabel') + 'style', (style dict, see below) + 'mplobj' (actual mpl text object) + ] + + props['style'].keys() -- [ + 'alpha', (opacity of text) + 'fontsize', (size in points of text) + 'color', (hex color) + 'halign', (horizontal alignment, 'left', 'center', or 'right') + 'valign', (vertical alignment, 'baseline', 'center', or 'top') + 'rotation', + 'zorder', (precedence of text when stacked with other objs) + ] + + """ + self.msg += " Adding xlabel\n" + axis_key = "xaxis{0}".format(self.axis_ct) + self.plotly_fig["layout"][axis_key]["title"] = str(props["text"]) + titlefont = dict(size=props["style"]["fontsize"], color=props["style"]["color"]) + self.plotly_fig["layout"][axis_key]["titlefont"] = titlefont + + def draw_ylabel(self, **props): + """Add a yaxis label to the current subplot in layout dictionary. + + props.keys() -- [ + 'text', (actual content string, not the text obj) + 'position', (an x, y pair, not an mpl Bbox) + 'coordinates', ('data', 'axes', 'figure', 'display') + 'text_type', ('title', 'xlabel', or 'ylabel') + 'style', (style dict, see below) + 'mplobj' (actual mpl text object) + ] + + props['style'].keys() -- [ + 'alpha', (opacity of text) + 'fontsize', (size in points of text) + 'color', (hex color) + 'halign', (horizontal alignment, 'left', 'center', or 'right') + 'valign', (vertical alignment, 'baseline', 'center', or 'top') + 'rotation', + 'zorder', (precedence of text when stacked with other objs) + ] + + """ + self.msg += " Adding ylabel\n" + axis_key = "yaxis{0}".format(self.axis_ct) + self.plotly_fig["layout"][axis_key]["title"] = props["text"] + titlefont = dict(size=props["style"]["fontsize"], color=props["style"]["color"]) + self.plotly_fig["layout"][axis_key]["titlefont"] = titlefont + + def resize(self): + """Revert figure layout to allow plotly to resize. + + By default, PlotlyRenderer tries its hardest to precisely mimic an + mpl figure. However, plotly is pretty good with aesthetics. By + running PlotlyRenderer.resize(), layout parameters are deleted. This + lets plotly choose them instead of mpl. + + """ + self.msg += "Resizing figure, deleting keys from layout\n" + for key in ["width", "height", "autosize", "margin"]: + try: + del self.plotly_fig["layout"][key] + except (KeyError, AttributeError): + pass + + def strip_style(self): + self.msg += "Stripping mpl style is no longer supported\n"