Add polars support to Dataset

BlakeJC94 2024-12-02 22:03:16 +11:00
parent a300d7a8bd
commit 3c38f008a0

@@ -48,6 +48,10 @@ except ImportError:
 except Exception as e:
     logging.warning("ClearML Dataset failed importing pandas: {}".format(e))
     pd = None
+try:
+    import polars as pl
+except ImportError:
+    pl = None
 try:
     import pyarrow  # noqa
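
Polars gets the same guarded-import treatment already used for pandas: if the library is missing, the module-level name is bound to None instead of a module. That is why every polars check later in the diff is written as `pl and isinstance(...)`. A minimal sketch of the pitfall the guard avoids (the `obj` value is a placeholder):

try:
    import polars as pl
except ImportError:
    pl = None

obj = {"a": [1, 2]}  # placeholder value to classify

# pl.DataFrame would raise AttributeError when polars is absent,
# so truthiness is checked before isinstance()
is_polars_frame = bool(pl) and isinstance(obj, pl.DataFrame)
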
@@ -850,7 +854,7 @@ class Dataset(object):
         return True

     def set_metadata(self, metadata, metadata_name='metadata', ui_visible=True):
-        # type: (Union[numpy.array, pd.DataFrame, Dict[str, Any]], str, bool) -> ()  # noqa: F821
+        # type: (Union[numpy.array, pd.DataFrame, pl.DataFrame, Dict[str, Any]], str, bool) -> ()  # noqa: F821
         """
         Attach a user-defined metadata to the dataset. Check `Task.upload_artifact` for supported types.
         If type is Pandas Dataframes, optionally make it visible as a table in the UI.
@@ -859,7 +863,7 @@ class Dataset(object):
             raise ValueError("metadata_name can not start with '{}'".format(self.__data_entry_name_prefix))
         self._task.upload_artifact(name=metadata_name, artifact_object=metadata)
         if ui_visible:
-            if pd and isinstance(metadata, pd.DataFrame):
+            if (pd and isinstance(metadata, pd.DataFrame)) or (pl and isinstance(metadata, pl.DataFrame)):
                 self.get_logger().report_table(
                     title='Dataset Metadata',
                     series='Dataset Metadata',
@@ -872,7 +876,7 @@ class Dataset(object):
             )

     def get_metadata(self, metadata_name='metadata'):
-        # type: (str) -> Optional[numpy.array, pd.DataFrame, dict, str, bool]  # noqa: F821
+        # type: (str) -> Optional[numpy.array, pd.DataFrame, pl.DataFrame, dict, str, bool]  # noqa: F821
         """
         Get attached metadata back in its original format. Will return None if none was found.
         """
@@ -3091,19 +3095,34 @@ class Dataset(object):
         def convert_to_tabular_artifact(file_path_, file_extension_, compression_=None):
             # noinspection PyBroadException
             try:
-                if file_extension_ == ".csv" and pd:
-                    return pd.read_csv(
-                        file_path_,
-                        nrows=self.__preview_tabular_row_count,
-                        compression=compression_.lstrip(".") if compression_ else None,
-                    )
-                elif file_extension_ == ".tsv" and pd:
-                    return pd.read_csv(
-                        file_path_,
-                        sep='\t',
-                        nrows=self.__preview_tabular_row_count,
-                        compression=compression_.lstrip(".") if compression_ else None,
-                    )
+                if file_extension_ == ".csv" and (pl or pd):
+                    if pd:
+                        return pd.read_csv(
+                            file_path_,
+                            nrows=self.__preview_tabular_row_count,
+                            compression=compression_.lstrip(".") if compression_ else None,
+                        )
+                    else:
+                        # TODO Re-implement compression after testing all extensions
+                        return pl.read_csv(
+                            file_path_,
+                            n_rows=self.__preview_tabular_row_count,
+                        )
+                elif file_extension_ == ".tsv" and (pl or pd):
+                    if pd:
+                        return pd.read_csv(
+                            file_path_,
+                            sep='\t',
+                            nrows=self.__preview_tabular_row_count,
+                            compression=compression_.lstrip(".") if compression_ else None,
+                        )
+                    else:
+                        # TODO Re-implement compression after testing all extensions
+                        return pl.read_csv(
+                            file_path_,
+                            n_rows=self.__preview_tabular_row_count,
+                            separator='\t',
+                        )
                 elif file_extension_ == ".parquet" or file_extension_ == ".parq":
                     if pyarrow:
                         pf = pyarrow.parquet.ParquetFile(file_path_)
@@ -3112,7 +3131,10 @@ class Dataset(object):
                 elif fastparquet:
                     return fastparquet.ParquetFile(file_path_).head(self.__preview_tabular_row_count).to_pandas()
                 elif (file_extension_ == ".npz" or file_extension_ == ".npy") and np:
-                    return pd.DataFrame(np.loadtxt(file_path_, max_rows=self.__preview_tabular_row_count))
+                    if pd:
+                        return pd.DataFrame(np.loadtxt(file_path_, max_rows=self.__preview_tabular_row_count))
+                    else:
+                        return pl.DataFrame(np.loadtxt(file_path_, max_rows=self.__preview_tabular_row_count))
             except Exception:
                 pass
             return None
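
The polars branches above are not a drop-in copy of the pandas calls because the two read_csv signatures differ: pandas' `nrows` and `sep` become `n_rows` and `separator` in polars, and polars' read_csv has no `compression` keyword (hence the TODO comments). A quick parity sketch over a hypothetical data.tsv:

import pandas as pd
import polars as pl

pdf = pd.read_csv("data.tsv", sep="\t", nrows=10)
plf = pl.read_csv("data.tsv", separator="\t", n_rows=10)
assert pdf.shape == plf.shape  # both cap the preview at the same 10 rows
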
@@ -3144,7 +3166,10 @@ class Dataset(object):
             # because it will not upload the sample to that destination.
             # use report_media instead to not leak data
             if (
-                isinstance(artifact, pd.DataFrame)
+                (
+                    (pd and isinstance(artifact, pd.DataFrame))
+                    or (pl and isinstance(artifact, pl.DataFrame))
+                )
                 and self._task.get_logger().get_default_upload_destination() == Session.get_files_server_host()
             ):
                 self._task.get_logger().report_table("Tables", "summary", table_plot=artifact)
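
End to end, the effect of this commit is that a polars-only environment (no pandas installed) still gets tabular previews and metadata tables. A closing usage sketch with placeholder names and paths, assuming a configured ClearML server:

from clearml import Dataset

ds = Dataset.create(dataset_name="demo-polars", dataset_project="examples")
ds.add_files("data.csv")  # the preview path now falls back to pl.read_csv when pandas is absent
ds.upload()
ds.finalize()
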