From 88cddcfe1d57e4c168562f4f70a3876d152932a1 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Fri, 22 May 2020 11:00:17 +0300 Subject: [PATCH] Add raise_on_error (default=False) argument to Model.get_local_copy() --- trains/backend_interface/model.py | 17 ++++++++++++++--- trains/model.py | 31 +++++++++++++++++++++---------- 2 files changed, 35 insertions(+), 13 deletions(-) diff --git a/trains/backend_interface/model.py b/trains/backend_interface/model.py index 8cfe6b21..26a4db40 100644 --- a/trains/backend_interface/model.py +++ b/trains/backend_interface/model.py @@ -394,8 +394,15 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin): return False return bool(self.data.ready) - def download_model_weights(self): - """ Download the model weights into a local file in our cache """ + def download_model_weights(self, raise_on_error=False): + """ + Download the model weights into a local file in our cache + + :param bool raise_on_error: If True and the artifact could not be downloaded, + raise ValueError, otherwise return None on failure and output log warning. + + :return: a local path to a downloaded copy of the model + """ uri = self.data.uri if not uri or not uri.strip(): return None @@ -411,7 +418,11 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin): local_download = StorageManager.get_local_copy(uri) # save local model, so we can later query what was the original one - Model._local_model_to_id_uri[str(local_download)] = (self.model_id, uri) + if local_download is not None: + Model._local_model_to_id_uri[str(local_download)] = (self.model_id, uri) + elif raise_on_error: + raise ValueError("Could not retrieve a local copy of model weights {}, " + "failed downloading {}".format(self.model_id, uri)) return local_download diff --git a/trains/model.py b/trains/model.py index 50669e89..270354cf 100644 --- a/trains/model.py +++ b/trains/model.py @@ -258,20 +258,23 @@ class BaseModel(object): self._task = None self._set_task(task) - def get_weights(self): - # type: () -> str + def get_weights(self, raise_on_error=False): + # type: (bool) -> str """ Download the base model and return the locally stored filename. + :param bool raise_on_error: If True and the artifact could not be downloaded, + raise ValueError, otherwise return None on failure and output log warning. + :return: The locally stored file. :rtype: str """ # download model (synchronously) and return local file - return self._get_base_model().download_model_weights() + return self._get_base_model().download_model_weights(raise_on_error=raise_on_error) - def get_weights_package(self, return_path=False): - # type: (bool) -> Union[str, List[Path]] + def get_weights_package(self, return_path=False, raise_on_error=False): + # type: (bool, bool) -> Union[str, List[Path]] """ Download the base model package into a temporary directory (extract the files), or return a list of the locally stored filenames. @@ -281,6 +284,9 @@ class BaseModel(object): - ``True`` - Download the model weights into a temporary directory, and return the temporary directory path. - ``False`` - Return a list of the locally stored filenames. (Default) + :param bool raise_on_error: If True and the artifact could not be downloaded, + raise ValueError, otherwise return None on failure and output log warning. + :return: The model weights, or a list of the locally stored filenames. :rtype: package or path @@ -290,7 +296,7 @@ class BaseModel(object): raise ValueError('Model is not packaged') # download packaged model - packed_file = self.get_weights() + packed_file = self.get_weights(raise_on_error=raise_on_error) # unpack target_folder = mkdtemp(prefix='model_package_') @@ -314,6 +320,7 @@ class BaseModel(object): return target_files def publish(self): + # type: () -> () """ Set the model to the status ``published`` and for public use. If the model's status is already ``published``, then this method is a no-op. @@ -323,9 +330,11 @@ class BaseModel(object): self._get_base_model().publish() def _running_remotely(self): + # type: () -> () return bool(running_remotely() and self._task is not None) def _set_task(self, value): + # type: (_Task) -> () if value is not None and not isinstance(value, _Task): raise ValueError('task argument must be of Task type') self._task = value @@ -404,8 +413,8 @@ class Model(BaseModel): self._base_model_id = model_id self._base_model = None - def get_local_copy(self, extract_archive=True): - # type: (bool) -> str + def get_local_copy(self, extract_archive=True, raise_on_error=False): + # type: (bool, bool) -> str """ Retrieve a valid link to the model file(s). If the model URL is a file system link, it will be returned directly. @@ -414,11 +423,13 @@ class Model(BaseModel): :param bool extract_archive: If True and the model is of type 'packaged' (e.g. TensorFlow compressed folder) The returned path will be a temporary folder containing the archive content + :param bool raise_on_error: If True and the artifact could not be downloaded, + raise ValueError, otherwise return None on failure and output log warning. :return str: a local path to the model (or a downloaded copy of it) """ if extract_archive and self._package_tag in self.tags: - return self.get_weights_package(return_path=True) - return self.get_weights() + return self.get_weights_package(return_path=True, raise_on_error=raise_on_error) + return self.get_weights(raise_on_error=raise_on_error) def _get_base_model(self): if self._base_model: