Add artifacts upload with a custom serializer (#689)

This commit is contained in:
allegroai 2022-09-06 10:47:56 +03:00
parent f04c39fb30
commit 5228b799c1
3 changed files with 83 additions and 15 deletions

View File

@ -16,7 +16,7 @@ import six
from PIL import Image
from pathlib2 import Path
from six.moves.urllib.parse import urlparse
from typing import Dict, Union, Optional, Any, Sequence
from typing import Dict, Union, Optional, Any, Sequence, Callable
from ..backend_api import Session
from ..backend_api.services import tasks
@ -140,8 +140,8 @@ class Artifact(object):
self._content_type = artifact_api_object.type_data.content_type if artifact_api_object.type_data else None
self._object = self._not_set
def get(self, force_download=False):
# type: (bool) -> Any
def get(self, force_download=False, deserialization_function=None):
# type: (bool, Optional[Callable[Union[bytes], Any]]) -> Any
"""
Return an object constructed from the artifact file
@ -156,7 +156,12 @@ class Artifact(object):
pointing to a local copy of the artifacts file (or directory) will be returned
:param bool force_download: download file from remote even if exists in local cache
:return: One of the following objects Numpy.array, pandas.DataFrame, PIL.Image, dict (json), or pathlib2.Path.
:param Callable[Union[bytes], Any] deserialization_function: A deserialization function that takes one parameter of type `bytes`,
which represents the serialized object. This function should return the deserialized object.
Useful when the artifact was uploaded using a custom serialization function when calling the
`Task.upload_artifact` method with the `serialization_function` argument.
:return: Usually, one of the following objects: Numpy.array, pandas.DataFrame, PIL.Image, dict (json), or pathlib2.Path.
An object with an arbitrary type may also be returned if it was serialized (using pickle or a custom serialization function).
"""
if self._object is not self._not_set:
return self._object
@ -165,7 +170,10 @@ class Artifact(object):
# noinspection PyBroadException
try:
if self.type == "numpy" and np:
if deserialization_function:
with open(local_file, "rb") as f:
self._object = deserialization_function(f.read())
elif self.type == "numpy" and np:
if self._content_type == "text/csv":
self._object = np.genfromtxt(local_file, delimiter=",")
else:
@ -339,12 +347,24 @@ class Artifacts(object):
self._unregister_request.add(name)
self.flush()
def upload_artifact(self, name, artifact_object=None, metadata=None, preview=None,
delete_after_upload=False, auto_pickle=True, wait_on_upload=False, extension_name=None):
# type: (str, Optional[object], Optional[dict], Optional[str], bool, bool, bool, Optional[str]) -> bool
if not Session.check_min_api_version('2.3'):
LoggerRoot.get_base_logger().warning('Artifacts not supported by your ClearML-server version, '
'please upgrade to the latest server version')
def upload_artifact(
self,
name, # type: str
artifact_object=None, # type: Optional[object]
metadata=None, # type: Optional[dict]
preview=None, # type: Optional[str]
delete_after_upload=False, # type: bool
auto_pickle=True, # type: bool
wait_on_upload=False, # type: bool
extension_name=None, # type: Optional[str]
serialization_function=None, # type: Optional[Callable[Any, Union[bytes, bytearray]]]
):
# type: (...) -> bool
if not Session.check_min_api_version("2.3"):
LoggerRoot.get_base_logger().warning(
"Artifacts not supported by your ClearML-server version,"
" please upgrade to the latest server version"
)
return False
if name in self._artifacts_container:
@ -399,7 +419,27 @@ class Artifacts(object):
)
return default_extension
if np and isinstance(artifact_object, np.ndarray):
if serialization_function:
artifact_type = "custom"
# noinspection PyBroadException
try:
artifact_type_data.preview = preview or str(artifact_object.__repr__())[:self.max_preview_size_bytes]
except Exception:
artifact_type_data.preview = ""
override_filename_ext_in_uri = ""
override_filename_in_uri = name
fd, local_filename = mkstemp(prefix=quote(name, safe="") + ".")
os.close(fd)
# noinspection PyBroadException
try:
with open(local_filename, "wb") as f:
f.write(serialization_function(artifact_object))
except Exception:
# cleanup and raise exception
os.unlink(local_filename)
raise
artifact_type_data.content_type = mimetypes.guess_type(local_filename)[0]
elif np and isinstance(artifact_object, np.ndarray):
artifact_type = 'numpy'
artifact_type_data.preview = preview or str(artifact_object.__repr__())
override_filename_ext_in_uri = get_extension(

View File

@ -18,7 +18,19 @@ try:
except ImportError:
from collections import Sequence as CollectionsSequence
from typing import Optional, Union, Mapping, Sequence, Any, Dict, Iterable, TYPE_CHECKING, Callable, Tuple, List
from typing import (
Optional,
Union,
Mapping,
Sequence,
Any,
Dict,
Iterable,
TYPE_CHECKING,
Callable,
Tuple,
List,
)
import psutil
import six
@ -1861,6 +1873,7 @@ class Task(_Task):
preview=None, # type: Any
wait_on_upload=False, # type: bool
extension_name=None, # type: Optional[str]
serialization_function=None # type: Optional[Callable[Any, Union[bytes, bytearray]]]
):
# type: (...) -> bool
"""
@ -1908,6 +1921,12 @@ class Task(_Task):
- numpy.ndarray - ``.npz``, ``.csv.gz`` (default ``.npz``)
- PIL.Image - whatever extensions PIL supports (default ``.png``)
:param Callable[Any, Union[bytes, bytearray]] serialization_function: A serialization function that takes one
parameter of any types which is the object to be serialized. The function should return a `bytes` or `bytearray`
object, which represents the serialized object. Note that the object will be immediately serialized using this function,
thus other serialization methods will not be used (e.g. `pandas.DataFrame.to_csv`), even if possible.
To deserialize this artifact when getting it using the `Artifact.get` method, use its `deserialization_function` argument
:return: The status of the upload.
- ``True`` - Upload succeeded.
@ -1916,8 +1935,16 @@ class Task(_Task):
:raise: If the artifact object type is not supported, raise a ``ValueError``.
"""
return self._artifacts_manager.upload_artifact(
name=name, artifact_object=artifact_object, metadata=metadata, delete_after_upload=delete_after_upload,
auto_pickle=auto_pickle, preview=preview, wait_on_upload=wait_on_upload, extension_name=extension_name)
name=name,
artifact_object=artifact_object,
metadata=metadata,
delete_after_upload=delete_after_upload,
auto_pickle=auto_pickle,
preview=preview,
wait_on_upload=wait_on_upload,
extension_name=extension_name,
serialization_function=serialization_function,
)
def get_models(self):
# type: () -> Mapping[str, Sequence[Model]]

View File

@ -572,6 +572,7 @@ def prep_xy_axis(ax, props, x_bounds, y_bounds):
yaxis.update(prep_ticks(ax, 1, "y", props))
return xaxis, yaxis
def prep_xyz_axis(ax, props, x_bounds, y_bounds):
# there is no z_bounds as they can't (at least easily) be extracted from an `Axes3DSubplot` object
xaxis, yaxis = prep_xy_axis(ax, props, x_bounds, y_bounds)