Add support for pl.DataFrame in Logger.report_table

This commit is contained in:
BlakeJC94 2024-12-02 20:43:46 +11:00
parent 01af13a294
commit 315486bf9d
3 changed files with 29 additions and 11 deletions

View File

@ -654,7 +654,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
:param series: Series (AKA variant) :param series: Series (AKA variant)
:type series: str :type series: str
:param table: The table data :param table: The table data
:type table: pandas.DataFrame :type table: pandas.DataFrame or polars.DataFrame
:param iteration: Iteration number :param iteration: Iteration number
:type iteration: int :type iteration: int
:param layout_config: optional dictionary for layout configuration, passed directly to plotly :param layout_config: optional dictionary for layout configuration, passed directly to plotly

View File

@ -10,6 +10,11 @@ from pathlib2 import Path
from .debugging.log import LoggerRoot from .debugging.log import LoggerRoot
try:
import polars as pl
except ImportError:
pl = None
try: try:
import pandas as pd import pandas as pd
except ImportError: except ImportError:
@ -326,7 +331,7 @@ class Logger(object):
title, # type: str title, # type: str
series, # type: str series, # type: str
iteration=None, # type: Optional[int] iteration=None, # type: Optional[int]
table_plot=None, # type: Optional[pd.DataFrame, Sequence[Sequence]] table_plot=None, # type: Optional[pd.DataFrame, pl.DataFrame, Sequence[Sequence]]
csv=None, # type: Optional[str] csv=None, # type: Optional[str]
url=None, # type: Optional[str] url=None, # type: Optional[str]
extra_layout=None, # type: Optional[dict] extra_layout=None, # type: Optional[dict]
@ -392,15 +397,15 @@ class Logger(object):
mutually_exclusive(UsageError, _check_none=True, table_plot=table_plot, csv=csv, url=url) mutually_exclusive(UsageError, _check_none=True, table_plot=table_plot, csv=csv, url=url)
table = table_plot table = table_plot
if url or csv: if url or csv:
if not pd: if not pd and not pl:
raise UsageError( raise UsageError(
"pandas is required in order to support reporting tables using CSV or a URL, " "pandas or polars is required in order to support reporting tables using CSV "
"please install the pandas python package" "or a URL, please install the pandas or polars python package"
) )
if url: if url:
table = pd.read_csv(url, index_col=[0]) table = pd.read_csv(url, index_col=[0]) if pd else pl.read_csv(url)
elif csv: elif csv:
table = pd.read_csv(csv, index_col=[0]) table = pd.read_csv(csv, index_col=[0]) if pd else pl.read_csv(csv)
def replace(dst, *srcs): def replace(dst, *srcs):
for src in srcs: for src in srcs:
@ -409,7 +414,8 @@ class Logger(object):
if isinstance(table, (list, tuple)): if isinstance(table, (list, tuple)):
reporter_table = table reporter_table = table
else: else:
reporter_table = table.fillna(str(np.nan)) nan = str(np.nan)
reporter_table = table.fillna(nan) if pd else table.fill_nan(nan)
replace("NaN", np.nan, math.nan if six.PY3 else float("nan")) replace("NaN", np.nan, math.nan if six.PY3 else float("nan"))
replace("Inf", np.inf, math.inf if six.PY3 else float("inf")) replace("Inf", np.inf, math.inf if six.PY3 else float("inf"))
minus_inf = [-np.inf, -math.inf if six.PY3 else -float("inf")] minus_inf = [-np.inf, -math.inf if six.PY3 else -float("inf")]

View File

@ -4,6 +4,10 @@ import numpy as np
from ..errors import UsageError from ..errors import UsageError
from ..utilities.dicts import merge_dicts from ..utilities.dicts import merge_dicts
try:
import polars as pl
except ImportError:
pl = None
try: try:
import pandas as pd import pandas as pd
except ImportError: except ImportError:
@ -486,7 +490,7 @@ def create_plotly_table(table_plot, title, series, layout_config=None, data_conf
""" """
Create a basic Plotly table json style to be sent Create a basic Plotly table json style to be sent
:param table_plot: the output table in pandas.DataFrame structure or list of rows (list) in a table :param table_plot: the output table in pandas.DataFrame structure or polars.DataFrame structure or list of rows (list) in a table
:param title: Title (AKA metric) :param title: Title (AKA metric)
:type title: str :type title: str
:param series: Series (AKA variant) :param series: Series (AKA variant)
@ -503,11 +507,19 @@ def create_plotly_table(table_plot, title, series, layout_config=None, data_conf
elif is_list and table_plot[0] and isinstance(table_plot[0], (list, tuple)): elif is_list and table_plot[0] and isinstance(table_plot[0], (list, tuple)):
headers_values = table_plot[0] headers_values = table_plot[0]
cells_values = [list(i) for i in zip(*table_plot[1:])] cells_values = [list(i) for i in zip(*table_plot[1:])]
elif pl and isinstance(table_plot, pl.DataFrame):
headers_values = list([col] for col in table_plot.columns)
# Convert datetimes to ISO strings
datetime_columns = table_plot.select(pl.selectors.datetime()).columns
exprs = [pl.col(col).dt.to_string("iso:strict") for col in datetime_columns]
# Get cell values and preserve value types
cells_values_transpose = table_plot.with_columns(*exprs).rows()
cells_values = list(map(list, zip(*cells_values_transpose)))
else: else:
if not pd: if not pd:
raise UsageError( raise UsageError(
"pandas is required in order to support reporting tables using CSV or a URL, " "pandas or polars is required in order to support reporting tables using CSV or a URL, "
"please install the pandas python package" "please install the pandas or polars python package"
) )
index_added = not isinstance(table_plot.index, pd.RangeIndex) index_added = not isinstance(table_plot.index, pd.RangeIndex)
headers_values = list([col] for col in table_plot.columns) headers_values = list([col] for col in table_plot.columns)