Add support for a default extension name when uploading a pandas dataframe artifact (see sdk.development.artifacts.default_pandas_dataframe_extension_name)

This commit is contained in:
allegroai 2024-08-24 22:10:23 +03:00
parent 0086785372
commit 9a9ee54681
2 changed files with 21 additions and 6 deletions

View File

@ -28,6 +28,7 @@ from ..storage.helper import remote_driver_schemes
from ..storage.util import sha256sum, format_size, get_common_path
from ..utilities.process.mp import SafeEvent, ForkSafeRLock
from ..utilities.proxy_object import LazyEvalWrapper
from ..config import deferred_config
try:
import pandas as pd
@ -262,6 +263,9 @@ class Artifacts(object):
# hashing constants
_hash_block_size = 65536
_pd_artifact_type = 'data-audit-table'
_default_pandas_dataframe_extension_name = deferred_config(
"development.artifacts.default_pandas_dataframe_extension_name", None
)
class _ProxyDictWrite(dict):
""" Dictionary wrapper that updates an arguments instance on any item set in the dictionary """
@ -464,19 +468,23 @@ class Artifacts(object):
artifact_type_data.content_type = "text/csv"
np.savetxt(local_filename, artifact_object, delimiter=",")
delete_after_upload = True
elif pd and isinstance(artifact_object, pd.DataFrame) \
and (isinstance(artifact_object.index, pd.MultiIndex) or
isinstance(artifact_object.columns, pd.MultiIndex)):
store_as_pickle = True
elif pd and isinstance(artifact_object, pd.DataFrame):
artifact_type = "pandas"
artifact_type_data.preview = preview or str(artifact_object.__repr__())
# str() makes sure self._default_pandas_dataframe_extension_name is resolved to a plain string
# rather than left as a deferred config value
extension_name = extension_name or str(self._default_pandas_dataframe_extension_name or "")
override_filename_ext_in_uri = get_extension(
extension_name, [".csv.gz", ".parquet", ".feather", ".pickle"], ".csv.gz", artifact_type
)
override_filename_in_uri = name
local_filename = self._push_temp_file(prefix=quote(name, safe="") + '.', suffix=override_filename_ext_in_uri)
if override_filename_ext_in_uri == ".csv.gz":
local_filename = self._push_temp_file(
prefix=quote(name, safe="") + ".", suffix=override_filename_ext_in_uri
)
if (
isinstance(artifact_object.index, pd.MultiIndex) or isinstance(artifact_object.columns, pd.MultiIndex)
) and not extension_name:
store_as_pickle = True
elif override_filename_ext_in_uri == ".csv.gz":
artifact_type_data.content_type = "text/csv"
self._store_compressed_pd_csv(artifact_object, local_filename)
elif override_filename_ext_in_uri == ".parquet":

View File

@ -240,5 +240,12 @@ sdk {
# iteration reporting x-axis after starting to report "seconds from start"
# max_wait_for_first_iteration_to_start_sec: 1800
}
artifacts {
# default extension_name to use when uploading pandas DataFrame objects as artifacts
# valid values are: ``.csv.gz``, ``.parquet``, ``.feather``, ``.pickle``
# an extension_name passed explicitly to Task.upload_artifact takes priority over this value
default_pandas_dataframe_extension_name: ""
}
}
}