diff --git a/trains/backend_interface/metrics/reporter.py b/trains/backend_interface/metrics/reporter.py index d9143e83..72911d8d 100644 --- a/trains/backend_interface/metrics/reporter.py +++ b/trains/backend_interface/metrics/reporter.py @@ -284,7 +284,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan self._report(ev) def report_histogram(self, title, series, histogram, iter, labels=None, xlabels=None, - xtitle=None, ytitle=None, comment=None): + xtitle=None, ytitle=None, comment=None, mode='group'): """ Report an histogram bar plot :param title: Title (AKA metric) @@ -304,7 +304,11 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan :param str ytitle: optional y-axis title :param comment: comment underneath the title :type comment: str + :param mode: multiple histograms mode. valid options are: stack / group / relative. Default is 'group'. + :type mode: str """ + assert mode in ('stack', 'group', 'relative') + plotly_dict = create_2d_histogram_plot( np_row_wise=histogram, title=title, @@ -314,6 +318,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan series=series, xlabels=xlabels, comment=comment, + mode=mode, ) return self.report_plot( diff --git a/trains/logger.py b/trains/logger.py index 67c61009..59afacfb 100644 --- a/trains/logger.py +++ b/trains/logger.py @@ -12,6 +12,7 @@ except ImportError: from PIL import Image from pathlib2 import Path +from .backend_api.services import tasks from .backend_interface.logger import StdStreamPatch, LogFlusher from .backend_interface.task import Task as _Task from .backend_interface.task.development.worker import DevWorker @@ -152,10 +153,11 @@ class Logger(object): labels=None, # type: Optional[List[str]] xlabels=None, # type: Optional[List[str]] xaxis=None, # type: Optional[str] - yaxis=None # type: Optional[str] + yaxis=None, # type: Optional[str] + mode=None # type: Optional[str] ): """ - For explicit reporting, plot a vector as (stacked) histogram. + For explicit reporting, plot a vector as (default stacked) histogram. For example: @@ -178,10 +180,11 @@ class Logger(object): for each histogram bar on the x-axis. (Optional) :param str xaxis: The x-axis title. (Optional) :param str yaxis: The y-axis title. (Optional) + :param str mode: Multiple histograms mode, stack / group / relative. Default is 'group'. """ self._touch_title_series(title, series) return self.report_histogram(title, series, values, iteration, labels=labels, xlabels=xlabels, - xaxis=xaxis, yaxis=yaxis) + xaxis=xaxis, yaxis=yaxis, mode=mode) def report_histogram( self, @@ -192,10 +195,11 @@ class Logger(object): labels=None, # type: Optional[List[str]] xlabels=None, # type: Optional[List[str]] xaxis=None, # type: Optional[str] - yaxis=None # type: Optional[str] + yaxis=None, # type: Optional[str] + mode=None # type: Optional[str] ): """ - For explicit reporting, plot a (stacked) histogram. + For explicit reporting, plot a (default stacked) histogram. For example: @@ -218,6 +222,7 @@ class Logger(object): for each histogram bar on the x-axis. (Optional) :param str xaxis: The x-axis title. (Optional) :param str yaxis: The y-axis title. (Optional) + :param str mode: Multiple histograms mode, stack / group / relative. Default is 'group'. """ if not isinstance(values, np.ndarray): @@ -235,6 +240,7 @@ class Logger(object): xlabels=xlabels, xtitle=xaxis, ytitle=yaxis, + mode=mode or 'group' ) def report_table( diff --git a/trains/utilities/plotly_reporter.py b/trains/utilities/plotly_reporter.py index 3ef53341..d8b98280 100644 --- a/trains/utilities/plotly_reporter.py +++ b/trains/utilities/plotly_reporter.py @@ -10,7 +10,7 @@ from attr import attrs, attrib def create_2d_histogram_plot(np_row_wise, labels, title=None, xtitle=None, ytitle=None, series=None, xlabels=None, - comment=None): + comment=None, mode='group'): """ Create a 2D Plotly histogram chart from a 2D numpy array :param np_row_wise: 2D numpy data array @@ -19,8 +19,11 @@ def create_2d_histogram_plot(np_row_wise, labels, title=None, xtitle=None, ytitl :param xtitle: X-Series title :param ytitle: Y-Series title :param comment: comment underneath the title + :param mode: multiple histograms mode. valid options are: stack / group / relative. Default is 'group'. :return: Plotly chart dict """ + assert mode in ('stack', 'group', 'relative') + np_row_wise = np.atleast_2d(np_row_wise) assert len(np_row_wise.shape) == 2, "Expected a 2D numpy array" # using labels without xlabels leads to original behavior @@ -36,7 +39,7 @@ def create_2d_histogram_plot(np_row_wise, labels, title=None, xtitle=None, ytitl data = [_np_row_to_plotly_data_item(np_row=np_row_wise[i, :], label=labels[i] if labels else None, xlabels=xlabels) for i in range(np_row_wise.shape[0])] - return _plotly_hist_dict(title=title, xtitle=xtitle, ytitle=ytitle, data=data, comment=comment) + return _plotly_hist_dict(title=title, xtitle=xtitle, ytitle=ytitle, mode=mode, data=data, comment=comment) def _to_np_array(value): @@ -334,16 +337,19 @@ def _get_z_colorbar_data(z_data=None, values=None, colors=None): return colorscale, colorbar -def _plotly_hist_dict(title, xtitle, ytitle, data=None, comment=None): +def _plotly_hist_dict(title, xtitle, ytitle, mode='group', data=None, comment=None): """ Create a basic Plotly chart dictionary :param title: Chart title :param xtitle: X-Series title :param ytitle: Y-Series title + :param mode: multiple histograms mode. optionals stack / group / relative. Default is 'group'. :param data: Data items :type data: list :return: Plotly chart dict """ + assert mode in ('stack', 'group', 'relative') + return { "data": data or [], "layout": { @@ -354,7 +360,7 @@ def _plotly_hist_dict(title, xtitle, ytitle, data=None, comment=None): "yaxis": { "title": ytitle }, - "barmode": "stack", + "barmode": mode, "bargap": 0.08, "bargroupgap": 0 }