diff --git a/examples/report_plotly.py b/examples/report_plotly.py new file mode 100644 index 00000000..36dffd67 --- /dev/null +++ b/examples/report_plotly.py @@ -0,0 +1,14 @@ +# TRAINS - Example of Plotly integration and reporting +# +from trains import Task +import plotly.express as px + + +task = Task.init('examples', 'plotly report') + +df = px.data.iris() +fig = px.scatter(df, x="sepal_width", y="sepal_length", color="species", marginal_y="rug", marginal_x="histogram") + +task.get_logger().report_plotly(title="iris", series="sepal", iteration=0, figure=fig) + +print('done') diff --git a/trains/backend_interface/metrics/reporter.py b/trains/backend_interface/metrics/reporter.py index 678b4f1c..9d489c1e 100644 --- a/trains/backend_interface/metrics/reporter.py +++ b/trains/backend_interface/metrics/reporter.py @@ -176,10 +176,16 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan :param iter: Iteration number :type value: int """ + # noinspection PyBroadException try: - def default(o): - if isinstance(o, np.int64): - return int(o) + # Special json encoder for numpy types + def default(obj): + if isinstance(obj, (np.integer, np.int64)): + return int(obj) + elif isinstance(obj, np.floating): + return float(obj) + elif isinstance(obj, np.ndarray): + return obj.tolist() except Exception: default = None diff --git a/trains/logger.py b/trains/logger.py index 2659dbdf..fd92989e 100644 --- a/trains/logger.py +++ b/trains/logger.py @@ -1,7 +1,7 @@ import logging import math import warnings -from typing import Any, Sequence, Union, List, Optional, Tuple +from typing import Any, Sequence, Union, List, Optional, Tuple, Dict import numpy as np import six @@ -130,9 +130,9 @@ class Logger(object): You can view the scalar plots in the **Trains Web-App (UI)**, **RESULTS** tab, **SCALARS** sub-tab. - :param str title: The title of the plot. Plot more than one scalar series on the same plot by using + :param str title: The title (metric) of the plot. Plot more than one scalar series on the same plot by using the same ``title`` for each call to this method. - :param str series: The title of the series. + :param str series: The series name (variant) of the reported scalar. :param float value: The value to plot per iteration. :param int iteration: The iteration number. Iterations are on the x-axis. """ @@ -168,8 +168,8 @@ class Logger(object): You can view the vectors plots in the **Trains Web-App (UI)**, **RESULTS** tab, **PLOTS** sub-tab. - :param str title: The title of the plot. - :param str series: The title of the series. + :param str title: The title (metric) of the plot. + :param str series: The series name (variant) of the reported histogram. :param list(float) values: The series values. A list of floats, or an N-dimensional Numpy array containing data for each histogram bar. :type values: list(float), numpy.ndarray @@ -215,8 +215,8 @@ class Logger(object): You can view the reported histograms in the **Trains Web-App (UI)**, **RESULTS** tab, **PLOTS** sub-tab. - :param str title: The title of the plot. - :param str series: The title of the series. + :param str title: The title (metric) of the plot. + :param str series: The series name (variant) of the reported histogram. :param list(float) values: The series values. A list of floats, or an N-dimensional Numpy array containing data for each histogram bar. :type values: list(float), numpy.ndarray @@ -282,8 +282,8 @@ class Logger(object): You can view the reported tables in the **Trains Web-App (UI)**, **RESULTS** tab, **PLOTS** sub-tab. - :param str title: The title of the table. - :param str series: The title of the series. + :param str title: The title (metric) of the table. + :param str series: The series name (variant) of the reported table. :param int iteration: The iteration number. :param table_plot: The output table plot object :type table_plot: pandas.DataFrame @@ -331,7 +331,7 @@ class Logger(object): def report_line_plot( self, title, # type: str - series, # type: str + series, # type: Sequence[SeriesInfo] iteration, # type: int xaxis, # type: str yaxis, # type: str @@ -343,9 +343,8 @@ class Logger(object): """ For explicit reporting, plot one or more series as lines. - :param str title: The title of the plot. - :param list(LineSeriesInfo) series: All the series data, one list element for each line - in the plot. + :param str title: The title (metric) of the plot. + :param list series: All the series data, one list element for each line in the plot. :param int iteration: The iteration number. :param str xaxis: The x-axis title. (Optional) :param str yaxis: The y-axis title. (Optional) @@ -423,8 +422,8 @@ class Logger(object): logger.report_scatter2d("example_scatter", "series_2", iteration=1, scatter=scatter2d_2, xaxis="title x", yaxis="title y") - :param str title: The title of the plot. - :param str series: The title of the series. + :param str title: The title (metric) of the plot. + :param str series: The series name (variant) of the reported scatter plot. :param list scatter: The scatter data. numpy.ndarray or list of (pairs of x,y) scatter: :param int iteration: The iteration number. To set an initial iteration, for example to continue a previously :param str xaxis: The x-axis title. (Optional) @@ -483,8 +482,8 @@ class Logger(object): """ For explicit reporting, plot a 3d scatter graph (with markers). - :param str title: The title of the plot. - :param str series: The title of the series. + :param str title: The title (metric) of the plot. + :param str series: The series name (variant) of the reported scatter plot. :param Union[numpy.ndarray, list] scatter: The scatter data. list of (pairs of x,y,z), list of series [[(x1,y1,z1)...]], or numpy.ndarray :param int iteration: The iteration number. @@ -585,8 +584,8 @@ class Logger(object): logger.report_confusion_matrix("example confusion matrix", "ignored", iteration=1, matrix=confusion, xaxis="title X", yaxis="title Y") - :param str title: The title of the plot. - :param str series: The title of the series. + :param str title: The title (metric) of the plot. + :param str series: The series name (variant) of the reported confusion matrix. :param numpy.ndarray matrix: A heat-map matrix (example: confusion matrix) :param int iteration: The iteration number. :param str xaxis: The x-axis title. (Optional) @@ -635,8 +634,8 @@ class Logger(object): .. note:: This method is the same as :meth:`Logger.report_confusion_matrix`. - :param str title: The title of the plot. - :param str series: The title of the series. + :param str title: The title (metric) of the plot. + :param str series: The series name (variant) of the reported confusion matrix. :param numpy.ndarray matrix: A heat-map matrix (example: confusion matrix) :param int iteration: The iteration number. :param str xaxis: The x-axis title. (Optional) @@ -679,8 +678,8 @@ class Logger(object): logger.report_surface("example surface", "series", iteration=0, matrix=surface_matrix, xaxis="title X", yaxis="title Y", zaxis="title Z") - :param str title: The title of the plot. - :param str series: The title of the series. + :param str title: The title (metric) of the plot. + :param str series: The series name (variant) of the reported surface. :param numpy.ndarray matrix: A heat-map matrix (example: confusion matrix) :param int iteration: The iteration number. :param str xaxis: The x-axis title. (Optional) @@ -751,8 +750,8 @@ class Logger(object): - ``image`` - ``matrix`` - :param str title: The title of the image. - :param str series: The title of the series of this image. + :param str title: The title (metric) of the image. + :param str series: The series name (variant) of the reported image. :param int iteration: The iteration number. :param str local_path: A path to an image file. :param str url: A URL for the location of a pre-uploaded image. @@ -847,8 +846,8 @@ class Logger(object): If you use ``stream`` for a BytesIO stream to upload, ``file_extension`` must be provided. - :param str title: The title of the media (metric). - :param str series: The title of the series of this (variant). + :param str title: The title (metric) of the media. + :param str series: The series name (variant) of the reported media. :param int iteration: The iteration number. :param str local_path: A path to an media file. :param stream: BytesIO stream to upload. If provided, ``file_extension`` must also be provided. @@ -904,6 +903,41 @@ class Logger(object): file_extension=file_extension, ) + def report_plotly( + self, + title, # type: str + series, # type: str + iteration, # type: int + figure, # type: Union[Dict, "Figure"] + ): + """ + Report a ``Plotly`` figure (plot) directly + + ``Plotly`` figure can be a ``plotly.graph_objs._figure.Figure`` or a dictionary as defined by ``plotly.js`` + + :param str title: The title (metric) of the plot. + :param str series: The series name (variant) of the reported plot. + :param int iteration: The iteration number. + :param dict figure: A ``plotly`` Figure object or a ``poltly`` dictionary + """ + # if task was not started, we have to start it + self._start_task_if_needed() + + self._touch_title_series(title, series) + + plot = figure if isinstance(figure, dict) else figure.to_plotly_json() + # noinspection PyBroadException + try: + plot['layout']['title'] = series + except Exception: + pass + self._task.reporter.report_plot( + title=title, + series=series, + plot=plot, + iter=iteration, + ) + def set_default_upload_destination(self, uri): # type: (str) -> None """