mirror of
https://github.com/clearml/clearml
synced 2025-04-23 07:45:24 +00:00
Use StorageManager cache when downloading models
This commit is contained in:
parent
b8ceba38dc
commit
b2bb04899c
@ -527,7 +527,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
|||||||
return False
|
return False
|
||||||
return bool(self.data.ready)
|
return bool(self.data.ready)
|
||||||
|
|
||||||
def download_model_weights(self, raise_on_error=False, force_download=False):
|
def download_model_weights(self, raise_on_error=False, force_download=False, extract_archive=False):
|
||||||
"""
|
"""
|
||||||
Download the model weights into a local file in our cache
|
Download the model weights into a local file in our cache
|
||||||
|
|
||||||
@ -537,6 +537,8 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
|||||||
:param bool force_download: If True, the base artifact will be downloaded,
|
:param bool force_download: If True, the base artifact will be downloaded,
|
||||||
even if the artifact is already cached.
|
even if the artifact is already cached.
|
||||||
|
|
||||||
|
:param bool extract_archive: If True, unzip the downloaded file if possible
|
||||||
|
|
||||||
:return: a local path to a downloaded copy of the model
|
:return: a local path to a downloaded copy of the model
|
||||||
"""
|
"""
|
||||||
uri = self.data.uri
|
uri = self.data.uri
|
||||||
@ -556,7 +558,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
|||||||
Model._local_model_to_id_uri.pop(dl_file, None)
|
Model._local_model_to_id_uri.pop(dl_file, None)
|
||||||
|
|
||||||
local_download = StorageManager.get_local_copy(
|
local_download = StorageManager.get_local_copy(
|
||||||
uri, extract_archive=False, force_download=force_download
|
uri, extract_archive=extract_archive, force_download=force_download
|
||||||
)
|
)
|
||||||
|
|
||||||
# save local model, so we can later query what was the original one
|
# save local model, so we can later query what was the original one
|
||||||
|
@ -362,8 +362,8 @@ class BaseModel(object):
|
|||||||
self._task_connect_name = None
|
self._task_connect_name = None
|
||||||
self._set_task(task)
|
self._set_task(task)
|
||||||
|
|
||||||
def get_weights(self, raise_on_error=False, force_download=False):
|
def get_weights(self, raise_on_error=False, force_download=False, extract_archive=False):
|
||||||
# type: (bool, bool) -> str
|
# type: (bool, bool, bool) -> str
|
||||||
"""
|
"""
|
||||||
Download the base model and return the locally stored filename.
|
Download the base model and return the locally stored filename.
|
||||||
|
|
||||||
@ -373,17 +373,19 @@ class BaseModel(object):
|
|||||||
:param bool force_download: If True, the base model will be downloaded,
|
:param bool force_download: If True, the base model will be downloaded,
|
||||||
even if the base model is already cached.
|
even if the base model is already cached.
|
||||||
|
|
||||||
|
:param bool extract_archive: If True, the downloaded weights file will be extracted if possible
|
||||||
|
|
||||||
:return: The locally stored file.
|
:return: The locally stored file.
|
||||||
"""
|
"""
|
||||||
# download model (synchronously) and return local file
|
# download model (synchronously) and return local file
|
||||||
return self._get_base_model().download_model_weights(
|
return self._get_base_model().download_model_weights(
|
||||||
raise_on_error=raise_on_error, force_download=force_download
|
raise_on_error=raise_on_error, force_download=force_download, extract_archive=extract_archive
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_weights_package(
|
def get_weights_package(
|
||||||
self, return_path=False, raise_on_error=False, force_download=False
|
self, return_path=False, raise_on_error=False, force_download=False, extract_archive=True
|
||||||
):
|
):
|
||||||
# type: (bool, bool, bool) -> Optional[Union[str, List[Path]]]
|
# type: (bool, bool, bool, bool) -> Optional[Union[str, List[Path]]]
|
||||||
"""
|
"""
|
||||||
Download the base model package into a temporary directory (extract the files), or return a list of the
|
Download the base model package into a temporary directory (extract the files), or return a list of the
|
||||||
locally stored filenames.
|
locally stored filenames.
|
||||||
@ -399,6 +401,8 @@ class BaseModel(object):
|
|||||||
:param bool force_download: If True, the base artifact will be downloaded,
|
:param bool force_download: If True, the base artifact will be downloaded,
|
||||||
even if the artifact is already cached.
|
even if the artifact is already cached.
|
||||||
|
|
||||||
|
:param bool extract_archive: If True, the downloaded weights file will be extracted if possible
|
||||||
|
|
||||||
:return: The model weights, or a list of the locally stored filenames.
|
:return: The model weights, or a list of the locally stored filenames.
|
||||||
if raise_on_error=False, returns None on error.
|
if raise_on_error=False, returns None on error.
|
||||||
"""
|
"""
|
||||||
@ -407,40 +411,21 @@ class BaseModel(object):
|
|||||||
raise ValueError("Model is not packaged")
|
raise ValueError("Model is not packaged")
|
||||||
|
|
||||||
# download packaged model
|
# download packaged model
|
||||||
packed_file = self.get_weights(
|
model_path = self.get_weights(
|
||||||
raise_on_error=raise_on_error, force_download=force_download
|
raise_on_error=raise_on_error, force_download=force_download, extract_archive=extract_archive
|
||||||
)
|
)
|
||||||
|
|
||||||
if not packed_file:
|
if not model_path:
|
||||||
if raise_on_error:
|
if raise_on_error:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Model package '{}' could not be downloaded".format(self.url)
|
"Model package '{}' could not be downloaded".format(self.url)
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# unpack
|
|
||||||
target_folder = mkdtemp(prefix="model_package_")
|
|
||||||
if not target_folder:
|
|
||||||
raise ValueError(
|
|
||||||
"cannot create temporary directory for packed weight files"
|
|
||||||
)
|
|
||||||
|
|
||||||
for func in (zipfile.ZipFile, tarfile.open):
|
|
||||||
try:
|
|
||||||
obj = func(packed_file)
|
|
||||||
obj.extractall(path=target_folder)
|
|
||||||
break
|
|
||||||
except (zipfile.BadZipfile, tarfile.ReadError):
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
"cannot extract files from packaged model at %s", packed_file
|
|
||||||
)
|
|
||||||
|
|
||||||
if return_path:
|
if return_path:
|
||||||
return target_folder
|
return model_path
|
||||||
|
|
||||||
target_files = list(Path(target_folder).glob("*"))
|
target_files = list(Path(model_path).glob("*"))
|
||||||
return target_files
|
return target_files
|
||||||
|
|
||||||
def report_scalar(self, title, series, value, iteration):
|
def report_scalar(self, title, series, value, iteration):
|
||||||
@ -1374,17 +1359,18 @@ class Model(BaseModel):
|
|||||||
self._base_model = None
|
self._base_model = None
|
||||||
|
|
||||||
def get_local_copy(
|
def get_local_copy(
|
||||||
self, extract_archive=True, raise_on_error=False, force_download=False
|
self, extract_archive=None, raise_on_error=False, force_download=False
|
||||||
):
|
):
|
||||||
# type: (bool, bool, bool) -> str
|
# type: (Optional[bool], bool, bool) -> str
|
||||||
"""
|
"""
|
||||||
Retrieve a valid link to the model file(s).
|
Retrieve a valid link to the model file(s).
|
||||||
If the model URL is a file system link, it will be returned directly.
|
If the model URL is a file system link, it will be returned directly.
|
||||||
If the model URL points to a remote location (http/s3/gs etc.),
|
If the model URL points to a remote location (http/s3/gs etc.),
|
||||||
it will download the file(s) and return the temporary location of the downloaded model.
|
it will download the file(s) and return the temporary location of the downloaded model.
|
||||||
|
|
||||||
:param bool extract_archive: If True, and the model is of type 'packaged' (e.g. TensorFlow compressed folder)
|
:param bool extract_archive: If True, the local copy will be extracted if possible. If False,
|
||||||
The returned path will be a temporary folder containing the archive content
|
the local copy will not be extracted. If None (default), the downloaded file will be extracted
|
||||||
|
if the model is a package.
|
||||||
:param bool raise_on_error: If True, and the artifact could not be downloaded,
|
:param bool raise_on_error: If True, and the artifact could not be downloaded,
|
||||||
raise ValueError, otherwise return None on failure and output log warning.
|
raise ValueError, otherwise return None on failure and output log warning.
|
||||||
:param bool force_download: If True, the artifact will be downloaded,
|
:param bool force_download: If True, the artifact will be downloaded,
|
||||||
@ -1392,14 +1378,17 @@ class Model(BaseModel):
|
|||||||
|
|
||||||
:return: A local path to the model (or a downloaded copy of it).
|
:return: A local path to the model (or a downloaded copy of it).
|
||||||
"""
|
"""
|
||||||
if extract_archive and self._is_package():
|
if self._is_package():
|
||||||
return self.get_weights_package(
|
return self.get_weights_package(
|
||||||
return_path=True,
|
return_path=True,
|
||||||
raise_on_error=raise_on_error,
|
raise_on_error=raise_on_error,
|
||||||
force_download=force_download,
|
force_download=force_download,
|
||||||
|
extract_archive=True if extract_archive is None else extract_archive
|
||||||
)
|
)
|
||||||
return self.get_weights(
|
return self.get_weights(
|
||||||
raise_on_error=raise_on_error, force_download=force_download
|
raise_on_error=raise_on_error,
|
||||||
|
force_download=force_download,
|
||||||
|
extract_archive=False if extract_archive is None else extract_archive
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_base_model(self):
|
def _get_base_model(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user