Add Logger histogram mode (stack/group/relative)

This commit is contained in:
allegroai 2020-05-08 22:05:33 +03:00
parent 5a85d40fc7
commit a5ff2ba9c8
3 changed files with 27 additions and 10 deletions

View File

@ -284,7 +284,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
self._report(ev) self._report(ev)
def report_histogram(self, title, series, histogram, iter, labels=None, xlabels=None, def report_histogram(self, title, series, histogram, iter, labels=None, xlabels=None,
xtitle=None, ytitle=None, comment=None): xtitle=None, ytitle=None, comment=None, mode='group'):
""" """
Report an histogram bar plot Report an histogram bar plot
:param title: Title (AKA metric) :param title: Title (AKA metric)
@ -304,7 +304,11 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
:param str ytitle: optional y-axis title :param str ytitle: optional y-axis title
:param comment: comment underneath the title :param comment: comment underneath the title
:type comment: str :type comment: str
:param mode: multiple histograms mode. valid options are: stack / group / relative. Default is 'group'.
:type mode: str
""" """
assert mode in ('stack', 'group', 'relative')
plotly_dict = create_2d_histogram_plot( plotly_dict = create_2d_histogram_plot(
np_row_wise=histogram, np_row_wise=histogram,
title=title, title=title,
@ -314,6 +318,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
series=series, series=series,
xlabels=xlabels, xlabels=xlabels,
comment=comment, comment=comment,
mode=mode,
) )
return self.report_plot( return self.report_plot(

View File

@ -12,6 +12,7 @@ except ImportError:
from PIL import Image from PIL import Image
from pathlib2 import Path from pathlib2 import Path
from .backend_api.services import tasks
from .backend_interface.logger import StdStreamPatch, LogFlusher from .backend_interface.logger import StdStreamPatch, LogFlusher
from .backend_interface.task import Task as _Task from .backend_interface.task import Task as _Task
from .backend_interface.task.development.worker import DevWorker from .backend_interface.task.development.worker import DevWorker
@ -152,10 +153,11 @@ class Logger(object):
labels=None, # type: Optional[List[str]] labels=None, # type: Optional[List[str]]
xlabels=None, # type: Optional[List[str]] xlabels=None, # type: Optional[List[str]]
xaxis=None, # type: Optional[str] xaxis=None, # type: Optional[str]
yaxis=None # type: Optional[str] yaxis=None, # type: Optional[str]
mode=None # type: Optional[str]
): ):
""" """
For explicit reporting, plot a vector as (stacked) histogram. For explicit reporting, plot a vector as (default stacked) histogram.
For example: For example:
@ -178,10 +180,11 @@ class Logger(object):
for each histogram bar on the x-axis. (Optional) for each histogram bar on the x-axis. (Optional)
:param str xaxis: The x-axis title. (Optional) :param str xaxis: The x-axis title. (Optional)
:param str yaxis: The y-axis title. (Optional) :param str yaxis: The y-axis title. (Optional)
:param str mode: Multiple histograms mode, stack / group / relative. Default is 'group'.
""" """
self._touch_title_series(title, series) self._touch_title_series(title, series)
return self.report_histogram(title, series, values, iteration, labels=labels, xlabels=xlabels, return self.report_histogram(title, series, values, iteration, labels=labels, xlabels=xlabels,
xaxis=xaxis, yaxis=yaxis) xaxis=xaxis, yaxis=yaxis, mode=mode)
def report_histogram( def report_histogram(
self, self,
@ -192,10 +195,11 @@ class Logger(object):
labels=None, # type: Optional[List[str]] labels=None, # type: Optional[List[str]]
xlabels=None, # type: Optional[List[str]] xlabels=None, # type: Optional[List[str]]
xaxis=None, # type: Optional[str] xaxis=None, # type: Optional[str]
yaxis=None # type: Optional[str] yaxis=None, # type: Optional[str]
mode=None # type: Optional[str]
): ):
""" """
For explicit reporting, plot a (stacked) histogram. For explicit reporting, plot a (default stacked) histogram.
For example: For example:
@ -218,6 +222,7 @@ class Logger(object):
for each histogram bar on the x-axis. (Optional) for each histogram bar on the x-axis. (Optional)
:param str xaxis: The x-axis title. (Optional) :param str xaxis: The x-axis title. (Optional)
:param str yaxis: The y-axis title. (Optional) :param str yaxis: The y-axis title. (Optional)
:param str mode: Multiple histograms mode, stack / group / relative. Default is 'group'.
""" """
if not isinstance(values, np.ndarray): if not isinstance(values, np.ndarray):
@ -235,6 +240,7 @@ class Logger(object):
xlabels=xlabels, xlabels=xlabels,
xtitle=xaxis, xtitle=xaxis,
ytitle=yaxis, ytitle=yaxis,
mode=mode or 'group'
) )
def report_table( def report_table(

View File

@ -10,7 +10,7 @@ from attr import attrs, attrib
def create_2d_histogram_plot(np_row_wise, labels, title=None, xtitle=None, ytitle=None, series=None, xlabels=None, def create_2d_histogram_plot(np_row_wise, labels, title=None, xtitle=None, ytitle=None, series=None, xlabels=None,
comment=None): comment=None, mode='group'):
""" """
Create a 2D Plotly histogram chart from a 2D numpy array Create a 2D Plotly histogram chart from a 2D numpy array
:param np_row_wise: 2D numpy data array :param np_row_wise: 2D numpy data array
@ -19,8 +19,11 @@ def create_2d_histogram_plot(np_row_wise, labels, title=None, xtitle=None, ytitl
:param xtitle: X-Series title :param xtitle: X-Series title
:param ytitle: Y-Series title :param ytitle: Y-Series title
:param comment: comment underneath the title :param comment: comment underneath the title
:param mode: multiple histograms mode. valid options are: stack / group / relative. Default is 'group'.
:return: Plotly chart dict :return: Plotly chart dict
""" """
assert mode in ('stack', 'group', 'relative')
np_row_wise = np.atleast_2d(np_row_wise) np_row_wise = np.atleast_2d(np_row_wise)
assert len(np_row_wise.shape) == 2, "Expected a 2D numpy array" assert len(np_row_wise.shape) == 2, "Expected a 2D numpy array"
# using labels without xlabels leads to original behavior # using labels without xlabels leads to original behavior
@ -36,7 +39,7 @@ def create_2d_histogram_plot(np_row_wise, labels, title=None, xtitle=None, ytitl
data = [_np_row_to_plotly_data_item(np_row=np_row_wise[i, :], label=labels[i] if labels else None, xlabels=xlabels) data = [_np_row_to_plotly_data_item(np_row=np_row_wise[i, :], label=labels[i] if labels else None, xlabels=xlabels)
for i in range(np_row_wise.shape[0])] for i in range(np_row_wise.shape[0])]
return _plotly_hist_dict(title=title, xtitle=xtitle, ytitle=ytitle, data=data, comment=comment) return _plotly_hist_dict(title=title, xtitle=xtitle, ytitle=ytitle, mode=mode, data=data, comment=comment)
def _to_np_array(value): def _to_np_array(value):
@ -334,16 +337,19 @@ def _get_z_colorbar_data(z_data=None, values=None, colors=None):
return colorscale, colorbar return colorscale, colorbar
def _plotly_hist_dict(title, xtitle, ytitle, data=None, comment=None): def _plotly_hist_dict(title, xtitle, ytitle, mode='group', data=None, comment=None):
""" """
Create a basic Plotly chart dictionary Create a basic Plotly chart dictionary
:param title: Chart title :param title: Chart title
:param xtitle: X-Series title :param xtitle: X-Series title
:param ytitle: Y-Series title :param ytitle: Y-Series title
:param mode: multiple histograms mode. optionals stack / group / relative. Default is 'group'.
:param data: Data items :param data: Data items
:type data: list :type data: list
:return: Plotly chart dict :return: Plotly chart dict
""" """
assert mode in ('stack', 'group', 'relative')
return { return {
"data": data or [], "data": data or [],
"layout": { "layout": {
@ -354,7 +360,7 @@ def _plotly_hist_dict(title, xtitle, ytitle, data=None, comment=None):
"yaxis": { "yaxis": {
"title": ytitle "title": ytitle
}, },
"barmode": "stack", "barmode": mode,
"bargap": 0.08, "bargap": 0.08,
"bargroupgap": 0 "bargroupgap": 0
} }