mirror of
				https://github.com/clearml/clearml
				synced 2025-06-26 18:16:07 +00:00 
			
		
		
		
	Add artifacts upload with a custom serializer (#689)
This commit is contained in:
		
							parent
							
								
									f04c39fb30
								
							
						
					
					
						commit
						5228b799c1
					
				| @ -16,7 +16,7 @@ import six | ||||
| from PIL import Image | ||||
| from pathlib2 import Path | ||||
| from six.moves.urllib.parse import urlparse | ||||
| from typing import Dict, Union, Optional, Any, Sequence | ||||
| from typing import Dict, Union, Optional, Any, Sequence, Callable | ||||
| 
 | ||||
| from ..backend_api import Session | ||||
| from ..backend_api.services import tasks | ||||
| @ -140,8 +140,8 @@ class Artifact(object): | ||||
|         self._content_type = artifact_api_object.type_data.content_type if artifact_api_object.type_data else None | ||||
|         self._object = self._not_set | ||||
| 
 | ||||
|     def get(self, force_download=False): | ||||
|         # type: (bool) -> Any | ||||
|     def get(self, force_download=False, deserialization_function=None): | ||||
|         # type: (bool, Optional[Callable[Union[bytes], Any]]) -> Any | ||||
|         """ | ||||
|         Return an object constructed from the artifact file | ||||
| 
 | ||||
| @ -156,7 +156,12 @@ class Artifact(object): | ||||
|         pointing to a local copy of the artifacts file (or directory) will be returned | ||||
| 
 | ||||
|         :param bool force_download: download file from remote even if exists in local cache | ||||
|         :return: One of the following objects Numpy.array, pandas.DataFrame, PIL.Image, dict (json), or pathlib2.Path. | ||||
|         :param Callable[[bytes], Any] deserialization_function: A deserialization function that takes one parameter of type `bytes`, | ||||
|             which represents the serialized object. This function should return the deserialized object. | ||||
|             Useful when the artifact was uploaded using a custom serialization function when calling the | ||||
|             `Task.upload_artifact` method with the `serialization_function` argument. | ||||
|         :return: Usually, one of the following objects: Numpy.array, pandas.DataFrame, PIL.Image, dict (json), or pathlib2.Path. | ||||
|             An object with an arbitrary type may also be returned if it was serialized (using pickle or a custom serialization function). | ||||
|         """ | ||||
|         if self._object is not self._not_set: | ||||
|             return self._object | ||||
| @ -165,7 +170,10 @@ class Artifact(object): | ||||
| 
 | ||||
|         # noinspection PyBroadException | ||||
|         try: | ||||
|             if self.type == "numpy" and np: | ||||
|             if deserialization_function: | ||||
|                 with open(local_file, "rb") as f: | ||||
|                     self._object = deserialization_function(f.read()) | ||||
|             elif self.type == "numpy" and np: | ||||
|                 if self._content_type == "text/csv": | ||||
|                     self._object = np.genfromtxt(local_file, delimiter=",") | ||||
|                 else: | ||||
| @ -339,12 +347,24 @@ class Artifacts(object): | ||||
|         self._unregister_request.add(name) | ||||
|         self.flush() | ||||
| 
 | ||||
|     def upload_artifact(self, name, artifact_object=None, metadata=None, preview=None, | ||||
|                         delete_after_upload=False, auto_pickle=True, wait_on_upload=False, extension_name=None): | ||||
|         # type: (str, Optional[object], Optional[dict], Optional[str], bool, bool, bool, Optional[str]) -> bool | ||||
|         if not Session.check_min_api_version('2.3'): | ||||
|             LoggerRoot.get_base_logger().warning('Artifacts not supported by your ClearML-server version, ' | ||||
|                                                  'please upgrade to the latest server version') | ||||
|     def upload_artifact( | ||||
|         self, | ||||
|         name,  # type: str | ||||
|         artifact_object=None,  # type: Optional[object] | ||||
|         metadata=None,  # type: Optional[dict] | ||||
|         preview=None,  # type: Optional[str] | ||||
|         delete_after_upload=False,  # type: bool | ||||
|         auto_pickle=True,  # type: bool | ||||
|         wait_on_upload=False,  # type: bool | ||||
|         extension_name=None,  # type: Optional[str] | ||||
|         serialization_function=None,  # type: Optional[Callable[Any, Union[bytes, bytearray]]] | ||||
|     ): | ||||
|         # type: (...) -> bool | ||||
|         if not Session.check_min_api_version("2.3"): | ||||
|             LoggerRoot.get_base_logger().warning( | ||||
|                 "Artifacts not supported by your ClearML-server version," | ||||
|                 " please upgrade to the latest server version" | ||||
|             ) | ||||
|             return False | ||||
| 
 | ||||
|         if name in self._artifacts_container: | ||||
| @ -399,7 +419,27 @@ class Artifacts(object): | ||||
|             ) | ||||
|             return default_extension | ||||
| 
 | ||||
