Mirror of https://github.com/clearml/clearml, synced 2025-04-05 13:15:17 +00:00
Commit 49706ebeda: Merge 969a48906b into 5f680c3079
@@ -654,7 +654,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
         :param series: Series (AKA variant)
         :type series: str
         :param table: The table data
-        :type table: pandas.DataFrame
+        :type table: pandas.DataFrame or polars.DataFrame
         :param iteration: Iteration number
         :type iteration: int
         :param layout_config: optional dictionary for layout configuration, passed directly to plotly
@@ -30,6 +30,10 @@ from ..utilities.process.mp import SafeEvent, ForkSafeRLock
 from ..utilities.proxy_object import LazyEvalWrapper
 from ..config import deferred_config, config
 
+try:
+    import polars as pl
+except ImportError:
+    pl = None
 try:
     import pandas as pd
     DataFrame = pd.DataFrame
@@ -152,6 +156,7 @@ class Artifact(object):
         Supported content types are:
         - dict - ``.json``, ``.yaml``
         - pandas.DataFrame - ``.csv.gz``, ``.parquet``, ``.feather``, ``.pickle``
+        - polars.DataFrame - ``.csv.gz``, ``.parquet``, ``.feather``, ``.pickle``
         - numpy.ndarray - ``.npz``, ``.csv.gz``
         - PIL.Image - whatever content types PIL supports
         All other types will return a pathlib2.Path object pointing to a local copy of the artifacts file (or directory).
@@ -192,6 +197,18 @@ class Artifact(object):
                 self._object = pd.read_csv(local_file)
             else:
                 self._object = pd.read_csv(local_file, index_col=[0])
+        elif self.type == Artifacts._pd_artifact_type or self.type == "polars" and pl:
+            if self._content_type == "application/parquet":
+                self._object = pl.read_parquet(local_file)
+            elif self._content_type == "application/feather":
+                self._object = pl.read_ipc(local_file)
+            elif self._content_type == "application/pickle":
+                with open(local_file, "rb") as f:
+                    self._object = pickle.load(f)
+            elif self.type == Artifacts._pd_artifact_type:
+                self._object = pl.read_csv(local_file)
+            else:
+                self._object = pl.read_csv(local_file)
         elif self.type == "image":
             self._object = Image.open(local_file)
         elif self.type == "JSON" or self.type == "dict":
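For orientation, a minimal usage sketch of the artifact round-trip these branches enable (hypothetical example: assumes this change is merged, a ClearML server is configured, and polars is installed; project, task, and artifact names are made up):

import polars as pl
from clearml import Task

task = Task.init(project_name="examples", task_name="polars artifact")
df = pl.DataFrame({"id": [1, 2, 3], "value": [0.1, 0.2, 0.3]})

# Serialized by Artifacts.upload_artifact as .csv.gz by default,
# or as .parquet / .feather / .pickle when extension_name is passed.
task.upload_artifact(name="table", artifact_object=df, wait_on_upload=True)

# Deserialized back through the Artifact.get() branch added above.
restored = task.artifacts["table"].get()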
@@ -280,14 +297,14 @@ class Artifacts(object):
             self.artifact_hash_columns = {}
 
         def __setitem__(self, key, value):
-            # check that value is of type pandas
-            if pd and isinstance(value, pd.DataFrame):
+            # check that value is of type pandas or polars
+            if (pd and isinstance(value, pd.DataFrame)) or (pl and isinstance(value, pl.DataFrame)):
                 super(Artifacts._ProxyDictWrite, self).__setitem__(key, value)
 
                 if self._artifacts_manager:
                     self._artifacts_manager.flush()
             else:
-                raise ValueError('Artifacts currently support pandas.DataFrame objects only')
+                raise ValueError('Artifacts currently support pandas.DataFrame and polars.DataFrame objects only')
 
         def unregister_artifact(self, name):
             self.artifact_metadata.pop(name, None)
@@ -472,8 +489,8 @@ class Artifacts(object):
             artifact_type_data.content_type = "text/csv"
             np.savetxt(local_filename, artifact_object, delimiter=",")
             delete_after_upload = True
-        elif pd and isinstance(artifact_object, pd.DataFrame):
-            artifact_type = "pandas"
+        elif (pd and isinstance(artifact_object, pd.DataFrame)) or (pl and isinstance(artifact_object, pl.DataFrame)):
+            artifact_type = "pandas" if (pd and isinstance(artifact_object, pd.DataFrame)) else "polars"
             artifact_type_data.preview = preview or str(artifact_object.__repr__())
             # we are making sure self._default_pandas_dataframe_extension_name is not deferred
             extension_name = extension_name or str(self._default_pandas_dataframe_extension_name or "")
@@ -484,9 +501,9 @@ class Artifacts(object):
             local_filename = self._push_temp_file(
                 prefix=quote(name, safe="") + ".", suffix=override_filename_ext_in_uri
             )
-            if (
+            if (pd and isinstance(artifact_object, pd.DataFrame) and (
                 isinstance(artifact_object.index, pd.MultiIndex) or isinstance(artifact_object.columns, pd.MultiIndex)
-            ) and not extension_name:
+            ) or (pl and isinstance(artifact_object, pl.DataFrame))) and not extension_name:
                 store_as_pickle = True
             elif override_filename_ext_in_uri == ".csv.gz":
                 artifact_type_data.content_type = "text/csv"
@@ -494,7 +511,10 @@ class Artifacts(object):
             elif override_filename_ext_in_uri == ".parquet":
                 try:
                     artifact_type_data.content_type = "application/parquet"
-                    artifact_object.to_parquet(local_filename)
+                    if (pd and isinstance(artifact_object, pd.DataFrame)):
+                        artifact_object.to_parquet(local_filename)
+                    else:
+                        artifact_object.write_parquet(local_filename)
                 except Exception as e:
                     LoggerRoot.get_base_logger().warning(
                         "Exception '{}' encountered when uploading artifact as .parquet. Defaulting to .csv.gz".format(
@@ -506,7 +526,10 @@ class Artifacts(object):
             elif override_filename_ext_in_uri == ".feather":
                 try:
                     artifact_type_data.content_type = "application/feather"
-                    artifact_object.to_feather(local_filename)
+                    if (pd and isinstance(artifact_object, pd.DataFrame)):
+                        artifact_object.to_feather(local_filename)
+                    else:
+                        artifact_object.write_ipc(local_filename)
                 except Exception as e:
                     LoggerRoot.get_base_logger().warning(
                         "Exception '{}' encountered when uploading artifact as .feather. Defaulting to .csv.gz".format(
@@ -517,7 +540,11 @@ class Artifacts(object):
                     self._store_compressed_pd_csv(artifact_object, local_filename)
             elif override_filename_ext_in_uri == ".pickle":
                 artifact_type_data.content_type = "application/pickle"
-                artifact_object.to_pickle(local_filename)
+                if (pd and isinstance(artifact_object, pd.DataFrame)):
+                    artifact_object.to_pickle(local_filename)
+                else:
+                    with open(local_filename, "wb") as f:
+                        pickle.dump(artifact_object, f)
             delete_after_upload = True
         elif isinstance(artifact_object, Image.Image):
             artifact_type = "image"
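The three serialization branches above share one dispatch rule: pandas writers (to_parquet, to_feather, to_pickle) for pandas frames, and polars writers (write_parquet, write_ipc) or plain pickle otherwise. A standalone sketch of that mapping, as a hypothetical helper that is not part of the diff (assumes at least one of the two libraries is installed):

import pickle

try:
    import pandas as pd
except ImportError:
    pd = None
try:
    import polars as pl
except ImportError:
    pl = None


def save_dataframe(df, path, fmt):
    # mirror the pandas-vs-polars writer dispatch used by upload_artifact above
    is_pandas = pd is not None and isinstance(df, pd.DataFrame)
    if fmt == "parquet":
        if is_pandas:
            df.to_parquet(path)
        else:
            df.write_parquet(path)
    elif fmt == "feather":
        if is_pandas:
            df.to_feather(path)
        else:
            df.write_ipc(path)
    elif fmt == "pickle":
        if is_pandas:
            df.to_pickle(path)
        else:
            with open(path, "wb") as f:
                pickle.dump(df, f)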
@@ -1006,7 +1033,7 @@ class Artifacts(object):
         artifacts_summary = []
         for a_name, a_df in artifacts_dict.items():
             hash_cols = self._artifacts_container.get_hash_columns(a_name)
-            if not pd or not isinstance(a_df, pd.DataFrame):
+            if not ((pd and isinstance(a_df, pd.DataFrame)) or (pl and isinstance(a_df, pl.DataFrame))):
                 continue
 
             if hash_cols is True:
@@ -1038,8 +1065,12 @@ class Artifacts(object):
 
             a_shape = a_df.shape
             # parallelize
-            a_hash_cols = a_df.drop(columns=hash_col_drop)
-            thread_pool.map(hash_row, a_hash_cols.values)
+            if pd and isinstance(a_df, pd.DataFrame):
+                a_hash_cols = a_df.drop(columns=hash_col_drop)
+                thread_pool.map(hash_row, a_hash_cols.values)
+            else:
+                a_hash_cols = a_df.drop(hash_col_drop)
+                a_unique_hash.add(a_hash_cols.hash_rows())
             # add result
             artifacts_summary.append((a_name, a_shape, a_unique_hash,))
 
@@ -1083,16 +1114,19 @@ class Artifacts(object):
         # (otherwise it is encoded and creates new hash every time)
         if self._compression == "gzip":
             with gzip.GzipFile(local_filename, 'wb', mtime=0) as gzip_file:
-                try:
-                    pd_version = int(pd.__version__.split(".")[0])
-                except ValueError:
-                    pd_version = 0
-
-                if pd_version >= 2:
-                    artifact_object.to_csv(gzip_file, **kwargs)
+                if pl and isinstance(artifact_object, pl.DataFrame):
+                    artifact_object.write_csv(gzip_file)
                 else:
-                    # old (pandas<2) versions of pandas cannot handle direct gzip stream, so we manually encode it
-                    artifact_object.to_csv(io.TextIOWrapper(gzip_file), **kwargs)
+                    try:
+                        pd_version = int(pd.__version__.split(".")[0])
+                    except ValueError:
+                        pd_version = 0
+
+                    if pd_version >= 2:
+                        artifact_object.to_csv(gzip_file, **kwargs)
+                    else:
+                        # old (pandas<2) versions of pandas cannot handle direct gzip stream, so we manually encode it
+                        artifact_object.to_csv(io.TextIOWrapper(gzip_file), **kwargs)
         else:
             artifact_object.to_csv(local_filename, compression=self._compression)
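A standalone sketch of the deterministic gzip write above (assumes polars is installed; mtime=0 keeps the gzip header constant, so identical content always produces identical bytes, which the hash-based artifact deduplication relies on):

import gzip

import polars as pl

df = pl.DataFrame({"a": [1, 2], "b": [3.0, 4.0]})
# polars writes the CSV bytes directly into the binary gzip stream
with gzip.GzipFile("table.csv.gz", "wb", mtime=0) as gzip_file:
    df.write_csv(gzip_file)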
@@ -48,6 +48,10 @@ except ImportError:
 except Exception as e:
     logging.warning("ClearML Dataset failed importing pandas: {}".format(e))
     pd = None
+try:
+    import polars as pl
+except ImportError:
+    pl = None
 
 try:
     import pyarrow  # noqa
@@ -852,7 +856,7 @@ class Dataset(object):
         return True
 
     def set_metadata(self, metadata, metadata_name='metadata', ui_visible=True):
-        # type: (Union[numpy.array, pd.DataFrame, Dict[str, Any]], str, bool) -> () # noqa: F821
+        # type: (Union[numpy.array, pd.DataFrame, pl.DataFrame, Dict[str, Any]], str, bool) -> () # noqa: F821
         """
         Attach a user-defined metadata to the dataset. Check `Task.upload_artifact` for supported types.
         If type is Pandas Dataframes, optionally make it visible as a table in the UI.
@@ -861,7 +865,7 @@ class Dataset(object):
             raise ValueError("metadata_name can not start with '{}'".format(self.__data_entry_name_prefix))
         self._task.upload_artifact(name=metadata_name, artifact_object=metadata)
         if ui_visible:
-            if pd and isinstance(metadata, pd.DataFrame):
+            if (pd and isinstance(metadata, pd.DataFrame)) or (pl and isinstance(metadata, pl.DataFrame)):
                 self.get_logger().report_table(
                     title='Dataset Metadata',
                     series='Dataset Metadata',
@@ -874,7 +878,7 @@ class Dataset(object):
                 )
 
     def get_metadata(self, metadata_name='metadata'):
-        # type: (str) -> Optional[numpy.array, pd.DataFrame, dict, str, bool] # noqa: F821
+        # type: (str) -> Optional[numpy.array, pd.DataFrame, pl.DataFrame, dict, str, bool] # noqa: F821
         """
         Get attached metadata back in its original format. Will return None if none was found.
         """
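Usage sketch for the metadata changes above (hypothetical names; assumes this change is merged, a ClearML server is configured, and polars is installed):

import polars as pl
from clearml import Dataset

ds = Dataset.create(dataset_name="my_dataset", dataset_project="examples")
meta = pl.DataFrame({"split": ["train", "val"], "rows": [8000, 2000]})
ds.set_metadata(meta)         # with ui_visible=True (the default) it is also rendered as a UI table
restored = ds.get_metadata()  # returned in its original DataFrame form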
@@ -3119,19 +3123,34 @@ class Dataset(object):
         def convert_to_tabular_artifact(file_path_, file_extension_, compression_=None):
             # noinspection PyBroadException
             try:
-                if file_extension_ == ".csv" and pd:
-                    return pd.read_csv(
-                        file_path_,
-                        nrows=self.__preview_tabular_row_count,
-                        compression=compression_.lstrip(".") if compression_ else None,
-                    )
-                elif file_extension_ == ".tsv" and pd:
-                    return pd.read_csv(
-                        file_path_,
-                        sep='\t',
-                        nrows=self.__preview_tabular_row_count,
-                        compression=compression_.lstrip(".") if compression_ else None,
-                    )
+                if file_extension_ == ".csv" and (pl or pd):
+                    if pd:
+                        return pd.read_csv(
+                            file_path_,
+                            nrows=self.__preview_tabular_row_count,
+                            compression=compression_.lstrip(".") if compression_ else None,
+                        )
+                    else:
+                        # TODO Re-implement compression after testing all extensions
+                        return pl.read_csv(
+                            file_path_,
+                            n_rows=self.__preview_tabular_row_count,
+                        )
+                elif file_extension_ == ".tsv" and (pl or pd):
+                    if pd:
+                        return pd.read_csv(
+                            file_path_,
+                            sep='\t',
+                            nrows=self.__preview_tabular_row_count,
+                            compression=compression_.lstrip(".") if compression_ else None,
+                        )
+                    else:
+                        # TODO Re-implement compression after testing all extensions
+                        return pl.read_csv(
+                            file_path_,
+                            n_rows=self.__preview_tabular_row_count,
+                            separator='\t',
+                        )
                 elif file_extension_ == ".parquet" or file_extension_ == ".parq":
                     if pyarrow:
                         pf = pyarrow.parquet.ParquetFile(file_path_)
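Note the keyword differences the polars branches rely on: n_rows instead of pandas' nrows and, in recent polars versions, separator instead of sep. Standalone equivalents of the two preview reads (file names are placeholders):

import polars as pl

csv_preview = pl.read_csv("data.csv", n_rows=10)
tsv_preview = pl.read_csv("data.tsv", n_rows=10, separator="\t")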
@@ -3140,7 +3159,10 @@ class Dataset(object):
                 elif fastparquet:
                     return fastparquet.ParquetFile(file_path_).head(self.__preview_tabular_row_count).to_pandas()
                 elif (file_extension_ == ".npz" or file_extension_ == ".npy") and np:
-                    return pd.DataFrame(np.loadtxt(file_path_, max_rows=self.__preview_tabular_row_count))
+                    if pd:
+                        return pd.DataFrame(np.loadtxt(file_path_, max_rows=self.__preview_tabular_row_count))
+                    else:
+                        return pl.DataFrame(np.loadtxt(file_path_, max_rows=self.__preview_tabular_row_count))
             except Exception:
                 pass
             return None
@@ -3172,7 +3194,10 @@ class Dataset(object):
                 # because it will not upload the sample to that destination.
                 # use report_media instead to not leak data
                 if (
-                    isinstance(artifact, pd.DataFrame)
+                    (
+                        (pd and isinstance(artifact, pd.DataFrame))
+                        or (pl and isinstance(artifact, pl.DataFrame))
+                    )
                     and self._task.get_logger().get_default_upload_destination() == Session.get_files_server_host()
                 ):
                     self._task.get_logger().report_table("Tables", "summary", table_plot=artifact)
@@ -10,6 +10,11 @@ from pathlib2 import Path
 
 from .debugging.log import LoggerRoot
 
+try:
+    import polars as pl
+except ImportError:
+    pl = None
+
 try:
     import pandas as pd
 except ImportError:
@@ -327,7 +332,7 @@ class Logger(object):
         title,  # type: str
         series,  # type: str
         iteration=None,  # type: Optional[int]
-        table_plot=None,  # type: Optional[pd.DataFrame, Sequence[Sequence]]
+        table_plot=None,  # type: Optional[pd.DataFrame, pl.DataFrame, Sequence[Sequence]]
         csv=None,  # type: Optional[str]
         url=None,  # type: Optional[str]
         extra_layout=None,  # type: Optional[dict]
@@ -393,15 +398,15 @@ class Logger(object):
         mutually_exclusive(UsageError, _check_none=True, table_plot=table_plot, csv=csv, url=url)
         table = table_plot
         if url or csv:
-            if not pd:
+            if not pd and not pl:
                 raise UsageError(
-                    "pandas is required in order to support reporting tables using CSV or a URL, "
-                    "please install the pandas python package"
+                    "pandas or polars is required in order to support reporting tables using CSV "
+                    "or a URL, please install the pandas or polars python package"
                 )
             if url:
-                table = pd.read_csv(url, index_col=[0])
+                table = pd.read_csv(url, index_col=[0]) if pd else pl.read_csv(url)
             elif csv:
-                table = pd.read_csv(csv, index_col=[0])
+                table = pd.read_csv(csv, index_col=[0]) if pd else pl.read_csv(csv)
 
         def replace(dst, *srcs):
             for src in srcs:
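A usage sketch of table reporting with a polars frame after this change (hypothetical project and task names; report_table continues to accept pandas DataFrames, lists of rows, a CSV path, or a URL):

import polars as pl
from clearml import Task

task = Task.init(project_name="examples", task_name="polars table")
df = pl.DataFrame({"metric": ["loss", "accuracy"], "value": [0.12, 0.97]})
task.get_logger().report_table(title="results", series="epoch 1", iteration=1, table_plot=df)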
@@ -410,7 +415,8 @@ class Logger(object):
         if isinstance(table, (list, tuple)):
             reporter_table = table
         else:
-            reporter_table = table.fillna(str(np.nan))
+            nan = str(np.nan)
+            reporter_table = table.fillna(nan) if pd else table.fill_nan(nan)
             replace("NaN", np.nan, math.nan if six.PY3 else float("nan"))
             replace("Inf", np.inf, math.inf if six.PY3 else float("inf"))
             minus_inf = [-np.inf, -math.inf if six.PY3 else -float("inf")]
@@ -14,6 +14,10 @@ try:
     import pandas as pd
 except ImportError:
     pd = None
+try:
+    import polars as pl
+except ImportError:
+    pl = None
 
 from .backend_api import Session
 from .backend_api.services import models, projects
@@ -638,15 +642,21 @@ class BaseModel(object):
             )
         table = table_plot
         if url or csv:
-            if not pd:
+            if not pd and not pl:
                 raise UsageError(
-                    "pandas is required in order to support reporting tables using CSV or a URL, "
-                    "please install the pandas python package"
+                    "pandas or polars is required in order to support reporting tables using CSV or a URL, "
+                    "please install the pandas or polars python package"
                 )
             if url:
-                table = pd.read_csv(url, index_col=[0])
+                if pd:
+                    table = pd.read_csv(url, index_col=[0])
+                else:
+                    table = pl.read_csv(url)
             elif csv:
-                table = pd.read_csv(csv, index_col=[0])
+                if pd:
+                    table = pd.read_csv(csv, index_col=[0])
+                else:
+                    table = pl.read_csv(csv)
 
         def replace(dst, *srcs):
             for src in srcs:
@@ -4,6 +4,10 @@ import numpy as np
 from ..errors import UsageError
 from ..utilities.dicts import merge_dicts
 
+try:
+    import polars as pl
+except ImportError:
+    pl = None
 try:
     import pandas as pd
 except ImportError:
@@ -486,7 +490,7 @@ def create_plotly_table(table_plot, title, series, layout_config=None, data_conf
     """
     Create a basic Plotly table json style to be sent
 
-    :param table_plot: the output table in pandas.DataFrame structure or list of rows (list) in a table
+    :param table_plot: the output table in pandas.DataFrame or polars.DataFrame structure, or a list of rows (list) in a table
     :param title: Title (AKA metric)
     :type title: str
     :param series: Series (AKA variant)
@@ -503,11 +507,19 @@ def create_plotly_table(table_plot, title, series, layout_config=None, data_conf
     elif is_list and table_plot[0] and isinstance(table_plot[0], (list, tuple)):
         headers_values = table_plot[0]
         cells_values = [list(i) for i in zip(*table_plot[1:])]
+    elif pl and isinstance(table_plot, pl.DataFrame):
+        headers_values = list([col] for col in table_plot.columns)
+        # Convert datetimes to ISO strings
+        datetime_columns = table_plot.select(pl.selectors.datetime()).columns
+        exprs = [pl.col(col).dt.to_string("iso:strict") for col in datetime_columns]
+        # Get cell values and preserve value types
+        cells_values_transpose = table_plot.with_columns(*exprs).rows()
+        cells_values = list(map(list, zip(*cells_values_transpose)))
     else:
         if not pd:
             raise UsageError(
-                "pandas is required in order to support reporting tables using CSV or a URL, "
-                "please install the pandas python package"
+                "pandas or polars is required in order to support reporting tables using CSV or a URL, "
+                "please install the pandas or polars python package"
             )
         index_added = not isinstance(table_plot.index, pd.RangeIndex)
         headers_values = list([col] for col in table_plot.columns)
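For reference, the polars branch above reduces to the following standalone conversion (assumes a recent polars release that provides selectors and dt.to_string("iso:strict")): datetimes become ISO strings, and the row tuples are transposed into the per-column cell lists Plotly expects:

from datetime import datetime

import polars as pl

table_plot = pl.DataFrame({
    "name": ["a", "b"],
    "when": [datetime(2024, 1, 1), datetime(2024, 1, 2)],
})
headers_values = list([col] for col in table_plot.columns)
datetime_columns = table_plot.select(pl.selectors.datetime()).columns
exprs = [pl.col(col).dt.to_string("iso:strict") for col in datetime_columns]
cells_values_transpose = table_plot.with_columns(*exprs).rows()
cells_values = list(map(list, zip(*cells_values_transpose)))  # one list per column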