diff --git a/trains/logger.py b/trains/logger.py index 29467d75..dc645f5c 100644 --- a/trains/logger.py +++ b/trains/logger.py @@ -265,7 +265,7 @@ class Logger(object): title, # type: str series, # type: str iteration, # type: int - table_plot=None, # type: Optional[pd.DataFrame] + table_plot=None, # type: Optional[pd.DataFrame, Sequence[Sequence]] csv=None, # type: Optional[str] url=None, # type: Optional[str] extra_layout=None, # type: Optional[dict] @@ -275,7 +275,7 @@ class Logger(object): One and only one of the following parameters must be provided. - - ``table_plot`` - Pandas DataFrame + - ``table_plot`` - Pandas DataFrame or Table as list of rows (list) - ``csv`` - CSV file - ``url`` - URL to CSV file @@ -296,7 +296,7 @@ class Logger(object): :param str series: The series name (variant) of the reported table. :param int iteration: The iteration number. :param table_plot: The output table plot object - :type table_plot: pandas.DataFrame + :type table_plot: pandas.DataFrame or Table as list of rows (list) :param csv: path to local csv file :type csv: str :param url: A URL to the location of csv file. @@ -325,10 +325,13 @@ class Logger(object): for src in srcs: reporter_table.replace(src, dst, inplace=True) - reporter_table = table.fillna(str(np.nan)) - replace("NaN", np.nan, math.nan) - replace("Inf", np.inf, math.inf) - replace("-Inf", -np.inf, np.NINF, -math.inf) + if isinstance(table, (list, tuple)): + reporter_table = table + else: + reporter_table = table.fillna(str(np.nan)) + replace("NaN", np.nan, math.nan) + replace("Inf", np.inf, math.inf) + replace("-Inf", -np.inf, np.NINF, -math.inf) # noinspection PyProtectedMember return self._task._reporter.report_table( title=title,