From b2bb04899c4976de254df7be87731f65bd0caa31 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Tue, 24 Oct 2023 18:45:18 +0300 Subject: [PATCH] Use StorageManager cache when downloading models --- clearml/backend_interface/model.py | 6 ++- clearml/model.py | 59 ++++++++++++------------------ 2 files changed, 28 insertions(+), 37 deletions(-) diff --git a/clearml/backend_interface/model.py b/clearml/backend_interface/model.py index 9e16dcbc..bc14258f 100644 --- a/clearml/backend_interface/model.py +++ b/clearml/backend_interface/model.py @@ -527,7 +527,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin): return False return bool(self.data.ready) - def download_model_weights(self, raise_on_error=False, force_download=False): + def download_model_weights(self, raise_on_error=False, force_download=False, extract_archive=False): """ Download the model weights into a local file in our cache @@ -537,6 +537,8 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin): :param bool force_download: If True, the base artifact will be downloaded, even if the artifact is already cached. + :param bool extract_archive: If True, unzip the downloaded file if possible + :return: a local path to a downloaded copy of the model """ uri = self.data.uri @@ -556,7 +558,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin): Model._local_model_to_id_uri.pop(dl_file, None) local_download = StorageManager.get_local_copy( - uri, extract_archive=False, force_download=force_download + uri, extract_archive=extract_archive, force_download=force_download ) # save local model, so we can later query what was the original one diff --git a/clearml/model.py b/clearml/model.py index f143da82..7f529bda 100644 --- a/clearml/model.py +++ b/clearml/model.py @@ -362,8 +362,8 @@ class BaseModel(object): self._task_connect_name = None self._set_task(task) - def get_weights(self, raise_on_error=False, force_download=False): - # type: (bool, bool) -> str + def get_weights(self, raise_on_error=False, force_download=False, extract_archive=False): + # type: (bool, bool, bool) -> str """ Download the base model and return the locally stored filename. @@ -373,17 +373,19 @@ class BaseModel(object): :param bool force_download: If True, the base model will be downloaded, even if the base model is already cached. + :param bool extract_archive: If True, the downloaded weights file will be extracted if possible + :return: The locally stored file. """ # download model (synchronously) and return local file return self._get_base_model().download_model_weights( - raise_on_error=raise_on_error, force_download=force_download + raise_on_error=raise_on_error, force_download=force_download, extract_archive=extract_archive ) def get_weights_package( - self, return_path=False, raise_on_error=False, force_download=False + self, return_path=False, raise_on_error=False, force_download=False, extract_archive=True ): - # type: (bool, bool, bool) -> Optional[Union[str, List[Path]]] + # type: (bool, bool, bool, bool) -> Optional[Union[str, List[Path]]] """ Download the base model package into a temporary directory (extract the files), or return a list of the locally stored filenames. @@ -399,6 +401,8 @@ class BaseModel(object): :param bool force_download: If True, the base artifact will be downloaded, even if the artifact is already cached. + :param bool extract_archive: If True, the downloaded weights file will be extracted if possible + :return: The model weights, or a list of the locally stored filenames. if raise_on_error=False, returns None on error. """ @@ -407,40 +411,21 @@ class BaseModel(object): raise ValueError("Model is not packaged") # download packaged model - packed_file = self.get_weights( - raise_on_error=raise_on_error, force_download=force_download + model_path = self.get_weights( + raise_on_error=raise_on_error, force_download=force_download, extract_archive=extract_archive ) - if not packed_file: + if not model_path: if raise_on_error: raise ValueError( "Model package '{}' could not be downloaded".format(self.url) ) return None - # unpack - target_folder = mkdtemp(prefix="model_package_") - if not target_folder: - raise ValueError( - "cannot create temporary directory for packed weight files" - ) - - for func in (zipfile.ZipFile, tarfile.open): - try: - obj = func(packed_file) - obj.extractall(path=target_folder) - break - except (zipfile.BadZipfile, tarfile.ReadError): - pass - else: - raise ValueError( - "cannot extract files from packaged model at %s", packed_file - ) - if return_path: - return target_folder + return model_path - target_files = list(Path(target_folder).glob("*")) + target_files = list(Path(model_path).glob("*")) return target_files def report_scalar(self, title, series, value, iteration): @@ -1374,17 +1359,18 @@ class Model(BaseModel): self._base_model = None def get_local_copy( - self, extract_archive=True, raise_on_error=False, force_download=False + self, extract_archive=None, raise_on_error=False, force_download=False ): - # type: (bool, bool, bool) -> str + # type: (Optional[bool], bool, bool) -> str """ Retrieve a valid link to the model file(s). If the model URL is a file system link, it will be returned directly. If the model URL points to a remote location (http/s3/gs etc.), it will download the file(s) and return the temporary location of the downloaded model. - :param bool extract_archive: If True, and the model is of type 'packaged' (e.g. TensorFlow compressed folder) - The returned path will be a temporary folder containing the archive content + :param bool extract_archive: If True, the local copy will be extracted if possible. If False, + the local copy will not be extracted. If None (default), the downloaded file will be extracted + if the model is a package. :param bool raise_on_error: If True, and the artifact could not be downloaded, raise ValueError, otherwise return None on failure and output log warning. :param bool force_download: If True, the artifact will be downloaded, @@ -1392,14 +1378,17 @@ class Model(BaseModel): :return: A local path to the model (or a downloaded copy of it). """ - if extract_archive and self._is_package(): + if self._is_package(): return self.get_weights_package( return_path=True, raise_on_error=raise_on_error, force_download=force_download, + extract_archive=True if extract_archive is None else extract_archive ) return self.get_weights( - raise_on_error=raise_on_error, force_download=force_download + raise_on_error=raise_on_error, + force_download=force_download, + extract_archive=False if extract_archive is None else extract_archive ) def _get_base_model(self):