diff --git a/.gitignore b/.gitignore index c9c5075f..24a4ec15 100644 --- a/.gitignore +++ b/.gitignore @@ -11,9 +11,14 @@ build/ dist/ *.egg-info .env +venv/ +.venv/ # example data examples/runs/ examples/*_data examples/frameworks/data -venv/ \ No newline at end of file + +# vscode +.workspace/ +.vscode/ diff --git a/clearml/backend_interface/model.py b/clearml/backend_interface/model.py index 15f09d7e..f1bc3b2b 100644 --- a/clearml/backend_interface/model.py +++ b/clearml/backend_interface/model.py @@ -388,13 +388,14 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin): return False return bool(self.data.ready) - def download_model_weights(self, raise_on_error=False): + def download_model_weights(self, raise_on_error=False, force_download=False): """ Download the model weights into a local file in our cache :param bool raise_on_error: If True and the artifact could not be downloaded, raise ValueError, otherwise return None on failure and output log warning. - + :param bool force_download: If True, the base artifact will be downloaded, + even if the artifact is already cached. :return: a local path to a downloaded copy of the model """ uri = self.data.uri @@ -404,12 +405,12 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin): # check if we already downloaded the file downloaded_models = [k for k, (i, u) in Model._local_model_to_id_uri.items() if i == self.id and u == uri] for dl_file in downloaded_models: - if Path(dl_file).exists(): + if Path(dl_file).exists() and not force_download: return dl_file # remove non existing model file Model._local_model_to_id_uri.pop(dl_file, None) - local_download = StorageManager.get_local_copy(uri, extract_archive=False) + local_download = StorageManager.get_local_copy(uri, extract_archive=False, force_download=force_download) # save local model, so we can later query what was the original one if local_download is not None: diff --git a/clearml/model.py b/clearml/model.py index 1dc2e942..223a1989 100644 --- a/clearml/model.py +++ b/clearml/model.py @@ -330,21 +330,26 @@ class BaseModel(object): self._reload_required = False self._set_task(task) - def get_weights(self, raise_on_error=False): - # type: (bool) -> str + def get_weights(self, raise_on_error=False, force_download=False): + # type: (bool, bool) -> str """ Download the base model and return the locally stored filename. :param bool raise_on_error: If True, and the artifact could not be downloaded, raise ValueError, otherwise return None on failure and output log warning. + :param bool force_download: If True, the base model will be downloaded, + even if the base model is already cached. + :return: The locally stored file. """ # download model (synchronously) and return local file - return self._get_base_model().download_model_weights(raise_on_error=raise_on_error) + return self._get_base_model().download_model_weights(raise_on_error=raise_on_error, force_download=force_download) - def get_weights_package(self, return_path=False, raise_on_error=False): - # type: (bool, bool) -> Optional[Union[str, List[Path]]] + def get_weights_package( + self, return_path=False, raise_on_error=False, force_download=False + ): + # type: (bool, bool, bool) -> Optional[Union[str, List[Path]]] """ Download the base model package into a temporary directory (extract the files), or return a list of the locally stored filenames. @@ -356,7 +361,8 @@ class BaseModel(object): :param bool raise_on_error: If True, and the artifact could not be downloaded, raise ValueError, otherwise return None on failure and output log warning. - + :param bool force_download: If True, the base artifact will be downloaded, + even if the artifact is already cached. :return: The model weights, or a list of the locally stored filenames. if raise_on_error=False, returns None on error. """ @@ -365,7 +371,7 @@ class BaseModel(object): raise ValueError('Model is not packaged') # download packaged model - packed_file = self.get_weights(raise_on_error=raise_on_error) + packed_file = self.get_weights(raise_on_error=raise_on_error, force_download=force_download) if not packed_file: if raise_on_error: @@ -580,8 +586,10 @@ class Model(BaseModel): self._base_model_id = model_id self._base_model = None - def get_local_copy(self, extract_archive=True, raise_on_error=False): - # type: (bool, bool) -> str + def get_local_copy( + self, extract_archive=True, raise_on_error=False, force_download=False + ): + # type: (bool, bool, bool) -> str """ Retrieve a valid link to the model file(s). If the model URL is a file system link, it will be returned directly. @@ -592,12 +600,20 @@ class Model(BaseModel): The returned path will be a temporary folder containing the archive content :param bool raise_on_error: If True, and the artifact could not be downloaded, raise ValueError, otherwise return None on failure and output log warning. + :param bool force_download: If True, the artifact will be downloaded, + even if the model artifact is already cached. :return: A local path to the model (or a downloaded copy of it). """ if extract_archive and self._is_package(): - return self.get_weights_package(return_path=True, raise_on_error=raise_on_error) - return self.get_weights(raise_on_error=raise_on_error) + return self.get_weights_package( + return_path=True, + raise_on_error=raise_on_error, + force_download=force_download, + ) + return self.get_weights( + raise_on_error=raise_on_error, force_download=force_download + ) def _get_base_model(self): if self._base_model: