Add raise_on_error (default=False) argument to Model.get_local_copy()

This commit is contained in:
allegroai 2020-05-22 11:00:17 +03:00
parent 163f0c8587
commit 88cddcfe1d
2 changed files with 35 additions and 13 deletions

View File

@ -394,8 +394,15 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
return False
return bool(self.data.ready)
def download_model_weights(self):
""" Download the model weights into a local file in our cache """
def download_model_weights(self, raise_on_error=False):
"""
Download the model weights into a local file in our cache
:param bool raise_on_error: If True and the artifact could not be downloaded,
raise ValueError, otherwise return None on failure and output log warning.
:return: a local path to a downloaded copy of the model
"""
uri = self.data.uri
if not uri or not uri.strip():
return None
@ -411,7 +418,11 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
local_download = StorageManager.get_local_copy(uri)
# save local model, so we can later query what was the original one
Model._local_model_to_id_uri[str(local_download)] = (self.model_id, uri)
if local_download is not None:
Model._local_model_to_id_uri[str(local_download)] = (self.model_id, uri)
elif raise_on_error:
raise ValueError("Could not retrieve a local copy of model weights {}, "
"failed downloading {}".format(self.model_id, uri))
return local_download

View File

@ -258,20 +258,23 @@ class BaseModel(object):
self._task = None
self._set_task(task)
def get_weights(self):
# type: () -> str
def get_weights(self, raise_on_error=False):
# type: (bool) -> str
"""
Download the base model and return the locally stored filename.
:param bool raise_on_error: If True and the artifact could not be downloaded,
raise ValueError, otherwise return None on failure and output log warning.
:return: The locally stored file.
:rtype: str
"""
# download model (synchronously) and return local file
return self._get_base_model().download_model_weights()
return self._get_base_model().download_model_weights(raise_on_error=raise_on_error)
def get_weights_package(self, return_path=False):
# type: (bool) -> Union[str, List[Path]]
def get_weights_package(self, return_path=False, raise_on_error=False):
# type: (bool, bool) -> Union[str, List[Path]]
"""
Download the base model package into a temporary directory (extract the files), or return a list of the
locally stored filenames.
@ -281,6 +284,9 @@ class BaseModel(object):
- ``True`` - Download the model weights into a temporary directory, and return the temporary directory path.
- ``False`` - Return a list of the locally stored filenames. (Default)
:param bool raise_on_error: If True and the artifact could not be downloaded,
raise ValueError, otherwise return None on failure and output log warning.
:return: The model weights, or a list of the locally stored filenames.
:rtype: package or path
@ -290,7 +296,7 @@ class BaseModel(object):
raise ValueError('Model is not packaged')
# download packaged model
packed_file = self.get_weights()
packed_file = self.get_weights(raise_on_error=raise_on_error)
# unpack
target_folder = mkdtemp(prefix='model_package_')
@ -314,6 +320,7 @@ class BaseModel(object):
return target_files
def publish(self):
# type: () -> ()
"""
Set the model to the status ``published`` and for public use. If the model's status is already ``published``,
then this method is a no-op.
@ -323,9 +330,11 @@ class BaseModel(object):
self._get_base_model().publish()
def _running_remotely(self):
# type: () -> ()
return bool(running_remotely() and self._task is not None)
def _set_task(self, value):
# type: (_Task) -> ()
if value is not None and not isinstance(value, _Task):
raise ValueError('task argument must be of Task type')
self._task = value
@ -404,8 +413,8 @@ class Model(BaseModel):
self._base_model_id = model_id
self._base_model = None
def get_local_copy(self, extract_archive=True):
# type: (bool) -> str
def get_local_copy(self, extract_archive=True, raise_on_error=False):
# type: (bool, bool) -> str
"""
Retrieve a valid link to the model file(s).
If the model URL is a file system link, it will be returned directly.
@ -414,11 +423,13 @@ class Model(BaseModel):
:param bool extract_archive: If True and the model is of type 'packaged' (e.g. TensorFlow compressed folder)
The returned path will be a temporary folder containing the archive content
:param bool raise_on_error: If True and the artifact could not be downloaded,
raise ValueError, otherwise return None on failure and output log warning.
:return str: a local path to the model (or a downloaded copy of it)
"""
if extract_archive and self._package_tag in self.tags:
return self.get_weights_package(return_path=True)
return self.get_weights()
return self.get_weights_package(return_path=True, raise_on_error=raise_on_error)
return self.get_weights(raise_on_error=raise_on_error)
def _get_base_model(self):
if self._base_model: