Use StorageManager cache when downloading models

This commit is contained in:
allegroai 2023-10-24 18:45:18 +03:00
parent b8ceba38dc
commit b2bb04899c
2 changed files with 28 additions and 37 deletions

View File

@ -527,7 +527,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
return False
return bool(self.data.ready)
def download_model_weights(self, raise_on_error=False, force_download=False):
def download_model_weights(self, raise_on_error=False, force_download=False, extract_archive=False):
"""
Download the model weights into a local file in our cache
@ -537,6 +537,8 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
:param bool force_download: If True, the base artifact will be downloaded,
even if the artifact is already cached.
:param bool extract_archive: If True, unzip the downloaded file if possible
:return: a local path to a downloaded copy of the model
"""
uri = self.data.uri
@ -556,7 +558,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
Model._local_model_to_id_uri.pop(dl_file, None)
local_download = StorageManager.get_local_copy(
uri, extract_archive=False, force_download=force_download
uri, extract_archive=extract_archive, force_download=force_download
)
# save local model, so we can later query what was the original one

View File

@ -362,8 +362,8 @@ class BaseModel(object):
self._task_connect_name = None
self._set_task(task)
def get_weights(self, raise_on_error=False, force_download=False):
# type: (bool, bool) -> str
def get_weights(self, raise_on_error=False, force_download=False, extract_archive=False):
# type: (bool, bool, bool) -> str
"""
Download the base model and return the locally stored filename.
@ -373,17 +373,19 @@ class BaseModel(object):
:param bool force_download: If True, the base model will be downloaded,
even if the base model is already cached.
:param bool extract_archive: If True, the downloaded weights file will be extracted if possible
:return: The locally stored file.
"""
# download model (synchronously) and return local file
return self._get_base_model().download_model_weights(
raise_on_error=raise_on_error, force_download=force_download
raise_on_error=raise_on_error, force_download=force_download, extract_archive=extract_archive
)
def get_weights_package(
self, return_path=False, raise_on_error=False, force_download=False
self, return_path=False, raise_on_error=False, force_download=False, extract_archive=True
):
# type: (bool, bool, bool) -> Optional[Union[str, List[Path]]]
# type: (bool, bool, bool, bool) -> Optional[Union[str, List[Path]]]
"""
Download the base model package into a temporary directory (extract the files), or return a list of the
locally stored filenames.
@ -399,6 +401,8 @@ class BaseModel(object):
:param bool force_download: If True, the base artifact will be downloaded,
even if the artifact is already cached.
:param bool extract_archive: If True, the downloaded weights file will be extracted if possible
:return: The model weights, or a list of the locally stored filenames.
If raise_on_error=False, returns None on error.
"""
@ -407,40 +411,21 @@ class BaseModel(object):
raise ValueError("Model is not packaged")
# download packaged model
packed_file = self.get_weights(
raise_on_error=raise_on_error, force_download=force_download
model_path = self.get_weights(
raise_on_error=raise_on_error, force_download=force_download, extract_archive=extract_archive
)
if not packed_file:
if not model_path:
if raise_on_error:
raise ValueError(
"Model package '{}' could not be downloaded".format(self.url)
)
return None
# unpack
target_folder = mkdtemp(prefix="model_package_")
if not target_folder:
raise ValueError(
"cannot create temporary directory for packed weight files"
)
for func in (zipfile.ZipFile, tarfile.open):
try:
obj = func(packed_file)
obj.extractall(path=target_folder)
break
except (zipfile.BadZipfile, tarfile.ReadError):
pass
else:
raise ValueError(
"cannot extract files from packaged model at %s", packed_file
)
if return_path:
return target_folder
return model_path
target_files = list(Path(target_folder).glob("*"))
target_files = list(Path(model_path).glob("*"))
return target_files
def report_scalar(self, title, series, value, iteration):
@ -1374,17 +1359,18 @@ class Model(BaseModel):
self._base_model = None
def get_local_copy(
self, extract_archive=True, raise_on_error=False, force_download=False
self, extract_archive=None, raise_on_error=False, force_download=False
):
# type: (bool, bool, bool) -> str
# type: (Optional[bool], bool, bool) -> str
"""
Retrieve a valid link to the model file(s).
If the model URL is a file system link, it will be returned directly.
If the model URL points to a remote location (http/s3/gs etc.),
it will download the file(s) and return the temporary location of the downloaded model.
:param bool extract_archive: If True, and the model is of type 'packaged' (e.g. TensorFlow compressed folder)
The returned path will be a temporary folder containing the archive content
:param bool extract_archive: If True, the local copy will be extracted if possible. If False,
the local copy will not be extracted. If None (default), the downloaded file will be extracted
if the model is a package.
:param bool raise_on_error: If True, and the artifact could not be downloaded,
raise ValueError, otherwise return None on failure and output log warning.
:param bool force_download: If True, the artifact will be downloaded,
@ -1392,14 +1378,17 @@ class Model(BaseModel):
:return: A local path to the model (or a downloaded copy of it).
"""
if extract_archive and self._is_package():
if self._is_package():
return self.get_weights_package(
return_path=True,
raise_on_error=raise_on_error,
force_download=force_download,
extract_archive=True if extract_archive is None else extract_archive
)
return self.get_weights(
raise_on_error=raise_on_error, force_download=force_download
raise_on_error=raise_on_error,
force_download=force_download,
extract_archive=False if extract_archive is None else extract_archive
)
def _get_base_model(self):