From d19fde7041878780da7b1f3a30f489b07145972e Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Fri, 15 Nov 2019 21:59:36 +0200 Subject: [PATCH] Add artifact get() returning the artifact object (after downloading and loading) --- trains/binding/artifacts.py | 41 +++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/trains/binding/artifacts.py b/trains/binding/artifacts.py index 8c7bbd60..9e380a60 100644 --- a/trains/binding/artifacts.py +++ b/trains/binding/artifacts.py @@ -114,6 +114,47 @@ class Artifact(object): self._timestamp = datetime.fromtimestamp(artifact_api_object.timestamp) self._metadata = dict(artifact_api_object.display_data) if artifact_api_object.display_data else {} self._preview = artifact_api_object.type_data.preview if artifact_api_object.type_data else None + self._object = None + + def get(self): + """ + Return an object constructed from the artifact file + + Currently supported types: Numpy.array, pandas.DataFrame, PIL.Image, dict (json) + All other types will return a pathlib2.Path object pointing to a local copy of the artifacts file (or directory) + + :return: One of the following objects Numpy.array, pandas.DataFrame, PIL.Image, dict (json), pathlib2.Path + """ + if self._object: + return self._object + + local_file = self.get_local_copy() + + if self.type == 'numpy' and np: + self._object = np.load(local_file)[self.name] + elif self.type in ('pandas', Artifacts._pd_artifact_type) and pd: + self._object = pd.read_csv(local_file) + elif self.type == 'image': + self._object = Image.open(local_file) + elif self.type == 'JSON': + with open(local_file, 'rt') as f: + self._object = json.load(f) + + local_file = Path(local_file) + + if self._object is None: + self._object = local_file + else: + from trains.storage.helper import StorageHelper + # only of we are not using cache, we should delete the file + if not hasattr(StorageHelper, 'get_cached_disabled'): + # delete the temporary file, we already used it + try: + local_file.unlink() + except Exception: + pass + + return self._object def get_local_copy(self, extract_archive=True): """