Add artifacts upload with a custom serializer (#689)

This commit is contained in:
allegroai 2022-09-06 10:47:56 +03:00
parent f04c39fb30
commit 5228b799c1
3 changed files with 83 additions and 15 deletions

View File

@ -16,7 +16,7 @@ import six
from PIL import Image from PIL import Image
from pathlib2 import Path from pathlib2 import Path
from six.moves.urllib.parse import urlparse from six.moves.urllib.parse import urlparse
from typing import Dict, Union, Optional, Any, Sequence from typing import Dict, Union, Optional, Any, Sequence, Callable
from ..backend_api import Session from ..backend_api import Session
from ..backend_api.services import tasks from ..backend_api.services import tasks
@ -140,8 +140,8 @@ class Artifact(object):
self._content_type = artifact_api_object.type_data.content_type if artifact_api_object.type_data else None self._content_type = artifact_api_object.type_data.content_type if artifact_api_object.type_data else None
self._object = self._not_set self._object = self._not_set
def get(self, force_download=False): def get(self, force_download=False, deserialization_function=None):
# type: (bool) -> Any # type: (bool, Optional[Callable[Union[bytes], Any]]) -> Any
""" """
Return an object constructed from the artifact file Return an object constructed from the artifact file
@ -156,7 +156,12 @@ class Artifact(object):
pointing to a local copy of the artifacts file (or directory) will be returned pointing to a local copy of the artifacts file (or directory) will be returned
:param bool force_download: download file from remote even if exists in local cache :param bool force_download: download file from remote even if exists in local cache
:return: One of the following objects Numpy.array, pandas.DataFrame, PIL.Image, dict (json), or pathlib2.Path. :param Callable[Union[bytes], Any] deserialization_function: A deserialization function that takes one parameter of type `bytes`,
which represents the serialized object. This function should return the deserialized object.
Useful when the artifact was uploaded using a custom serialization function when calling the
`Task.upload_artifact` method with the `serialization_function` argument.
:return: Usually, one of the following objects: Numpy.array, pandas.DataFrame, PIL.Image, dict (json), or pathlib2.Path.
An object with an arbitrary type may also be returned if it was serialized (using pickle or a custom serialization function).
""" """
if self._object is not self._not_set: if self._object is not self._not_set:
return self._object return self._object
@ -165,7 +170,10 @@ class Artifact(object):
# noinspection PyBroadException # noinspection PyBroadException
try: try:
if self.type == "numpy" and np: if deserialization_function:
with open(local_file, "rb") as f:
self._object = deserialization_function(f.read())
elif self.type == "numpy" and np:
if self._content_type == "text/csv": if self._content_type == "text/csv":
self._object = np.genfromtxt(local_file, delimiter=",") self._object = np.genfromtxt(local_file, delimiter=",")
else: else:
@ -339,12 +347,24 @@ class Artifacts(object):
self._unregister_request.add(name) self._unregister_request.add(name)
self.flush() self.flush()
def upload_artifact(self, name, artifact_object=None, metadata=None, preview=None, def upload_artifact(
delete_after_upload=False, auto_pickle=True, wait_on_upload=False, extension_name=None): self,
# type: (str, Optional[object], Optional[dict], Optional[str], bool, bool, bool, Optional[str]) -> bool name, # type: str
if not Session.check_min_api_version('2.3'): artifact_object=None, # type: Optional[object]
LoggerRoot.get_base_logger().warning('Artifacts not supported by your ClearML-server version, ' metadata=None, # type: Optional[dict]
'please upgrade to the latest server version') preview=None, # type: Optional[str]
delete_after_upload=False, # type: bool
auto_pickle=True, # type: bool
wait_on_upload=False, # type: bool
extension_name=None, # type: Optional[str]
serialization_function=None, # type: Optional[Callable[Any, Union[bytes, bytearray]]]
):
# type: (...) -> bool
if not Session.check_min_api_version("2.3"):
LoggerRoot.get_base_logger().warning(
"Artifacts not supported by your ClearML-server version,"
" please upgrade to the latest server version"
)
return False return False
if name in self._artifacts_container: if name in self._artifacts_container:
@ -399,7 +419,27 @@ class Artifacts(object):
) )
return default_extension return default_extension
if np and isinstance(artifact_object, np.ndarray): if serialization_function:
artifact_type = "custom"
# noinspection PyBroadException
try:
artifact_type_data.preview = preview or str(artifact_object.__repr__())[:self.max_preview_size_bytes]
except Exception:
artifact_type_data.preview = ""
override_filename_ext_in_uri = ""
override_filename_in_uri = name
fd, local_filename = mkstemp(prefix=quote(name, safe="") + ".")
os.close(fd)
# noinspection PyBroadException
try:
with open(local_filename, "wb") as f:
f.write(serialization_function(artifact_object))
except Exception:
# cleanup and raise exception
os.unlink(local_filename)
raise
artifact_type_data.content_type = mimetypes.guess_type(local_filename)[0]
elif np and isinstance(artifact_object, np.ndarray):
artifact_type = 'numpy' artifact_type = 'numpy'
artifact_type_data.preview = preview or str(artifact_object.__repr__()) artifact_type_data.preview = preview or str(artifact_object.__repr__())
override_filename_ext_in_uri = get_extension( override_filename_ext_in_uri = get_extension(

View File

@ -18,7 +18,19 @@ try:
except ImportError: except ImportError:
from collections import Sequence as CollectionsSequence from collections import Sequence as CollectionsSequence
from typing import Optional, Union, Mapping, Sequence, Any, Dict, Iterable, TYPE_CHECKING, Callable, Tuple, List from typing import (
Optional,
Union,
Mapping,
Sequence,
Any,
Dict,
Iterable,
TYPE_CHECKING,
Callable,
Tuple,
List,
)
import psutil import psutil
import six import six
@ -1861,6 +1873,7 @@ class Task(_Task):
preview=None, # type: Any preview=None, # type: Any
wait_on_upload=False, # type: bool wait_on_upload=False, # type: bool
extension_name=None, # type: Optional[str] extension_name=None, # type: Optional[str]
serialization_function=None # type: Optional[Callable[Any, Union[bytes, bytearray]]]
): ):
# type: (...) -> bool # type: (...) -> bool
""" """
@ -1908,6 +1921,12 @@ class Task(_Task):
- numpy.ndarray - ``.npz``, ``.csv.gz`` (default ``.npz``) - numpy.ndarray - ``.npz``, ``.csv.gz`` (default ``.npz``)
- PIL.Image - whatever extensions PIL supports (default ``.png``) - PIL.Image - whatever extensions PIL supports (default ``.png``)
:param Callable[Any, Union[bytes, bytearray]] serialization_function: A serialization function that takes one
parameter of any types which is the object to be serialized. The function should return a `bytes` or `bytearray`
object, which represents the serialized object. Note that the object will be immediately serialized using this function,
thus other serialization methods will not be used (e.g. `pandas.DataFrame.to_csv`), even if possible.
To deserialize this artifact when getting it using the `Artifact.get` method, use its `deserialization_function` argument
:return: The status of the upload. :return: The status of the upload.
- ``True`` - Upload succeeded. - ``True`` - Upload succeeded.
@ -1916,8 +1935,16 @@ class Task(_Task):
:raise: If the artifact object type is not supported, raise a ``ValueError``. :raise: If the artifact object type is not supported, raise a ``ValueError``.
""" """
return self._artifacts_manager.upload_artifact( return self._artifacts_manager.upload_artifact(
name=name, artifact_object=artifact_object, metadata=metadata, delete_after_upload=delete_after_upload, name=name,
auto_pickle=auto_pickle, preview=preview, wait_on_upload=wait_on_upload, extension_name=extension_name) artifact_object=artifact_object,
metadata=metadata,
delete_after_upload=delete_after_upload,
auto_pickle=auto_pickle,
preview=preview,
wait_on_upload=wait_on_upload,
extension_name=extension_name,
serialization_function=serialization_function,
)
def get_models(self): def get_models(self):
# type: () -> Mapping[str, Sequence[Model]] # type: () -> Mapping[str, Sequence[Model]]

View File

@ -572,6 +572,7 @@ def prep_xy_axis(ax, props, x_bounds, y_bounds):
yaxis.update(prep_ticks(ax, 1, "y", props)) yaxis.update(prep_ticks(ax, 1, "y", props))
return xaxis, yaxis return xaxis, yaxis
def prep_xyz_axis(ax, props, x_bounds, y_bounds): def prep_xyz_axis(ax, props, x_bounds, y_bounds):
# there is no z_bounds as they can't (at least easily) be extracted from an `Axes3DSubplot` object # there is no z_bounds as they can't (at least easily) be extracted from an `Axes3DSubplot` object
xaxis, yaxis = prep_xy_axis(ax, props, x_bounds, y_bounds) xaxis, yaxis = prep_xy_axis(ax, props, x_bounds, y_bounds)