diff --git a/trains/storage/cache.py b/trains/storage/cache.py index 4e73db57..7b1043e5 100644 --- a/trains/storage/cache.py +++ b/trains/storage/cache.py @@ -27,7 +27,7 @@ class CacheManager(object): self._file_limit = max(self._file_limit, int(cache_file_limit)) return self._file_limit - def get_local_copy(self, remote_url): + def get_local_copy(self, remote_url, force_download): helper = StorageHelper.get(remote_url) if not helper: raise ValueError("Storage access failed: {}".format(remote_url)) @@ -44,11 +44,11 @@ class CacheManager(object): # check if we already have the file in our cache cached_file, cached_size = self._get_cache_file(remote_url) - if cached_size is not None: + if cached_size is not None and not force_download: CacheManager._add_remote_url(remote_url, cached_file) return cached_file # we need to download the file: - downloaded_file = helper.download_to_file(remote_url, cached_file) + downloaded_file = helper.download_to_file(remote_url, cached_file, overwrite_existing=force_download) if downloaded_file != cached_file: # something happened return None diff --git a/trains/storage/manager.py b/trains/storage/manager.py index 0f5e734d..1ec62dca 100644 --- a/trains/storage/manager.py +++ b/trains/storage/manager.py @@ -21,9 +21,9 @@ class StorageManager(object): @classmethod def get_local_copy( - cls, remote_url, cache_context=None, extract_archive=True, name=None + cls, remote_url, force_download=False, cache_context=None, extract_archive=True, name=None ): - # type: (str, Optional[str], bool, Optional[str]) -> str + # type: (str, bool, Optional[str], bool, Optional[str]) -> str """ Get a local copy of the remote file. If the remote URL is a direct file access, the returned link is the same, otherwise a link to a local copy of the url file is returned. @@ -31,6 +31,7 @@ class StorageManager(object): Oldest accessed files are deleted when cache is full. 
:param str remote_url: remote url link (string) + :param bool force_download: download the file from remote even if it exists in the local cache :param str cache_context: Optional caching context identifier (string), default context 'global' :param bool extract_archive: if True returned path will be a cached folder containing the archive's content, currently only zip files are supported. @@ -39,7 +40,7 @@ class StorageManager(object): """ cached_file = CacheManager.get_cache_manager( cache_context=cache_context - ).get_local_copy(remote_url=remote_url) + ).get_local_copy(remote_url=remote_url, force_download=force_download) if not extract_archive or not cached_file: return cached_file return cls._extract_to_cache(cached_file, name)