mirror of
				https://github.com/clearml/clearml
				synced 2025-06-26 18:16:07 +00:00 
			
		
		
		
	Use StorageManager cache when downloading models
This commit is contained in:
		
							parent
							
								
									b8ceba38dc
								
							
						
					
					
						commit
						b2bb04899c
					
				@ -527,7 +527,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
 | 
			
		||||
            return False
 | 
			
		||||
        return bool(self.data.ready)
 | 
			
		||||
 | 
			
		||||
    def download_model_weights(self, raise_on_error=False, force_download=False):
 | 
			
		||||
    def download_model_weights(self, raise_on_error=False, force_download=False, extract_archive=False):
 | 
			
		||||
        """
 | 
			
		||||
        Download the model weights into a local file in our cache
 | 
			
		||||
 | 
			
		||||
@ -537,6 +537,8 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
 | 
			
		||||
        :param bool force_download: If True, the base artifact will be downloaded,
 | 
			
		||||
            even if the artifact is already cached.
 | 
			
		||||
 | 
			
		||||
        :param bool extract_archive: If True, unzip the downloaded file if possible
 | 
			
		||||
 | 
			
		||||
        :return: a local path to a downloaded copy of the model
 | 
			
		||||
        """
 | 
			
		||||
        uri = self.data.uri
 | 
			
		||||
