Add support for pl.DataFrame in Logger.report_table

This commit is contained in:
BlakeJC94 2024-12-02 20:43:46 +11:00
parent 01af13a294
commit 315486bf9d
3 changed files with 29 additions and 11 deletions

View File

@ -654,7 +654,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
:param series: Series (AKA variant)
:type series: str
:param table: The table data
:type table: pandas.DataFrame
:type table: pandas.DataFrame or polars.DataFrame
:param iteration: Iteration number
:type iteration: int
:param layout_config: optional dictionary for layout configuration, passed directly to plotly

View File

@ -10,6 +10,11 @@ from pathlib2 import Path
from .debugging.log import LoggerRoot
try:
import polars as pl
except ImportError:
pl = None
try:
import pandas as pd
except ImportError:
@ -326,7 +331,7 @@ class Logger(object):
title, # type: str
series, # type: str
iteration=None, # type: Optional[int]
table_plot=None, # type: Optional[pd.DataFrame, Sequence[Sequence]]
table_plot=None, # type: Optional[pd.DataFrame, pl.DataFrame, Sequence[Sequence]]
csv=None, # type: Optional[str]
url=None, # type: Optional[str]
extra_layout=None, # type: Optional[dict]
@ -392,15 +397,15 @@ class Logger(object):
mutually_exclusive(UsageError, _check_none=True, table_plot=table_plot, csv=csv, url=url)
table = table_plot
if url or csv:
if not pd:
if not pd and not pl:
raise UsageError(
"pandas is required in order to support reporting tables using CSV or a URL, "
"please install the pandas python package"
"pandas or polars is required in order to support reporting tables using CSV "
"or a URL, please install the pandas or polars python package"
)
if url:
table = pd.read_csv(url, index_col=[0])
table = pd.read_csv(url, index_col=[0]) if pd else pl.read_csv(url)
elif csv:
table = pd.read_csv(csv, index_col=[0])
table = pd.read_csv(csv, index_col=[0]) if pd else pl.read_csv(csv)
def replace(dst, *srcs):
for src in srcs:
@ -409,7 +414,8 @@ class Logger(object):
if isinstance(table, (list, tuple)):
reporter_table = table
else:
reporter_table = table.fillna(str(np.nan))
nan = str(np.nan)
reporter_table = table.fillna(nan) if pd else table.fill_nan(nan)
replace("NaN", np.nan, math.nan if six.PY3 else float("nan"))
replace("Inf", np.inf, math.inf if six.PY3 else float("inf"))
minus_inf = [-np.inf, -math.inf if six.PY3 else -float("inf")]

View File

@ -4,6 +4,10 @@ import numpy as np
from ..errors import UsageError
from ..utilities.dicts import merge_dicts
try:
import polars as pl
except ImportError:
pl = None
try:
import pandas as pd
except ImportError:
@ -486,7 +490,7 @@ def create_plotly_table(table_plot, title, series, layout_config=None, data_conf
"""
Create a basic Plotly table json style to be sent
:param table_plot: the output table in pandas.DataFrame structure or list of rows (list) in a table
:param table_plot: the output table in pandas.DataFrame structure or polars.DataFrame structure or list of rows (list) in a table
:param title: Title (AKA metric)
:type title: str
:param series: Series (AKA variant)
@ -503,11 +507,19 @@ def create_plotly_table(table_plot, title, series, layout_config=None, data_conf
elif is_list and table_plot[0] and isinstance(table_plot[0], (list, tuple)):
headers_values = table_plot[0]
cells_values = [list(i) for i in zip(*table_plot[1:])]
elif pl and isinstance(table_plot, pl.DataFrame):
headers_values = list([col] for col in table_plot.columns)
# Convert datetimes to ISO strings
datetime_columns = table_plot.select(pl.selectors.datetime()).columns
exprs = [pl.col(col).dt.to_string("iso:strict") for col in datetime_columns]
# Get cell values and preserve value types
cells_values_transpose = table_plot.with_columns(*exprs).rows()
cells_values = list(map(list, zip(*cells_values_transpose)))
else:
if not pd:
raise UsageError(
"pandas is required in order to support reporting tables using CSV or a URL, "
"please install the pandas python package"
"pandas or polars is required in order to support reporting tables using CSV or a URL, "
"please install the pandas or polars python package"
)
index_added = not isinstance(table_plot.index, pd.RangeIndex)
headers_values = list([col] for col in table_plot.columns)