From 315486bf9d17ec33cbb015f35aa29129c9157496 Mon Sep 17 00:00:00 2001 From: BlakeJC94 Date: Mon, 2 Dec 2024 20:43:46 +1100 Subject: [PATCH] Add support for pl.DataFrame in Logger.report_table --- clearml/backend_interface/metrics/reporter.py | 2 +- clearml/logger.py | 20 ++++++++++++------- clearml/utilities/plotly_reporter.py | 18 ++++++++++++++--- 3 files changed, 29 insertions(+), 11 deletions(-) diff --git a/clearml/backend_interface/metrics/reporter.py b/clearml/backend_interface/metrics/reporter.py index b5fb4d49..dcf56d96 100644 --- a/clearml/backend_interface/metrics/reporter.py +++ b/clearml/backend_interface/metrics/reporter.py @@ -654,7 +654,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan :param series: Series (AKA variant) :type series: str :param table: The table data - :type table: pandas.DataFrame + :type table: pandas.DataFrame or polars.DataFrame :param iteration: Iteration number :type iteration: int :param layout_config: optional dictionary for layout configuration, passed directly to plotly diff --git a/clearml/logger.py b/clearml/logger.py index fe0a19df..d745078e 100644 --- a/clearml/logger.py +++ b/clearml/logger.py @@ -10,6 +10,11 @@ from pathlib2 import Path from .debugging.log import LoggerRoot +try: + import polars as pl +except ImportError: + pl = None + try: import pandas as pd except ImportError: @@ -326,7 +331,7 @@ class Logger(object): title, # type: str series, # type: str iteration=None, # type: Optional[int] - table_plot=None, # type: Optional[pd.DataFrame, Sequence[Sequence]] + table_plot=None, # type: Optional[pd.DataFrame, pl.DataFrame, Sequence[Sequence]] csv=None, # type: Optional[str] url=None, # type: Optional[str] extra_layout=None, # type: Optional[dict] @@ -392,15 +397,15 @@ class Logger(object): mutually_exclusive(UsageError, _check_none=True, table_plot=table_plot, csv=csv, url=url) table = table_plot if url or csv: - if not pd: + if not pd and not pl: raise UsageError( - "pandas is required in order to support reporting tables using CSV or a URL, " - "please install the pandas python package" + "pandas or polars is required in order to support reporting tables using CSV " + "or a URL, please install the pandas or polars python package" ) if url: - table = pd.read_csv(url, index_col=[0]) + table = pd.read_csv(url, index_col=[0]) if pd else pl.read_csv(url) elif csv: - table = pd.read_csv(csv, index_col=[0]) + table = pd.read_csv(csv, index_col=[0]) if pd else pl.read_csv(csv) def replace(dst, *srcs): for src in srcs: @@ -409,7 +414,8 @@ class Logger(object): if isinstance(table, (list, tuple)): reporter_table = table else: - reporter_table = table.fillna(str(np.nan)) + nan = str(np.nan) + reporter_table = table.fillna(nan) if pd else table.fill_nan(nan) replace("NaN", np.nan, math.nan if six.PY3 else float("nan")) replace("Inf", np.inf, math.inf if six.PY3 else float("inf")) minus_inf = [-np.inf, -math.inf if six.PY3 else -float("inf")] diff --git a/clearml/utilities/plotly_reporter.py b/clearml/utilities/plotly_reporter.py index e3eb51e0..1da65029 100644 --- a/clearml/utilities/plotly_reporter.py +++ b/clearml/utilities/plotly_reporter.py @@ -4,6 +4,10 @@ import numpy as np from ..errors import UsageError from ..utilities.dicts import merge_dicts +try: + import polars as pl +except ImportError: + pl = None try: import pandas as pd except ImportError: @@ -486,7 +490,7 @@ def create_plotly_table(table_plot, title, series, layout_config=None, data_conf """ Create a basic Plotly table json style to be sent - :param table_plot: the output table in pandas.DataFrame structure or list of rows (list) in a table + :param table_plot: the output table in pandas.DataFrame structure or polars.Dataframe structure or list of rows (list) in a table :param title: Title (AKA metric) :type title: str :param series: Series (AKA variant) @@ -503,11 +507,19 @@ def create_plotly_table(table_plot, title, series, layout_config=None, data_conf elif is_list and table_plot[0] and isinstance(table_plot[0], (list, tuple)): headers_values = table_plot[0] cells_values = [list(i) for i in zip(*table_plot[1:])] + elif pl and isinstance(table_plot, pl.DataFrame): + headers_values = list([col] for col in table_plot.columns) + # Convert datetimes to ISO strings + datetime_columns = table_plot.select(pl.selectors.datetime()).columns + exprs = [pl.col(col).dt.to_string("iso:strict") for col in datetime_columns] + # Get cell values and preserve value types + cells_values_transpose = table_plot.with_columns(*exprs).rows() + cells_values = list(map(list, zip(*cells_values_transpose))) else: if not pd: raise UsageError( - "pandas is required in order to support reporting tables using CSV or a URL, " - "please install the pandas python package" + "pandas or polars is required in order to support reporting tables using CSV or a URL, " + "please install the pandas or polars python package" ) index_added = not isinstance(table_plot.index, pd.RangeIndex) headers_values = list([col] for col in table_plot.columns)