Add StorageHelper.get_local_copy to quickly download and provide local path to remote files (http/s3/gs/azure support)

This commit is contained in:
allegroai 2019-09-23 18:40:56 +03:00
parent 0b4f00af4d
commit 4f1eeb49c6

View File

@ -12,6 +12,7 @@ from concurrent.futures import ThreadPoolExecutor
from copy import copy from copy import copy
from datetime import datetime from datetime import datetime
from multiprocessing.pool import ThreadPool from multiprocessing.pool import ThreadPool
from tempfile import mktemp
from time import time from time import time
from types import GeneratorType from types import GeneratorType
@ -215,6 +216,22 @@ class StorageHelper(object):
cls._helpers[instance_key] = instance cls._helpers[instance_key] = instance
return instance return instance
@classmethod
def get_local_copy(cls, remote_url):
"""
Download a file from remote URL to a local storage, and return path to local copy,
:param remote_url: Remote URL. Example: https://example.com/file.jpg s3://bucket/folder/file.mp4 etc.
:return: Path to local copy of the downloaded file. None if error occurred.
"""
helper = cls.get(remote_url)
if not helper:
return None
# create temp file with the requested file name
file_name = '.' + remote_url.split('/')[-1].split(os.path.sep)[-1]
local_path = mktemp(suffix=file_name)
return helper.download_to_file(remote_url, local_path)
def __init__(self, base_url, url, key=None, secret=None, region=None, verbose=False, logger=None, retries=5, def __init__(self, base_url, url, key=None, secret=None, region=None, verbose=False, logger=None, retries=5,
**kwargs): **kwargs):
self._log = logger or log self._log = logger or log
@ -540,7 +557,7 @@ class StorageHelper(object):
if self._verbose: if self._verbose:
self._log.info('Start downloading from %s' % remote_path) self._log.info('Start downloading from %s' % remote_path)
if not overwrite_existing and Path(local_path).is_file(): if not overwrite_existing and Path(local_path).is_file():
self._log.warn( self._log.warning(
'File {} already exists, no need to download, thread id = {}'.format( 'File {} already exists, no need to download, thread id = {}'.format(
local_path, local_path,
threading.current_thread().ident, threading.current_thread().ident,
@ -911,7 +928,7 @@ class _HttpDriver(object):
def download_object(self, obj, local_path, overwrite_existing=True, delete_on_failure=True, callback=None): def download_object(self, obj, local_path, overwrite_existing=True, delete_on_failure=True, callback=None):
p = Path(local_path) p = Path(local_path)
if not overwrite_existing and p.is_file(): if not overwrite_existing and p.is_file():
log.warn('failed saving after download: overwrite=False and file exists (%s)' % str(p)) log.warning('failed saving after download: overwrite=False and file exists (%s)' % str(p))
return return
length = 0 length = 0
with p.open(mode='wb') as f: with p.open(mode='wb') as f:
@ -1160,7 +1177,7 @@ class _Boto3Driver(object):
import boto3.s3.transfer import boto3.s3.transfer
p = Path(local_path) p = Path(local_path)
if not overwrite_existing and p.is_file(): if not overwrite_existing and p.is_file():
log.warn('failed saving after download: overwrite=False and file exists (%s)' % str(p)) log.warning('failed saving after download: overwrite=False and file exists (%s)' % str(p))
return return
container = self._containers[obj.container_name] container = self._containers[obj.container_name]
obj.download_file(str(p), obj.download_file(str(p),
@ -1360,7 +1377,7 @@ class _GoogleCloudStorageDriver(object):
def download_object(self, obj, local_path, overwrite_existing=True, delete_on_failure=True, callback=None): def download_object(self, obj, local_path, overwrite_existing=True, delete_on_failure=True, callback=None):
p = Path(local_path) p = Path(local_path)
if not overwrite_existing and p.is_file(): if not overwrite_existing and p.is_file():
log.warn('failed saving after download: overwrite=False and file exists (%s)' % str(p)) log.warning('failed saving after download: overwrite=False and file exists (%s)' % str(p))
return return
obj.download_to_filename(str(p)) obj.download_to_filename(str(p))
@ -1498,7 +1515,7 @@ class _AzureBlobServiceStorageDriver(object):
def download_object(self, obj, local_path, overwrite_existing=True, delete_on_failure=True, callback=None): def download_object(self, obj, local_path, overwrite_existing=True, delete_on_failure=True, callback=None):
p = Path(local_path) p = Path(local_path)
if not overwrite_existing and p.is_file(): if not overwrite_existing and p.is_file():
log.warn('failed saving after download: overwrite=False and file exists (%s)' % str(p)) log.warning('failed saving after download: overwrite=False and file exists (%s)' % str(p))
return return
download_done = threading.Event() download_done = threading.Event()