mirror of
https://github.com/clearml/clearml
synced 2025-06-04 03:47:57 +00:00
Allow to force_download model weights
This commit is contained in:
parent
c5b05fcd49
commit
4bb83c1c6c
7
.gitignore
vendored
7
.gitignore
vendored
@ -11,9 +11,14 @@ build/
|
|||||||
dist/
|
dist/
|
||||||
*.egg-info
|
*.egg-info
|
||||||
.env
|
.env
|
||||||
|
venv/
|
||||||
|
.venv/
|
||||||
|
|
||||||
# example data
|
# example data
|
||||||
examples/runs/
|
examples/runs/
|
||||||
examples/*_data
|
examples/*_data
|
||||||
examples/frameworks/data
|
examples/frameworks/data
|
||||||
venv/
|
|
||||||
|
# vscode
|
||||||
|
.workspace/
|
||||||
|
.vscode/
|
||||||
|
@ -388,13 +388,14 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
|||||||
return False
|
return False
|
||||||
return bool(self.data.ready)
|
return bool(self.data.ready)
|
||||||
|
|
||||||
def download_model_weights(self, raise_on_error=False):
|
def download_model_weights(self, raise_on_error=False, force_download=False):
|
||||||
"""
|
"""
|
||||||
Download the model weights into a local file in our cache
|
Download the model weights into a local file in our cache
|
||||||
|
|
||||||
:param bool raise_on_error: If True and the artifact could not be downloaded,
|
:param bool raise_on_error: If True and the artifact could not be downloaded,
|
||||||
raise ValueError, otherwise return None on failure and output log warning.
|
raise ValueError, otherwise return None on failure and output log warning.
|
||||||
|
:param bool force_download: If True, the base artifact will be downloaded,
|
||||||
|
even if the artifact is already cached.
|
||||||
:return: a local path to a downloaded copy of the model
|
:return: a local path to a downloaded copy of the model
|
||||||
"""
|
"""
|
||||||
uri = self.data.uri
|
uri = self.data.uri
|
||||||
@ -404,12 +405,12 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
|||||||
# check if we already downloaded the file
|
# check if we already downloaded the file
|
||||||
downloaded_models = [k for k, (i, u) in Model._local_model_to_id_uri.items() if i == self.id and u == uri]
|
downloaded_models = [k for k, (i, u) in Model._local_model_to_id_uri.items() if i == self.id and u == uri]
|
||||||
for dl_file in downloaded_models:
|
for dl_file in downloaded_models:
|
||||||
if Path(dl_file).exists():
|
if Path(dl_file).exists() and not force_download:
|
||||||
return dl_file
|
return dl_file
|
||||||
# remove non existing model file
|
# remove non existing model file
|
||||||
Model._local_model_to_id_uri.pop(dl_file, None)
|
Model._local_model_to_id_uri.pop(dl_file, None)
|
||||||
|
|
||||||
local_download = StorageManager.get_local_copy(uri, extract_archive=False)
|
local_download = StorageManager.get_local_copy(uri, extract_archive=False, force_download=force_download)
|
||||||
|
|
||||||
# save local model, so we can later query what was the original one
|
# save local model, so we can later query what was the original one
|
||||||
if local_download is not None:
|
if local_download is not None:
|
||||||
|
@ -330,21 +330,26 @@ class BaseModel(object):
|
|||||||
self._reload_required = False
|
self._reload_required = False
|
||||||
self._set_task(task)
|
self._set_task(task)
|
||||||
|
|
||||||
def get_weights(self, raise_on_error=False):
|
def get_weights(self, raise_on_error=False, force_download=False):
|
||||||
# type: (bool) -> str
|
# type: (bool, bool) -> str
|
||||||
"""
|
"""
|
||||||
Download the base model and return the locally stored filename.
|
Download the base model and return the locally stored filename.
|
||||||
|
|
||||||
:param bool raise_on_error: If True, and the artifact could not be downloaded,
|
:param bool raise_on_error: If True, and the artifact could not be downloaded,
|
||||||
raise ValueError, otherwise return None on failure and output log warning.
|
raise ValueError, otherwise return None on failure and output log warning.
|
||||||
|
|
||||||
|
:param bool force_download: If True, the base model will be downloaded,
|
||||||
|
even if the base model is already cached.
|
||||||
|
|
||||||
:return: The locally stored file.
|
:return: The locally stored file.
|
||||||
"""
|
"""
|
||||||
# download model (synchronously) and return local file
|
# download model (synchronously) and return local file
|
||||||
return self._get_base_model().download_model_weights(raise_on_error=raise_on_error)
|
return self._get_base_model().download_model_weights(raise_on_error=raise_on_error, force_download=force_download)
|
||||||
|
|
||||||
def get_weights_package(self, return_path=False, raise_on_error=False):
|
def get_weights_package(
|
||||||
# type: (bool, bool) -> Optional[Union[str, List[Path]]]
|
self, return_path=False, raise_on_error=False, force_download=False
|
||||||
|
):
|
||||||
|
# type: (bool, bool, bool) -> Optional[Union[str, List[Path]]]
|
||||||
"""
|
"""
|
||||||
Download the base model package into a temporary directory (extract the files), or return a list of the
|
Download the base model package into a temporary directory (extract the files), or return a list of the
|
||||||
locally stored filenames.
|
locally stored filenames.
|
||||||
@ -356,7 +361,8 @@ class BaseModel(object):
|
|||||||
|
|
||||||
:param bool raise_on_error: If True, and the artifact could not be downloaded,
|
:param bool raise_on_error: If True, and the artifact could not be downloaded,
|
||||||
raise ValueError, otherwise return None on failure and output log warning.
|
raise ValueError, otherwise return None on failure and output log warning.
|
||||||
|
:param bool force_download: If True, the base artifact will be downloaded,
|
||||||
|
even if the artifact is already cached.
|
||||||
:return: The model weights, or a list of the locally stored filenames.
|
:return: The model weights, or a list of the locally stored filenames.
|
||||||
if raise_on_error=False, returns None on error.
|
if raise_on_error=False, returns None on error.
|
||||||
"""
|
"""
|
||||||
@ -365,7 +371,7 @@ class BaseModel(object):
|
|||||||
raise ValueError('Model is not packaged')
|
raise ValueError('Model is not packaged')
|
||||||
|
|
||||||
# download packaged model
|
# download packaged model
|
||||||
packed_file = self.get_weights(raise_on_error=raise_on_error)
|
packed_file = self.get_weights(raise_on_error=raise_on_error, force_download=force_download)
|
||||||
|
|
||||||
if not packed_file:
|
if not packed_file:
|
||||||
if raise_on_error:
|
if raise_on_error:
|
||||||
@ -580,8 +586,10 @@ class Model(BaseModel):
|
|||||||
self._base_model_id = model_id
|
self._base_model_id = model_id
|
||||||
self._base_model = None
|
self._base_model = None
|
||||||
|
|
||||||
def get_local_copy(self, extract_archive=True, raise_on_error=False):
|
def get_local_copy(
|
||||||
# type: (bool, bool) -> str
|
self, extract_archive=True, raise_on_error=False, force_download=False
|
||||||
|
):
|
||||||
|
# type: (bool, bool, bool) -> str
|
||||||
"""
|
"""
|
||||||
Retrieve a valid link to the model file(s).
|
Retrieve a valid link to the model file(s).
|
||||||
If the model URL is a file system link, it will be returned directly.
|
If the model URL is a file system link, it will be returned directly.
|
||||||
@ -592,12 +600,20 @@ class Model(BaseModel):
|
|||||||
The returned path will be a temporary folder containing the archive content
|
The returned path will be a temporary folder containing the archive content
|
||||||
:param bool raise_on_error: If True, and the artifact could not be downloaded,
|
:param bool raise_on_error: If True, and the artifact could not be downloaded,
|
||||||
raise ValueError, otherwise return None on failure and output log warning.
|
raise ValueError, otherwise return None on failure and output log warning.
|
||||||
|
:param bool force_download: If True, the artifact will be downloaded,
|
||||||
|
even if the model artifact is already cached.
|
||||||
|
|
||||||
:return: A local path to the model (or a downloaded copy of it).
|
:return: A local path to the model (or a downloaded copy of it).
|
||||||
"""
|
"""
|
||||||
if extract_archive and self._is_package():
|
if extract_archive and self._is_package():
|
||||||
return self.get_weights_package(return_path=True, raise_on_error=raise_on_error)
|
return self.get_weights_package(
|
||||||
return self.get_weights(raise_on_error=raise_on_error)
|
return_path=True,
|
||||||
|
raise_on_error=raise_on_error,
|
||||||
|
force_download=force_download,
|
||||||
|
)
|
||||||
|
return self.get_weights(
|
||||||
|
raise_on_error=raise_on_error, force_download=force_download
|
||||||
|
)
|
||||||
|
|
||||||
def _get_base_model(self):
|
def _get_base_model(self):
|
||||||
if self._base_model:
|
if self._base_model:
|
||||||
|
Loading…
Reference in New Issue
Block a user