From 3c38f008a07f8c498f59d37d38cbb33ba864b595 Mon Sep 17 00:00:00 2001
From: BlakeJC94
Date: Mon, 2 Dec 2024 22:03:16 +1100
Subject: [PATCH] Add polars support to Dataset

---
 clearml/datasets/dataset.py | 61 ++++++++++++++++++++++++++-----------
 1 file changed, 43 insertions(+), 18 deletions(-)

diff --git a/clearml/datasets/dataset.py b/clearml/datasets/dataset.py
index 4ac0e001..5efda250 100644
--- a/clearml/datasets/dataset.py
+++ b/clearml/datasets/dataset.py
@@ -48,6 +48,10 @@ except ImportError:
 except Exception as e:
     logging.warning("ClearML Dataset failed importing pandas: {}".format(e))
     pd = None
+try:
+    import polars as pl
+except ImportError:
+    pl = None
 
 try:
     import pyarrow  # noqa
@@ -850,7 +854,7 @@ class Dataset(object):
         return True
 
     def set_metadata(self, metadata, metadata_name='metadata', ui_visible=True):
-        # type: (Union[numpy.array, pd.DataFrame, Dict[str, Any]], str, bool) -> ()  # noqa: F821
+        # type: (Union[numpy.array, pd.DataFrame, pl.DataFrame, Dict[str, Any]], str, bool) -> ()  # noqa: F821
         """
         Attach a user-defined metadata to the dataset. Check `Task.upload_artifact` for supported types.
         If type is Pandas Dataframes, optionally make it visible as a table in the UI.
@@ -859,7 +863,7 @@ class Dataset(object):
             raise ValueError("metadata_name can not start with '{}'".format(self.__data_entry_name_prefix))
         self._task.upload_artifact(name=metadata_name, artifact_object=metadata)
         if ui_visible:
-            if pd and isinstance(metadata, pd.DataFrame):
+            if (pd and isinstance(metadata, pd.DataFrame)) or (pl and isinstance(metadata, pl.DataFrame)):
                 self.get_logger().report_table(
                     title='Dataset Metadata',
                     series='Dataset Metadata',
@@ -872,7 +876,7 @@ class Dataset(object):
                 )
 
     def get_metadata(self, metadata_name='metadata'):
-        # type: (str) -> Optional[numpy.array, pd.DataFrame, dict, str, bool]  # noqa: F821
+        # type: (str) -> Optional[numpy.array, pd.DataFrame, pl.DataFrame, dict, str, bool]  # noqa: F821
         """
         Get attached metadata back in its original format. Will return None if none was found.
""" @@ -3091,19 +3095,34 @@ class Dataset(object): def convert_to_tabular_artifact(file_path_, file_extension_, compression_=None): # noinspection PyBroadException try: - if file_extension_ == ".csv" and pd: - return pd.read_csv( - file_path_, - nrows=self.__preview_tabular_row_count, - compression=compression_.lstrip(".") if compression_ else None, - ) - elif file_extension_ == ".tsv" and pd: - return pd.read_csv( - file_path_, - sep='\t', - nrows=self.__preview_tabular_row_count, - compression=compression_.lstrip(".") if compression_ else None, - ) + if file_extension_ == ".csv" and (pl or pd): + if pd: + return pd.read_csv( + file_path_, + nrows=self.__preview_tabular_row_count, + compression=compression_.lstrip(".") if compression_ else None, + ) + else: + # TODO Re-implement compression after testing all extensions + return pl.read_csv( + file_path_, + n_rows=self.__preview_tabular_row_count, + ) + elif file_extension_ == ".tsv" and (pl or pd): + if pd: + return pd.read_csv( + file_path_, + sep='\t', + nrows=self.__preview_tabular_row_count, + compression=compression_.lstrip(".") if compression_ else None, + ) + else: + # TODO Re-implement compression after testing all extensions + return pl.read_csv( + file_path_, + n_rows=self.__preview_tabular_row_count, + separator='\t', + ) elif file_extension_ == ".parquet" or file_extension_ == ".parq": if pyarrow: pf = pyarrow.parquet.ParquetFile(file_path_) @@ -3112,7 +3131,10 @@ class Dataset(object): elif fastparquet: return fastparquet.ParquetFile(file_path_).head(self.__preview_tabular_row_count).to_pandas() elif (file_extension_ == ".npz" or file_extension_ == ".npy") and np: - return pd.DataFrame(np.loadtxt(file_path_, max_rows=self.__preview_tabular_row_count)) + if pd: + return pd.DataFrame(np.loadtxt(file_path_, max_rows=self.__preview_tabular_row_count)) + else + return pl.DataFrame(np.loadtxt(file_path_, max_rows=self.__preview_tabular_row_count)) except Exception: pass return None @@ -3144,7 +3166,10 @@ class Dataset(object): # because it will not upload the sample to that destination. # use report_media instead to not leak data if ( - isinstance(artifact, pd.DataFrame) + ( + (pd and isinstance(artifact, pd.DataFrame)) + or (pl and isinstance(artifact, pl.DataFrame)) + ) and self._task.get_logger().get_default_upload_destination() == Session.get_files_server_host() ): self._task.get_logger().report_table("Tables", "summary", table_plot=artifact)