Fix Logger.report_confusion_matrix (#894)

* set annotated cells (cells with text inside them) to be the default.
* removed the unused series argument from the signature.
This commit is contained in:
Yiftach Beer 2023-01-26 19:10:10 +02:00 committed by GitHub
parent 0b9428bf21
commit da6f75363d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -617,7 +617,6 @@ class Logger(object):
def report_confusion_matrix( def report_confusion_matrix(
self, self,
title, # type: str title, # type: str
series, # type: str
matrix, # type: np.ndarray matrix, # type: np.ndarray
iteration=None, # type: Optional[int] iteration=None, # type: Optional[int]
xaxis=None, # type: Optional[str] xaxis=None, # type: Optional[str]
@ -626,7 +625,7 @@ class Logger(object):
ylabels=None, # type: Optional[List[str]] ylabels=None, # type: Optional[List[str]]
yaxis_reversed=False, # type: bool yaxis_reversed=False, # type: bool
comment=None, # type: Optional[str] comment=None, # type: Optional[str]
extra_layout=None, # type: Optional[dict] extra_layout={'texttemplate': '%{z}'}, # type: Optional[dict]
): ):
""" """
For explicit reporting, plot a heat-map matrix. For explicit reporting, plot a heat-map matrix.
@ -636,11 +635,10 @@ class Logger(object):
.. code-block:: py .. code-block:: py
confusion = np.random.randint(10, size=(10, 10)) confusion = np.random.randint(10, size=(10, 10))
logger.report_confusion_matrix("example confusion matrix", "ignored", iteration=1, matrix=confusion, logger.report_confusion_matrix("example confusion matrix", iteration=1, matrix=confusion,
xaxis="title X", yaxis="title Y") xaxis="title X", yaxis="title Y")
:param str title: The title (metric) of the plot. :param str title: The title (metric) of the plot.
:param str series: The series name (variant) of the reported confusion matrix.
:param numpy.ndarray matrix: A heat-map matrix (example: confusion matrix) :param numpy.ndarray matrix: A heat-map matrix (example: confusion matrix)
:param int iteration: The reported iteration / step. :param int iteration: The reported iteration / step.
:param str xaxis: The x-axis title. (Optional) :param str xaxis: The x-axis title. (Optional)
@ -663,7 +661,7 @@ class Logger(object):
# noinspection PyProtectedMember # noinspection PyProtectedMember
return self._task._reporter.report_value_matrix( return self._task._reporter.report_value_matrix(
title=title, title=title,
series=series, series='ignored',
data=matrix.astype(np.float32), data=matrix.astype(np.float32),
iter=iteration or 0, iter=iteration or 0,
xtitle=xaxis, xtitle=xaxis,