mirror of
https://github.com/clearml/clearml
synced 2025-05-31 18:48:16 +00:00
Allow to force_download model weights
This commit is contained in:
parent
c5b05fcd49
commit
4bb83c1c6c
7
.gitignore
vendored
7
.gitignore
vendored
@ -11,9 +11,14 @@ build/
|
||||
dist/
|
||||
*.egg-info
|
||||
.env
|
||||
venv/
|
||||
.venv/
|
||||
|
||||
# example data
|
||||
examples/runs/
|
||||
examples/*_data
|
||||
examples/frameworks/data
|
||||
venv/
|
||||
|
||||
# vscode
|
||||
.workspace/
|
||||
.vscode/
|
||||
|
@ -388,13 +388,14 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
||||
return False
|
||||
return bool(self.data.ready)
|
||||
|
||||
def download_model_weights(self, raise_on_error=False):
|
||||
def download_model_weights(self, raise_on_error=False, force_download=False):
|
||||
"""
|
||||
Download the model weights into a local file in our cache
|
||||
|
||||
:param bool raise_on_error: If True and the artifact could not be downloaded,
|
||||
raise ValueError, otherwise return None on failure and output log warning.
|
||||
|
||||
:param bool force_download: If True, the base artifact will be downloaded,
|
||||
even if the artifact is already cached.
|
||||
:return: a local path to a downloaded copy of the model
|
||||
"""
|
||||
uri = self.data.uri
|
||||
@ -404,12 +405,12 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
||||
# check if we already downloaded the file
|
||||
downloaded_models = [k for k, (i, u) in Model._local_model_to_id_uri.items() if i == self.id and u == uri]
|
||||
for dl_file in downloaded_models:
|
||||
if Path(dl_file).exists():
|
||||
if Path(dl_file).exists() and not force_download:
|
||||
return dl_file
|
||||
# remove non existing model file
|
||||
Model._local_model_to_id_uri.pop(dl_file, None)
|
||||
|
||||
local_download = StorageManager.get_local_copy(uri, extract_archive=False)
|
||||
local_download = StorageManager.get_local_copy(uri, extract_archive=False, force_download=force_download)
|
||||
|
||||
# save local model, so we can later query what was the original one
|
||||
if local_download is not None:
|
||||
|
@ -330,21 +330,26 @@ class BaseModel(object):
|
||||
self._reload_required = False
|
||||
self._set_task(task)
|
||||
|
||||
def get_weights(self, raise_on_error=False):
|
||||
# type: (bool) -> str
|
||||
def get_weights(self, raise_on_error=False, force_download=False):
|
||||
# type: (bool, bool) -> str
|
||||
"""
|
||||
Download the base model and return the locally stored filename.
|
||||
|
||||
:param bool raise_on_error: If True, and the artifact could not be downloaded,
|
||||
raise ValueError, otherwise return None on failure and output log warning.
|
||||
|
||||
:param bool force_download: If True, the base model will be downloaded,
|
||||
even if the base model is already cached.
|
||||
|
||||
:return: The locally stored file.
|
||||
"""
|
||||
# download model (synchronously) and return local file
|
||||
return self._get_base_model().download_model_weights(raise_on_error=raise_on_error)
|
||||
return self._get_base_model().download_model_weights(raise_on_error=raise_on_error, force_download=force_download)
|
||||
|
||||
def get_weights_package(self, return_path=False, raise_on_error=False):
|
||||
# type: (bool, bool) -> Optional[Union[str, List[Path]]]
|
||||
def get_weights_package(
|
||||
self, return_path=False, raise_on_error=False, force_download=False
|
||||
):
|
||||
# type: (bool, bool, bool) -> Optional[Union[str, List[Path]]]
|
||||
"""
|
||||
Download the base model package into a temporary directory (extract the files), or return a list of the
|
||||
locally stored filenames.
|
||||
@ -356,7 +361,8 @@ class BaseModel(object):
|
||||
|
||||
:param bool raise_on_error: If True, and the artifact could not be downloaded,
|
||||
raise ValueError, otherwise return None on failure and output log warning.
|
||||
|
||||
:param bool force_download: If True, the base artifact will be downloaded,
|
||||
even if the artifact is already cached.
|
||||
:return: The model weights, or a list of the locally stored filenames.
|
||||
if raise_on_error=False, returns None on error.
|
||||
"""
|
||||
@ -365,7 +371,7 @@ class BaseModel(object):
|
||||
raise ValueError('Model is not packaged')
|
||||
|
||||
# download packaged model
|
||||
packed_file = self.get_weights(raise_on_error=raise_on_error)
|
||||
packed_file = self.get_weights(raise_on_error=raise_on_error, force_download=force_download)
|
||||
|
||||
if not packed_file:
|
||||
if raise_on_error:
|
||||
@ -580,8 +586,10 @@ class Model(BaseModel):
|
||||
self._base_model_id = model_id
|
||||
self._base_model = None
|
||||
|
||||
def get_local_copy(self, extract_archive=True, raise_on_error=False):
|
||||
# type: (bool, bool) -> str
|
||||
def get_local_copy(
|
||||
self, extract_archive=True, raise_on_error=False, force_download=False
|
||||
):
|
||||
# type: (bool, bool, bool) -> str
|
||||
"""
|
||||
Retrieve a valid link to the model file(s).
|
||||
If the model URL is a file system link, it will be returned directly.
|
||||
@ -592,12 +600,20 @@ class Model(BaseModel):
|
||||
The returned path will be a temporary folder containing the archive content
|
||||
:param bool raise_on_error: If True, and the artifact could not be downloaded,
|
||||
raise ValueError, otherwise return None on failure and output log warning.
|
||||
:param bool force_download: If True, the artifact will be downloaded,
|
||||
even if the model artifact is already cached.
|
||||
|
||||
:return: A local path to the model (or a downloaded copy of it).
|
||||
"""
|
||||
if extract_archive and self._is_package():
|
||||
return self.get_weights_package(return_path=True, raise_on_error=raise_on_error)
|
||||
return self.get_weights(raise_on_error=raise_on_error)
|
||||
return self.get_weights_package(
|
||||
return_path=True,
|
||||
raise_on_error=raise_on_error,
|
||||
force_download=force_download,
|
||||
)
|
||||
return self.get_weights(
|
||||
raise_on_error=raise_on_error, force_download=force_download
|
||||
)
|
||||
|
||||
def _get_base_model(self):
|
||||
if self._base_model:
|
||||
|
Loading…
Reference in New Issue
Block a user