Mirror of https://github.com/clearml/clearml (synced 2025-04-16 21:42:10 +00:00)
Add polars support to Artifact
commit a300d7a8bd (parent 315486bf9d)
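
A minimal usage sketch of what this change enables, assuming a configured ClearML environment with polars installed; project, task, and artifact names are placeholders:

import polars as pl
from clearml import Task

task = Task.init(project_name="examples", task_name="polars artifact")

df = pl.DataFrame({"id": [1, 2, 3], "value": ["a", "b", "c"]})

# with this commit, a polars DataFrame is accepted directly as an artifact object
task.upload_artifact(name="table", artifact_object=df, wait_on_upload=True)

# retrieval goes through Artifact.get(), which now deserializes back into polars
restored = Task.get_task(task_id=task.id).artifacts["table"].get()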
@@ -30,6 +30,10 @@ from ..utilities.process.mp import SafeEvent, ForkSafeRLock
 from ..utilities.proxy_object import LazyEvalWrapper
 from ..config import deferred_config, config
 
+try:
+    import polars as pl
+except ImportError:
+    pl = None
 try:
     import pandas as pd
     DataFrame = pd.DataFrame
@@ -152,6 +156,7 @@ class Artifact(object):
         Supported content types are:
         - dict - ``.json``, ``.yaml``
         - pandas.DataFrame - ``.csv.gz``, ``.parquet``, ``.feather``, ``.pickle``
+        - polars.DataFrame - ``.csv.gz``, ``.parquet``, ``.feather``, ``.pickle``
         - numpy.ndarray - ``.npz``, ``.csv.gz``
         - PIL.Image - whatever content types PIL supports
         All other types will return a pathlib2.Path object pointing to a local copy of the artifacts file (or directory).
@@ -192,6 +197,18 @@ class Artifact(object):
                 self._object = pd.read_csv(local_file)
             else:
                 self._object = pd.read_csv(local_file, index_col=[0])
+        elif self.type == Artifacts._pd_artifact_type or self.type == "polars" and pl:
+            if self._content_type == "application/parquet":
+                self._object = pl.read_parquet(local_file)
+            elif self._content_type == "application/feather":
+                self._object = pl.read_ipc(local_file)
+            elif self._content_type == "application/pickle":
+                with open(local_file, "rb") as f:
+                    self._object = pickle.load(f)
+            elif self.type == Artifacts._pd_artifact_type:
+                self._object = pl.read_csv(local_file)
+            else:
+                self._object = pl.read_csv(local_file)
         elif self.type == "image":
             self._object = Image.open(local_file)
         elif self.type == "JSON" or self.type == "dict":
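
The new branch in Artifact.get() maps the stored content type onto a polars reader. A minimal standalone sketch of that mapping, assuming polars is installed; read_polars_artifact is a hypothetical helper, not part of the ClearML API:

import pickle
import polars as pl

def read_polars_artifact(local_file, content_type):
    if content_type == "application/parquet":
        return pl.read_parquet(local_file)
    if content_type == "application/feather":
        # .feather artifacts are read back through polars' Arrow IPC reader
        return pl.read_ipc(local_file)
    if content_type == "application/pickle":
        with open(local_file, "rb") as f:
            return pickle.load(f)
    # remaining content types fall back to CSV
    return pl.read_csv(local_file)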
@@ -279,14 +296,14 @@ class Artifacts(object):
             self.artifact_hash_columns = {}
 
         def __setitem__(self, key, value):
-            # check that value is of type pandas
-            if pd and isinstance(value, pd.DataFrame):
+            # check that value is of type pandas or polars
+            if (pd and isinstance(value, pd.DataFrame)) or (pl and isinstance(value, pl.DataFrame)):
                 super(Artifacts._ProxyDictWrite, self).__setitem__(key, value)
 
                 if self._artifacts_manager:
                     self._artifacts_manager.flush()
             else:
-                raise ValueError('Artifacts currently support pandas.DataFrame objects only')
+                raise ValueError('Artifacts currently support pandas.DataFrame and polars.DataFrame objects only')
 
         def unregister_artifact(self, name):
             self.artifact_metadata.pop(name, None)
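
With the relaxed type check above, registered (live-tracked) artifacts may also be polars frames. A short usage sketch, assuming a configured ClearML environment; names and values are placeholders:

import polars as pl
from clearml import Task

task = Task.init(project_name="examples", task_name="register polars")
df = pl.DataFrame({"epoch": [1, 2], "loss": [0.9, 0.7]})

# before this commit, register_artifact() raised ValueError for anything
# that was not a pandas DataFrame
task.register_artifact(name="metrics", artifact=df)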
@@ -471,8 +488,8 @@ class Artifacts(object):
                 artifact_type_data.content_type = "text/csv"
                 np.savetxt(local_filename, artifact_object, delimiter=",")
                 delete_after_upload = True
-        elif pd and isinstance(artifact_object, pd.DataFrame):
-            artifact_type = "pandas"
+        elif (pd and isinstance(artifact_object, pd.DataFrame)) or (pl and isinstance(artifact_object, pl.DataFrame)):
+            artifact_type = "pandas" if (pd and isinstance(artifact_object, pd.DataFrame)) else "polars"
             artifact_type_data.preview = preview or str(artifact_object.__repr__())
             # we are making sure self._default_pandas_dataframe_extension_name is not deferred
             extension_name = extension_name or str(self._default_pandas_dataframe_extension_name or "")
@@ -483,9 +500,9 @@ class Artifacts(object):
             local_filename = self._push_temp_file(
                 prefix=quote(name, safe="") + ".", suffix=override_filename_ext_in_uri
             )
-            if (
+            if (pd and isinstance(artifact_object, pd.DataFrame) and (
                 isinstance(artifact_object.index, pd.MultiIndex) or isinstance(artifact_object.columns, pd.MultiIndex)
-            ) and not extension_name:
+            ) or (pl and isinstance(artifact_object, pl.DataFrame))) and not extension_name:
                 store_as_pickle = True
             elif override_filename_ext_in_uri == ".csv.gz":
                 artifact_type_data.content_type = "text/csv"
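
The combined condition above is dense, so here is a hypothetical helper restating it as a reading aid, assuming both libraries are installed; should_store_as_pickle is not part of ClearML:

import pandas as pd
import polars as pl

def should_store_as_pickle(obj, extension_name):
    # pandas frames with a MultiIndex (rows or columns), and any polars frame,
    # fall back to pickle storage when no explicit extension name was requested
    if extension_name:
        return False
    if isinstance(obj, pd.DataFrame):
        return isinstance(obj.index, pd.MultiIndex) or isinstance(obj.columns, pd.MultiIndex)
    return isinstance(obj, pl.DataFrame)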
@@ -493,7 +510,10 @@ class Artifacts(object):
             elif override_filename_ext_in_uri == ".parquet":
                 try:
                     artifact_type_data.content_type = "application/parquet"
-                    artifact_object.to_parquet(local_filename)
+                    if (pd and isinstance(artifact_object, pd.DataFrame)):
+                        artifact_object.to_parquet(local_filename)
+                    else:
+                        artifact_object.write_parquet(local_filename)
                 except Exception as e:
                     LoggerRoot.get_base_logger().warning(
                         "Exception '{}' encountered when uploading artifact as .parquet. Defaulting to .csv.gz".format(
@@ -505,7 +525,10 @@ class Artifacts(object):
             elif override_filename_ext_in_uri == ".feather":
                 try:
                     artifact_type_data.content_type = "application/feather"
-                    artifact_object.to_feather(local_filename)
+                    if (pd and isinstance(artifact_object, pd.DataFrame)):
+                        artifact_object.to_feather(local_filename)
+                    else:
+                        artifact_object.write_ipc(local_filename)
                 except Exception as e:
                     LoggerRoot.get_base_logger().warning(
                         "Exception '{}' encountered when uploading artifact as .feather. Defaulting to .csv.gz".format(
@@ -516,7 +539,11 @@ class Artifacts(object):
                     self._store_compressed_pd_csv(artifact_object, local_filename)
             elif override_filename_ext_in_uri == ".pickle":
                 artifact_type_data.content_type = "application/pickle"
-                artifact_object.to_pickle(local_filename)
+                if (pl and isinstance(artifact_object, pd.DataFrame)):
+                    artifact_object.to_pickle(local_filename)
+                else:
+                    with open(local_filename, "wb") as f:
+                        pickle.dump(artifact_object, f)
             delete_after_upload = True
         elif isinstance(artifact_object, Image.Image):
             artifact_type = "image"
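
The three hunks above branch between the pandas and polars writer APIs for .parquet, .feather, and .pickle targets. A minimal standalone sketch of that dispatch, assuming both libraries are installed; write_frame is a hypothetical helper, not part of ClearML:

import pickle
import pandas as pd
import polars as pl

def write_frame(obj, local_filename, extension):
    is_pandas = isinstance(obj, pd.DataFrame)
    if extension == ".parquet":
        if is_pandas:
            obj.to_parquet(local_filename)
        else:
            obj.write_parquet(local_filename)
    elif extension == ".feather":
        if is_pandas:
            obj.to_feather(local_filename)
        else:
            # polars writes Feather files through its Arrow IPC writer
            obj.write_ipc(local_filename)
    elif extension == ".pickle":
        if is_pandas:
            obj.to_pickle(local_filename)
        else:
            # polars has no to_pickle(), so the frame is pickled directly
            with open(local_filename, "wb") as f:
                pickle.dump(obj, f)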
@@ -1005,7 +1032,7 @@ class Artifacts(object):
             artifacts_summary = []
             for a_name, a_df in artifacts_dict.items():
                 hash_cols = self._artifacts_container.get_hash_columns(a_name)
-                if not pd or not isinstance(a_df, pd.DataFrame):
+                if not pd or not isinstance(a_df, pd.DataFrame) or not pl or not isinstance(a_df, pl.DataFrame):
                     continue
 
                 if hash_cols is True:
@@ -1037,8 +1064,12 @@ class Artifacts(object):
 
                 a_shape = a_df.shape
                 # parallelize
-                a_hash_cols = a_df.drop(columns=hash_col_drop)
-                thread_pool.map(hash_row, a_hash_cols.values)
+                if pd and isinstance(a_df, pd.DataFrame):
+                    a_hash_cols = a_df.drop(columns=hash_col_drop)
+                    thread_pool.map(hash_row, a_hash_cols.values)
+                else:
+                    a_hash_cols = a_df.drop(hash_col_drop)
+                    a_unique_hash.add(a_hash_cols.hash_rows())
                 # add result
                 artifacts_summary.append((a_name, a_shape, a_unique_hash,))
 
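
The polars branch above relies on DataFrame.hash_rows() instead of the per-row thread-pool hashing used for pandas. A small standalone illustration of that call, assuming polars is installed; the data values are arbitrary:

import polars as pl

df = pl.DataFrame({"a": [1, 2, 3], "b": ["x", "y", "z"]})
row_hashes = df.hash_rows()            # one UInt64 hash per row
n_unique_rows = row_hashes.n_unique()  # count of distinct row hashes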
@@ -1082,16 +1113,19 @@ class Artifacts(object):
         # (otherwise it is encoded and creates new hash every time)
         if self._compression == "gzip":
             with gzip.GzipFile(local_filename, 'wb', mtime=0) as gzip_file:
-                try:
-                    pd_version = int(pd.__version__.split(".")[0])
-                except ValueError:
-                    pd_version = 0
-
-                if pd_version >= 2:
-                    artifact_object.to_csv(gzip_file, **kwargs)
-                else:
-                    # old (pandas<2) versions of pandas cannot handle direct gzip stream, so we manually encode it
-                    artifact_object.to_csv(io.TextIOWrapper(gzip_file), **kwargs)
+                if pl and isinstance(artifact_object, pl.DataFrame):
+                    artifact_object.write_csv(gzip_file)
+                else:
+                    try:
+                        pd_version = int(pd.__version__.split(".")[0])
+                    except ValueError:
+                        pd_version = 0
+
+                    if pd_version >= 2:
+                        artifact_object.to_csv(gzip_file, **kwargs)
+                    else:
+                        # old (pandas<2) versions of pandas cannot handle direct gzip stream, so we manually encode it
+                        artifact_object.to_csv(io.TextIOWrapper(gzip_file), **kwargs)
         else:
             artifact_object.to_csv(local_filename, compression=self._compression)
 
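
In the gzip path above, polars' write_csv() accepts a writable file-like object, so the frame streams straight into the GzipFile (mtime=0 keeps the compressed bytes deterministic for stable hashing). A minimal standalone sketch, assuming polars is installed; the filename and data are placeholders:

import gzip
import polars as pl

df = pl.DataFrame({"a": [1, 2], "b": [3.0, 4.0]})
with gzip.GzipFile("table.csv.gz", "wb", mtime=0) as gzip_file:
    # write CSV bytes directly into the open gzip stream
    df.write_csv(gzip_file)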