|         if np and isinstance(artifact_object, np.ndarray): | ||||
|         if serialization_function: | ||||
|             artifact_type = "custom" | ||||
|             # noinspection PyBroadException | ||||
|             try: | ||||
|                 artifact_type_data.preview = preview or str(artifact_object.__repr__())[:self.max_preview_size_bytes] | ||||
|             except Exception: | ||||
|                 artifact_type_data.preview = "" | ||||
|             override_filename_ext_in_uri = "" | ||||
|             override_filename_in_uri = name | ||||
|             fd, local_filename = mkstemp(prefix=quote(name, safe="") + ".") | ||||
|             os.close(fd) | ||||
|             # noinspection PyBroadException | ||||
|             try: | ||||
|                 with open(local_filename, "wb") as f: | ||||
|                     f.write(serialization_function(artifact_object)) | ||||
|             except Exception: | ||||
|                 # cleanup and raise exception | ||||
|                 os.unlink(local_filename) | ||||
|                 raise | ||||
|             artifact_type_data.content_type = mimetypes.guess_type(local_filename)[0] | ||||
|         elif np and isinstance(artifact_object, np.ndarray): | ||||
|             artifact_type = 'numpy' | ||||
|             artifact_type_data.preview = preview or str(artifact_object.__repr__()) | ||||
|             override_filename_ext_in_uri = get_extension( | ||||
|  | ||||
| @ -18,7 +18,19 @@ try: | ||||
| except ImportError: | ||||
|     from collections import Sequence as CollectionsSequence | ||||
| 
 | ||||
| from typing import Optional, Union, Mapping, Sequence, Any, Dict, Iterable, TYPE_CHECKING, Callable, Tuple, List | ||||
| from typing import ( | ||||
|     Optional, | ||||
|     Union, | ||||
|     Mapping, | ||||
|     Sequence, | ||||
|     Any, | ||||
|     Dict, | ||||
|     Iterable, | ||||
|     TYPE_CHECKING, | ||||
|     Callable, | ||||
|     Tuple, | ||||
|     List, | ||||
| ) | ||||
| 
 | ||||
| import psutil | ||||
| import six | ||||
| @ -1861,6 +1873,7 @@ class Task(_Task): | ||||
|             preview=None,  # type: Any | ||||
|             wait_on_upload=False,  # type: bool | ||||
|             extension_name=None,  # type: Optional[str] | ||||
|             serialization_function=None  # type: Optional[Callable[Any, Union[bytes, bytearray]]] | ||||
|     ): | ||||
|         # type: (...) -> bool | ||||
|         """ | ||||
| @ -1908,6 +1921,12 @@ class Task(_Task): | ||||
|         - numpy.ndarray - ``.npz``, ``.csv.gz`` (default ``.npz``) | ||||
|         - PIL.Image - whatever extensions PIL supports (default ``.png``) | ||||
| 
 | ||||
|         :param Callable[[Any], Union[bytes, bytearray]] serialization_function: A serialization function that takes one | ||||
|             parameter of any type, which is the object to be serialized. The function should return a `bytes` or `bytearray` | ||||
|             object, which represents the serialized object. Note that the object will be immediately serialized using this function, | ||||
|             thus other serialization methods will not be used (e.g. `pandas.DataFrame.to_csv`), even if possible. | ||||
|             To deserialize this artifact when getting it using the `Artifact.get` method, use its `deserialization_function` argument. | ||||
| 
 | ||||
|         :return: The status of the upload. | ||||
| 
 | ||||
|         - ``True`` - Upload succeeded. | ||||
| @ -1916,8 +1935,16 @@ class Task(_Task): | ||||
|         :raise: If the artifact object type is not supported, raise a ``ValueError``. | ||||
|         """ | ||||
|         return self._artifacts_manager.upload_artifact( | ||||
|             name=name, artifact_object=artifact_object, metadata=metadata, delete_after_upload=delete_after_upload, | ||||
|             auto_pickle=auto_pickle, preview=preview, wait_on_upload=wait_on_upload, extension_name=extension_name) | ||||
|             name=name, | ||||
|             artifact_object=artifact_object, | ||||
|             metadata=metadata, | ||||
|             delete_after_upload=delete_after_upload, | ||||
|             auto_pickle=auto_pickle, | ||||
|             preview=preview, | ||||
|             wait_on_upload=wait_on_upload, | ||||
|             extension_name=extension_name, | ||||
|             serialization_function=serialization_function, | ||||
|         ) | ||||
| 
 | ||||
|     def get_models(self): | ||||
|         # type: () -> Mapping[str, Sequence[Model]] | ||||
|  | ||||
| @ -572,6 +572,7 @@ def prep_xy_axis(ax, props, x_bounds, y_bounds): | ||||
|     yaxis.update(prep_ticks(ax, 1, "y", props)) | ||||
|     return xaxis, yaxis | ||||
| 
 | ||||
| 
 | ||||
| def prep_xyz_axis(ax, props, x_bounds, y_bounds): | ||||
|     # there is no z_bounds as they can't (at least easily) be extracted from an `Axes3DSubplot` object | ||||
|     xaxis, yaxis = prep_xy_axis(ax, props, x_bounds, y_bounds) | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user
	 allegroai
						allegroai