mirror of
https://github.com/clearml/clearml
synced 2025-01-31 09:07:00 +00:00
Add artifacts upload with a custom serializer (#689)
This commit is contained in:
parent
f04c39fb30
commit
5228b799c1
@ -16,7 +16,7 @@ import six
|
||||
from PIL import Image
|
||||
from pathlib2 import Path
|
||||
from six.moves.urllib.parse import urlparse
|
||||
from typing import Dict, Union, Optional, Any, Sequence
|
||||
from typing import Dict, Union, Optional, Any, Sequence, Callable
|
||||
|
||||
from ..backend_api import Session
|
||||
from ..backend_api.services import tasks
|
||||
@ -140,8 +140,8 @@ class Artifact(object):
|
||||
self._content_type = artifact_api_object.type_data.content_type if artifact_api_object.type_data else None
|
||||
self._object = self._not_set
|
||||
|
||||
def get(self, force_download=False):
|
||||
# type: (bool) -> Any
|
||||
def get(self, force_download=False, deserialization_function=None):
|
||||
# type: (bool, Optional[Callable[[bytes], Any]]) -> Any
|
||||
"""
|
||||
Return an object constructed from the artifact file
|
||||
|
||||
@ -156,7 +156,12 @@ class Artifact(object):
|
||||
pointing to a local copy of the artifacts file (or directory) will be returned
|
||||
|
||||
:param bool force_download: download file from remote even if exists in local cache
|
||||
:return: One of the following objects Numpy.array, pandas.DataFrame, PIL.Image, dict (json), or pathlib2.Path.
|
||||
:param Callable[[bytes], Any] deserialization_function: A deserialization function that takes one parameter of type `bytes`,
|
||||
which represents the serialized object. This function should return the deserialized object.
|
||||
Useful when the artifact was uploaded using a custom serialization function when calling the
|
||||
`Task.upload_artifact` method with the `serialization_function` argument.
|
||||
:return: Usually, one of the following objects: Numpy.array, pandas.DataFrame, PIL.Image, dict (json), or pathlib2.Path.
|
||||
An object with an arbitrary type may also be returned if it was serialized (using pickle or a custom serialization function).
|
||||
"""
|
||||
if self._object is not self._not_set:
|
||||
return self._object
|
||||
@ -165,7 +170,10 @@ class Artifact(object):
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
if self.type == "numpy" and np:
|
||||
if deserialization_function:
|
||||
with open(local_file, "rb") as f:
|
||||
self._object = deserialization_function(f.read())
|
||||
elif self.type == "numpy" and np:
|
||||
if self._content_type == "text/csv":
|
||||
self._object = np.genfromtxt(local_file, delimiter=",")
|
||||
else:
|
||||
@ -339,12 +347,24 @@ class Artifacts(object):
|
||||
self._unregister_request.add(name)
|
||||
self.flush()
|
||||
|
||||
def upload_artifact(self, name, artifact_object=None, metadata=None, preview=None,
|
||||
delete_after_upload=False, auto_pickle=True, wait_on_upload=False, extension_name=None):
|
||||
# type: (str, Optional[object], Optional[dict], Optional[str], bool, bool, bool, Optional[str]) -> bool
|
||||
if not Session.check_min_api_version('2.3'):
|
||||
LoggerRoot.get_base_logger().warning('Artifacts not supported by your ClearML-server version, '
|
||||
'please upgrade to the latest server version')
|
||||
def upload_artifact(
|
||||
self,
|
||||
name, # type: str
|
||||
artifact_object=None, # type: Optional[object]
|
||||
metadata=None, # type: Optional[dict]
|
||||
preview=None, # type: Optional[str]
|
||||
delete_after_upload=False, # type: bool
|
||||
auto_pickle=True, # type: bool
|
||||
wait_on_upload=False, # type: bool
|
||||
extension_name=None, # type: Optional[str]
|
||||
serialization_function=None, # type: Optional[Callable[Any, Union[bytes, bytearray]]]
|
||||
):
|
||||
# type: (...) -> bool
|
||||
if not Session.check_min_api_version("2.3"):
|
||||
LoggerRoot.get_base_logger().warning(
|
||||
"Artifacts not supported by your ClearML-server version,"
|
||||
" please upgrade to the latest server version"
|
||||
)
|
||||
return False
|
||||
|
||||
if name in self._artifacts_container:
|
||||
@ -399,7 +419,27 @@ class Artifacts(object):
|
||||
)
|
||||
return default_extension
|
||||
|
||||
if np and isinstance(artifact_object, np.ndarray):
|
||||
if serialization_function:
|
||||
artifact_type = "custom"
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
artifact_type_data.preview = preview or str(artifact_object.__repr__())[:self.max_preview_size_bytes]
|
||||
except Exception:
|
||||
artifact_type_data.preview = ""
|
||||
override_filename_ext_in_uri = ""
|
||||
override_filename_in_uri = name
|
||||
fd, local_filename = mkstemp(prefix=quote(name, safe="") + ".")
|
||||
os.close(fd)
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
with open(local_filename, "wb") as f:
|
||||
f.write(serialization_function(artifact_object))
|
||||
except Exception:
|
||||
# cleanup and raise exception
|
||||
os.unlink(local_filename)
|
||||
raise
|
||||
artifact_type_data.content_type = mimetypes.guess_type(local_filename)[0]
|
||||
elif np and isinstance(artifact_object, np.ndarray):
|
||||
artifact_type = 'numpy'
|
||||
artifact_type_data.preview = preview or str(artifact_object.__repr__())
|
||||
override_filename_ext_in_uri = get_extension(
|
||||
|
@ -18,7 +18,19 @@ try:
|
||||
except ImportError:
|
||||
from collections import Sequence as CollectionsSequence
|
||||
|
||||
from typing import Optional, Union, Mapping, Sequence, Any, Dict, Iterable, TYPE_CHECKING, Callable, Tuple, List
|
||||
from typing import (
|
||||
Optional,
|
||||
Union,
|
||||
Mapping,
|
||||
Sequence,
|
||||
Any,
|
||||
Dict,
|
||||
Iterable,
|
||||
TYPE_CHECKING,
|
||||
Callable,
|
||||
Tuple,
|
||||
List,
|
||||
)
|
||||
|
||||
import psutil
|
||||
import six
|
||||
@ -1861,6 +1873,7 @@ class Task(_Task):
|
||||
preview=None, # type: Any
|
||||
wait_on_upload=False, # type: bool
|
||||
extension_name=None, # type: Optional[str]
|
||||
serialization_function=None # type: Optional[Callable[Any, Union[bytes, bytearray]]]
|
||||
):
|
||||
# type: (...) -> bool
|
||||
"""
|
||||
@ -1908,6 +1921,12 @@ class Task(_Task):
|
||||
- numpy.ndarray - ``.npz``, ``.csv.gz`` (default ``.npz``)
|
||||
- PIL.Image - whatever extensions PIL supports (default ``.png``)
|
||||
|
||||
:param Callable[[Any], Union[bytes, bytearray]] serialization_function: A serialization function that takes one
|
||||
parameter of any type, which is the object to be serialized. The function should return a `bytes` or `bytearray`
|
||||
object, which represents the serialized object. Note that the object will be immediately serialized using this function,
|
||||
thus other serialization methods will not be used (e.g. `pandas.DataFrame.to_csv`), even if possible.
|
||||
To deserialize this artifact when getting it using the `Artifact.get` method, use its `deserialization_function` argument.
|
||||
|
||||
:return: The status of the upload.
|
||||
|
||||
- ``True`` - Upload succeeded.
|
||||
@ -1916,8 +1935,16 @@ class Task(_Task):
|
||||
:raise: If the artifact object type is not supported, raise a ``ValueError``.
|
||||
"""
|
||||
return self._artifacts_manager.upload_artifact(
|
||||
name=name, artifact_object=artifact_object, metadata=metadata, delete_after_upload=delete_after_upload,
|
||||
auto_pickle=auto_pickle, preview=preview, wait_on_upload=wait_on_upload, extension_name=extension_name)
|
||||
name=name,
|
||||
artifact_object=artifact_object,
|
||||
metadata=metadata,
|
||||
delete_after_upload=delete_after_upload,
|
||||
auto_pickle=auto_pickle,
|
||||
preview=preview,
|
||||
wait_on_upload=wait_on_upload,
|
||||
extension_name=extension_name,
|
||||
serialization_function=serialization_function,
|
||||
)
|
||||
|
||||
def get_models(self):
|
||||
# type: () -> Mapping[str, Sequence[Model]]
|
||||
|
@ -572,6 +572,7 @@ def prep_xy_axis(ax, props, x_bounds, y_bounds):
|
||||
yaxis.update(prep_ticks(ax, 1, "y", props))
|
||||
return xaxis, yaxis
|
||||
|
||||
|
||||
def prep_xyz_axis(ax, props, x_bounds, y_bounds):
|
||||
# there is no z_bounds as they can't (at least easily) be extracted from an `Axes3DSubplot` object
|
||||
xaxis, yaxis = prep_xy_axis(ax, props, x_bounds, y_bounds)
|
||||
|
Loading…
Reference in New Issue
Block a user