mirror of
https://github.com/clearml/clearml
synced 2025-04-03 12:31:11 +00:00
Add raise_on_error (default=False) argument to Model.get_local_copy()
This commit is contained in:
parent
163f0c8587
commit
88cddcfe1d
@ -394,8 +394,15 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
||||
return False
|
||||
return bool(self.data.ready)
|
||||
|
||||
def download_model_weights(self):
|
||||
""" Download the model weights into a local file in our cache """
|
||||
def download_model_weights(self, raise_on_error=False):
|
||||
"""
|
||||
Download the model weights into a local file in our cache
|
||||
|
||||
:param bool raise_on_error: If True and the artifact could not be downloaded,
|
||||
raise ValueError, otherwise return None on failure and output log warning.
|
||||
|
||||
:return: a local path to a downloaded copy of the model
|
||||
"""
|
||||
uri = self.data.uri
|
||||
if not uri or not uri.strip():
|
||||
return None
|
||||
@ -411,7 +418,11 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
||||
local_download = StorageManager.get_local_copy(uri)
|
||||
|
||||
# save local model, so we can later query what was the original one
|
||||
Model._local_model_to_id_uri[str(local_download)] = (self.model_id, uri)
|
||||
if local_download is not None:
|
||||
Model._local_model_to_id_uri[str(local_download)] = (self.model_id, uri)
|
||||
elif raise_on_error:
|
||||
raise ValueError("Could not retrieve a local copy of model weights {}, "
|
||||
"failed downloading {}".format(self.model_id, uri))
|
||||
|
||||
return local_download
|
||||
|
||||
|
@ -258,20 +258,23 @@ class BaseModel(object):
|
||||
self._task = None
|
||||
self._set_task(task)
|
||||
|
||||
def get_weights(self):
|
||||
# type: () -> str
|
||||
def get_weights(self, raise_on_error=False):
|
||||
# type: (bool) -> str
|
||||
"""
|
||||
Download the base model and return the locally stored filename.
|
||||
|
||||
:param bool raise_on_error: If True and the artifact could not be downloaded,
|
||||
raise ValueError, otherwise return None on failure and output log warning.
|
||||
|
||||
:return: The locally stored file.
|
||||
|
||||
:rtype: str
|
||||
"""
|
||||
# download model (synchronously) and return local file
|
||||
return self._get_base_model().download_model_weights()
|
||||
return self._get_base_model().download_model_weights(raise_on_error=raise_on_error)
|
||||
|
||||
def get_weights_package(self, return_path=False):
|
||||
# type: (bool) -> Union[str, List[Path]]
|
||||
def get_weights_package(self, return_path=False, raise_on_error=False):
|
||||
# type: (bool, bool) -> Union[str, List[Path]]
|
||||
"""
|
||||
Download the base model package into a temporary directory (extract the files), or return a list of the
|
||||
locally stored filenames.
|
||||
@ -281,6 +284,9 @@ class BaseModel(object):
|
||||
- ``True`` - Download the model weights into a temporary directory, and return the temporary directory path.
|
||||
- ``False`` - Return a list of the locally stored filenames. (Default)
|
||||
|
||||
:param bool raise_on_error: If True and the artifact could not be downloaded,
|
||||
raise ValueError, otherwise return None on failure and output log warning.
|
||||
|
||||
:return: The model weights, or a list of the locally stored filenames.
|
||||
|
||||
:rtype: package or path
|
||||
@ -290,7 +296,7 @@ class BaseModel(object):
|
||||
raise ValueError('Model is not packaged')
|
||||
|
||||
# download packaged model
|
||||
packed_file = self.get_weights()
|
||||
packed_file = self.get_weights(raise_on_error=raise_on_error)
|
||||
|
||||
# unpack
|
||||
target_folder = mkdtemp(prefix='model_package_')
|
||||
@ -314,6 +320,7 @@ class BaseModel(object):
|
||||
return target_files
|
||||
|
||||
def publish(self):
|
||||
# type: () -> ()
|
||||
"""
|
||||
Set the model to the status ``published`` and for public use. If the model's status is already ``published``,
|
||||
then this method is a no-op.
|
||||
@ -323,9 +330,11 @@ class BaseModel(object):
|
||||
self._get_base_model().publish()
|
||||
|
||||
def _running_remotely(self):
|
||||
# type: () -> ()
|
||||
return bool(running_remotely() and self._task is not None)
|
||||
|
||||
def _set_task(self, value):
|
||||
# type: (_Task) -> ()
|
||||
if value is not None and not isinstance(value, _Task):
|
||||
raise ValueError('task argument must be of Task type')
|
||||
self._task = value
|
||||
@ -404,8 +413,8 @@ class Model(BaseModel):
|
||||
self._base_model_id = model_id
|
||||
self._base_model = None
|
||||
|
||||
def get_local_copy(self, extract_archive=True):
|
||||
# type: (bool) -> str
|
||||
def get_local_copy(self, extract_archive=True, raise_on_error=False):
|
||||
# type: (bool, bool) -> str
|
||||
"""
|
||||
Retrieve a valid link to the model file(s).
|
||||
If the model URL is a file system link, it will be returned directly.
|
||||
@ -414,11 +423,13 @@ class Model(BaseModel):
|
||||
|
||||
:param bool extract_archive: If True and the model is of type 'packaged' (e.g. TensorFlow compressed folder)
|
||||
The returned path will be a temporary folder containing the archive content
|
||||
:param bool raise_on_error: If True and the artifact could not be downloaded,
|
||||
raise ValueError, otherwise return None on failure and output log warning.
|
||||
:return str: a local path to the model (or a downloaded copy of it)
|
||||
"""
|
||||
if extract_archive and self._package_tag in self.tags:
|
||||
return self.get_weights_package(return_path=True)
|
||||
return self.get_weights()
|
||||
return self.get_weights_package(return_path=True, raise_on_error=raise_on_error)
|
||||
return self.get_weights(raise_on_error=raise_on_error)
|
||||
|
||||
def _get_base_model(self):
|
||||
if self._base_model:
|
||||
|
Loading…
Reference in New Issue
Block a user