diff --git a/clearml/binding/artifacts.py b/clearml/binding/artifacts.py
index 093c80d7..8b39d504 100644
--- a/clearml/binding/artifacts.py
+++ b/clearml/binding/artifacts.py
@@ -30,6 +30,10 @@ from ..utilities.process.mp import SafeEvent, ForkSafeRLock
 from ..utilities.proxy_object import LazyEvalWrapper
 from ..config import deferred_config, config
 
+try:
+    import polars as pl
+except ImportError:
+    pl = None
 try:
     import pandas as pd
     DataFrame = pd.DataFrame
@@ -152,6 +156,7 @@ class Artifact(object):
         Supported content types are:
         - dict - ``.json``, ``.yaml``
         - pandas.DataFrame - ``.csv.gz``, ``.parquet``, ``.feather``, ``.pickle``
+        - polars.DataFrame - ``.csv.gz``, ``.parquet``, ``.feather``, ``.pickle``
         - numpy.ndarray - ``.npz``, ``.csv.gz``
         - PIL.Image - whatever content types PIL supports
         All other types will return a pathlib2.Path object pointing to a local copy of the artifacts file (or directory).
@@ -192,6 +197,16 @@ class Artifact(object):
                     self._object = pd.read_csv(local_file)
                 else:
                     self._object = pd.read_csv(local_file, index_col=[0])
+        elif pl and self.type == "polars":
+            if self._content_type == "application/parquet":
+                self._object = pl.read_parquet(local_file)
+            elif self._content_type == "application/feather":
+                self._object = pl.read_ipc(local_file)
+            elif self._content_type == "application/pickle":
+                with open(local_file, "rb") as f:
+                    self._object = pickle.load(f)
+            else:
+                self._object = pl.read_csv(local_file)
         elif self.type == "image":
             self._object = Image.open(local_file)
         elif self.type == "JSON" or self.type == "dict":
@@ -279,14 +294,14 @@ class Artifacts(object):
             self.artifact_hash_columns = {}
 
         def __setitem__(self, key, value):
-            # check that value is of type pandas
-            if pd and isinstance(value, pd.DataFrame):
+            # check that value is of type pandas or polars
+            if (pd and isinstance(value, pd.DataFrame)) or (pl and isinstance(value, pl.DataFrame)):
                 super(Artifacts._ProxyDictWrite, self).__setitem__(key, value)
 
                 if self._artifacts_manager:
                     self._artifacts_manager.flush()
             else:
-                raise ValueError('Artifacts currently support pandas.DataFrame objects only')
+                raise ValueError('Artifacts currently support pandas.DataFrame and polars.DataFrame objects only')
 
         def unregister_artifact(self, name):
             self.artifact_metadata.pop(name, None)
@@ -471,8 +486,8 @@ class Artifacts(object):
                 artifact_type_data.content_type = "text/csv"
                 np.savetxt(local_filename, artifact_object, delimiter=",")
                 delete_after_upload = True
-        elif pd and isinstance(artifact_object, pd.DataFrame):
-            artifact_type = "pandas"
+        elif (pd and isinstance(artifact_object, pd.DataFrame)) or (pl and isinstance(artifact_object, pl.DataFrame)):
+            artifact_type = "pandas" if pd and isinstance(artifact_object, pd.DataFrame) else "polars"
             artifact_type_data.preview = preview or str(artifact_object.__repr__())
             # we are making sure self._default_pandas_dataframe_extension_name is not deferred
             extension_name = extension_name or str(self._default_pandas_dataframe_extension_name or "")
@@ -483,9 +498,10 @@ class Artifacts(object):
             local_filename = self._push_temp_file(
                 prefix=quote(name, safe="") + ".", suffix=override_filename_ext_in_uri
             )
+            # pandas frames with a MultiIndex cannot round-trip through csv; polars frames are pickled as-is
-            if (
+            if ((pd and isinstance(artifact_object, pd.DataFrame) and (
                 isinstance(artifact_object.index, pd.MultiIndex) or isinstance(artifact_object.columns, pd.MultiIndex)
-            ) and not extension_name:
+            )) or (pl and isinstance(artifact_object, pl.DataFrame))) and not extension_name:
                 store_as_pickle = True
             elif override_filename_ext_in_uri == ".csv.gz":
                 artifact_type_data.content_type = "text/csv"
@@ -493,7 +509,10 @@ class Artifacts(object):
             elif override_filename_ext_in_uri == ".parquet":
                 try:
                     artifact_type_data.content_type = "application/parquet"
-                    artifact_object.to_parquet(local_filename)
+                    if pd and isinstance(artifact_object, pd.DataFrame):
+                        artifact_object.to_parquet(local_filename)
+                    else:
+                        artifact_object.write_parquet(local_filename)
                 except Exception as e:
                     LoggerRoot.get_base_logger().warning(
                         "Exception '{}' encountered when uploading artifact as .parquet. Defaulting to .csv.gz".format(
@@ -505,7 +524,10 @@ class Artifacts(object):
             elif override_filename_ext_in_uri == ".feather":
                 try:
                     artifact_type_data.content_type = "application/feather"
-                    artifact_object.to_feather(local_filename)
+                    if pd and isinstance(artifact_object, pd.DataFrame):
+                        artifact_object.to_feather(local_filename)
+                    else:
+                        artifact_object.write_ipc(local_filename)
                 except Exception as e:
                     LoggerRoot.get_base_logger().warning(
                         "Exception '{}' encountered when uploading artifact as .feather. Defaulting to .csv.gz".format(
@@ -516,7 +538,12 @@ class Artifacts(object):
                     self._store_compressed_pd_csv(artifact_object, local_filename)
             elif override_filename_ext_in_uri == ".pickle":
                 artifact_type_data.content_type = "application/pickle"
-                artifact_object.to_pickle(local_filename)
+                if pd and isinstance(artifact_object, pd.DataFrame):
+                    artifact_object.to_pickle(local_filename)
+                else:
+                    # polars DataFrames have no to_pickle(); serialize the frame directly
+                    with open(local_filename, "wb") as f:
+                        pickle.dump(artifact_object, f)
             delete_after_upload = True
         elif isinstance(artifact_object, Image.Image):
             artifact_type = "image"
@@ -1005,7 +1032,7 @@ class Artifacts(object):
         artifacts_summary = []
         for a_name, a_df in artifacts_dict.items():
             hash_cols = self._artifacts_container.get_hash_columns(a_name)
-            if not pd or not isinstance(a_df, pd.DataFrame):
+            if not ((pd and isinstance(a_df, pd.DataFrame)) or (pl and isinstance(a_df, pl.DataFrame))):
                 continue
 
             if hash_cols is True:
@@ -1037,8 +1064,13 @@ class Artifacts(object):
             a_shape = a_df.shape
 
             # parallelize
-            a_hash_cols = a_df.drop(columns=hash_col_drop)
-            thread_pool.map(hash_row, a_hash_cols.values)
+            if pd and isinstance(a_df, pd.DataFrame):
+                a_hash_cols = a_df.drop(columns=hash_col_drop)
+                thread_pool.map(hash_row, a_hash_cols.values)
+            else:
+                a_hash_cols = a_df.drop(hash_col_drop)
+                # hash_rows() returns one UInt64 hash per row; collect them all
+                a_unique_hash.update(a_hash_cols.hash_rows())
 
             # add result
             artifacts_summary.append((a_name, a_shape, a_unique_hash,))
@@ -1082,16 +1114,19 @@ class Artifacts(object):
         # bugfix: to make pandas csv.gz consistent file hash we must not store the creation date in the header
         # (otherwise it is encoded and creates new hash every time)
         if self._compression == "gzip":
             with gzip.GzipFile(local_filename, 'wb', mtime=0) as gzip_file:
-                try:
-                    pd_version = int(pd.__version__.split(".")[0])
-                except ValueError:
-                    pd_version = 0
-
-                if pd_version >= 2:
-                    artifact_object.to_csv(gzip_file, **kwargs)
+                if pl and isinstance(artifact_object, pl.DataFrame):
+                    artifact_object.write_csv(gzip_file)
                 else:
-                    # old (pandas<2) versions of pandas cannot handle direct gzip stream, so we manually encode it
-                    artifact_object.to_csv(io.TextIOWrapper(gzip_file), **kwargs)
+                    try:
+                        pd_version = int(pd.__version__.split(".")[0])
+                    except ValueError:
+                        pd_version = 0
+
+                    if pd_version >= 2:
+                        artifact_object.to_csv(gzip_file, **kwargs)
+                    else:
+                        # old (pandas<2) versions of pandas cannot handle direct gzip stream, so we manually encode it
+                        artifact_object.to_csv(io.TextIOWrapper(gzip_file), **kwargs)
         else:
             artifact_object.to_csv(local_filename, compression=self._compression)
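
Usage sketch (not part of the patch): a minimal example of how the new polars path could be exercised once this change is applied, assuming polars is installed; the project/task names and the chosen extension are illustrative.

    import polars as pl
    from clearml import Task

    task = Task.init(project_name="examples", task_name="polars artifact")

    df = pl.DataFrame({"id": [1, 2, 3], "value": [0.1, 0.2, 0.3]})

    # The explicit extension routes the upload through write_parquet();
    # ".feather" would select write_ipc(), ".csv.gz" the gzip CSV path,
    # and ".pickle" (or no extension at all) the pickle fallback.
    task.upload_artifact(name="table", artifact_object=df, extension_name=".parquet", wait_on_upload=True)

    # Artifact.get() dispatches on the stored content type:
    # application/parquet -> pl.read_parquet, application/feather -> pl.read_ipc, ...
    df_back = task.artifacts["table"].get()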
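
On the hashing side, polars' DataFrame.hash_rows() returns a UInt64 Series with one hash per row, which is why the summary code above extends the set with update() rather than add() (add() would attempt to insert the Series object itself). A self-contained illustration, with made-up data:

    import polars as pl

    df = pl.DataFrame({"a": [1, 2, 2], "b": ["x", "y", "y"]})
    unique_hash = set()
    unique_hash.update(df.hash_rows())  # one UInt64 hash per row
    print(len(unique_hash))             # 2 -- the duplicate row hashes identically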