@ -556,7 +558,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
 | 
			
		||||
            Model._local_model_to_id_uri.pop(dl_file, None)
 | 
			
		||||
 | 
			
		||||
        local_download = StorageManager.get_local_copy(
 | 
			
		||||
            uri, extract_archive=False, force_download=force_download
 | 
			
		||||
            uri, extract_archive=extract_archive, force_download=force_download
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # save local model, so we can later query what was the original one
 | 
			
		||||
 | 
			
		||||
@ -362,8 +362,8 @@ class BaseModel(object):
 | 
			
		||||
        self._task_connect_name = None
 | 
			
		||||
        self._set_task(task)
 | 
			
		||||
 | 
			
		||||
    def get_weights(self, raise_on_error=False, force_download=False):
 | 
			
		||||
        # type: (bool, bool) -> str
 | 
			
		||||
    def get_weights(self, raise_on_error=False, force_download=False, extract_archive=False):
 | 
			
		||||
        # type: (bool, bool, bool) -> str
 | 
			
		||||
        """
 | 
			
		||||
        Download the base model and return the locally stored filename.
 | 
			
		||||
 | 
			
		||||
@ -373,17 +373,19 @@ class BaseModel(object):
 | 
			
		||||
        :param bool force_download: If True, the base model will be downloaded,
 | 
			
		||||
            even if the base model is already cached.
 | 
			
		||||
 | 
			
		||||
        :param bool extract_archive: If True, the downloaded weights file will be extracted if possible
 | 
			
		||||
 | 
			
		||||
        :return: The locally stored file.
 | 
			
		||||
        """
 | 
			
		||||
        # download model (synchronously) and return local file
 | 
			
		||||
        return self._get_base_model().download_model_weights(
 | 
			
		||||
            raise_on_error=raise_on_error, force_download=force_download
 | 
			
		||||
            raise_on_error=raise_on_error, force_download=force_download, extract_archive=extract_archive
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def get_weights_package(
 | 
			
		||||
        self, return_path=False, raise_on_error=False, force_download=False
 | 
			
		||||
        self, return_path=False, raise_on_error=False, force_download=False, extract_archive=True
 | 
			
		||||
    ):
 | 
			
		||||
        # type: (bool, bool, bool) -> Optional[Union[str, List[Path]]]
 | 
			
		||||
        # type: (bool, bool, bool, bool) -> Optional[Union[str, List[Path]]]
 | 
			
		||||
        """
 | 
			
		||||
        Download the base model package into a temporary directory (extract the files), or return a list of the
 | 
			
		||||
        locally stored filenames.
 | 
			
		||||
@ -399,6 +401,8 @@ class BaseModel(object):
 | 
			
		||||
        :param bool force_download: If True, the base artifact will be downloaded,
 | 
			
		||||
            even if the artifact is already cached.
 | 
			
		||||
 | 
			
		||||
        :param bool extract_archive: If True, the downloaded weights file will be extracted if possible
 | 
			
		||||
 | 
			
		||||
        :return: The model weights, or a list of the locally stored filenames.
 | 
			
		||||
            if raise_on_error=False, returns None on error.
 | 
			
		||||
        """
 | 
			
		||||
@ -407,40 +411,21 @@ class BaseModel(object):
 | 
			
		||||
            raise ValueError("Model is not packaged")
 | 
			
		||||
 | 
			
		||||
        # download packaged model
 | 
			
		||||
        packed_file = self.get_weights(
 | 
			
		||||
            raise_on_error=raise_on_error, force_download=force_download
 | 
			
		||||
        model_path = self.get_weights(
 | 
			
		||||
            raise_on_error=raise_on_error, force_download=force_download, extract_archive=extract_archive
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        if not packed_file:
 | 
			
		||||
        if not model_path:
 | 
			
		||||
            if raise_on_error:
 | 
			
		||||
                raise ValueError(
 | 
			
		||||
                    "Model package '{}' could not be downloaded".format(self.url)
 | 
			
		||||
                )
 | 
			
		||||
            return None
 | 
			
		||||
 | 
			
		||||
        # unpack
 | 
			
		||||
        target_folder = mkdtemp(prefix="model_package_")
 | 
			
		||||
        if not target_folder:
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
                "cannot create temporary directory for packed weight files"
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        for func in (zipfile.ZipFile, tarfile.open):
 | 
			
		||||
            try:
 | 
			
		||||
                obj = func(packed_file)
 | 
			
		||||
                obj.extractall(path=target_folder)
 | 
			
		||||
                break
 | 
			
		||||
            except (zipfile.BadZipfile, tarfile.ReadError):
 | 
			
		||||
                pass
 | 
			
		||||
        else:
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
                "cannot extract files from packaged model at %s", packed_file
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        if return_path:
 | 
			
		||||
            return target_folder
 | 
			
		||||
            return model_path
 | 
			
		||||
 | 
			
		||||
        target_files = list(Path(target_folder).glob("*"))
 | 
			
		||||
        target_files = list(Path(model_path).glob("*"))
 | 
			
		||||
        return target_files
 | 
			
		||||
 | 
			
		||||
    def report_scalar(self, title, series, value, iteration):
 | 
			
		||||
@ -1374,17 +1359,18 @@ class Model(BaseModel):
 | 
			
		||||
        self._base_model = None
 | 
			
		||||
 | 
			
		||||
    def get_local_copy(
 | 
			
		||||
        self, extract_archive=True, raise_on_error=False, force_download=False
 | 
			
		||||
        self, extract_archive=None, raise_on_error=False, force_download=False
 | 
			
		||||
    ):
 | 
			
		||||
        # type: (bool, bool, bool) -> str
 | 
			
		||||
        # type: (Optional[bool], bool, bool) -> str
 | 
			
		||||
        """
 | 
			
		||||
        Retrieve a valid link to the model file(s).
 | 
			
		||||
        If the model URL is a file system link, it will be returned directly.
 | 
			
		||||
        If the model URL points to a remote location (http/s3/gs etc.),
 | 
			
		||||
        it will download the file(s) and return the temporary location of the downloaded model.
 | 
			
		||||
 | 
			
		||||
        :param bool extract_archive: If True, and the model is of type 'packaged' (e.g. TensorFlow compressed folder)
 | 
			
		||||
            The returned path will be a temporary folder containing the archive content
 | 
			
		||||
        :param bool extract_archive: If True, the local copy will be extracted if possible. If False,
 | 
			
		||||
            the local copy will not be extracted. If None (default), the downloaded file will be extracted
 | 
			
		||||
            if the model is a package.
 | 
			
		||||
        :param bool raise_on_error: If True, and the artifact could not be downloaded,
 | 
			
		||||
            raise ValueError, otherwise return None on failure and output log warning.
 | 
			
		||||
        :param bool force_download: If True, the artifact will be downloaded,
 | 
			
		||||
@ -1392,14 +1378,17 @@ class Model(BaseModel):
 | 
			
		||||
 | 
			
		||||
        :return: A local path to the model (or a downloaded copy of it).
 | 
			
		||||
        """
 | 
			
		||||
        if extract_archive and self._is_package():
 | 
			
		||||
        if self._is_package():
 | 
			
		||||
            return self.get_weights_package(
 | 
			
		||||
                return_path=True,
 | 
			
		||||
                raise_on_error=raise_on_error,
 | 
			
		||||
                force_download=force_download,
 | 
			
		||||
                extract_archive=True if extract_archive is None else extract_archive
 | 
			
		||||
            )
 | 
			
		||||
        return self.get_weights(
 | 
			
		||||
            raise_on_error=raise_on_error, force_download=force_download
 | 
			
		||||
            raise_on_error=raise_on_error,
 | 
			
		||||
            force_download=force_download,
 | 
			
		||||
            extract_archive=False if extract_archive is None else extract_archive
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def _get_base_model(self):
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user