diff --git a/clearml/backend_interface/metrics/reporter.py b/clearml/backend_interface/metrics/reporter.py
index b5fb4d49..dcf56d96 100644
--- a/clearml/backend_interface/metrics/reporter.py
+++ b/clearml/backend_interface/metrics/reporter.py
@@ -654,7 +654,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
         :param series: Series (AKA variant)
         :type series: str
         :param table: The table data
-        :type table: pandas.DataFrame
+        :type table: pandas.DataFrame or polars.DataFrame
         :param iteration: Iteration number
         :type iteration: int
         :param layout_config: optional dictionary for layout configuration, passed directly to plotly
diff --git a/clearml/binding/artifacts.py b/clearml/binding/artifacts.py
index 5257e3b5..37d0482a 100644
--- a/clearml/binding/artifacts.py
+++ b/clearml/binding/artifacts.py
@@ -30,6 +30,10 @@ from ..utilities.process.mp import SafeEvent, ForkSafeRLock
 from ..utilities.proxy_object import LazyEvalWrapper
 from ..config import deferred_config, config

+try:
+    import polars as pl
+except ImportError:
+    pl = None
 try:
     import pandas as pd
     DataFrame = pd.DataFrame
@@ -152,6 +156,7 @@ class Artifact(object):
         Supported content types are:
         - dict - ``.json``, ``.yaml``
         - pandas.DataFrame - ``.csv.gz``, ``.parquet``, ``.feather``, ``.pickle``
+        - polars.DataFrame - ``.csv.gz``, ``.parquet``, ``.feather``, ``.pickle``
         - numpy.ndarray - ``.npz``, ``.csv.gz``
         - PIL.Image - whatever content types PIL supports
         All other types will return a pathlib2.Path object pointing to a local copy of the artifacts file (or directory).
@@ -192,6 +197,18 @@ class Artifact(object):
                     self._object = pd.read_csv(local_file)
                 else:
                     self._object = pd.read_csv(local_file, index_col=[0])
+            elif (self.type == Artifacts._pd_artifact_type or self.type == "polars") and pl:
+                if self._content_type == "application/parquet":
+                    self._object = pl.read_parquet(local_file)
+                elif self._content_type == "application/feather":
+                    self._object = pl.read_ipc(local_file)
+                elif self._content_type == "application/pickle":
+                    with open(local_file, "rb") as f:
+                        self._object = pickle.load(f)
+                elif self.type == Artifacts._pd_artifact_type:
+                    self._object = pl.read_csv(local_file)
+                else:
+                    self._object = pl.read_csv(local_file)
             elif self.type == "image":
                 self._object = Image.open(local_file)
             elif self.type == "JSON" or self.type == "dict":
@@ -280,14 +297,14 @@ class Artifacts(object):
             self.artifact_hash_columns = {}

         def __setitem__(self, key, value):
-            # check that value is of type pandas
-            if pd and isinstance(value, pd.DataFrame):
+            # check that value is of type pandas or polars
+            if (pd and isinstance(value, pd.DataFrame)) or (pl and isinstance(value, pl.DataFrame)):
                 super(Artifacts._ProxyDictWrite, self).__setitem__(key, value)

                 if self._artifacts_manager:
                     self._artifacts_manager.flush()
             else:
-                raise ValueError('Artifacts currently support pandas.DataFrame objects only')
+                raise ValueError('Artifacts currently support pandas.DataFrame and polars.DataFrame objects only')

         def unregister_artifact(self, name):
             self.artifact_metadata.pop(name, None)
@@ -472,8 +489,8 @@ class Artifacts(object):
                 artifact_type_data.content_type = "text/csv"
                 np.savetxt(local_filename, artifact_object, delimiter=",")
                 delete_after_upload = True
-        elif pd and isinstance(artifact_object, pd.DataFrame):
-            artifact_type = "pandas"
+        elif (pd and isinstance(artifact_object, pd.DataFrame)) or (pl and isinstance(artifact_object, pl.DataFrame)):
+            artifact_type = "pandas" if (pd and isinstance(artifact_object, pd.DataFrame)) else "polars"
"polars" artifact_type_data.preview = preview or str(artifact_object.__repr__()) # we are making sure self._default_pandas_dataframe_extension_name is not deferred extension_name = extension_name or str(self._default_pandas_dataframe_extension_name or "") @@ -484,9 +501,9 @@ class Artifacts(object): local_filename = self._push_temp_file( prefix=quote(name, safe="") + ".", suffix=override_filename_ext_in_uri ) - if ( + if (pd and isinstance(artifact_object, pd.DataFrame) and ( isinstance(artifact_object.index, pd.MultiIndex) or isinstance(artifact_object.columns, pd.MultiIndex) - ) and not extension_name: + ) or (pl and isinstance(artifact_object, pl.DataFrame))) and not extension_name: store_as_pickle = True elif override_filename_ext_in_uri == ".csv.gz": artifact_type_data.content_type = "text/csv" @@ -494,7 +511,10 @@ class Artifacts(object): elif override_filename_ext_in_uri == ".parquet": try: artifact_type_data.content_type = "application/parquet" - artifact_object.to_parquet(local_filename) + if (pd and isinstance(artifact_object, pd.DataFrame)): + artifact_object.to_parquet(local_filename) + else: + artifact_object.write_parquet(local_filename) except Exception as e: LoggerRoot.get_base_logger().warning( "Exception '{}' encountered when uploading artifact as .parquet. Defaulting to .csv.gz".format( @@ -506,7 +526,10 @@ class Artifacts(object): elif override_filename_ext_in_uri == ".feather": try: artifact_type_data.content_type = "application/feather" - artifact_object.to_feather(local_filename) + if (pd and isinstance(artifact_object, pd.DataFrame)): + artifact_object.to_feather(local_filename) + else: + artifact_object.write_ipc(local_filename) except Exception as e: LoggerRoot.get_base_logger().warning( "Exception '{}' encountered when uploading artifact as .feather. 
Defaulting to .csv.gz".format( @@ -517,7 +540,11 @@ class Artifacts(object): self._store_compressed_pd_csv(artifact_object, local_filename) elif override_filename_ext_in_uri == ".pickle": artifact_type_data.content_type = "application/pickle" - artifact_object.to_pickle(local_filename) + if (pl and isinstance(artifact_object, pd.DataFrame)): + artifact_object.to_pickle(local_filename) + else: + with open(local_filename, "wb") as f: + pickle.dump(artifact_object, f) delete_after_upload = True elif isinstance(artifact_object, Image.Image): artifact_type = "image" @@ -1006,7 +1033,7 @@ class Artifacts(object): artifacts_summary = [] for a_name, a_df in artifacts_dict.items(): hash_cols = self._artifacts_container.get_hash_columns(a_name) - if not pd or not isinstance(a_df, pd.DataFrame): + if not pd or not isinstance(a_df, pd.DataFrame) or not pl or not isinstance(a_df, pl.DataFrame): continue if hash_cols is True: @@ -1038,8 +1065,12 @@ class Artifacts(object): a_shape = a_df.shape # parallelize - a_hash_cols = a_df.drop(columns=hash_col_drop) - thread_pool.map(hash_row, a_hash_cols.values) + if pd and isinstance(a_df, pd.DataFrame): + a_hash_cols = a_df.drop(columns=hash_col_drop) + thread_pool.map(hash_row, a_hash_cols.values) + else: + a_hash_cols = a_df.drop(hash_col_drop) + a_unique_hash.add(a_hash_cols.hash_rows()) # add result artifacts_summary.append((a_name, a_shape, a_unique_hash,)) @@ -1083,16 +1114,19 @@ class Artifacts(object): # (otherwise it is encoded and creates new hash every time) if self._compression == "gzip": with gzip.GzipFile(local_filename, 'wb', mtime=0) as gzip_file: - try: - pd_version = int(pd.__version__.split(".")[0]) - except ValueError: - pd_version = 0 - - if pd_version >= 2: - artifact_object.to_csv(gzip_file, **kwargs) + if pl and isinstance(artifact_object, pl.DataFrame): + artifact_object.write_csv(gzip_file) else: - # old (pandas<2) versions of pandas cannot handle direct gzip stream, so we manually encode it - artifact_object.to_csv(io.TextIOWrapper(gzip_file), **kwargs) + try: + pd_version = int(pd.__version__.split(".")[0]) + except ValueError: + pd_version = 0 + + if pd_version >= 2: + artifact_object.to_csv(gzip_file, **kwargs) + else: + # old (pandas<2) versions of pandas cannot handle direct gzip stream, so we manually encode it + artifact_object.to_csv(io.TextIOWrapper(gzip_file), **kwargs) else: artifact_object.to_csv(local_filename, compression=self._compression) diff --git a/clearml/datasets/dataset.py b/clearml/datasets/dataset.py index eb347629..ea56c7c5 100644 --- a/clearml/datasets/dataset.py +++ b/clearml/datasets/dataset.py @@ -48,6 +48,10 @@ except ImportError: except Exception as e: logging.warning("ClearML Dataset failed importing pandas: {}".format(e)) pd = None +try: + import polars as pl +except ImportError: + pl = None try: import pyarrow # noqa @@ -852,7 +856,7 @@ class Dataset(object): return True def set_metadata(self, metadata, metadata_name='metadata', ui_visible=True): - # type: (Union[numpy.array, pd.DataFrame, Dict[str, Any]], str, bool) -> () # noqa: F821 + # type: (Union[numpy.array, pd.DataFrame, pl.DataFrame, Dict[str, Any]], str, bool) -> () # noqa: F821 """ Attach a user-defined metadata to the dataset. Check `Task.upload_artifact` for supported types. If type is Pandas Dataframes, optionally make it visible as a table in the UI. 
@@ -861,7 +865,7 @@ class Dataset(object):
             raise ValueError("metadata_name can not start with '{}'".format(self.__data_entry_name_prefix))
         self._task.upload_artifact(name=metadata_name, artifact_object=metadata)
         if ui_visible:
-            if pd and isinstance(metadata, pd.DataFrame):
+            if (pd and isinstance(metadata, pd.DataFrame)) or (pl and isinstance(metadata, pl.DataFrame)):
                 self.get_logger().report_table(
                     title='Dataset Metadata',
                     series='Dataset Metadata',
@@ -874,7 +878,7 @@ class Dataset(object):
         )

     def get_metadata(self, metadata_name='metadata'):
-        # type: (str) -> Optional[numpy.array, pd.DataFrame, dict, str, bool] # noqa: F821
+        # type: (str) -> Optional[numpy.array, pd.DataFrame, pl.DataFrame, dict, str, bool] # noqa: F821
         """
         Get attached metadata back in its original format. Will return None if none was found.
         """
@@ -3119,19 +3123,34 @@ class Dataset(object):
         def convert_to_tabular_artifact(file_path_, file_extension_, compression_=None):
             # noinspection PyBroadException
             try:
-                if file_extension_ == ".csv" and pd:
-                    return pd.read_csv(
-                        file_path_,
-                        nrows=self.__preview_tabular_row_count,
-                        compression=compression_.lstrip(".") if compression_ else None,
-                    )
-                elif file_extension_ == ".tsv" and pd:
-                    return pd.read_csv(
-                        file_path_,
-                        sep='\t',
-                        nrows=self.__preview_tabular_row_count,
-                        compression=compression_.lstrip(".") if compression_ else None,
-                    )
+                if file_extension_ == ".csv" and (pl or pd):
+                    if pd:
+                        return pd.read_csv(
+                            file_path_,
+                            nrows=self.__preview_tabular_row_count,
+                            compression=compression_.lstrip(".") if compression_ else None,
+                        )
+                    else:
+                        # TODO Re-implement compression after testing all extensions
+                        return pl.read_csv(
+                            file_path_,
+                            n_rows=self.__preview_tabular_row_count,
+                        )
+                elif file_extension_ == ".tsv" and (pl or pd):
+                    if pd:
+                        return pd.read_csv(
+                            file_path_,
+                            sep='\t',
+                            nrows=self.__preview_tabular_row_count,
+                            compression=compression_.lstrip(".") if compression_ else None,
+                        )
+                    else:
+                        # TODO Re-implement compression after testing all extensions
+                        return pl.read_csv(
+                            file_path_,
+                            n_rows=self.__preview_tabular_row_count,
+                            separator='\t',
+                        )
                 elif file_extension_ == ".parquet" or file_extension_ == ".parq":
                     if pyarrow:
                         pf = pyarrow.parquet.ParquetFile(file_path_)
@@ -3140,7 +3159,10 @@ class Dataset(object):
                     elif fastparquet:
                         return fastparquet.ParquetFile(file_path_).head(self.__preview_tabular_row_count).to_pandas()
                 elif (file_extension_ == ".npz" or file_extension_ == ".npy") and np:
-                    return pd.DataFrame(np.loadtxt(file_path_, max_rows=self.__preview_tabular_row_count))
+                    if pd:
+                        return pd.DataFrame(np.loadtxt(file_path_, max_rows=self.__preview_tabular_row_count))
+                    else:
+                        return pl.DataFrame(np.loadtxt(file_path_, max_rows=self.__preview_tabular_row_count))
             except Exception:
                 pass
             return None
@@ -3172,7 +3194,10 @@ class Dataset(object):
                 # because it will not upload the sample to that destination.
                # use report_media instead to not leak data
                 if (
-                    isinstance(artifact, pd.DataFrame)
+                    (
+                        (pd and isinstance(artifact, pd.DataFrame))
+                        or (pl and isinstance(artifact, pl.DataFrame))
+                    )
                     and self._task.get_logger().get_default_upload_destination() == Session.get_files_server_host()
                 ):
                     self._task.get_logger().report_table("Tables", "summary", table_plot=artifact)
diff --git a/clearml/logger.py b/clearml/logger.py
index ee3e35c3..e0fec6db 100644
--- a/clearml/logger.py
+++ b/clearml/logger.py
@@ -10,6 +10,11 @@ from pathlib2 import Path
 from .debugging.log import LoggerRoot

+try:
+    import polars as pl
+except ImportError:
+    pl = None
+
 try:
     import pandas as pd
 except ImportError:
     pd = None
@@ -327,7 +332,7 @@ class Logger(object):
         title,  # type: str
         series,  # type: str
         iteration=None,  # type: Optional[int]
-        table_plot=None,  # type: Optional[pd.DataFrame, Sequence[Sequence]]
+        table_plot=None,  # type: Optional[pd.DataFrame, pl.DataFrame, Sequence[Sequence]]
         csv=None,  # type: Optional[str]
         url=None,  # type: Optional[str]
         extra_layout=None,  # type: Optional[dict]
@@ -393,15 +398,15 @@ class Logger(object):
         mutually_exclusive(UsageError, _check_none=True, table_plot=table_plot, csv=csv, url=url)
         table = table_plot
         if url or csv:
-            if not pd:
+            if not pd and not pl:
                 raise UsageError(
-                    "pandas is required in order to support reporting tables using CSV or a URL, "
-                    "please install the pandas python package"
+                    "pandas or polars is required in order to support reporting tables using CSV "
+                    "or a URL, please install the pandas or polars python package"
                 )
             if url:
-                table = pd.read_csv(url, index_col=[0])
+                table = pd.read_csv(url, index_col=[0]) if pd else pl.read_csv(url)
             elif csv:
-                table = pd.read_csv(csv, index_col=[0])
+                table = pd.read_csv(csv, index_col=[0]) if pd else pl.read_csv(csv)

         def replace(dst, *srcs):
             for src in srcs:
@@ -410,7 +415,8 @@ class Logger(object):
         if isinstance(table, (list, tuple)):
             reporter_table = table
         else:
-            reporter_table = table.fillna(str(np.nan))
+            nan = str(np.nan)
+            reporter_table = table.fillna(nan) if (pd and isinstance(table, pd.DataFrame)) else table.fill_nan(nan)
         replace("NaN", np.nan, math.nan if six.PY3 else float("nan"))
         replace("Inf", np.inf, math.inf if six.PY3 else float("inf"))
         minus_inf = [-np.inf, -math.inf if six.PY3 else -float("inf")]
diff --git a/clearml/model.py b/clearml/model.py
index 58813cfe..bc593ca6 100644
--- a/clearml/model.py
+++ b/clearml/model.py
@@ -14,6 +14,10 @@ try:
     import pandas as pd
 except ImportError:
     pd = None
+try:
+    import polars as pl
+except ImportError:
+    pl = None

 from .backend_api import Session
 from .backend_api.services import models, projects
@@ -638,15 +642,21 @@ class BaseModel(object):
         )
         table = table_plot
         if url or csv:
-            if not pd:
+            if not pd and not pl:
                 raise UsageError(
-                    "pandas is required in order to support reporting tables using CSV or a URL, "
-                    "please install the pandas python package"
+                    "pandas or polars is required in order to support reporting tables using CSV or a URL, "
+                    "please install the pandas or polars python package"
                 )
             if url:
-                table = pd.read_csv(url, index_col=[0])
+                if pd:
+                    table = pd.read_csv(url, index_col=[0])
+                else:
+                    table = pl.read_csv(url)
             elif csv:
-                table = pd.read_csv(csv, index_col=[0])
+                if pd:
+                    table = pd.read_csv(csv, index_col=[0])
+                else:
+                    table = pl.read_csv(csv)

         def replace(dst, *srcs):
             for src in srcs:
diff --git a/clearml/utilities/plotly_reporter.py b/clearml/utilities/plotly_reporter.py
index cf8cac40..723dfa14 100644
--- a/clearml/utilities/plotly_reporter.py
+++ b/clearml/utilities/plotly_reporter.py
@@ -4,6 +4,10 @@ import numpy as np
 from ..errors import UsageError
 from ..utilities.dicts import merge_dicts

+try:
+    import polars as pl
+except ImportError:
+    pl = None
 try:
     import pandas as pd
 except ImportError:
@@ -486,7 +490,7 @@ def create_plotly_table(table_plot, title, series, layout_config=None, data_conf
     """
     Create a basic Plotly table json style to be sent

-    :param table_plot: the output table in pandas.DataFrame structure or list of rows (list) in a table
+    :param table_plot: the output table in pandas.DataFrame or polars.DataFrame structure, or list of rows (list) in a table
     :param title: Title (AKA metric)
     :type title: str
     :param series: Series (AKA variant)
@@ -503,11 +507,19 @@ def create_plotly_table(table_plot, title, series, layout_config=None, data_conf
     elif is_list and table_plot[0] and isinstance(table_plot[0], (list, tuple)):
         headers_values = table_plot[0]
         cells_values = [list(i) for i in zip(*table_plot[1:])]
+    elif pl and isinstance(table_plot, pl.DataFrame):
+        headers_values = list([col] for col in table_plot.columns)
+        # Convert datetimes to ISO strings
+        datetime_columns = table_plot.select(pl.selectors.datetime()).columns
+        exprs = [pl.col(col).dt.to_string("iso:strict") for col in datetime_columns]
+        # Get cell values and preserve value types
+        cells_values_transpose = table_plot.with_columns(*exprs).rows()
+        cells_values = list(map(list, zip(*cells_values_transpose)))
     else:
         if not pd:
             raise UsageError(
-                "pandas is required in order to support reporting tables using CSV or a URL, "
-                "please install the pandas python package"
+                "pandas or polars is required in order to support reporting tables using CSV or a URL, "
+                "please install the pandas or polars python package"
             )
         index_added = not isinstance(table_plot.index, pd.RangeIndex)
         headers_values = list([col] for col in table_plot.columns)
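
Not part of the patch: a minimal standalone sketch of the transformation the new polars branch in create_plotly_table performs, assuming polars is installed and recent enough to support the "iso:strict" datetime preset used above. The sample frame, variable values, and expected-output comments below are illustrative only.

from datetime import datetime

import polars as pl

# illustrative sample frame (not from the patch)
table_plot = pl.DataFrame(
    {
        "name": ["a", "b"],
        "score": [1.5, 2.5],
        "when": [datetime(2024, 1, 1), datetime(2024, 1, 2)],
    }
)

# one header cell per column, as a list of single-element lists
headers_values = list([col] for col in table_plot.columns)

# render datetime columns as ISO strings so they serialize cleanly
datetime_columns = table_plot.select(pl.selectors.datetime()).columns
exprs = [pl.col(col).dt.to_string("iso:strict") for col in datetime_columns]

# rows() yields row tuples; zip(*...) transposes them into the per-column
# lists that a plotly table trace expects for cells.values
cells_values_transpose = table_plot.with_columns(*exprs).rows()
cells_values = list(map(list, zip(*cells_values_transpose)))

print(headers_values)  # [['name'], ['score'], ['when']]
print(cells_values)    # per-column cell lists, datetimes as ISO strings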