diff --git a/clearml/binding/artifacts.py b/clearml/binding/artifacts.py index 0809f46f..5e773bbc 100644 --- a/clearml/binding/artifacts.py +++ b/clearml/binding/artifacts.py @@ -16,7 +16,7 @@ import six from PIL import Image from pathlib2 import Path from six.moves.urllib.parse import urlparse -from typing import Dict, Union, Optional, Any, Sequence +from typing import Dict, Union, Optional, Any, Sequence, Callable from ..backend_api import Session from ..backend_api.services import tasks @@ -140,8 +140,8 @@ class Artifact(object): self._content_type = artifact_api_object.type_data.content_type if artifact_api_object.type_data else None self._object = self._not_set - def get(self, force_download=False): - # type: (bool) -> Any + def get(self, force_download=False, deserialization_function=None): + # type: (bool, Optional[Callable[[bytes], Any]]) -> Any """ Return an object constructed from the artifact file @@ -156,7 +156,12 @@ class Artifact(object): pointing to a local copy of the artifacts file (or directory) will be returned :param bool force_download: download file from remote even if exists in local cache - :return: One of the following objects Numpy.array, pandas.DataFrame, PIL.Image, dict (json), or pathlib2.Path. + :param Callable[[bytes], Any] deserialization_function: A deserialization function that takes one parameter of type `bytes`, + which represents the serialized object. This function should return the deserialized object. + Useful when the artifact was uploaded using a custom serialization function when calling the + `Task.upload_artifact` method with the `serialization_function` argument. + :return: Usually, one of the following objects: Numpy.array, pandas.DataFrame, PIL.Image, dict (json), or pathlib2.Path. + An object with an arbitrary type may also be returned if it was serialized (using pickle or a custom serialization function). 
""" if self._object is not self._not_set: return self._object @@ -165,7 +170,10 @@ class Artifact(object): # noinspection PyBroadException try: - if self.type == "numpy" and np: + if deserialization_function: + with open(local_file, "rb") as f: + self._object = deserialization_function(f.read()) + elif self.type == "numpy" and np: if self._content_type == "text/csv": self._object = np.genfromtxt(local_file, delimiter=",") else: @@ -339,12 +347,24 @@ class Artifacts(object): self._unregister_request.add(name) self.flush() - def upload_artifact(self, name, artifact_object=None, metadata=None, preview=None, - delete_after_upload=False, auto_pickle=True, wait_on_upload=False, extension_name=None): - # type: (str, Optional[object], Optional[dict], Optional[str], bool, bool, bool, Optional[str]) -> bool - if not Session.check_min_api_version('2.3'): - LoggerRoot.get_base_logger().warning('Artifacts not supported by your ClearML-server version, ' - 'please upgrade to the latest server version') + def upload_artifact( + self, + name, # type: str + artifact_object=None, # type: Optional[object] + metadata=None, # type: Optional[dict] + preview=None, # type: Optional[str] + delete_after_upload=False, # type: bool + auto_pickle=True, # type: bool + wait_on_upload=False, # type: bool + extension_name=None, # type: Optional[str] + serialization_function=None, # type: Optional[Callable[Any, Union[bytes, bytearray]]] + ): + # type: (...) 
-> bool + if not Session.check_min_api_version("2.3"): + LoggerRoot.get_base_logger().warning( + "Artifacts not supported by your ClearML-server version," + " please upgrade to the latest server version" + ) return False if name in self._artifacts_container: @@ -399,7 +419,27 @@ class Artifacts(object): ) return default_extension - if np and isinstance(artifact_object, np.ndarray): + if serialization_function: + artifact_type = "custom" + # noinspection PyBroadException + try: + artifact_type_data.preview = preview or str(artifact_object.__repr__())[:self.max_preview_size_bytes] + except Exception: + artifact_type_data.preview = "" + override_filename_ext_in_uri = "" + override_filename_in_uri = name + fd, local_filename = mkstemp(prefix=quote(name, safe="") + ".") + os.close(fd) + # noinspection PyBroadException + try: + with open(local_filename, "wb") as f: + f.write(serialization_function(artifact_object)) + except Exception: + # cleanup and raise exception + os.unlink(local_filename) + raise + artifact_type_data.content_type = mimetypes.guess_type(local_filename)[0] + elif np and isinstance(artifact_object, np.ndarray): artifact_type = 'numpy' artifact_type_data.preview = preview or str(artifact_object.__repr__()) override_filename_ext_in_uri = get_extension( diff --git a/clearml/task.py b/clearml/task.py index c949bfc6..af9053c7 100644 --- a/clearml/task.py +++ b/clearml/task.py @@ -18,7 +18,19 @@ try: except ImportError: from collections import Sequence as CollectionsSequence -from typing import Optional, Union, Mapping, Sequence, Any, Dict, Iterable, TYPE_CHECKING, Callable, Tuple, List +from typing import ( + Optional, + Union, + Mapping, + Sequence, + Any, + Dict, + Iterable, + TYPE_CHECKING, + Callable, + Tuple, + List, +) import psutil import six @@ -1861,6 +1873,7 @@ class Task(_Task): preview=None, # type: Any wait_on_upload=False, # type: bool extension_name=None, # type: Optional[str] + serialization_function=None # type: Optional[Callable[Any, 
Union[bytes, bytearray]]] ): # type: (...) -> bool """ @@ -1908,6 +1921,12 @@ class Task(_Task): - numpy.ndarray - ``.npz``, ``.csv.gz`` (default ``.npz``) - PIL.Image - whatever extensions PIL supports (default ``.png``) + :param Callable[[Any], Union[bytes, bytearray]] serialization_function: A serialization function that takes one + parameter of any type, which is the object to be serialized. The function should return a `bytes` or `bytearray` + object, which represents the serialized object. Note that the object will be immediately serialized using this function, + thus other serialization methods will not be used (e.g. `pandas.DataFrame.to_csv`), even if possible. + To deserialize this artifact when getting it using the `Artifact.get` method, use its `deserialization_function` argument. + :return: The status of the upload. - ``True`` - Upload succeeded. @@ -1916,8 +1935,16 @@ class Task(_Task): :raise: If the artifact object type is not supported, raise a ``ValueError``. """ return self._artifacts_manager.upload_artifact( - name=name, artifact_object=artifact_object, metadata=metadata, delete_after_upload=delete_after_upload, - auto_pickle=auto_pickle, preview=preview, wait_on_upload=wait_on_upload, extension_name=extension_name) + name=name, + artifact_object=artifact_object, + metadata=metadata, + delete_after_upload=delete_after_upload, + auto_pickle=auto_pickle, + preview=preview, + wait_on_upload=wait_on_upload, + extension_name=extension_name, + serialization_function=serialization_function, + ) def get_models(self): # type: () -> Mapping[str, Sequence[Model]] diff --git a/clearml/utilities/plotlympl/mpltools.py b/clearml/utilities/plotlympl/mpltools.py index dcd19089..eefa6f4d 100644 --- a/clearml/utilities/plotlympl/mpltools.py +++ b/clearml/utilities/plotlympl/mpltools.py @@ -572,6 +572,7 @@ def prep_xy_axis(ax, props, x_bounds, y_bounds): yaxis.update(prep_ticks(ax, 1, "y", props)) return xaxis, yaxis + def prep_xyz_axis(ax, props, x_bounds, y_bounds): 
# there is no z_bounds as they can't (at least easily) be extracted from an `Axes3DSubplot` object xaxis, yaxis = prep_xy_axis(ax, props, x_bounds, y_bounds)