Add artifact get() returning the artifact object (after downloading and loading)

This commit is contained in:
allegroai 2019-11-15 21:59:36 +02:00
parent b37aea1839
commit d19fde7041

View File

@ -114,6 +114,47 @@ class Artifact(object):
self._timestamp = datetime.fromtimestamp(artifact_api_object.timestamp)
self._metadata = dict(artifact_api_object.display_data) if artifact_api_object.display_data else {}
self._preview = artifact_api_object.type_data.preview if artifact_api_object.type_data else None
self._object = None
def get(self):
"""
Return an object constructed from the artifact file
Currently supported types: Numpy.array, pandas.DataFrame, PIL.Image, dict (json)
All other types will return a pathlib2.Path object pointing to a local copy of the artifacts file (or directory)
:return: One of the following objects Numpy.array, pandas.DataFrame, PIL.Image, dict (json), pathlib2.Path
"""
if self._object:
return self._object
local_file = self.get_local_copy()
if self.type == 'numpy' and np:
self._object = np.load(local_file)[self.name]
elif self.type in ('pandas', Artifacts._pd_artifact_type) and pd:
self._object = pd.read_csv(local_file)
elif self.type == 'image':
self._object = Image.open(local_file)
elif self.type == 'JSON':
with open(local_file, 'rt') as f:
self._object = json.load(f)
local_file = Path(local_file)
if self._object is None:
self._object = local_file
else:
from trains.storage.helper import StorageHelper
# only of we are not using cache, we should delete the file
if not hasattr(StorageHelper, 'get_cached_disabled'):
# delete the temporary file, we already used it
try:
local_file.unlink()
except Exception:
pass
return self._object
def get_local_copy(self, extract_archive=True):
"""