mirror of
https://github.com/clearml/clearml
synced 2025-04-05 05:10:06 +00:00
Add support for pl.DataFrame in Logger.report_table
This commit is contained in:
parent
01af13a294
commit
315486bf9d
@ -654,7 +654,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
|
||||
:param series: Series (AKA variant)
|
||||
:type series: str
|
||||
:param table: The table data
|
||||
:type table: pandas.DataFrame
|
||||
:type table: pandas.DataFrame or polars.DataFrame
|
||||
:param iteration: Iteration number
|
||||
:type iteration: int
|
||||
:param layout_config: optional dictionary for layout configuration, passed directly to plotly
|
||||
|
@ -10,6 +10,11 @@ from pathlib2 import Path
|
||||
|
||||
from .debugging.log import LoggerRoot
|
||||
|
||||
try:
|
||||
import polars as pl
|
||||
except ImportError:
|
||||
pl = None
|
||||
|
||||
try:
|
||||
import pandas as pd
|
||||
except ImportError:
|
||||
@ -326,7 +331,7 @@ class Logger(object):
|
||||
title, # type: str
|
||||
series, # type: str
|
||||
iteration=None, # type: Optional[int]
|
||||
table_plot=None, # type: Optional[pd.DataFrame, Sequence[Sequence]]
|
||||
table_plot=None, # type: Optional[pd.DataFrame, pl.DataFrame, Sequence[Sequence]]
|
||||
csv=None, # type: Optional[str]
|
||||
url=None, # type: Optional[str]
|
||||
extra_layout=None, # type: Optional[dict]
|
||||
@ -392,15 +397,15 @@ class Logger(object):
|
||||
mutually_exclusive(UsageError, _check_none=True, table_plot=table_plot, csv=csv, url=url)
|
||||
table = table_plot
|
||||
if url or csv:
|
||||
if not pd:
|
||||
if not pd and not pl:
|
||||
raise UsageError(
|
||||
"pandas is required in order to support reporting tables using CSV or a URL, "
|
||||
"please install the pandas python package"
|
||||
"pandas or polars is required in order to support reporting tables using CSV "
|
||||
"or a URL, please install the pandas or polars python package"
|
||||
)
|
||||
if url:
|
||||
table = pd.read_csv(url, index_col=[0])
|
||||
table = pd.read_csv(url, index_col=[0]) if pd else pl.read_csv(url)
|
||||
elif csv:
|
||||
table = pd.read_csv(csv, index_col=[0])
|
||||
table = pd.read_csv(csv, index_col=[0]) if pd else pl.read_csv(csv)
|
||||
|
||||
def replace(dst, *srcs):
|
||||
for src in srcs:
|
||||
@ -409,7 +414,8 @@ class Logger(object):
|
||||
if isinstance(table, (list, tuple)):
|
||||
reporter_table = table
|
||||
else:
|
||||
reporter_table = table.fillna(str(np.nan))
|
||||
nan = str(np.nan)
|
||||
reporter_table = table.fillna(nan) if pd else table.fill_nan(nan)
|
||||
replace("NaN", np.nan, math.nan if six.PY3 else float("nan"))
|
||||
replace("Inf", np.inf, math.inf if six.PY3 else float("inf"))
|
||||
minus_inf = [-np.inf, -math.inf if six.PY3 else -float("inf")]
|
||||
|
@ -4,6 +4,10 @@ import numpy as np
|
||||
from ..errors import UsageError
|
||||
from ..utilities.dicts import merge_dicts
|
||||
|
||||
try:
|
||||
import polars as pl
|
||||
except ImportError:
|
||||
pl = None
|
||||
try:
|
||||
import pandas as pd
|
||||
except ImportError:
|
||||
@ -486,7 +490,7 @@ def create_plotly_table(table_plot, title, series, layout_config=None, data_conf
|
||||
"""
|
||||
Create a basic Plotly table json style to be sent
|
||||
|
||||
:param table_plot: the output table in pandas.DataFrame structure or list of rows (list) in a table
|
||||
:param table_plot: the output table in pandas.DataFrame structure or polars.Dataframe structure or list of rows (list) in a table
|
||||
:param title: Title (AKA metric)
|
||||
:type title: str
|
||||
:param series: Series (AKA variant)
|
||||
@ -503,11 +507,19 @@ def create_plotly_table(table_plot, title, series, layout_config=None, data_conf
|
||||
elif is_list and table_plot[0] and isinstance(table_plot[0], (list, tuple)):
|
||||
headers_values = table_plot[0]
|
||||
cells_values = [list(i) for i in zip(*table_plot[1:])]
|
||||
elif pl and isinstance(table_plot, pl.DataFrame):
|
||||
headers_values = list([col] for col in table_plot.columns)
|
||||
# Convert datetimes to ISO strings
|
||||
datetime_columns = table_plot.select(pl.selectors.datetime()).columns
|
||||
exprs = [pl.col(col).dt.to_string("iso:strict") for col in datetime_columns]
|
||||
# Get cell values and preserve value types
|
||||
cells_values_transpose = table_plot.with_columns(*exprs).rows()
|
||||
cells_values = list(map(list, zip(*cells_values_transpose)))
|
||||
else:
|
||||
if not pd:
|
||||
raise UsageError(
|
||||
"pandas is required in order to support reporting tables using CSV or a URL, "
|
||||
"please install the pandas python package"
|
||||
"pandas or polars is required in order to support reporting tables using CSV or a URL, "
|
||||
"please install the pandas or polars python package"
|
||||
)
|
||||
index_added = not isinstance(table_plot.index, pd.RangeIndex)
|
||||
headers_values = list([col] for col in table_plot.columns)
|
||||
|
Loading…
Reference in New Issue
Block a user