commit 49706ebeda
Blake, 2025-03-26 05:47:25 +00:00 (committed by GitHub)
6 changed files with 143 additions and 56 deletions

File 1 of 6: Reporter (report_table docstring)

@@ -654,7 +654,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
         :param series: Series (AKA variant)
         :type series: str
         :param table: The table data
-        :type table: pandas.DataFrame
+        :type table: pandas.DataFrame or polars.DataFrame
         :param iteration: Iteration number
         :type iteration: int
         :param layout_config: optional dictionary for layout configuration, passed directly to plotly

File 2 of 6: Artifact / Artifacts

@@ -30,6 +30,10 @@ from ..utilities.process.mp import SafeEvent, ForkSafeRLock
 from ..utilities.proxy_object import LazyEvalWrapper
 from ..config import deferred_config, config

+try:
+    import polars as pl
+except ImportError:
+    pl = None
 try:
     import pandas as pd
     DataFrame = pd.DataFrame
@@ -152,6 +156,7 @@ class Artifact(object):
         Supported content types are:
             - dict - ``.json``, ``.yaml``
             - pandas.DataFrame - ``.csv.gz``, ``.parquet``, ``.feather``, ``.pickle``
+            - polars.DataFrame - ``.csv.gz``, ``.parquet``, ``.feather``, ``.pickle``
             - numpy.ndarray - ``.npz``, ``.csv.gz``
             - PIL.Image - whatever content types PIL supports
         All other types will return a pathlib2.Path object pointing to a local copy of the artifacts file (or directory).
@@ -192,6 +197,18 @@ class Artifact(object):
                 self._object = pd.read_csv(local_file)
             else:
                 self._object = pd.read_csv(local_file, index_col=[0])
+        elif self.type == Artifacts._pd_artifact_type or self.type == "polars" and pl:
+            if self._content_type == "application/parquet":
+                self._object = pl.read_parquet(local_file)
+            elif self._content_type == "application/feather":
+                self._object = pl.read_ipc(local_file)
+            elif self._content_type == "application/pickle":
+                with open(local_file, "rb") as f:
+                    self._object = pickle.load(f)
+            elif self.type == Artifacts._pd_artifact_type:
+                self._object = pl.read_csv(local_file)
+            else:
+                self._object = pl.read_csv(local_file)
         elif self.type == "image":
             self._object = Image.open(local_file)
         elif self.type == "JSON" or self.type == "dict":
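
A minimal usage sketch (not part of the diff; the task id and artifact name are placeholders): with polars installed, Artifact.get() deserializes a "polars"-typed artifact with pl.read_parquet, pl.read_ipc, or pl.read_csv according to the stored content type.

    from clearml import Task

    task = Task.get_task(task_id="<task-id>")
    # for an artifact uploaded from a polars DataFrame, get() returns a polars.DataFrame
    table = task.artifacts["my_table"].get()
    print(type(table))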
@@ -280,14 +297,14 @@ class Artifacts(object):
             self.artifact_hash_columns = {}

         def __setitem__(self, key, value):
-            # check that value is of type pandas
-            if pd and isinstance(value, pd.DataFrame):
+            # check that value is of type pandas or polars
+            if (pd and isinstance(value, pd.DataFrame)) or (pl and isinstance(value, pl.DataFrame)):
                 super(Artifacts._ProxyDictWrite, self).__setitem__(key, value)

                 if self._artifacts_manager:
                     self._artifacts_manager.flush()
             else:
-                raise ValueError('Artifacts currently support pandas.DataFrame objects only')
+                raise ValueError('Artifacts currently support pandas.DataFrame and polars.DataFrame objects only')

         def unregister_artifact(self, name):
             self.artifact_metadata.pop(name, None)
@@ -472,8 +489,8 @@ class Artifacts(object):
                 artifact_type_data.content_type = "text/csv"
                 np.savetxt(local_filename, artifact_object, delimiter=",")
                 delete_after_upload = True
-        elif pd and isinstance(artifact_object, pd.DataFrame):
-            artifact_type = "pandas"
+        elif (pd and isinstance(artifact_object, pd.DataFrame)) or (pl and isinstance(artifact_object, pl.DataFrame)):
+            artifact_type = "pandas" if (pd and isinstance(artifact_object, pd.DataFrame)) else "polars"
             artifact_type_data.preview = preview or str(artifact_object.__repr__())
             # we are making sure self._default_pandas_dataframe_extension_name is not deferred
             extension_name = extension_name or str(self._default_pandas_dataframe_extension_name or "")
@@ -484,9 +501,9 @@ class Artifacts(object):
             local_filename = self._push_temp_file(
                 prefix=quote(name, safe="") + ".", suffix=override_filename_ext_in_uri
             )
-            if (
-                isinstance(artifact_object.index, pd.MultiIndex) or isinstance(artifact_object.columns, pd.MultiIndex)
-            ) and not extension_name:
+            if (pd and isinstance(artifact_object, pd.DataFrame) and (
+                isinstance(artifact_object.index, pd.MultiIndex) or isinstance(artifact_object.columns, pd.MultiIndex)
+            ) or (pl and isinstance(artifact_object, pl.DataFrame))) and not extension_name:
                 store_as_pickle = True
             elif override_filename_ext_in_uri == ".csv.gz":
                 artifact_type_data.content_type = "text/csv"
@@ -494,7 +511,10 @@ class Artifacts(object):
             elif override_filename_ext_in_uri == ".parquet":
                 try:
                     artifact_type_data.content_type = "application/parquet"
-                    artifact_object.to_parquet(local_filename)
+                    if (pd and isinstance(artifact_object, pd.DataFrame)):
+                        artifact_object.to_parquet(local_filename)
+                    else:
+                        artifact_object.write_parquet(local_filename)
                 except Exception as e:
                     LoggerRoot.get_base_logger().warning(
                         "Exception '{}' encountered when uploading artifact as .parquet. Defaulting to .csv.gz".format(
@@ -506,7 +526,10 @@ class Artifacts(object):
             elif override_filename_ext_in_uri == ".feather":
                 try:
                     artifact_type_data.content_type = "application/feather"
-                    artifact_object.to_feather(local_filename)
+                    if (pd and isinstance(artifact_object, pd.DataFrame)):
+                        artifact_object.to_feather(local_filename)
+                    else:
+                        artifact_object.write_ipc(local_filename)
                 except Exception as e:
                     LoggerRoot.get_base_logger().warning(
                         "Exception '{}' encountered when uploading artifact as .feather. Defaulting to .csv.gz".format(
@@ -517,7 +540,11 @@ class Artifacts(object):
                     self._store_compressed_pd_csv(artifact_object, local_filename)
             elif override_filename_ext_in_uri == ".pickle":
                 artifact_type_data.content_type = "application/pickle"
-                artifact_object.to_pickle(local_filename)
+                if (pd and isinstance(artifact_object, pd.DataFrame)):
+                    artifact_object.to_pickle(local_filename)
+                else:
+                    with open(local_filename, "wb") as f:
+                        pickle.dump(artifact_object, f)
                 delete_after_upload = True
             elif isinstance(artifact_object, Image.Image):
                 artifact_type = "image"
@@ -1006,7 +1033,7 @@ class Artifacts(object):
         artifacts_summary = []
         for a_name, a_df in artifacts_dict.items():
             hash_cols = self._artifacts_container.get_hash_columns(a_name)
-            if not pd or not isinstance(a_df, pd.DataFrame):
+            if not (pd and isinstance(a_df, pd.DataFrame)) and not (pl and isinstance(a_df, pl.DataFrame)):
                 continue

             if hash_cols is True:
@@ -1038,8 +1065,12 @@ class Artifacts(object):
             a_shape = a_df.shape
             # parallelize
-            a_hash_cols = a_df.drop(columns=hash_col_drop)
-            thread_pool.map(hash_row, a_hash_cols.values)
+            if pd and isinstance(a_df, pd.DataFrame):
+                a_hash_cols = a_df.drop(columns=hash_col_drop)
+                thread_pool.map(hash_row, a_hash_cols.values)
+            else:
+                a_hash_cols = a_df.drop(hash_col_drop)
+                a_unique_hash.add(a_hash_cols.hash_rows())

             # add result
             artifacts_summary.append((a_name, a_shape, a_unique_hash,))
@@ -1083,16 +1114,19 @@ class Artifacts(object):
         # (otherwise it is encoded and creates new hash every time)
         if self._compression == "gzip":
             with gzip.GzipFile(local_filename, 'wb', mtime=0) as gzip_file:
-                try:
-                    pd_version = int(pd.__version__.split(".")[0])
-                except ValueError:
-                    pd_version = 0
-                if pd_version >= 2:
-                    artifact_object.to_csv(gzip_file, **kwargs)
-                else:
-                    # old (pandas<2) versions of pandas cannot handle direct gzip stream, so we manually encode it
-                    artifact_object.to_csv(io.TextIOWrapper(gzip_file), **kwargs)
+                if pl and isinstance(artifact_object, pl.DataFrame):
+                    artifact_object.write_csv(gzip_file)
+                else:
+                    try:
+                        pd_version = int(pd.__version__.split(".")[0])
+                    except ValueError:
+                        pd_version = 0
+                    if pd_version >= 2:
+                        artifact_object.to_csv(gzip_file, **kwargs)
+                    else:
+                        # old (pandas<2) versions of pandas cannot handle direct gzip stream, so we manually encode it
+                        artifact_object.to_csv(io.TextIOWrapper(gzip_file), **kwargs)
         else:
             artifact_object.to_csv(local_filename, compression=self._compression)

File 3 of 6: Dataset

@@ -48,6 +48,10 @@ except ImportError:
 except Exception as e:
     logging.warning("ClearML Dataset failed importing pandas: {}".format(e))
     pd = None

+try:
+    import polars as pl
+except ImportError:
+    pl = None
 try:
     import pyarrow  # noqa
@@ -852,7 +856,7 @@ class Dataset(object):
         return True

     def set_metadata(self, metadata, metadata_name='metadata', ui_visible=True):
-        # type: (Union[numpy.array, pd.DataFrame, Dict[str, Any]], str, bool) -> () # noqa: F821
+        # type: (Union[numpy.array, pd.DataFrame, pl.DataFrame, Dict[str, Any]], str, bool) -> () # noqa: F821
         """
         Attach a user-defined metadata to the dataset. Check `Task.upload_artifact` for supported types.
         If type is Pandas Dataframes, optionally make it visible as a table in the UI.
@@ -861,7 +865,7 @@ class Dataset(object):
             raise ValueError("metadata_name can not start with '{}'".format(self.__data_entry_name_prefix))
         self._task.upload_artifact(name=metadata_name, artifact_object=metadata)
         if ui_visible:
-            if pd and isinstance(metadata, pd.DataFrame):
+            if (pd and isinstance(metadata, pd.DataFrame)) or (pl and isinstance(metadata, pl.DataFrame)):
                 self.get_logger().report_table(
                     title='Dataset Metadata',
                     series='Dataset Metadata',
@@ -874,7 +878,7 @@ class Dataset(object):
                 )

     def get_metadata(self, metadata_name='metadata'):
-        # type: (str) -> Optional[numpy.array, pd.DataFrame, dict, str, bool] # noqa: F821
+        # type: (str) -> Optional[numpy.array, pd.DataFrame, pl.DataFrame, dict, str, bool] # noqa: F821
         """
         Get attached metadata back in its original format. Will return None if none was found.
         """
@@ -3119,19 +3123,34 @@ class Dataset(object):
         def convert_to_tabular_artifact(file_path_, file_extension_, compression_=None):
             # noinspection PyBroadException
             try:
-                if file_extension_ == ".csv" and pd:
-                    return pd.read_csv(
-                        file_path_,
-                        nrows=self.__preview_tabular_row_count,
-                        compression=compression_.lstrip(".") if compression_ else None,
-                    )
-                elif file_extension_ == ".tsv" and pd:
-                    return pd.read_csv(
-                        file_path_,
-                        sep='\t',
-                        nrows=self.__preview_tabular_row_count,
-                        compression=compression_.lstrip(".") if compression_ else None,
-                    )
+                if file_extension_ == ".csv" and (pl or pd):
+                    if pd:
+                        return pd.read_csv(
+                            file_path_,
+                            nrows=self.__preview_tabular_row_count,
+                            compression=compression_.lstrip(".") if compression_ else None,
+                        )
+                    else:
+                        # TODO Re-implement compression after testing all extensions
+                        return pl.read_csv(
+                            file_path_,
+                            n_rows=self.__preview_tabular_row_count,
+                        )
+                elif file_extension_ == ".tsv" and (pl or pd):
+                    if pd:
+                        return pd.read_csv(
+                            file_path_,
+                            sep='\t',
+                            nrows=self.__preview_tabular_row_count,
+                            compression=compression_.lstrip(".") if compression_ else None,
+                        )
+                    else:
+                        # TODO Re-implement compression after testing all extensions
+                        return pl.read_csv(
+                            file_path_,
+                            n_rows=self.__preview_tabular_row_count,
+                            separator='\t',
+                        )
                 elif file_extension_ == ".parquet" or file_extension_ == ".parq":
                     if pyarrow:
                         pf = pyarrow.parquet.ParquetFile(file_path_)
@@ -3140,7 +3159,10 @@ class Dataset(object):
                 elif fastparquet:
                     return fastparquet.ParquetFile(file_path_).head(self.__preview_tabular_row_count).to_pandas()
                 elif (file_extension_ == ".npz" or file_extension_ == ".npy") and np:
-                    return pd.DataFrame(np.loadtxt(file_path_, max_rows=self.__preview_tabular_row_count))
+                    if pd:
+                        return pd.DataFrame(np.loadtxt(file_path_, max_rows=self.__preview_tabular_row_count))
+                    else:
+                        return pl.DataFrame(np.loadtxt(file_path_, max_rows=self.__preview_tabular_row_count))
             except Exception:
                 pass
             return None
@@ -3172,7 +3194,10 @@ class Dataset(object):
         # because it will not upload the sample to that destination.
         # use report_media instead to not leak data
         if (
-            isinstance(artifact, pd.DataFrame)
+            (
+                (pd and isinstance(artifact, pd.DataFrame))
+                or (pl and isinstance(artifact, pl.DataFrame))
+            )
             and self._task.get_logger().get_default_upload_destination() == Session.get_files_server_host()
         ):
             self._task.get_logger().report_table("Tables", "summary", table_plot=artifact)
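
A hedged usage sketch (dataset, project, and column names are invented) of the metadata path this change touches: a polars DataFrame can now be attached with set_metadata(), rendered as a UI table, and read back with get_metadata().

    import polars as pl
    from clearml import Dataset

    ds = Dataset.create(dataset_name="polars-demo", dataset_project="examples")
    meta = pl.DataFrame({"file": ["a.csv", "b.csv"], "rows": [100, 250]})

    ds.set_metadata(meta, metadata_name="files_summary", ui_visible=True)
    restored = ds.get_metadata("files_summary")  # round-trips through upload_artifact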

File 4 of 6: Logger

@@ -10,6 +10,11 @@ from pathlib2 import Path

 from .debugging.log import LoggerRoot

+try:
+    import polars as pl
+except ImportError:
+    pl = None
+
 try:
     import pandas as pd
 except ImportError:
@@ -327,7 +332,7 @@ class Logger(object):
         title,  # type: str
         series,  # type: str
         iteration=None,  # type: Optional[int]
-        table_plot=None,  # type: Optional[pd.DataFrame, Sequence[Sequence]]
+        table_plot=None,  # type: Optional[pd.DataFrame, pl.DataFrame, Sequence[Sequence]]
         csv=None,  # type: Optional[str]
         url=None,  # type: Optional[str]
         extra_layout=None,  # type: Optional[dict]
@@ -393,15 +398,15 @@ class Logger(object):
         mutually_exclusive(UsageError, _check_none=True, table_plot=table_plot, csv=csv, url=url)
         table = table_plot
         if url or csv:
-            if not pd:
+            if not pd and not pl:
                 raise UsageError(
-                    "pandas is required in order to support reporting tables using CSV or a URL, "
-                    "please install the pandas python package"
+                    "pandas or polars is required in order to support reporting tables using CSV "
+                    "or a URL, please install the pandas or polars python package"
                 )
             if url:
-                table = pd.read_csv(url, index_col=[0])
+                table = pd.read_csv(url, index_col=[0]) if pd else pl.read_csv(url)
             elif csv:
-                table = pd.read_csv(csv, index_col=[0])
+                table = pd.read_csv(csv, index_col=[0]) if pd else pl.read_csv(csv)

         def replace(dst, *srcs):
             for src in srcs:
@@ -410,7 +415,8 @@ class Logger(object):
             if isinstance(table, (list, tuple)):
                 reporter_table = table
             else:
-                reporter_table = table.fillna(str(np.nan))
+                nan = str(np.nan)
+                reporter_table = table.fillna(nan) if pd else table.fill_nan(nan)
             replace("NaN", np.nan, math.nan if six.PY3 else float("nan"))
             replace("Inf", np.inf, math.inf if six.PY3 else float("inf"))
             minus_inf = [-np.inf, -math.inf if six.PY3 else -float("inf")]
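
An illustrative sketch (project and task names are invented): with this change, report_table accepts a polars DataFrame directly for table_plot, in addition to pandas frames and plain lists.

    import polars as pl
    from clearml import Task

    task = Task.init(project_name="examples", task_name="polars table")
    df = pl.DataFrame({"epoch": [1, 2, 3], "loss": [0.9, 0.5, 0.3]})

    task.get_logger().report_table(title="training", series="loss", iteration=0, table_plot=df)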

File 5 of 6: BaseModel

@@ -14,6 +14,10 @@ try:
     import pandas as pd
 except ImportError:
     pd = None
+try:
+    import polars as pl
+except ImportError:
+    pl = None

 from .backend_api import Session
 from .backend_api.services import models, projects
@@ -638,15 +642,21 @@ class BaseModel(object):
             )
         table = table_plot
         if url or csv:
-            if not pd:
+            if not pd and not pl:
                 raise UsageError(
-                    "pandas is required in order to support reporting tables using CSV or a URL, "
-                    "please install the pandas python package"
+                    "pandas or polars is required in order to support reporting tables using CSV or a URL, "
+                    "please install the pandas or polars python package"
                 )
             if url:
-                table = pd.read_csv(url, index_col=[0])
+                if pd:
+                    table = pd.read_csv(url, index_col=[0])
+                else:
+                    table = pl.read_csv(url)
             elif csv:
-                table = pd.read_csv(csv, index_col=[0])
+                if pd:
+                    table = pd.read_csv(csv, index_col=[0])
+                else:
+                    table = pl.read_csv(csv)

         def replace(dst, *srcs):
             for src in srcs:
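
A rough sketch of the CSV branch above (the file name is made up, and it assumes an output model already registered on the task): when only polars is installed, the CSV is loaded with pl.read_csv before being reported as a table.

    from clearml import OutputModel, Task

    task = Task.init(project_name="examples", task_name="model table")
    model = OutputModel(task=task)
    model.report_table(title="validation", series="metrics", iteration=0, csv="metrics.csv")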

File 6 of 6: create_plotly_table (plotly reporter utility)

@@ -4,6 +4,10 @@ import numpy as np
 from ..errors import UsageError
 from ..utilities.dicts import merge_dicts

+try:
+    import polars as pl
+except ImportError:
+    pl = None
 try:
     import pandas as pd
 except ImportError:
@@ -486,7 +490,7 @@ def create_plotly_table(table_plot, title, series, layout_config=None, data_conf
     """
     Create a basic Plotly table json style to be sent
-    :param table_plot: the output table in pandas.DataFrame structure or list of rows (list) in a table
+    :param table_plot: the output table in pandas.DataFrame or polars.DataFrame structure, or list of rows (list) in a table
     :param title: Title (AKA metric)
     :type title: str
     :param series: Series (AKA variant)
@@ -503,11 +507,19 @@ def create_plotly_table(table_plot, title, series, layout_config=None, data_conf
     elif is_list and table_plot[0] and isinstance(table_plot[0], (list, tuple)):
         headers_values = table_plot[0]
         cells_values = [list(i) for i in zip(*table_plot[1:])]
+    elif pl and isinstance(table_plot, pl.DataFrame):
+        headers_values = list([col] for col in table_plot.columns)
+        # Convert datetimes to ISO strings
+        datetime_columns = table_plot.select(pl.selectors.datetime()).columns
+        exprs = [pl.col(col).dt.to_string("iso:strict") for col in datetime_columns]
+        # Get cell values and preserve value types
+        cells_values_transpose = table_plot.with_columns(*exprs).rows()
+        cells_values = list(map(list, zip(*cells_values_transpose)))
     else:
         if not pd:
             raise UsageError(
-                "pandas is required in order to support reporting tables using CSV or a URL, "
-                "please install the pandas python package"
+                "pandas or polars is required in order to support reporting tables using CSV or a URL, "
+                "please install the pandas or polars python package"
             )
         index_added = not isinstance(table_plot.index, pd.RangeIndex)
         headers_values = list([col] for col in table_plot.columns)
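
A standalone sketch (column names are invented; assumes a recent polars with selectors and dt.to_string("iso:strict")) of what the new branch does: datetime columns are rendered as ISO strings, rows() yields row tuples, and zip(*) transposes them into the column-major cell lists a Plotly table expects.

    import datetime
    import polars as pl

    df = pl.DataFrame({
        "name": ["a", "b"],
        "when": [datetime.datetime(2025, 1, 1), datetime.datetime(2025, 1, 2)],
    })

    headers_values = [[col] for col in df.columns]
    datetime_columns = df.select(pl.selectors.datetime()).columns
    exprs = [pl.col(c).dt.to_string("iso:strict") for c in datetime_columns]
    rows = df.with_columns(*exprs).rows()          # datetimes now ISO strings
    cells_values = list(map(list, zip(*rows)))     # column-major, one list per column
    print(headers_values, cells_values)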