mirror of
https://github.com/clearml/clearml
synced 2025-04-23 07:45:24 +00:00
Use StorageManager cache when downloading models
This commit is contained in:
parent
b8ceba38dc
commit
b2bb04899c
@ -527,7 +527,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
||||
return False
|
||||
return bool(self.data.ready)
|
||||
|
||||
def download_model_weights(self, raise_on_error=False, force_download=False):
|
||||
def download_model_weights(self, raise_on_error=False, force_download=False, extract_archive=False):
|
||||
"""
|
||||
Download the model weights into a local file in our cache
|
||||
|
||||
@ -537,6 +537,8 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
||||
:param bool force_download: If True, the base artifact will be downloaded,
|
||||
even if the artifact is already cached.
|
||||
|
||||
:param bool extract_archive: If True, unzip the downloaded file if possible
|
||||
|
||||
:return: a local path to a downloaded copy of the model
|
||||
"""
|
||||
uri = self.data.uri
|
||||
@ -556,7 +558,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
||||
Model._local_model_to_id_uri.pop(dl_file, None)
|
||||
|
||||
local_download = StorageManager.get_local_copy(
|
||||
uri, extract_archive=False, force_download=force_download
|
||||
uri, extract_archive=extract_archive, force_download=force_download
|
||||
)
|
||||
|
||||
# save local model, so we can later query what was the original one
|
||||
|
@ -362,8 +362,8 @@ class BaseModel(object):
|
||||
self._task_connect_name = None
|
||||
self._set_task(task)
|
||||
|
||||
def get_weights(self, raise_on_error=False, force_download=False):
|
||||
# type: (bool, bool) -> str
|
||||
def get_weights(self, raise_on_error=False, force_download=False, extract_archive=False):
|
||||
# type: (bool, bool, bool) -> str
|
||||
"""
|
||||
Download the base model and return the locally stored filename.
|
||||
|
||||
@ -373,17 +373,19 @@ class BaseModel(object):
|
||||
:param bool force_download: If True, the base model will be downloaded,
|
||||
even if the base model is already cached.
|
||||
|
||||
:param bool extract_archive: If True, the downloaded weights file will be extracted if possible
|
||||
|
||||
:return: The locally stored file.
|
||||
"""
|
||||
# download model (synchronously) and return local file
|
||||
return self._get_base_model().download_model_weights(
|
||||
raise_on_error=raise_on_error, force_download=force_download
|
||||
raise_on_error=raise_on_error, force_download=force_download, extract_archive=extract_archive
|
||||
)
|
||||
|
||||
def get_weights_package(
|
||||
self, return_path=False, raise_on_error=False, force_download=False
|
||||
self, return_path=False, raise_on_error=False, force_download=False, extract_archive=True
|
||||
):
|
||||
# type: (bool, bool, bool) -> Optional[Union[str, List[Path]]]
|
||||
# type: (bool, bool, bool, bool) -> Optional[Union[str, List[Path]]]
|
||||
"""
|
||||
Download the base model package into a temporary directory (extract the files), or return a list of the
|
||||
locally stored filenames.
|
||||
@ -399,6 +401,8 @@ class BaseModel(object):
|
||||
:param bool force_download: If True, the base artifact will be downloaded,
|
||||
even if the artifact is already cached.
|
||||
|
||||
:param bool extract_archive: If True, the downloaded weights file will be extracted if possible
|
||||
|
||||
:return: The model weights, or a list of the locally stored filenames.
|
||||
if raise_on_error=False, returns None on error.
|
||||
"""
|
||||
@ -407,40 +411,21 @@ class BaseModel(object):
|
||||
raise ValueError("Model is not packaged")
|
||||
|
||||
# download packaged model
|
||||
packed_file = self.get_weights(
|
||||
raise_on_error=raise_on_error, force_download=force_download
|
||||
model_path = self.get_weights(
|
||||
raise_on_error=raise_on_error, force_download=force_download, extract_archive=extract_archive
|
||||
)
|
||||
|
||||
if not packed_file:
|
||||
if not model_path:
|
||||
if raise_on_error:
|
||||
raise ValueError(
|
||||
"Model package '{}' could not be downloaded".format(self.url)
|
||||
)
|
||||
return None
|
||||
|
||||
# unpack
|
||||
target_folder = mkdtemp(prefix="model_package_")
|
||||
if not target_folder:
|
||||
raise ValueError(
|
||||
"cannot create temporary directory for packed weight files"
|
||||
)
|
||||
|
||||
for func in (zipfile.ZipFile, tarfile.open):
|
||||
try:
|
||||
obj = func(packed_file)
|
||||
obj.extractall(path=target_folder)
|
||||
break
|
||||
except (zipfile.BadZipfile, tarfile.ReadError):
|
||||
pass
|
||||
else:
|
||||
raise ValueError(
|
||||
"cannot extract files from packaged model at %s", packed_file
|
||||
)
|
||||
|
||||
if return_path:
|
||||
return target_folder
|
||||
return model_path
|
||||
|
||||
target_files = list(Path(target_folder).glob("*"))
|
||||
target_files = list(Path(model_path).glob("*"))
|
||||
return target_files
|
||||
|
||||
def report_scalar(self, title, series, value, iteration):
|
||||
@ -1374,17 +1359,18 @@ class Model(BaseModel):
|
||||
self._base_model = None
|
||||
|
||||
def get_local_copy(
|
||||
self, extract_archive=True, raise_on_error=False, force_download=False
|
||||
self, extract_archive=None, raise_on_error=False, force_download=False
|
||||
):
|
||||
# type: (bool, bool, bool) -> str
|
||||
# type: (Optional[bool], bool, bool) -> str
|
||||
"""
|
||||
Retrieve a valid link to the model file(s).
|
||||
If the model URL is a file system link, it will be returned directly.
|
||||
If the model URL points to a remote location (http/s3/gs etc.),
|
||||
it will download the file(s) and return the temporary location of the downloaded model.
|
||||
|
||||
:param bool extract_archive: If True, and the model is of type 'packaged' (e.g. TensorFlow compressed folder)
|
||||
The returned path will be a temporary folder containing the archive content
|
||||
:param bool extract_archive: If True, the local copy will be extracted if possible. If False,
|
||||
the local copy will not be extracted. If None (default), the downloaded file will be extracted
|
||||
if the model is a package.
|
||||
:param bool raise_on_error: If True, and the artifact could not be downloaded,
|
||||
raise ValueError, otherwise return None on failure and output log warning.
|
||||
:param bool force_download: If True, the artifact will be downloaded,
|
||||
@ -1392,14 +1378,17 @@ class Model(BaseModel):
|
||||
|
||||
:return: A local path to the model (or a downloaded copy of it).
|
||||
"""
|
||||
if extract_archive and self._is_package():
|
||||
if self._is_package():
|
||||
return self.get_weights_package(
|
||||
return_path=True,
|
||||
raise_on_error=raise_on_error,
|
||||
force_download=force_download,
|
||||
extract_archive=True if extract_archive is None else extract_archive
|
||||
)
|
||||
return self.get_weights(
|
||||
raise_on_error=raise_on_error, force_download=force_download
|
||||
raise_on_error=raise_on_error,
|
||||
force_download=force_download,
|
||||
extract_archive=False if extract_archive is None else extract_archive
|
||||
)
|
||||
|
||||
def _get_base_model(self):
|
||||
|
Loading…
Reference in New Issue
Block a user