@classmethod
def get_local_copy(cls, remote_url):
    """
    Download a file from a remote URL to local storage and return the path to the local copy.

    The file is saved under its original name inside a newly created temporary directory.
    The caller owns the returned file (and its directory) and is responsible for removing
    them once they are no longer needed.

    :param remote_url: Remote URL. Example: https://example.com/file.jpg s3://bucket/folder/file.mp4 etc.
    :return: Path to local copy of the downloaded file. None if error occurred
        (empty URL, no storage helper available for the URL, or the download failed).
    """
    # Imported locally: the module itself only imports ``mktemp`` from tempfile.
    from tempfile import mkdtemp

    if not remote_url:
        return None

    helper = cls.get(remote_url)
    if not helper:
        return None

    # For HTTP(S) links the query string and fragment are not part of the file name
    # (and may themselves contain '/'), so drop them before taking the last path segment.
    # Other schemes are left untouched since '?' / '#' may be legitimate key characters there.
    name_source = remote_url
    if remote_url.lower().startswith(('http://', 'https://')):
        name_source = remote_url.split('#', 1)[0].split('?', 1)[0]
    file_name = name_source.split('/')[-1].split(os.path.sep)[-1]
    if file_name in ('', '.', '..'):
        # URL ends with a separator or a relative-path marker; use a neutral file name
        # so the target is always a regular file inside the temporary directory.
        file_name = 'download'

    # mktemp() only *suggests* an unused name, which another process may claim before the
    # download writes to it. mkdtemp() actually creates a fresh directory, so the target
    # path inside it cannot collide and does not exist before the download starts.
    temp_dir = mkdtemp()
    local_path = os.path.join(temp_dir, file_name)

    local_copy = helper.download_to_file(remote_url, local_path)
    if not local_copy:
        # Download failed: best-effort removal of the (expected to be empty) directory.
        try:
            os.rmdir(temp_dir)
        except OSError:
            pass
    return local_copy
(%s)' % str(p)) return length = 0 with p.open(mode='wb') as f: @@ -1160,7 +1177,7 @@ class _Boto3Driver(object): import boto3.s3.transfer p = Path(local_path) if not overwrite_existing and p.is_file(): - log.warn('failed saving after download: overwrite=False and file exists (%s)' % str(p)) + log.warning('failed saving after download: overwrite=False and file exists (%s)' % str(p)) return container = self._containers[obj.container_name] obj.download_file(str(p), @@ -1360,7 +1377,7 @@ class _GoogleCloudStorageDriver(object): def download_object(self, obj, local_path, overwrite_existing=True, delete_on_failure=True, callback=None): p = Path(local_path) if not overwrite_existing and p.is_file(): - log.warn('failed saving after download: overwrite=False and file exists (%s)' % str(p)) + log.warning('failed saving after download: overwrite=False and file exists (%s)' % str(p)) return obj.download_to_filename(str(p)) @@ -1498,7 +1515,7 @@ class _AzureBlobServiceStorageDriver(object): def download_object(self, obj, local_path, overwrite_existing=True, delete_on_failure=True, callback=None): p = Path(local_path) if not overwrite_existing and p.is_file(): - log.warn('failed saving after download: overwrite=False and file exists (%s)' % str(p)) + log.warning('failed saving after download: overwrite=False and file exists (%s)' % str(p)) return download_done = threading.Event()