Mirror of https://github.com/clearml/clearml, synced 2025-04-08 06:34:37 +00:00

Add polars support to Dataset

This commit is contained in:
parent a300d7a8bd
commit 3c38f008a0
@@ -48,6 +48,10 @@ except ImportError:
 except Exception as e:
     logging.warning("ClearML Dataset failed importing pandas: {}".format(e))
     pd = None
+try:
+    import polars as pl
+except ImportError:
+    pl = None
 
 try:
     import pyarrow  # noqa
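The new import mirrors the optional-dependency pattern already used for pandas just above: bind the module at import time, or bind None so call sites can guard with a plain truthiness check. A minimal standalone sketch of that pattern (the helper function is illustrative, not part of the commit):

try:
    import polars as pl
except ImportError:
    # polars stays optional; downstream code must check for None
    pl = None

def frame_row_count(frame):
    # Guard on the module object before touching its API.
    if pl is not None and isinstance(frame, pl.DataFrame):
        return frame.height  # polars exposes the row count as .height
    raise TypeError("polars is unavailable or frame is not a pl.DataFrame")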
@@ -850,7 +854,7 @@ class Dataset(object):
         return True
 
     def set_metadata(self, metadata, metadata_name='metadata', ui_visible=True):
-        # type: (Union[numpy.array, pd.DataFrame, Dict[str, Any]], str, bool) -> () # noqa: F821
+        # type: (Union[numpy.array, pd.DataFrame, pl.DataFrame, Dict[str, Any]], str, bool) -> () # noqa: F821
         """
         Attach a user-defined metadata to the dataset. Check `Task.upload_artifact` for supported types.
         If type is Pandas Dataframes, optionally make it visible as a table in the UI.
@@ -859,7 +863,7 @@ class Dataset(object):
             raise ValueError("metadata_name can not start with '{}'".format(self.__data_entry_name_prefix))
         self._task.upload_artifact(name=metadata_name, artifact_object=metadata)
         if ui_visible:
-            if pd and isinstance(metadata, pd.DataFrame):
+            if (pd and isinstance(metadata, pd.DataFrame)) or (pl and isinstance(metadata, pl.DataFrame)):
                 self.get_logger().report_table(
                     title='Dataset Metadata',
                     series='Dataset Metadata',
@@ -872,7 +876,7 @@ class Dataset(object):
             )
 
     def get_metadata(self, metadata_name='metadata'):
-        # type: (str) -> Optional[numpy.array, pd.DataFrame, dict, str, bool] # noqa: F821
+        # type: (str) -> Optional[numpy.array, pd.DataFrame, pl.DataFrame, dict, str, bool] # noqa: F821
         """
         Get attached metadata back in its original format. Will return None if none was found.
         """
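Taken together, the two signature updates above mean a polars DataFrame round-trips through the metadata API the same way a pandas one does. A hedged usage sketch (the dataset and project names are hypothetical, and a configured ClearML setup is assumed):

import polars as pl
from clearml import Dataset

ds = Dataset.create(dataset_name="demo-polars", dataset_project="examples")  # hypothetical names
meta = pl.DataFrame({"split": ["train", "val"], "rows": [8000, 2000]})
ds.set_metadata(meta, metadata_name="splits")       # also rendered as a UI table (ui_visible=True)
restored = ds.get_metadata(metadata_name="splits")  # returned in its original format, or None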
@@ -3091,19 +3095,34 @@ class Dataset(object):
         def convert_to_tabular_artifact(file_path_, file_extension_, compression_=None):
             # noinspection PyBroadException
             try:
-                if file_extension_ == ".csv" and pd:
-                    return pd.read_csv(
-                        file_path_,
-                        nrows=self.__preview_tabular_row_count,
-                        compression=compression_.lstrip(".") if compression_ else None,
-                    )
-                elif file_extension_ == ".tsv" and pd:
-                    return pd.read_csv(
-                        file_path_,
-                        sep='\t',
-                        nrows=self.__preview_tabular_row_count,
-                        compression=compression_.lstrip(".") if compression_ else None,
-                    )
+                if file_extension_ == ".csv" and (pl or pd):
+                    if pd:
+                        return pd.read_csv(
+                            file_path_,
+                            nrows=self.__preview_tabular_row_count,
+                            compression=compression_.lstrip(".") if compression_ else None,
+                        )
+                    else:
+                        # TODO Re-implement compression after testing all extensions
+                        return pl.read_csv(
+                            file_path_,
+                            n_rows=self.__preview_tabular_row_count,
+                        )
+                elif file_extension_ == ".tsv" and (pl or pd):
+                    if pd:
+                        return pd.read_csv(
+                            file_path_,
+                            sep='\t',
+                            nrows=self.__preview_tabular_row_count,
+                            compression=compression_.lstrip(".") if compression_ else None,
+                        )
+                    else:
+                        # TODO Re-implement compression after testing all extensions
+                        return pl.read_csv(
+                            file_path_,
+                            n_rows=self.__preview_tabular_row_count,
+                            separator='\t',
+                        )
                 elif file_extension_ == ".parquet" or file_extension_ == ".parq":
                     if pyarrow:
                         pf = pyarrow.parquet.ParquetFile(file_path_)
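Note the keyword spelling each reader uses for the same ideas: pandas takes nrows and sep, polars takes n_rows and separator, and the polars branch deliberately skips the compression argument for now (hence the TODO). A reduced sketch of the same prefer-pandas fallback, with a hypothetical row limit:

def read_tsv_preview(path, limit=10):
    # Prefer pandas when installed; otherwise fall back to polars.
    if pd:
        return pd.read_csv(path, sep="\t", nrows=limit)         # pandas spelling
    if pl:
        return pl.read_csv(path, separator="\t", n_rows=limit)  # polars spelling
    return None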
@@ -3112,7 +3131,10 @@ class Dataset(object):
                     elif fastparquet:
                         return fastparquet.ParquetFile(file_path_).head(self.__preview_tabular_row_count).to_pandas()
                 elif (file_extension_ == ".npz" or file_extension_ == ".npy") and np:
-                    return pd.DataFrame(np.loadtxt(file_path_, max_rows=self.__preview_tabular_row_count))
+                    if pd:
+                        return pd.DataFrame(np.loadtxt(file_path_, max_rows=self.__preview_tabular_row_count))
+                    else:
+                        return pl.DataFrame(np.loadtxt(file_path_, max_rows=self.__preview_tabular_row_count))
             except Exception:
                 pass
             return None
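The NumPy branch gets the same guarded fallback: both constructors accept the 2-D array np.loadtxt returns, with pandas labeling columns 0, 1, ... and polars generating column_0-style names. A small sketch, assuming a whitespace-separated text file:

import numpy as np

arr = np.loadtxt("table.txt", max_rows=5)  # hypothetical file; yields a 2-D float array
preview = pd.DataFrame(arr) if pd else (pl.DataFrame(arr) if pl else None)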
@@ -3144,7 +3166,10 @@ class Dataset(object):
             # because it will not upload the sample to that destination.
             # use report_media instead to not leak data
             if (
-                isinstance(artifact, pd.DataFrame)
+                (
+                    (pd and isinstance(artifact, pd.DataFrame))
+                    or (pl and isinstance(artifact, pl.DataFrame))
+                )
                 and self._task.get_logger().get_default_upload_destination() == Session.get_files_server_host()
             ):
                 self._task.get_logger().report_table("Tables", "summary", table_plot=artifact)
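The widened condition also fixes a latent hazard: the old isinstance(artifact, pd.DataFrame) dereferences pd unconditionally, which raises AttributeError when pandas is not installed (pd is None). Guarding each isinstance behind the module's truthiness keeps the check safe under any combination of optional installs. A reduced sketch of the pattern:

def is_supported_dataframe(obj):
    # Safe even when the optional modules are bound to None.
    return bool((pd and isinstance(obj, pd.DataFrame)) or (pl and isinstance(obj, pl.DataFrame)))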