Fix breaking change in call to Logger.report_confusion_matrix()

This commit is contained in:
allegroai 2023-02-28 17:04:33 +02:00
parent 89e39f0b02
commit dd5036a0a0

View File

@ -611,6 +611,7 @@ class Logger(object):
def report_confusion_matrix(
self,
title, # type: str
series, # type: str
matrix, # type: np.ndarray
iteration=None, # type: Optional[int]
xaxis=None, # type: Optional[str]
@ -619,7 +620,7 @@ class Logger(object):
ylabels=None, # type: Optional[List[str]]
yaxis_reversed=False, # type: bool
comment=None, # type: Optional[str]
extra_layout={'texttemplate': '%{z}'}, # type: Optional[dict]
extra_layout=None, # type: Optional[dict]
):
"""
For explicit reporting, plot a heat-map matrix.
@ -629,10 +630,11 @@ class Logger(object):
.. code-block:: py
confusion = np.random.randint(10, size=(10, 10))
logger.report_confusion_matrix("example confusion matrix", iteration=1, matrix=confusion,
logger.report_confusion_matrix("example confusion matrix", "ignored", iteration=1, matrix=confusion,
xaxis="title X", yaxis="title Y")
:param str title: The title (metric) of the plot.
:param str series: The series name (variant) of the reported confusion matrix.
:param numpy.ndarray matrix: A heat-map matrix (example: confusion matrix)
:param int iteration: The reported iteration / step.
:param str xaxis: The x-axis title. (Optional)
@ -649,12 +651,15 @@ class Logger(object):
if not isinstance(matrix, np.ndarray):
matrix = np.array(matrix)
if extra_layout is None:
extra_layout = {'texttemplate': '%{z}'}
# if task was not started, we have to start it
self._start_task_if_needed()
# noinspection PyProtectedMember
return self._task._reporter.report_value_matrix(
title=title,
series='ignored',
series=series,
data=matrix.astype(np.float32),
iter=iteration or 0,
xtitle=xaxis,