This commit is contained in:
Blake 2025-03-26 05:47:25 +00:00 committed by GitHub
commit 49706ebeda
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 143 additions and 56 deletions

View File

@ -654,7 +654,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
:param series: Series (AKA variant)
:type series: str
:param table: The table data
:type table: pandas.DataFrame
:type table: pandas.DataFrame or polars.DataFrame
:param iteration: Iteration number
:type iteration: int
:param layout_config: optional dictionary for layout configuration, passed directly to plotly

View File

@ -30,6 +30,10 @@ from ..utilities.process.mp import SafeEvent, ForkSafeRLock
from ..utilities.proxy_object import LazyEvalWrapper
from ..config import deferred_config, config
try:
import polars as pl
except ImportError:
pl = None
try:
import pandas as pd
DataFrame = pd.DataFrame
@ -152,6 +156,7 @@ class Artifact(object):
Supported content types are:
- dict - ``.json``, ``.yaml``
- pandas.DataFrame - ``.csv.gz``, ``.parquet``, ``.feather``, ``.pickle``
- polars.DataFrame - ``.csv.gz``, ``.parquet``, ``.feather``, ``.pickle``
- numpy.ndarray - ``.npz``, ``.csv.gz``
- PIL.Image - whatever content types PIL supports
All other types will return a pathlib2.Path object pointing to a local copy of the artifacts file (or directory).
@ -192,6 +197,18 @@ class Artifact(object):
self._object = pd.read_csv(local_file)
else:
self._object = pd.read_csv(local_file, index_col=[0])
elif self.type == Artifacts._pd_artifact_type or self.type == "polars" and pl:
if self._content_type == "application/parquet":
self._object = pl.read_parquet(local_file)
elif self._content_type == "application/feather":
self._object = pl.read_ipc(local_file)
elif self._content_type == "application/pickle":
with open(local_file, "rb") as f:
self._object = pickle.load(f)
elif self.type == Artifacts._pd_artifact_type:
self._object = pl.read_csv(local_file)
else:
self._object = pl.read_csv(local_file)
elif self.type == "image":
self._object = Image.open(local_file)
elif self.type == "JSON" or self.type == "dict":
@ -280,14 +297,14 @@ class Artifacts(object):
self.artifact_hash_columns = {}
def __setitem__(self, key, value):
# check that value is of type pandas
if pd and isinstance(value, pd.DataFrame):
# check that value is of type pandas or polars
if (pd and isinstance(value, pd.DataFrame)) or (pl and isinstance(value, pl.DataFrame)):
super(Artifacts._ProxyDictWrite, self).__setitem__(key, value)
if self._artifacts_manager:
self._artifacts_manager.flush()
else:
raise ValueError('Artifacts currently support pandas.DataFrame objects only')
raise ValueError('Artifacts currently support pandas.DataFrame and polars.DataFrame objects only')
def unregister_artifact(self, name):
self.artifact_metadata.pop(name, None)
@ -472,8 +489,8 @@ class Artifacts(object):
artifact_type_data.content_type = "text/csv"
np.savetxt(local_filename, artifact_object, delimiter=",")
delete_after_upload = True
elif pd and isinstance(artifact_object, pd.DataFrame):
artifact_type = "pandas"
elif (pd and isinstance(artifact_object, pd.DataFrame)) or (pl and isinstance(artifact_object, pl.DataFrame)):
artifact_type = "pandas" if (pd and isinstance(artifact_object, pd.DataFrame)) else "polars"
artifact_type_data.preview = preview or str(artifact_object.__repr__())
# we are making sure self._default_pandas_dataframe_extension_name is not deferred
extension_name = extension_name or str(self._default_pandas_dataframe_extension_name or "")
@ -484,9 +501,9 @@ class Artifacts(object):
local_filename = self._push_temp_file(
prefix=quote(name, safe="") + ".", suffix=override_filename_ext_in_uri
)
if (
if (pd and isinstance(artifact_object, pd.DataFrame) and (
isinstance(artifact_object.index, pd.MultiIndex) or isinstance(artifact_object.columns, pd.MultiIndex)
) and not extension_name:
) or (pl and isinstance(artifact_object, pl.DataFrame))) and not extension_name:
store_as_pickle = True
elif override_filename_ext_in_uri == ".csv.gz":
artifact_type_data.content_type = "text/csv"
@ -494,7 +511,10 @@ class Artifacts(object):
elif override_filename_ext_in_uri == ".parquet":
try:
artifact_type_data.content_type = "application/parquet"
artifact_object.to_parquet(local_filename)
if (pd and isinstance(artifact_object, pd.DataFrame)):
artifact_object.to_parquet(local_filename)
else:
artifact_object.write_parquet(local_filename)
except Exception as e:
LoggerRoot.get_base_logger().warning(
"Exception '{}' encountered when uploading artifact as .parquet. Defaulting to .csv.gz".format(
@ -506,7 +526,10 @@ class Artifacts(object):
elif override_filename_ext_in_uri == ".feather":
try:
artifact_type_data.content_type = "application/feather"
artifact_object.to_feather(local_filename)
if (pd and isinstance(artifact_object, pd.DataFrame)):
artifact_object.to_feather(local_filename)
else:
artifact_object.write_ipc(local_filename)
except Exception as e:
LoggerRoot.get_base_logger().warning(
"Exception '{}' encountered when uploading artifact as .feather. Defaulting to .csv.gz".format(
@ -517,7 +540,11 @@ class Artifacts(object):
self._store_compressed_pd_csv(artifact_object, local_filename)
elif override_filename_ext_in_uri == ".pickle":
artifact_type_data.content_type = "application/pickle"
artifact_object.to_pickle(local_filename)
if (pl and isinstance(artifact_object, pd.DataFrame)):
artifact_object.to_pickle(local_filename)
else:
with open(local_filename, "wb") as f:
pickle.dump(artifact_object, f)
delete_after_upload = True
elif isinstance(artifact_object, Image.Image):
artifact_type = "image"
@ -1006,7 +1033,7 @@ class Artifacts(object):
artifacts_summary = []
for a_name, a_df in artifacts_dict.items():
hash_cols = self._artifacts_container.get_hash_columns(a_name)
if not pd or not isinstance(a_df, pd.DataFrame):
if not pd or not isinstance(a_df, pd.DataFrame) or not pl or not isinstance(a_df, pl.DataFrame):
continue
if hash_cols is True:
@ -1038,8 +1065,12 @@ class Artifacts(object):
a_shape = a_df.shape
# parallelize
a_hash_cols = a_df.drop(columns=hash_col_drop)
thread_pool.map(hash_row, a_hash_cols.values)
if pd and isinstance(a_df, pd.DataFrame):
a_hash_cols = a_df.drop(columns=hash_col_drop)
thread_pool.map(hash_row, a_hash_cols.values)
else:
a_hash_cols = a_df.drop(hash_col_drop)
a_unique_hash.add(a_hash_cols.hash_rows())
# add result
artifacts_summary.append((a_name, a_shape, a_unique_hash,))
@ -1083,16 +1114,19 @@ class Artifacts(object):
# (otherwise it is encoded and creates new hash every time)
if self._compression == "gzip":
with gzip.GzipFile(local_filename, 'wb', mtime=0) as gzip_file:
try:
pd_version = int(pd.__version__.split(".")[0])
except ValueError:
pd_version = 0
if pd_version >= 2:
artifact_object.to_csv(gzip_file, **kwargs)
if pl and isinstance(artifact_object, pl.DataFrame):
artifact_object.write_csv(gzip_file)
else:
# old (pandas<2) versions of pandas cannot handle direct gzip stream, so we manually encode it
artifact_object.to_csv(io.TextIOWrapper(gzip_file), **kwargs)
try:
pd_version = int(pd.__version__.split(".")[0])
except ValueError:
pd_version = 0
if pd_version >= 2:
artifact_object.to_csv(gzip_file, **kwargs)
else:
# old (pandas<2) versions of pandas cannot handle direct gzip stream, so we manually encode it
artifact_object.to_csv(io.TextIOWrapper(gzip_file), **kwargs)
else:
artifact_object.to_csv(local_filename, compression=self._compression)

View File

@ -48,6 +48,10 @@ except ImportError:
except Exception as e:
logging.warning("ClearML Dataset failed importing pandas: {}".format(e))
pd = None
try:
import polars as pl
except ImportError:
pl = None
try:
import pyarrow # noqa
@ -852,7 +856,7 @@ class Dataset(object):
return True
def set_metadata(self, metadata, metadata_name='metadata', ui_visible=True):
# type: (Union[numpy.array, pd.DataFrame, Dict[str, Any]], str, bool) -> () # noqa: F821
# type: (Union[numpy.array, pd.DataFrame, pl.DataFrame, Dict[str, Any]], str, bool) -> () # noqa: F821
"""
Attach a user-defined metadata to the dataset. Check `Task.upload_artifact` for supported types.
If type is Pandas Dataframes, optionally make it visible as a table in the UI.
@ -861,7 +865,7 @@ class Dataset(object):
raise ValueError("metadata_name can not start with '{}'".format(self.__data_entry_name_prefix))
self._task.upload_artifact(name=metadata_name, artifact_object=metadata)
if ui_visible:
if pd and isinstance(metadata, pd.DataFrame):
if (pd and isinstance(metadata, pd.DataFrame)) or (pl and isinstance(metadata, pl.DataFrame)):
self.get_logger().report_table(
title='Dataset Metadata',
series='Dataset Metadata',
@ -874,7 +878,7 @@ class Dataset(object):
)
def get_metadata(self, metadata_name='metadata'):
# type: (str) -> Optional[numpy.array, pd.DataFrame, dict, str, bool] # noqa: F821
# type: (str) -> Optional[numpy.array, pd.DataFrame, pl.DataFrame, dict, str, bool] # noqa: F821
"""
Get attached metadata back in its original format. Will return None if none was found.
"""
@ -3119,19 +3123,34 @@ class Dataset(object):
def convert_to_tabular_artifact(file_path_, file_extension_, compression_=None):
# noinspection PyBroadException
try:
if file_extension_ == ".csv" and pd:
return pd.read_csv(
file_path_,
nrows=self.__preview_tabular_row_count,
compression=compression_.lstrip(".") if compression_ else None,
)
elif file_extension_ == ".tsv" and pd:
return pd.read_csv(
file_path_,
sep='\t',
nrows=self.__preview_tabular_row_count,
compression=compression_.lstrip(".") if compression_ else None,
)
if file_extension_ == ".csv" and (pl or pd):
if pd:
return pd.read_csv(
file_path_,
nrows=self.__preview_tabular_row_count,
compression=compression_.lstrip(".") if compression_ else None,
)
else:
# TODO Re-implement compression after testing all extensions
return pl.read_csv(
file_path_,
n_rows=self.__preview_tabular_row_count,
)
elif file_extension_ == ".tsv" and (pl or pd):
if pd:
return pd.read_csv(
file_path_,
sep='\t',
nrows=self.__preview_tabular_row_count,
compression=compression_.lstrip(".") if compression_ else None,
)
else:
# TODO Re-implement compression after testing all extensions
return pl.read_csv(
file_path_,
n_rows=self.__preview_tabular_row_count,
separator='\t',
)
elif file_extension_ == ".parquet" or file_extension_ == ".parq":
if pyarrow:
pf = pyarrow.parquet.ParquetFile(file_path_)
@ -3140,7 +3159,10 @@ class Dataset(object):
elif fastparquet:
return fastparquet.ParquetFile(file_path_).head(self.__preview_tabular_row_count).to_pandas()
elif (file_extension_ == ".npz" or file_extension_ == ".npy") and np:
return pd.DataFrame(np.loadtxt(file_path_, max_rows=self.__preview_tabular_row_count))
if pd:
return pd.DataFrame(np.loadtxt(file_path_, max_rows=self.__preview_tabular_row_count))
else
return pl.DataFrame(np.loadtxt(file_path_, max_rows=self.__preview_tabular_row_count))
except Exception:
pass
return None
@ -3172,7 +3194,10 @@ class Dataset(object):
# because it will not upload the sample to that destination.
# use report_media instead to not leak data
if (
isinstance(artifact, pd.DataFrame)
(
(pd and isinstance(artifact, pd.DataFrame))
or (pl and isinstance(artifact, pl.DataFrame))
)
and self._task.get_logger().get_default_upload_destination() == Session.get_files_server_host()
):
self._task.get_logger().report_table("Tables", "summary", table_plot=artifact)

View File

@ -10,6 +10,11 @@ from pathlib2 import Path
from .debugging.log import LoggerRoot
try:
import polars as pl
except ImportError:
pl = None
try:
import pandas as pd
except ImportError:
@ -327,7 +332,7 @@ class Logger(object):
title, # type: str
series, # type: str
iteration=None, # type: Optional[int]
table_plot=None, # type: Optional[pd.DataFrame, Sequence[Sequence]]
table_plot=None, # type: Optional[pd.DataFrame, pl.DataFrame, Sequence[Sequence]]
csv=None, # type: Optional[str]
url=None, # type: Optional[str]
extra_layout=None, # type: Optional[dict]
@ -393,15 +398,15 @@ class Logger(object):
mutually_exclusive(UsageError, _check_none=True, table_plot=table_plot, csv=csv, url=url)
table = table_plot
if url or csv:
if not pd:
if not pd and not pl:
raise UsageError(
"pandas is required in order to support reporting tables using CSV or a URL, "
"please install the pandas python package"
"pandas or polars is required in order to support reporting tables using CSV "
"or a URL, please install the pandas or polars python package"
)
if url:
table = pd.read_csv(url, index_col=[0])
table = pd.read_csv(url, index_col=[0]) if pd else pl.read_csv(url)
elif csv:
table = pd.read_csv(csv, index_col=[0])
table = pd.read_csv(csv, index_col=[0]) if pd else pl.read_csv(csv)
def replace(dst, *srcs):
for src in srcs:
@ -410,7 +415,8 @@ class Logger(object):
if isinstance(table, (list, tuple)):
reporter_table = table
else:
reporter_table = table.fillna(str(np.nan))
nan = str(np.nan)
reporter_table = table.fillna(nan) if pd else table.fill_nan(nan)
replace("NaN", np.nan, math.nan if six.PY3 else float("nan"))
replace("Inf", np.inf, math.inf if six.PY3 else float("inf"))
minus_inf = [-np.inf, -math.inf if six.PY3 else -float("inf")]

View File

@ -14,6 +14,10 @@ try:
import pandas as pd
except ImportError:
pd = None
try:
import polars as pl
except ImportError:
pl = None
from .backend_api import Session
from .backend_api.services import models, projects
@ -638,15 +642,21 @@ class BaseModel(object):
)
table = table_plot
if url or csv:
if not pd:
if not pd and not pl:
raise UsageError(
"pandas is required in order to support reporting tables using CSV or a URL, "
"please install the pandas python package"
"pandas or polars is required in order to support reporting tables using CSV or a URL, "
"please install the pandas or polars python package"
)
if url:
table = pd.read_csv(url, index_col=[0])
if pd:
table = pd.read_csv(url, index_col=[0])
else:
table = pd.read_csv(url)
elif csv:
table = pd.read_csv(csv, index_col=[0])
if pd:
table = pd.read_csv(csv, index_col=[0])
else:
table = pd.read_csv(url)
def replace(dst, *srcs):
for src in srcs:

View File

@ -4,6 +4,10 @@ import numpy as np
from ..errors import UsageError
from ..utilities.dicts import merge_dicts
try:
import polars as pl
except ImportError:
pl = None
try:
import pandas as pd
except ImportError:
@ -486,7 +490,7 @@ def create_plotly_table(table_plot, title, series, layout_config=None, data_conf
"""
Create a basic Plotly table json style to be sent
:param table_plot: the output table in pandas.DataFrame structure or list of rows (list) in a table
:param table_plot: the output table in pandas.DataFrame structure or polars.Dataframe structure or list of rows (list) in a table
:param title: Title (AKA metric)
:type title: str
:param series: Series (AKA variant)
@ -503,11 +507,19 @@ def create_plotly_table(table_plot, title, series, layout_config=None, data_conf
elif is_list and table_plot[0] and isinstance(table_plot[0], (list, tuple)):
headers_values = table_plot[0]
cells_values = [list(i) for i in zip(*table_plot[1:])]
elif pl and isinstance(table_plot, pl.DataFrame):
headers_values = list([col] for col in table_plot.columns)
# Convert datetimes to ISO strings
datetime_columns = table_plot.select(pl.selectors.datetime()).columns
exprs = [pl.col(col).dt.to_string("iso:strict") for col in datetime_columns]
# Get cell values and preserve value types
cells_values_transpose = table_plot.with_columns(*exprs).rows()
cells_values = list(map(list, zip(*cells_values_transpose)))
else:
if not pd:
raise UsageError(
"pandas is required in order to support reporting tables using CSV or a URL, "
"please install the pandas python package"
"pandas or polars is required in order to support reporting tables using CSV or a URL, "
"please install the pandas or polars python package"
)
index_added = not isinstance(table_plot.index, pd.RangeIndex)
headers_values = list([col] for col in table_plot.columns)