Add polars support to Artifact

This commit is contained in:
BlakeJC94 2024-12-02 21:47:54 +11:00
parent 315486bf9d
commit a300d7a8bd

View File

@ -30,6 +30,10 @@ from ..utilities.process.mp import SafeEvent, ForkSafeRLock
from ..utilities.proxy_object import LazyEvalWrapper
from ..config import deferred_config, config
try:
import polars as pl
except ImportError:
pl = None
try:
import pandas as pd
DataFrame = pd.DataFrame
@ -152,6 +156,7 @@ class Artifact(object):
Supported content types are:
- dict - ``.json``, ``.yaml``
- pandas.DataFrame - ``.csv.gz``, ``.parquet``, ``.feather``, ``.pickle``
- polars.DataFrame - ``.csv.gz``, ``.parquet``, ``.feather``, ``.pickle``
- numpy.ndarray - ``.npz``, ``.csv.gz``
- PIL.Image - whatever content types PIL supports
All other types will return a pathlib2.Path object pointing to a local copy of the artifacts file (or directory).
@ -192,6 +197,18 @@ class Artifact(object):
self._object = pd.read_csv(local_file)
else:
self._object = pd.read_csv(local_file, index_col=[0])
elif self.type == Artifacts._pd_artifact_type or self.type == "polars" and pl:
if self._content_type == "application/parquet":
self._object = pl.read_parquet(local_file)
elif self._content_type == "application/feather":
self._object = pl.read_ipc(local_file)
elif self._content_type == "application/pickle":
with open(local_file, "rb") as f:
self._object = pickle.load(f)
elif self.type == Artifacts._pd_artifact_type:
self._object = pl.read_csv(local_file)
else:
self._object = pl.read_csv(local_file)
elif self.type == "image":
self._object = Image.open(local_file)
elif self.type == "JSON" or self.type == "dict":
@ -279,14 +296,14 @@ class Artifacts(object):
self.artifact_hash_columns = {}
def __setitem__(self, key, value):
# check that value is of type pandas
if pd and isinstance(value, pd.DataFrame):
# check that value is of type pandas or polars
if (pd and isinstance(value, pd.DataFrame)) or (pl and isinstance(value, pl.DataFrame)):
super(Artifacts._ProxyDictWrite, self).__setitem__(key, value)
if self._artifacts_manager:
self._artifacts_manager.flush()
else:
raise ValueError('Artifacts currently support pandas.DataFrame objects only')
raise ValueError('Artifacts currently support pandas.DataFrame and polars.DataFrame objects only')
def unregister_artifact(self, name):
self.artifact_metadata.pop(name, None)
@ -471,8 +488,8 @@ class Artifacts(object):
artifact_type_data.content_type = "text/csv"
np.savetxt(local_filename, artifact_object, delimiter=",")
delete_after_upload = True
elif pd and isinstance(artifact_object, pd.DataFrame):
artifact_type = "pandas"
elif (pd and isinstance(artifact_object, pd.DataFrame)) or (pl and isinstance(artifact_object, pl.DataFrame)):
artifact_type = "pandas" if (pd and isinstance(artifact_object, pd.DataFrame)) else "polars"
artifact_type_data.preview = preview or str(artifact_object.__repr__())
# we are making sure self._default_pandas_dataframe_extension_name is not deferred
extension_name = extension_name or str(self._default_pandas_dataframe_extension_name or "")
@ -483,9 +500,9 @@ class Artifacts(object):
local_filename = self._push_temp_file(
prefix=quote(name, safe="") + ".", suffix=override_filename_ext_in_uri
)
if (
if (pd and isinstance(artifact_object, pd.DataFrame) and (
isinstance(artifact_object.index, pd.MultiIndex) or isinstance(artifact_object.columns, pd.MultiIndex)
) and not extension_name:
) or (pl and isinstance(artifact_object, pl.DataFrame))) and not extension_name:
store_as_pickle = True
elif override_filename_ext_in_uri == ".csv.gz":
artifact_type_data.content_type = "text/csv"
@ -493,7 +510,10 @@ class Artifacts(object):
elif override_filename_ext_in_uri == ".parquet":
try:
artifact_type_data.content_type = "application/parquet"
artifact_object.to_parquet(local_filename)
if (pd and isinstance(artifact_object, pd.DataFrame)):
artifact_object.to_parquet(local_filename)
else:
artifact_object.write_parquet(local_filename)
except Exception as e:
LoggerRoot.get_base_logger().warning(
"Exception '{}' encountered when uploading artifact as .parquet. Defaulting to .csv.gz".format(
@ -505,7 +525,10 @@ class Artifacts(object):
elif override_filename_ext_in_uri == ".feather":
try:
artifact_type_data.content_type = "application/feather"
artifact_object.to_feather(local_filename)
if (pd and isinstance(artifact_object, pd.DataFrame)):
artifact_object.to_feather(local_filename)
else:
artifact_object.write_ipc(local_filename)
except Exception as e:
LoggerRoot.get_base_logger().warning(
"Exception '{}' encountered when uploading artifact as .feather. Defaulting to .csv.gz".format(
@ -516,7 +539,11 @@ class Artifacts(object):
self._store_compressed_pd_csv(artifact_object, local_filename)
elif override_filename_ext_in_uri == ".pickle":
artifact_type_data.content_type = "application/pickle"
artifact_object.to_pickle(local_filename)
if (pl and isinstance(artifact_object, pd.DataFrame)):
artifact_object.to_pickle(local_filename)
else:
with open(local_filename, "wb") as f:
pickle.dump(artifact_object, f)
delete_after_upload = True
elif isinstance(artifact_object, Image.Image):
artifact_type = "image"
@ -1005,7 +1032,7 @@ class Artifacts(object):
artifacts_summary = []
for a_name, a_df in artifacts_dict.items():
hash_cols = self._artifacts_container.get_hash_columns(a_name)
if not pd or not isinstance(a_df, pd.DataFrame):
if not pd or not isinstance(a_df, pd.DataFrame) or not pl or not isinstance(a_df, pl.DataFrame):
continue
if hash_cols is True:
@ -1037,8 +1064,12 @@ class Artifacts(object):
a_shape = a_df.shape
# parallelize
a_hash_cols = a_df.drop(columns=hash_col_drop)
thread_pool.map(hash_row, a_hash_cols.values)
if pd and isinstance(a_df, pd.DataFrame):
a_hash_cols = a_df.drop(columns=hash_col_drop)
thread_pool.map(hash_row, a_hash_cols.values)
else:
a_hash_cols = a_df.drop(hash_col_drop)
a_unique_hash.add(a_hash_cols.hash_rows())
# add result
artifacts_summary.append((a_name, a_shape, a_unique_hash,))
@ -1082,16 +1113,19 @@ class Artifacts(object):
# (otherwise it is encoded and creates new hash every time)
if self._compression == "gzip":
with gzip.GzipFile(local_filename, 'wb', mtime=0) as gzip_file:
try:
pd_version = int(pd.__version__.split(".")[0])
except ValueError:
pd_version = 0
if pd_version >= 2:
artifact_object.to_csv(gzip_file, **kwargs)
if pl and isinstance(artifact_object, pl.DataFrame):
artifact_object.write_csv(gzip_file)
else:
# old (pandas<2) versions of pandas cannot handle direct gzip stream, so we manually encode it
artifact_object.to_csv(io.TextIOWrapper(gzip_file), **kwargs)
try:
pd_version = int(pd.__version__.split(".")[0])
except ValueError:
pd_version = 0
if pd_version >= 2:
artifact_object.to_csv(gzip_file, **kwargs)
else:
# old (pandas<2) versions of pandas cannot handle direct gzip stream, so we manually encode it
artifact_object.to_csv(io.TextIOWrapper(gzip_file), **kwargs)
else:
artifact_object.to_csv(local_filename, compression=self._compression)