Add polars support to Dataset

BlakeJC94 2024-12-02 22:03:16 +11:00
parent a300d7a8bd
commit 3c38f008a0


@ -48,6 +48,10 @@ except ImportError:
except Exception as e:
logging.warning("ClearML Dataset failed importing pandas: {}".format(e))
pd = None
try:
import polars as pl
except ImportError:
pl = None
try:
import pyarrow # noqa
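
The new polars import follows the guarded-import pattern this module already uses for pandas and pyarrow: the import is attempted once at load time and the name is bound to None on failure, so call sites can check the module object before using it. A minimal sketch of the pattern outside this file (frame_from_rows is a hypothetical helper, not part of the change):

try:
    import polars as pl
except ImportError:
    pl = None

def frame_from_rows(rows):
    # pl is None when polars is not installed, so guard before use.
    if pl is None:
        raise RuntimeError("polars is required for this operation")
    return pl.DataFrame(rows)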
@ -850,7 +854,7 @@ class Dataset(object):
return True
def set_metadata(self, metadata, metadata_name='metadata', ui_visible=True):
# type: (Union[numpy.array, pd.DataFrame, Dict[str, Any]], str, bool) -> () # noqa: F821
# type: (Union[numpy.array, pd.DataFrame, pl.DataFrame, Dict[str, Any]], str, bool) -> () # noqa: F821
"""
Attach user-defined metadata to the dataset. Check `Task.upload_artifact` for supported types.
If the type is a pandas or polars DataFrame, optionally make it visible as a table in the UI.
@ -859,7 +863,7 @@ class Dataset(object):
raise ValueError("metadata_name can not start with '{}'".format(self.__data_entry_name_prefix))
self._task.upload_artifact(name=metadata_name, artifact_object=metadata)
if ui_visible:
if pd and isinstance(metadata, pd.DataFrame):
if (pd and isinstance(metadata, pd.DataFrame)) or (pl and isinstance(metadata, pl.DataFrame)):
self.get_logger().report_table(
title='Dataset Metadata',
series='Dataset Metadata',
@ -872,7 +876,7 @@ class Dataset(object):
)
def get_metadata(self, metadata_name='metadata'):
# type: (str) -> Optional[numpy.array, pd.DataFrame, dict, str, bool] # noqa: F821
# type: (str) -> Optional[numpy.array, pd.DataFrame, pl.DataFrame, dict, str, bool] # noqa: F821
"""
Get attached metadata back in its original format. Will return None if none was found.
"""
@ -3091,19 +3095,34 @@ class Dataset(object):
def convert_to_tabular_artifact(file_path_, file_extension_, compression_=None):
# noinspection PyBroadException
try:
if file_extension_ == ".csv" and pd:
return pd.read_csv(
file_path_,
nrows=self.__preview_tabular_row_count,
compression=compression_.lstrip(".") if compression_ else None,
)
elif file_extension_ == ".tsv" and pd:
return pd.read_csv(
file_path_,
sep='\t',
nrows=self.__preview_tabular_row_count,
compression=compression_.lstrip(".") if compression_ else None,
)
if file_extension_ == ".csv" and (pl or pd):
if pd:
return pd.read_csv(
file_path_,
nrows=self.__preview_tabular_row_count,
compression=compression_.lstrip(".") if compression_ else None,
)
else:
# TODO Re-implement compression after testing all extensions
return pl.read_csv(
file_path_,
n_rows=self.__preview_tabular_row_count,
)
elif file_extension_ == ".tsv" and (pl or pd):
if pd:
return pd.read_csv(
file_path_,
sep='\t',
nrows=self.__preview_tabular_row_count,
compression=compression_.lstrip(".") if compression_ else None,
)
else:
# TODO Re-implement compression after testing all extensions
return pl.read_csv(
file_path_,
n_rows=self.__preview_tabular_row_count,
separator='\t',
)
elif file_extension_ == ".parquet" or file_extension_ == ".parq":
if pyarrow:
pf = pyarrow.parquet.ParquetFile(file_path_)
@ -3112,7 +3131,10 @@ class Dataset(object):
elif fastparquet:
return fastparquet.ParquetFile(file_path_).head(self.__preview_tabular_row_count).to_pandas()
elif (file_extension_ == ".npz" or file_extension_ == ".npy") and np:
return pd.DataFrame(np.loadtxt(file_path_, max_rows=self.__preview_tabular_row_count))
if pd:
return pd.DataFrame(np.loadtxt(file_path_, max_rows=self.__preview_tabular_row_count))
else:
return pl.DataFrame(np.loadtxt(file_path_, max_rows=self.__preview_tabular_row_count))
except Exception:
pass
return None
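
The polars fallback mirrors the pandas calls with polars' renamed keyword arguments: pandas' nrows and sep correspond to n_rows and separator in polars.read_csv, and, per the TODOs above, the compression argument has no counterpart in the polars path yet. Side by side, assuming both libraries are installed and preview.tsv exists:

import pandas as pd
import polars as pl

path, limit = "preview.tsv", 10

# pandas spelling: nrows / sep (compression handled transparently)
df_pd = pd.read_csv(path, sep="\t", nrows=limit)

# polars spelling: n_rows / separator
df_pl = pl.read_csv(path, separator="\t", n_rows=limit)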
@ -3144,7 +3166,10 @@ class Dataset(object):
# because it will not upload the sample to that destination.
# use report_media instead to not leak data
if (
isinstance(artifact, pd.DataFrame)
(
(pd and isinstance(artifact, pd.DataFrame))
or (pl and isinstance(artifact, pl.DataFrame))
)
and self._task.get_logger().get_default_upload_destination() == Session.get_files_server_host()
):
self._task.get_logger().report_table("Tables", "summary", table_plot=artifact)
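
The pandas-or-polars check now appears both here and in set_metadata; a hypothetical helper (not part of this change) could centralize the dispatch while tolerating either library being absent:

def _is_dataframe(obj):
    # pd and pl are bound to None when the library is not installed,
    # so each isinstance check sits behind a None guard.
    return (pd is not None and isinstance(obj, pd.DataFrame)) or (
        pl is not None and isinstance(obj, pl.DataFrame)
    )

The destination guard itself is unchanged: previews are reported as tables only when the default upload destination is the files server, which avoids leaking samples to external storage, as the comment above notes.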