mirror of
https://github.com/clearml/clearml
synced 2025-01-31 09:07:00 +00:00
Add artifacts upload with a custom serializer (#689)
This commit is contained in:
parent
f04c39fb30
commit
5228b799c1
@ -16,7 +16,7 @@ import six
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
from pathlib2 import Path
|
from pathlib2 import Path
|
||||||
from six.moves.urllib.parse import urlparse
|
from six.moves.urllib.parse import urlparse
|
||||||
from typing import Dict, Union, Optional, Any, Sequence
|
from typing import Dict, Union, Optional, Any, Sequence, Callable
|
||||||
|
|
||||||
from ..backend_api import Session
|
from ..backend_api import Session
|
||||||
from ..backend_api.services import tasks
|
from ..backend_api.services import tasks
|
||||||
@ -140,8 +140,8 @@ class Artifact(object):
|
|||||||
self._content_type = artifact_api_object.type_data.content_type if artifact_api_object.type_data else None
|
self._content_type = artifact_api_object.type_data.content_type if artifact_api_object.type_data else None
|
||||||
self._object = self._not_set
|
self._object = self._not_set
|
||||||
|
|
||||||
def get(self, force_download=False):
|
def get(self, force_download=False, deserialization_function=None):
|
||||||
# type: (bool) -> Any
|
# type: (bool, Optional[Callable[Union[bytes], Any]]) -> Any
|
||||||
"""
|
"""
|
||||||
Return an object constructed from the artifact file
|
Return an object constructed from the artifact file
|
||||||
|
|
||||||
@ -156,7 +156,12 @@ class Artifact(object):
|
|||||||
pointing to a local copy of the artifacts file (or directory) will be returned
|
pointing to a local copy of the artifacts file (or directory) will be returned
|
||||||
|
|
||||||
:param bool force_download: download file from remote even if exists in local cache
|
:param bool force_download: download file from remote even if exists in local cache
|
||||||
:return: One of the following objects Numpy.array, pandas.DataFrame, PIL.Image, dict (json), or pathlib2.Path.
|
:param Callable[Union[bytes], Any] deserialization_function: A deserialization function that takes one parameter of type `bytes`,
|
||||||
|
which represents the serialized object. This function should return the deserialized object.
|
||||||
|
Useful when the artifact was uploaded using a custom serialization function when calling the
|
||||||
|
`Task.upload_artifact` method with the `serialization_function` argument.
|
||||||
|
:return: Usually, one of the following objects: Numpy.array, pandas.DataFrame, PIL.Image, dict (json), or pathlib2.Path.
|
||||||
|
An object with an arbitrary type may also be returned if it was serialized (using pickle or a custom serialization function).
|
||||||
"""
|
"""
|
||||||
if self._object is not self._not_set:
|
if self._object is not self._not_set:
|
||||||
return self._object
|
return self._object
|
||||||
@ -165,7 +170,10 @@ class Artifact(object):
|
|||||||
|
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
if self.type == "numpy" and np:
|
if deserialization_function:
|
||||||
|
with open(local_file, "rb") as f:
|
||||||
|
self._object = deserialization_function(f.read())
|
||||||
|
elif self.type == "numpy" and np:
|
||||||
if self._content_type == "text/csv":
|
if self._content_type == "text/csv":
|
||||||
self._object = np.genfromtxt(local_file, delimiter=",")
|
self._object = np.genfromtxt(local_file, delimiter=",")
|
||||||
else:
|
else:
|
||||||
@ -339,12 +347,24 @@ class Artifacts(object):
|
|||||||
self._unregister_request.add(name)
|
self._unregister_request.add(name)
|
||||||
self.flush()
|
self.flush()
|
||||||
|
|
||||||
def upload_artifact(self, name, artifact_object=None, metadata=None, preview=None,
|
def upload_artifact(
|
||||||
delete_after_upload=False, auto_pickle=True, wait_on_upload=False, extension_name=None):
|
self,
|
||||||
# type: (str, Optional[object], Optional[dict], Optional[str], bool, bool, bool, Optional[str]) -> bool
|
name, # type: str
|
||||||
if not Session.check_min_api_version('2.3'):
|
artifact_object=None, # type: Optional[object]
|
||||||
LoggerRoot.get_base_logger().warning('Artifacts not supported by your ClearML-server version, '
|
metadata=None, # type: Optional[dict]
|
||||||
'please upgrade to the latest server version')
|
preview=None, # type: Optional[str]
|
||||||
|
delete_after_upload=False, # type: bool
|
||||||
|
auto_pickle=True, # type: bool
|
||||||
|
wait_on_upload=False, # type: bool
|
||||||
|
extension_name=None, # type: Optional[str]
|
||||||
|
serialization_function=None, # type: Optional[Callable[Any, Union[bytes, bytearray]]]
|
||||||
|
):
|
||||||
|
# type: (...) -> bool
|
||||||
|
if not Session.check_min_api_version("2.3"):
|
||||||
|
LoggerRoot.get_base_logger().warning(
|
||||||
|
"Artifacts not supported by your ClearML-server version,"
|
||||||
|
" please upgrade to the latest server version"
|
||||||
|
)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if name in self._artifacts_container:
|
if name in self._artifacts_container:
|
||||||
@ -399,7 +419,27 @@ class Artifacts(object):
|
|||||||
)
|
)
|
||||||
return default_extension
|
return default_extension
|
||||||
|
|
||||||
if np and isinstance(artifact_object, np.ndarray):
|
if serialization_function:
|
||||||
|
artifact_type = "custom"
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
artifact_type_data.preview = preview or str(artifact_object.__repr__())[:self.max_preview_size_bytes]
|
||||||
|
except Exception:
|
||||||
|
artifact_type_data.preview = ""
|
||||||
|
override_filename_ext_in_uri = ""
|
||||||
|
override_filename_in_uri = name
|
||||||
|
fd, local_filename = mkstemp(prefix=quote(name, safe="") + ".")
|
||||||
|
os.close(fd)
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
with open(local_filename, "wb") as f:
|
||||||
|
f.write(serialization_function(artifact_object))
|
||||||
|
except Exception:
|
||||||
|
# cleanup and raise exception
|
||||||
|
os.unlink(local_filename)
|
||||||
|
raise
|
||||||
|
artifact_type_data.content_type = mimetypes.guess_type(local_filename)[0]
|
||||||
|
elif np and isinstance(artifact_object, np.ndarray):
|
||||||
artifact_type = 'numpy'
|
artifact_type = 'numpy'
|
||||||
artifact_type_data.preview = preview or str(artifact_object.__repr__())
|
artifact_type_data.preview = preview or str(artifact_object.__repr__())
|
||||||
override_filename_ext_in_uri = get_extension(
|
override_filename_ext_in_uri = get_extension(
|
||||||
|
@ -18,7 +18,19 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
from collections import Sequence as CollectionsSequence
|
from collections import Sequence as CollectionsSequence
|
||||||
|
|
||||||
from typing import Optional, Union, Mapping, Sequence, Any, Dict, Iterable, TYPE_CHECKING, Callable, Tuple, List
|
from typing import (
|
||||||
|
Optional,
|
||||||
|
Union,
|
||||||
|
Mapping,
|
||||||
|
Sequence,
|
||||||
|
Any,
|
||||||
|
Dict,
|
||||||
|
Iterable,
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Callable,
|
||||||
|
Tuple,
|
||||||
|
List,
|
||||||
|
)
|
||||||
|
|
||||||
import psutil
|
import psutil
|
||||||
import six
|
import six
|
||||||
@ -1861,6 +1873,7 @@ class Task(_Task):
|
|||||||
preview=None, # type: Any
|
preview=None, # type: Any
|
||||||
wait_on_upload=False, # type: bool
|
wait_on_upload=False, # type: bool
|
||||||
extension_name=None, # type: Optional[str]
|
extension_name=None, # type: Optional[str]
|
||||||
|
serialization_function=None # type: Optional[Callable[Any, Union[bytes, bytearray]]]
|
||||||
):
|
):
|
||||||
# type: (...) -> bool
|
# type: (...) -> bool
|
||||||
"""
|
"""
|
||||||
@ -1908,6 +1921,12 @@ class Task(_Task):
|
|||||||
- numpy.ndarray - ``.npz``, ``.csv.gz`` (default ``.npz``)
|
- numpy.ndarray - ``.npz``, ``.csv.gz`` (default ``.npz``)
|
||||||
- PIL.Image - whatever extensions PIL supports (default ``.png``)
|
- PIL.Image - whatever extensions PIL supports (default ``.png``)
|
||||||
|
|
||||||
|
:param Callable[Any, Union[bytes, bytearray]] serialization_function: A serialization function that takes one
|
||||||
|
parameter of any types which is the object to be serialized. The function should return a `bytes` or `bytearray`
|
||||||
|
object, which represents the serialized object. Note that the object will be immediately serialized using this function,
|
||||||
|
thus other serialization methods will not be used (e.g. `pandas.DataFrame.to_csv`), even if possible.
|
||||||
|
To deserialize this artifact when getting it using the `Artifact.get` method, use its `deserialization_function` argument
|
||||||
|
|
||||||
:return: The status of the upload.
|
:return: The status of the upload.
|
||||||
|
|
||||||
- ``True`` - Upload succeeded.
|
- ``True`` - Upload succeeded.
|
||||||
@ -1916,8 +1935,16 @@ class Task(_Task):
|
|||||||
:raise: If the artifact object type is not supported, raise a ``ValueError``.
|
:raise: If the artifact object type is not supported, raise a ``ValueError``.
|
||||||
"""
|
"""
|
||||||
return self._artifacts_manager.upload_artifact(
|
return self._artifacts_manager.upload_artifact(
|
||||||
name=name, artifact_object=artifact_object, metadata=metadata, delete_after_upload=delete_after_upload,
|
name=name,
|
||||||
auto_pickle=auto_pickle, preview=preview, wait_on_upload=wait_on_upload, extension_name=extension_name)
|
artifact_object=artifact_object,
|
||||||
|
metadata=metadata,
|
||||||
|
delete_after_upload=delete_after_upload,
|
||||||
|
auto_pickle=auto_pickle,
|
||||||
|
preview=preview,
|
||||||
|
wait_on_upload=wait_on_upload,
|
||||||
|
extension_name=extension_name,
|
||||||
|
serialization_function=serialization_function,
|
||||||
|
)
|
||||||
|
|
||||||
def get_models(self):
|
def get_models(self):
|
||||||
# type: () -> Mapping[str, Sequence[Model]]
|
# type: () -> Mapping[str, Sequence[Model]]
|
||||||
|
@ -572,6 +572,7 @@ def prep_xy_axis(ax, props, x_bounds, y_bounds):
|
|||||||
yaxis.update(prep_ticks(ax, 1, "y", props))
|
yaxis.update(prep_ticks(ax, 1, "y", props))
|
||||||
return xaxis, yaxis
|
return xaxis, yaxis
|
||||||
|
|
||||||
|
|
||||||
def prep_xyz_axis(ax, props, x_bounds, y_bounds):
|
def prep_xyz_axis(ax, props, x_bounds, y_bounds):
|
||||||
# there is no z_bounds as they can't (at least easily) be extracted from an `Axes3DSubplot` object
|
# there is no z_bounds as they can't (at least easily) be extracted from an `Axes3DSubplot` object
|
||||||
xaxis, yaxis = prep_xy_axis(ax, props, x_bounds, y_bounds)
|
xaxis, yaxis = prep_xy_axis(ax, props, x_bounds, y_bounds)
|
||||||
|
Loading…
Reference in New Issue
Block a user