mirror of
https://github.com/clearml/clearml
synced 2025-04-09 23:24:31 +00:00
Add raise_on_error (default=False) argument to Model.get_local_copy()
This commit is contained in:
parent
163f0c8587
commit
88cddcfe1d
@ -394,8 +394,15 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
|||||||
return False
|
return False
|
||||||
return bool(self.data.ready)
|
return bool(self.data.ready)
|
||||||
|
|
||||||
def download_model_weights(self):
|
def download_model_weights(self, raise_on_error=False):
|
||||||
""" Download the model weights into a local file in our cache """
|
"""
|
||||||
|
Download the model weights into a local file in our cache
|
||||||
|
|
||||||
|
:param bool raise_on_error: If True and the artifact could not be downloaded,
|
||||||
|
raise ValueError, otherwise return None on failure and output log warning.
|
||||||
|
|
||||||
|
:return: a local path to a downloaded copy of the model
|
||||||
|
"""
|
||||||
uri = self.data.uri
|
uri = self.data.uri
|
||||||
if not uri or not uri.strip():
|
if not uri or not uri.strip():
|
||||||
return None
|
return None
|
||||||
@ -411,7 +418,11 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
|||||||
local_download = StorageManager.get_local_copy(uri)
|
local_download = StorageManager.get_local_copy(uri)
|
||||||
|
|
||||||
# save local model, so we can later query what was the original one
|
# save local model, so we can later query what was the original one
|
||||||
|
if local_download is not None:
|
||||||
Model._local_model_to_id_uri[str(local_download)] = (self.model_id, uri)
|
Model._local_model_to_id_uri[str(local_download)] = (self.model_id, uri)
|
||||||
|
elif raise_on_error:
|
||||||
|
raise ValueError("Could not retrieve a local copy of model weights {}, "
|
||||||
|
"failed downloading {}".format(self.model_id, uri))
|
||||||
|
|
||||||
return local_download
|
return local_download
|
||||||
|
|
||||||
|
@ -258,20 +258,23 @@ class BaseModel(object):
|
|||||||
self._task = None
|
self._task = None
|
||||||
self._set_task(task)
|
self._set_task(task)
|
||||||
|
|
||||||
def get_weights(self):
|
def get_weights(self, raise_on_error=False):
|
||||||
# type: () -> str
|
# type: (bool) -> str
|
||||||
"""
|
"""
|
||||||
Download the base model and return the locally stored filename.
|
Download the base model and return the locally stored filename.
|
||||||
|
|
||||||
|
:param bool raise_on_error: If True and the artifact could not be downloaded,
|
||||||
|
raise ValueError, otherwise return None on failure and output log warning.
|
||||||
|
|
||||||
:return: The locally stored file.
|
:return: The locally stored file.
|
||||||
|
|
||||||
:rtype: str
|
:rtype: str
|
||||||
"""
|
"""
|
||||||
# download model (synchronously) and return local file
|
# download model (synchronously) and return local file
|
||||||
return self._get_base_model().download_model_weights()
|
return self._get_base_model().download_model_weights(raise_on_error=raise_on_error)
|
||||||
|
|
||||||
def get_weights_package(self, return_path=False):
|
def get_weights_package(self, return_path=False, raise_on_error=False):
|
||||||
# type: (bool) -> Union[str, List[Path]]
|
# type: (bool, bool) -> Union[str, List[Path]]
|
||||||
"""
|
"""
|
||||||
Download the base model package into a temporary directory (extract the files), or return a list of the
|
Download the base model package into a temporary directory (extract the files), or return a list of the
|
||||||
locally stored filenames.
|
locally stored filenames.
|
||||||
@ -281,6 +284,9 @@ class BaseModel(object):
|
|||||||
- ``True`` - Download the model weights into a temporary directory, and return the temporary directory path.
|
- ``True`` - Download the model weights into a temporary directory, and return the temporary directory path.
|
||||||
- ``False`` - Return a list of the locally stored filenames. (Default)
|
- ``False`` - Return a list of the locally stored filenames. (Default)
|
||||||
|
|
||||||
|
:param bool raise_on_error: If True and the artifact could not be downloaded,
|
||||||
|
raise ValueError, otherwise return None on failure and output log warning.
|
||||||
|
|
||||||
:return: The model weights, or a list of the locally stored filenames.
|
:return: The model weights, or a list of the locally stored filenames.
|
||||||
|
|
||||||
:rtype: package or path
|
:rtype: package or path
|
||||||
@ -290,7 +296,7 @@ class BaseModel(object):
|
|||||||
raise ValueError('Model is not packaged')
|
raise ValueError('Model is not packaged')
|
||||||
|
|
||||||
# download packaged model
|
# download packaged model
|
||||||
packed_file = self.get_weights()
|
packed_file = self.get_weights(raise_on_error=raise_on_error)
|
||||||
|
|
||||||
# unpack
|
# unpack
|
||||||
target_folder = mkdtemp(prefix='model_package_')
|
target_folder = mkdtemp(prefix='model_package_')
|
||||||
@ -314,6 +320,7 @@ class BaseModel(object):
|
|||||||
return target_files
|
return target_files
|
||||||
|
|
||||||
def publish(self):
|
def publish(self):
|
||||||
|
# type: () -> ()
|
||||||
"""
|
"""
|
||||||
Set the model to the status ``published`` and for public use. If the model's status is already ``published``,
|
Set the model to the status ``published`` and for public use. If the model's status is already ``published``,
|
||||||
then this method is a no-op.
|
then this method is a no-op.
|
||||||
@ -323,9 +330,11 @@ class BaseModel(object):
|
|||||||
self._get_base_model().publish()
|
self._get_base_model().publish()
|
||||||
|
|
||||||
def _running_remotely(self):
|
def _running_remotely(self):
|
||||||
|
# type: () -> ()
|
||||||
return bool(running_remotely() and self._task is not None)
|
return bool(running_remotely() and self._task is not None)
|
||||||
|
|
||||||
def _set_task(self, value):
|
def _set_task(self, value):
|
||||||
|
# type: (_Task) -> ()
|
||||||
if value is not None and not isinstance(value, _Task):
|
if value is not None and not isinstance(value, _Task):
|
||||||
raise ValueError('task argument must be of Task type')
|
raise ValueError('task argument must be of Task type')
|
||||||
self._task = value
|
self._task = value
|
||||||
@ -404,8 +413,8 @@ class Model(BaseModel):
|
|||||||
self._base_model_id = model_id
|
self._base_model_id = model_id
|
||||||
self._base_model = None
|
self._base_model = None
|
||||||
|
|
||||||
def get_local_copy(self, extract_archive=True):
|
def get_local_copy(self, extract_archive=True, raise_on_error=False):
|
||||||
# type: (bool) -> str
|
# type: (bool, bool) -> str
|
||||||
"""
|
"""
|
||||||
Retrieve a valid link to the model file(s).
|
Retrieve a valid link to the model file(s).
|
||||||
If the model URL is a file system link, it will be returned directly.
|
If the model URL is a file system link, it will be returned directly.
|
||||||
@ -414,11 +423,13 @@ class Model(BaseModel):
|
|||||||
|
|
||||||
:param bool extract_archive: If True and the model is of type 'packaged' (e.g. TensorFlow compressed folder)
|
:param bool extract_archive: If True and the model is of type 'packaged' (e.g. TensorFlow compressed folder)
|
||||||
The returned path will be a temporary folder containing the archive content
|
The returned path will be a temporary folder containing the archive content
|
||||||
|
:param bool raise_on_error: If True and the artifact could not be downloaded,
|
||||||
|
raise ValueError, otherwise return None on failure and output log warning.
|
||||||
:return str: a local path to the model (or a downloaded copy of it)
|
:return str: a local path to the model (or a downloaded copy of it)
|
||||||
"""
|
"""
|
||||||
if extract_archive and self._package_tag in self.tags:
|
if extract_archive and self._package_tag in self.tags:
|
||||||
return self.get_weights_package(return_path=True)
|
return self.get_weights_package(return_path=True, raise_on_error=raise_on_error)
|
||||||
return self.get_weights()
|
return self.get_weights(raise_on_error=raise_on_error)
|
||||||
|
|
||||||
def _get_base_model(self):
|
def _get_base_model(self):
|
||||||
if self._base_model:
|
if self._base_model:
|
||||||
|
Loading…
Reference in New Issue
Block a user