Mirror of https://github.com/clearml/clearml (synced 2025-04-20 14:25:19 +00:00)
Add polars support to Dataset

parent a300d7a8bd
commit 3c38f008a0
@@ -48,6 +48,10 @@ except ImportError:
 except Exception as e:
     logging.warning("ClearML Dataset failed importing pandas: {}".format(e))
     pd = None
+try:
+    import polars as pl
+except ImportError:
+    pl = None

 try:
     import pyarrow  # noqa
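
The polars import follows the same optional-dependency pattern already used for pandas: a failed import leaves `pl` as `None`, and every call site tests it before touching the module. A minimal standalone sketch of that guard (the `describe` helper is illustrative, not part of the commit):

    # Optional-import guard: degrade gracefully when polars is absent.
    try:
        import polars as pl
    except ImportError:
        pl = None

    def describe(obj):
        # Only dereference pl after confirming the import succeeded.
        if pl is not None and isinstance(obj, pl.DataFrame):
            return "polars DataFrame with {} rows".format(obj.height)
        return "unsupported type: {}".format(type(obj).__name__)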
@@ -850,7 +854,7 @@ class Dataset(object):
         return True

     def set_metadata(self, metadata, metadata_name='metadata', ui_visible=True):
-        # type: (Union[numpy.array, pd.DataFrame, Dict[str, Any]], str, bool) -> () # noqa: F821
+        # type: (Union[numpy.array, pd.DataFrame, pl.DataFrame, Dict[str, Any]], str, bool) -> () # noqa: F821
         """
         Attach user-defined metadata to the dataset. Check `Task.upload_artifact` for supported types.
         If the type is a pandas DataFrame, optionally make it visible as a table in the UI.
@@ -859,7 +863,7 @@ class Dataset(object):
             raise ValueError("metadata_name can not start with '{}'".format(self.__data_entry_name_prefix))
         self._task.upload_artifact(name=metadata_name, artifact_object=metadata)
         if ui_visible:
-            if pd and isinstance(metadata, pd.DataFrame):
+            if (pd and isinstance(metadata, pd.DataFrame)) or (pl and isinstance(metadata, pl.DataFrame)):
                 self.get_logger().report_table(
                     title='Dataset Metadata',
                     series='Dataset Metadata',
@@ -872,7 +876,7 @@ class Dataset(object):
         )

     def get_metadata(self, metadata_name='metadata'):
-        # type: (str) -> Optional[numpy.array, pd.DataFrame, dict, str, bool] # noqa: F821
+        # type: (str) -> Optional[numpy.array, pd.DataFrame, pl.DataFrame, dict, str, bool] # noqa: F821
         """
         Get attached metadata back in its original format. Will return None if none was found.
         """
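
Taken together, these two changes let a polars DataFrame round-trip through dataset metadata the same way a pandas one does. A hypothetical usage sketch (dataset and project names are illustrative):

    import polars as pl
    from clearml import Dataset

    ds = Dataset.create(dataset_name="demo", dataset_project="examples")
    meta = pl.DataFrame({"split": ["train", "val"], "rows": [8000, 2000]})
    ds.set_metadata(meta)       # uploaded as an artifact; shown as a UI table
    table = ds.get_metadata()   # returned in its original format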
@@ -3091,19 +3095,34 @@ class Dataset(object):
         def convert_to_tabular_artifact(file_path_, file_extension_, compression_=None):
             # noinspection PyBroadException
             try:
-                if file_extension_ == ".csv" and pd:
-                    return pd.read_csv(
-                        file_path_,
-                        nrows=self.__preview_tabular_row_count,
-                        compression=compression_.lstrip(".") if compression_ else None,
-                    )
-                elif file_extension_ == ".tsv" and pd:
-                    return pd.read_csv(
-                        file_path_,
-                        sep='\t',
-                        nrows=self.__preview_tabular_row_count,
-                        compression=compression_.lstrip(".") if compression_ else None,
-                    )
+                if file_extension_ == ".csv" and (pl or pd):
+                    if pd:
+                        return pd.read_csv(
+                            file_path_,
+                            nrows=self.__preview_tabular_row_count,
+                            compression=compression_.lstrip(".") if compression_ else None,
+                        )
+                    else:
+                        # TODO Re-implement compression after testing all extensions
+                        return pl.read_csv(
+                            file_path_,
+                            n_rows=self.__preview_tabular_row_count,
+                        )
+                elif file_extension_ == ".tsv" and (pl or pd):
+                    if pd:
+                        return pd.read_csv(
+                            file_path_,
+                            sep='\t',
+                            nrows=self.__preview_tabular_row_count,
+                            compression=compression_.lstrip(".") if compression_ else None,
+                        )
+                    else:
+                        # TODO Re-implement compression after testing all extensions
+                        return pl.read_csv(
+                            file_path_,
+                            n_rows=self.__preview_tabular_row_count,
+                            separator='\t',
+                        )
                 elif file_extension_ == ".parquet" or file_extension_ == ".parq":
                     if pyarrow:
                         pf = pyarrow.parquet.ParquetFile(file_path_)
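
Note that the two readers take different keywords for the same options: pandas uses `nrows` and `sep`, polars uses `n_rows` and `separator`, and the polars branch drops `compression` entirely for now (hence the TODO). A minimal side-by-side sketch, assuming a local `preview.tsv` exists:

    import pandas as pd
    import polars as pl

    # pandas: nrows / sep, with optional compression handling
    df_pd = pd.read_csv("preview.tsv", sep="\t", nrows=10)

    # polars: n_rows / separator; compression support deferred here
    df_pl = pl.read_csv("preview.tsv", n_rows=10, separator="\t")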
@@ -3112,7 +3131,10 @@ class Dataset(object):
                     elif fastparquet:
                         return fastparquet.ParquetFile(file_path_).head(self.__preview_tabular_row_count).to_pandas()
                 elif (file_extension_ == ".npz" or file_extension_ == ".npy") and np:
-                    return pd.DataFrame(np.loadtxt(file_path_, max_rows=self.__preview_tabular_row_count))
+                    if pd:
+                        return pd.DataFrame(np.loadtxt(file_path_, max_rows=self.__preview_tabular_row_count))
+                    else:
+                        return pl.DataFrame(np.loadtxt(file_path_, max_rows=self.__preview_tabular_row_count))
             except Exception:
                 pass
             return None
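
Both branches wrap the `np.loadtxt` result directly; polars accepts a 2-D NumPy array and auto-names the columns `column_0`, `column_1`, and so on. A standalone sketch of the fallback path, using a stand-in array instead of a file:

    import numpy as np
    import polars as pl

    arr = np.arange(12, dtype=float).reshape(4, 3)  # stand-in for np.loadtxt output
    df = pl.DataFrame(arr)  # 4 rows, 3 columns named column_0..column_2
    print(df.shape)         # (4, 3)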
@@ -3144,7 +3166,10 @@ class Dataset(object):
         # because it will not upload the sample to that destination.
         # use report_media instead to not leak data
         if (
-            isinstance(artifact, pd.DataFrame)
+            (
+                (pd and isinstance(artifact, pd.DataFrame))
+                or (pl and isinstance(artifact, pl.DataFrame))
+            )
             and self._task.get_logger().get_default_upload_destination() == Session.get_files_server_host()
         ):
             self._task.get_logger().report_table("Tables", "summary", table_plot=artifact)
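
This branch hands the DataFrame straight to the logger's `report_table`, so the commit implies the logger renders polars frames the same way it renders pandas ones. A sketch of that call in isolation (task and project names are illustrative):

    import polars as pl
    from clearml import Task

    task = Task.init(project_name="examples", task_name="table-demo")
    df = pl.DataFrame({"metric": ["loss", "acc"], "value": [0.31, 0.94]})
    task.get_logger().report_table(title="Tables", series="summary", table_plot=df)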