Allow to force_download model weights

This commit is contained in:
Alex Burlacu 2023-03-23 18:17:19 +02:00
parent c5b05fcd49
commit 4bb83c1c6c
3 changed files with 38 additions and 16 deletions

7
.gitignore vendored
View File

@ -11,9 +11,14 @@ build/
dist/
*.egg-info
.env
venv/
.venv/
# example data
examples/runs/
examples/*_data
examples/frameworks/data
venv/
# vscode
.workspace/
.vscode/

View File

@ -388,13 +388,14 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
return False
return bool(self.data.ready)
def download_model_weights(self, raise_on_error=False):
def download_model_weights(self, raise_on_error=False, force_download=False):
"""
Download the model weights into a local file in our cache
:param bool raise_on_error: If True and the artifact could not be downloaded,
raise ValueError, otherwise return None on failure and output log warning.
:param bool force_download: If True, the base artifact will be downloaded,
even if the artifact is already cached.
:return: a local path to a downloaded copy of the model
"""
uri = self.data.uri
@ -404,12 +405,12 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
# check if we already downloaded the file
downloaded_models = [k for k, (i, u) in Model._local_model_to_id_uri.items() if i == self.id and u == uri]
for dl_file in downloaded_models:
if Path(dl_file).exists():
if Path(dl_file).exists() and not force_download:
return dl_file
# remove non existing model file
Model._local_model_to_id_uri.pop(dl_file, None)
local_download = StorageManager.get_local_copy(uri, extract_archive=False)
local_download = StorageManager.get_local_copy(uri, extract_archive=False, force_download=force_download)
# save local model, so we can later query what was the original one
if local_download is not None:

View File

@ -330,21 +330,26 @@ class BaseModel(object):
self._reload_required = False
self._set_task(task)
def get_weights(self, raise_on_error=False):
# type: (bool) -> str
def get_weights(self, raise_on_error=False, force_download=False):
# type: (bool, bool) -> str
"""
Download the base model and return the locally stored filename.
:param bool raise_on_error: If True, and the artifact could not be downloaded,
raise ValueError, otherwise return None on failure and output log warning.
:param bool force_download: If True, the base model will be downloaded,
even if the base model is already cached.
:return: The locally stored file.
"""
# download model (synchronously) and return local file
return self._get_base_model().download_model_weights(raise_on_error=raise_on_error)
return self._get_base_model().download_model_weights(raise_on_error=raise_on_error, force_download=force_download)
def get_weights_package(self, return_path=False, raise_on_error=False):
# type: (bool, bool) -> Optional[Union[str, List[Path]]]
def get_weights_package(
self, return_path=False, raise_on_error=False, force_download=False
):
# type: (bool, bool, bool) -> Optional[Union[str, List[Path]]]
"""
Download the base model package into a temporary directory (extract the files), or return a list of the
locally stored filenames.
@ -356,7 +361,8 @@ class BaseModel(object):
:param bool raise_on_error: If True, and the artifact could not be downloaded,
raise ValueError, otherwise return None on failure and output log warning.
:param bool force_download: If True, the base artifact will be downloaded,
even if the artifact is already cached.
:return: The model weights, or a list of the locally stored filenames.
if raise_on_error=False, returns None on error.
"""
@ -365,7 +371,7 @@ class BaseModel(object):
raise ValueError('Model is not packaged')
# download packaged model
packed_file = self.get_weights(raise_on_error=raise_on_error)
packed_file = self.get_weights(raise_on_error=raise_on_error, force_download=force_download)
if not packed_file:
if raise_on_error:
@ -580,8 +586,10 @@ class Model(BaseModel):
self._base_model_id = model_id
self._base_model = None
def get_local_copy(self, extract_archive=True, raise_on_error=False):
# type: (bool, bool) -> str
def get_local_copy(
self, extract_archive=True, raise_on_error=False, force_download=False
):
# type: (bool, bool, bool) -> str
"""
Retrieve a valid link to the model file(s).
If the model URL is a file system link, it will be returned directly.
@ -592,12 +600,20 @@ class Model(BaseModel):
The returned path will be a temporary folder containing the archive content
:param bool raise_on_error: If True, and the artifact could not be downloaded,
raise ValueError, otherwise return None on failure and output log warning.
:param bool force_download: If True, the artifact will be downloaded,
even if the model artifact is already cached.
:return: A local path to the model (or a downloaded copy of it).
"""
if extract_archive and self._is_package():
return self.get_weights_package(return_path=True, raise_on_error=raise_on_error)
return self.get_weights(raise_on_error=raise_on_error)
return self.get_weights_package(
return_path=True,
raise_on_error=raise_on_error,
force_download=force_download,
)
return self.get_weights(
raise_on_error=raise_on_error, force_download=force_download
)
def _get_base_model(self):
if self._base_model: