Add StorageHelper.get_local_copy to quickly download and provide local path to remote files (http/s3/gs/azure support)

This commit is contained in:
allegroai 2019-09-23 18:40:56 +03:00
parent 0b4f00af4d
commit 4f1eeb49c6

View File

@ -12,6 +12,7 @@ from concurrent.futures import ThreadPoolExecutor
from copy import copy
from datetime import datetime
from multiprocessing.pool import ThreadPool
from tempfile import mktemp
from time import time
from types import GeneratorType
@ -215,6 +216,22 @@ class StorageHelper(object):
cls._helpers[instance_key] = instance
return instance
@classmethod
def get_local_copy(cls, remote_url):
    """
    Download a file from a remote URL to local storage and return the path to the local copy.

    The destination is a file carrying the remote file's base name, located inside a newly
    created private temporary directory. Neither the file nor the directory is removed here;
    cleaning them up is left to the caller.

    :param remote_url: Remote URL. Example: https://example.com/file.jpg s3://bucket/folder/file.mp4 etc.
    :return: The value returned by the helper's ``download_to_file`` for the chosen local path
        (expected to be the path to the local copy). None if no helper could be obtained
        for the URL.
    """
    # Imported locally so this method does not depend on the module-level import list.
    from tempfile import mkdtemp

    helper = cls.get(remote_url)
    if not helper:
        return None

    # Keep only the last path component; the extra split handles the platform separator
    # (e.g. backslashes on Windows).
    file_name = remote_url.split('/')[-1].split(os.path.sep)[-1]
    if file_name in ('', '.', '..'):
        # A URL ending in '/' (or a dot component) has no usable base name and would
        # resolve to a directory instead of a file.
        file_name = 'download'

    # mktemp() only returns a name and is race-prone: another process may create that path
    # first. mkdtemp() atomically creates a directory accessible only to the current user,
    # so the file path inside it is safe and does not exist before the download starts.
    local_path = os.path.join(mkdtemp(), file_name)
    return helper.download_to_file(remote_url, local_path)
def __init__(self, base_url, url, key=None, secret=None, region=None, verbose=False, logger=None, retries=5,
**kwargs):
self._log = logger or log
@ -540,7 +557,7 @@ class StorageHelper(object):
if self._verbose:
self._log.info('Start downloading from %s' % remote_path)
if not overwrite_existing and Path(local_path).is_file():
self._log.warn(
self._log.warning(
'File {} already exists, no need to download, thread id = {}'.format(
local_path,
threading.current_thread().ident,
@ -911,7 +928,7 @@ class _HttpDriver(object):
def download_object(self, obj, local_path, overwrite_existing=True, delete_on_failure=True, callback=None):
p = Path(local_path)
if not overwrite_existing and p.is_file():
log.warn('failed saving after download: overwrite=False and file exists (%s)' % str(p))
log.warning('failed saving after download: overwrite=False and file exists (%s)' % str(p))
return
length = 0
with p.open(mode='wb') as f:
@ -1160,7 +1177,7 @@ class _Boto3Driver(object):
import boto3.s3.transfer
p = Path(local_path)
if not overwrite_existing and p.is_file():
log.warn('failed saving after download: overwrite=False and file exists (%s)' % str(p))
log.warning('failed saving after download: overwrite=False and file exists (%s)' % str(p))
return
container = self._containers[obj.container_name]
obj.download_file(str(p),
@ -1360,7 +1377,7 @@ class _GoogleCloudStorageDriver(object):
def download_object(self, obj, local_path, overwrite_existing=True, delete_on_failure=True, callback=None):
p = Path(local_path)
if not overwrite_existing and p.is_file():
log.warn('failed saving after download: overwrite=False and file exists (%s)' % str(p))
log.warning('failed saving after download: overwrite=False and file exists (%s)' % str(p))
return
obj.download_to_filename(str(p))
@ -1498,7 +1515,7 @@ class _AzureBlobServiceStorageDriver(object):
def download_object(self, obj, local_path, overwrite_existing=True, delete_on_failure=True, callback=None):
p = Path(local_path)
if not overwrite_existing and p.is_file():
log.warn('failed saving after download: overwrite=False and file exists (%s)' % str(p))
log.warning('failed saving after download: overwrite=False and file exists (%s)' % str(p))
return
download_done = threading.Event()