Add direct plotly figure reporting (see issue #136)

This commit is contained in:
allegroai 2020-06-14 00:01:30 +03:00
parent 8a5f6b7d02
commit 20a9f0997d
3 changed files with 84 additions and 30 deletions

14
examples/report_plotly.py Normal file
View File

@ -0,0 +1,14 @@
# TRAINS - Example of Plotly integration and reporting
#
from trains import Task
import plotly.express as px
task = Task.init('examples', 'plotly report')
df = px.data.iris()
fig = px.scatter(df, x="sepal_width", y="sepal_length", color="species", marginal_y="rug", marginal_x="histogram")
task.get_logger().report_plotly(title="iris", series="sepal", iteration=0, figure=fig)
print('done')

View File

@ -176,10 +176,16 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
:param iter: Iteration number
:type value: int
"""
# noinspection PyBroadException
try:
def default(o):
if isinstance(o, np.int64):
return int(o)
# Special json encoder for numpy types
def default(obj):
if isinstance(obj, (np.integer, np.int64)):
return int(obj)
elif isinstance(obj, np.floating):
return float(obj)
elif isinstance(obj, np.ndarray):
return obj.tolist()
except Exception:
default = None

View File

@ -1,7 +1,7 @@
import logging
import math
import warnings
from typing import Any, Sequence, Union, List, Optional, Tuple
from typing import Any, Sequence, Union, List, Optional, Tuple, Dict
import numpy as np
import six
@ -130,9 +130,9 @@ class Logger(object):
You can view the scalar plots in the **Trains Web-App (UI)**, **RESULTS** tab, **SCALARS** sub-tab.
:param str title: The title of the plot. Plot more than one scalar series on the same plot by using
:param str title: The title (metric) of the plot. Plot more than one scalar series on the same plot by using
the same ``title`` for each call to this method.
:param str series: The title of the series.
:param str series: The series name (variant) of the reported scalar.
:param float value: The value to plot per iteration.
:param int iteration: The iteration number. Iterations are on the x-axis.
"""
@ -168,8 +168,8 @@ class Logger(object):
You can view the vectors plots in the **Trains Web-App (UI)**, **RESULTS** tab, **PLOTS** sub-tab.
:param str title: The title of the plot.
:param str series: The title of the series.
:param str title: The title (metric) of the plot.
:param str series: The series name (variant) of the reported histogram.
:param list(float) values: The series values. A list of floats, or an N-dimensional Numpy array containing
data for each histogram bar.
:type values: list(float), numpy.ndarray
@ -215,8 +215,8 @@ class Logger(object):
You can view the reported histograms in the **Trains Web-App (UI)**, **RESULTS** tab, **PLOTS** sub-tab.
:param str title: The title of the plot.
:param str series: The title of the series.
:param str title: The title (metric) of the plot.
:param str series: The series name (variant) of the reported histogram.
:param list(float) values: The series values. A list of floats, or an N-dimensional Numpy array containing
data for each histogram bar.
:type values: list(float), numpy.ndarray
@ -282,8 +282,8 @@ class Logger(object):
You can view the reported tables in the **Trains Web-App (UI)**, **RESULTS** tab, **PLOTS** sub-tab.
:param str title: The title of the table.
:param str series: The title of the series.
:param str title: The title (metric) of the table.
:param str series: The series name (variant) of the reported table.
:param int iteration: The iteration number.
:param table_plot: The output table plot object
:type table_plot: pandas.DataFrame
@ -331,7 +331,7 @@ class Logger(object):
def report_line_plot(
self,
title, # type: str
series, # type: str
series, # type: Sequence[SeriesInfo]
iteration, # type: int
xaxis, # type: str
yaxis, # type: str
@ -343,9 +343,8 @@ class Logger(object):
"""
For explicit reporting, plot one or more series as lines.
:param str title: The title of the plot.
:param list(LineSeriesInfo) series: All the series data, one list element for each line
in the plot.
:param str title: The title (metric) of the plot.
:param list series: All the series data, one list element for each line in the plot.
:param int iteration: The iteration number.
:param str xaxis: The x-axis title. (Optional)
:param str yaxis: The y-axis title. (Optional)
@ -423,8 +422,8 @@ class Logger(object):
logger.report_scatter2d("example_scatter", "series_2", iteration=1, scatter=scatter2d_2,
xaxis="title x", yaxis="title y")
:param str title: The title of the plot.
:param str series: The title of the series.
:param str title: The title (metric) of the plot.
:param str series: The series name (variant) of the reported scatter plot.
:param list scatter: The scatter data. numpy.ndarray or list of (pairs of x,y) scatter:
:param int iteration: The iteration number. To set an initial iteration, for example to continue a previously
:param str xaxis: The x-axis title. (Optional)
@ -483,8 +482,8 @@ class Logger(object):
"""
For explicit reporting, plot a 3d scatter graph (with markers).
:param str title: The title of the plot.
:param str series: The title of the series.
:param str title: The title (metric) of the plot.
:param str series: The series name (variant) of the reported scatter plot.
:param Union[numpy.ndarray, list] scatter: The scatter data.
list of (pairs of x,y,z), list of series [[(x1,y1,z1)...]], or numpy.ndarray
:param int iteration: The iteration number.
@ -585,8 +584,8 @@ class Logger(object):
logger.report_confusion_matrix("example confusion matrix", "ignored", iteration=1, matrix=confusion,
xaxis="title X", yaxis="title Y")
:param str title: The title of the plot.
:param str series: The title of the series.
:param str title: The title (metric) of the plot.
:param str series: The series name (variant) of the reported confusion matrix.
:param numpy.ndarray matrix: A heat-map matrix (example: confusion matrix)
:param int iteration: The iteration number.
:param str xaxis: The x-axis title. (Optional)
@ -635,8 +634,8 @@ class Logger(object):
.. note::
This method is the same as :meth:`Logger.report_confusion_matrix`.
:param str title: The title of the plot.
:param str series: The title of the series.
:param str title: The title (metric) of the plot.
:param str series: The series name (variant) of the reported confusion matrix.
:param numpy.ndarray matrix: A heat-map matrix (example: confusion matrix)
:param int iteration: The iteration number.
:param str xaxis: The x-axis title. (Optional)
@ -679,8 +678,8 @@ class Logger(object):
logger.report_surface("example surface", "series", iteration=0, matrix=surface_matrix,
xaxis="title X", yaxis="title Y", zaxis="title Z")
:param str title: The title of the plot.
:param str series: The title of the series.
:param str title: The title (metric) of the plot.
:param str series: The series name (variant) of the reported surface.
:param numpy.ndarray matrix: A heat-map matrix (example: confusion matrix)
:param int iteration: The iteration number.
:param str xaxis: The x-axis title. (Optional)
@ -751,8 +750,8 @@ class Logger(object):
- ``image``
- ``matrix``
:param str title: The title of the image.
:param str series: The title of the series of this image.
:param str title: The title (metric) of the image.
:param str series: The series name (variant) of the reported image.
:param int iteration: The iteration number.
:param str local_path: A path to an image file.
:param str url: A URL for the location of a pre-uploaded image.
@ -847,8 +846,8 @@ class Logger(object):
If you use ``stream`` for a BytesIO stream to upload, ``file_extension`` must be provided.
:param str title: The title of the media (metric).
:param str series: The title of the series of this (variant).
:param str title: The title (metric) of the media.
:param str series: The series name (variant) of the reported media.
:param int iteration: The iteration number.
:param str local_path: A path to an media file.
:param stream: BytesIO stream to upload. If provided, ``file_extension`` must also be provided.
@ -904,6 +903,41 @@ class Logger(object):
file_extension=file_extension,
)
def report_plotly(
self,
title, # type: str
series, # type: str
iteration, # type: int
figure, # type: Union[Dict, "Figure"]
):
"""
Report a ``Plotly`` figure (plot) directly
``Plotly`` figure can be a ``plotly.graph_objs._figure.Figure`` or a dictionary as defined by ``plotly.js``
:param str title: The title (metric) of the plot.
:param str series: The series name (variant) of the reported plot.
:param int iteration: The iteration number.
:param dict figure: A ``plotly`` Figure object or a ``poltly`` dictionary
"""
# if task was not started, we have to start it
self._start_task_if_needed()
self._touch_title_series(title, series)
plot = figure if isinstance(figure, dict) else figure.to_plotly_json()
# noinspection PyBroadException
try:
plot['layout']['title'] = series
except Exception:
pass
self._task.reporter.report_plot(
title=title,
series=series,
plot=plot,
iter=iteration,
)
def set_default_upload_destination(self, uri):
# type: (str) -> None
"""