Add Logger histogram mode (stack/group/relative)

This commit is contained in:
allegroai 2020-05-08 22:05:33 +03:00
parent 5a85d40fc7
commit a5ff2ba9c8
3 changed files with 27 additions and 10 deletions

View File

@ -284,7 +284,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
self._report(ev)
def report_histogram(self, title, series, histogram, iter, labels=None, xlabels=None,
xtitle=None, ytitle=None, comment=None):
xtitle=None, ytitle=None, comment=None, mode='group'):
"""
Report an histogram bar plot
:param title: Title (AKA metric)
@ -304,7 +304,11 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
:param str ytitle: optional y-axis title
:param comment: comment underneath the title
:type comment: str
:param mode: multiple histograms mode. valid options are: stack / group / relative. Default is 'group'.
:type mode: str
"""
assert mode in ('stack', 'group', 'relative')
plotly_dict = create_2d_histogram_plot(
np_row_wise=histogram,
title=title,
@ -314,6 +318,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
series=series,
xlabels=xlabels,
comment=comment,
mode=mode,
)
return self.report_plot(

View File

@ -12,6 +12,7 @@ except ImportError:
from PIL import Image
from pathlib2 import Path
from .backend_api.services import tasks
from .backend_interface.logger import StdStreamPatch, LogFlusher
from .backend_interface.task import Task as _Task
from .backend_interface.task.development.worker import DevWorker
@ -152,10 +153,11 @@ class Logger(object):
labels=None, # type: Optional[List[str]]
xlabels=None, # type: Optional[List[str]]
xaxis=None, # type: Optional[str]
yaxis=None # type: Optional[str]
yaxis=None, # type: Optional[str]
mode=None # type: Optional[str]
):
"""
For explicit reporting, plot a vector as (stacked) histogram.
For explicit reporting, plot a vector as (default stacked) histogram.
For example:
@ -178,10 +180,11 @@ class Logger(object):
for each histogram bar on the x-axis. (Optional)
:param str xaxis: The x-axis title. (Optional)
:param str yaxis: The y-axis title. (Optional)
:param str mode: Multiple histograms mode, stack / group / relative. Default is 'group'.
"""
self._touch_title_series(title, series)
return self.report_histogram(title, series, values, iteration, labels=labels, xlabels=xlabels,
xaxis=xaxis, yaxis=yaxis)
xaxis=xaxis, yaxis=yaxis, mode=mode)
def report_histogram(
self,
@ -192,10 +195,11 @@ class Logger(object):
labels=None, # type: Optional[List[str]]
xlabels=None, # type: Optional[List[str]]
xaxis=None, # type: Optional[str]
yaxis=None # type: Optional[str]
yaxis=None, # type: Optional[str]
mode=None # type: Optional[str]
):
"""
For explicit reporting, plot a (stacked) histogram.
For explicit reporting, plot a (default stacked) histogram.
For example:
@ -218,6 +222,7 @@ class Logger(object):
for each histogram bar on the x-axis. (Optional)
:param str xaxis: The x-axis title. (Optional)
:param str yaxis: The y-axis title. (Optional)
:param str mode: Multiple histograms mode, stack / group / relative. Default is 'group'.
"""
if not isinstance(values, np.ndarray):
@ -235,6 +240,7 @@ class Logger(object):
xlabels=xlabels,
xtitle=xaxis,
ytitle=yaxis,
mode=mode or 'group'
)
def report_table(

View File

@ -10,7 +10,7 @@ from attr import attrs, attrib
def create_2d_histogram_plot(np_row_wise, labels, title=None, xtitle=None, ytitle=None, series=None, xlabels=None,
comment=None):
comment=None, mode='group'):
"""
Create a 2D Plotly histogram chart from a 2D numpy array
:param np_row_wise: 2D numpy data array
@ -19,8 +19,11 @@ def create_2d_histogram_plot(np_row_wise, labels, title=None, xtitle=None, ytitl
:param xtitle: X-Series title
:param ytitle: Y-Series title
:param comment: comment underneath the title
:param mode: multiple histograms mode. valid options are: stack / group / relative. Default is 'group'.
:return: Plotly chart dict
"""
assert mode in ('stack', 'group', 'relative')
np_row_wise = np.atleast_2d(np_row_wise)
assert len(np_row_wise.shape) == 2, "Expected a 2D numpy array"
# using labels without xlabels leads to original behavior
@ -36,7 +39,7 @@ def create_2d_histogram_plot(np_row_wise, labels, title=None, xtitle=None, ytitl
data = [_np_row_to_plotly_data_item(np_row=np_row_wise[i, :], label=labels[i] if labels else None, xlabels=xlabels)
for i in range(np_row_wise.shape[0])]
return _plotly_hist_dict(title=title, xtitle=xtitle, ytitle=ytitle, data=data, comment=comment)
return _plotly_hist_dict(title=title, xtitle=xtitle, ytitle=ytitle, mode=mode, data=data, comment=comment)
def _to_np_array(value):
@ -334,16 +337,19 @@ def _get_z_colorbar_data(z_data=None, values=None, colors=None):
return colorscale, colorbar
def _plotly_hist_dict(title, xtitle, ytitle, data=None, comment=None):
def _plotly_hist_dict(title, xtitle, ytitle, mode='group', data=None, comment=None):
"""
Create a basic Plotly chart dictionary
:param title: Chart title
:param xtitle: X-Series title
:param ytitle: Y-Series title
:param mode: multiple histograms mode. optionals stack / group / relative. Default is 'group'.
:param data: Data items
:type data: list
:return: Plotly chart dict
"""
assert mode in ('stack', 'group', 'relative')
return {
"data": data or [],
"layout": {
@ -354,7 +360,7 @@ def _plotly_hist_dict(title, xtitle, ytitle, data=None, comment=None):
"yaxis": {
"title": ytitle
},
"barmode": "stack",
"barmode": mode,
"bargap": 0.08,
"bargroupgap": 0
}