Add StorageManager.upload_folder() and StorageManager.download_folder()

This commit is contained in:
allegroai 2021-01-18 11:35:23 +02:00
parent d06504c32c
commit 8c4d3aca5b

View File

@ -1,5 +1,8 @@
import fnmatch
import os
import shutil
import tarfile
from multiprocessing.pool import ThreadPool
from random import random
from time import time
from typing import Optional
@ -7,9 +10,10 @@ from zipfile import ZipFile
from pathlib2 import Path
from .cache import CacheManager
from .helper import StorageHelper
from .util import encode_string_to_filename
from ..debugging.log import LoggerRoot
from .cache import CacheManager
class StorageManager(object):
@ -184,3 +188,98 @@ class StorageManager(object):
def get_files_server(cls):
    """
    Return the files server host, as reported by `Session.get_files_server_host()`.

    :return: The value returned by `Session.get_files_server_host()`
    """
    # Imported at call time rather than at module level, as in the original code
    from ..backend_api import Session as _Session

    files_server_host = _Session.get_files_server_host()
    return files_server_host
@classmethod
def upload_folder(cls, local_folder, remote_url, match_wildcard=None):
    # type: (str, str, Optional[str]) -> None
    """
    Upload local folder recursively to a remote storage, maintaining the sub folder structure
    in the remote storage. For Example:
        If we have a local file: ~/folder/sub/file.ext
        StorageManager.upload_folder('~/folder/', 's3://bucket/')
        will create: s3://bucket/sub/file.ext

    If `local_folder` is not an existing directory, an error is logged and nothing is uploaded.

    :param local_folder: Local folder to recursively upload
    :param remote_url: Target remote storage location, tree structure of `local_folder` will
        be created under the target remote_url. Supports Http/S3/GS/Azure and shared filesystem.
        Example: 's3://bucket/data/'
        A trailing '/' is optional, one is added when missing.
    :param match_wildcard: If specified only upload files matching the `match_wildcard`
        Example: `*.json`
        Notice: target file size/date are not checked, matching files are always uploaded.
        Notice if uploading to http, we will always overwrite the target.
    """
    base_logger = LoggerRoot.get_base_logger()
    local_root = Path(local_folder)
    if not local_root.is_dir():
        base_logger.error("Local folder '{}' does not exist".format(local_folder))
        return

    # Make sure exactly one separator sits between the remote base and the relative path,
    # whether or not the caller passed a trailing slash.
    remote_base = remote_url if remote_url.endswith("/") else remote_url + "/"

    helper = StorageHelper.get(remote_url)
    results = []
    with ThreadPool() as pool:
        for path in local_root.rglob(match_wildcard or "*"):
            if not path.is_file():
                continue
            # Derive the remote name from the path relative to the upload root instead of a
            # textual str.replace(): the replace breaks when `local_folder` is spelled
            # differently from how Path renders it (e.g. './folder'), when the folder string
            # re-appears later in the path, or when the OS separator is not '/'.
            relative = path.relative_to(local_root).as_posix()
            results.append(
                pool.apply_async(
                    helper.upload,
                    args=(str(path), remote_base + relative),
                )
            )
        # Wait while still inside the `with` block: leaving it terminates the pool,
        # which would abort uploads that have not finished yet.
        for res in results:
            res.wait()
@classmethod
def download_folder(
    cls, remote_url, local_folder=None, match_wildcard=None, overwrite=False
):
    # type: (str, Optional[str], Optional[str], bool) -> Optional[str]
    """
    Download remote folder recursively to the local machine, maintaining the sub folder structure
    from the remote storage. For Example:
        If we have a remote file: s3://bucket/sub/file.ext
        StorageManager.download_folder('s3://bucket/', '~/folder/')
        will create: ~/folder/sub/file.ext

    :param remote_url: Source remote storage location, tree structure of `remote_url` will
        be created under the target local_folder. Supports S3/GS/Azure and shared filesystem.
        Example: 's3://bucket/data/'
    :param local_folder: Local target folder to create the full tree from remote_url.
        If None, use the cache folder. (Default: use cache folder)
    :param match_wildcard: If specified only download files matching the `match_wildcard`
        Example: `*.json`
    :param overwrite: If False, and target files exist (and are not empty) do not download.
        If True always download the remote files. Default False.
    :return: Target local folder, or None if `local_folder` could not be created
    """
    base_logger = LoggerRoot.get_base_logger()
    if local_folder:
        try:
            Path(local_folder).mkdir(parents=True, exist_ok=True)
        except OSError as ex:
            base_logger.error("Failed creating local folder '{}': {}".format(local_folder, ex))
            return
    else:
        local_folder = CacheManager.get_cache_manager().get_cache_folder()

    helper = StorageHelper.get(remote_url)
    results = []
    with ThreadPool() as pool:
        for path in helper.list(prefix=remote_url):
            remote_path = os.path.join(helper.base_url, path)
            if match_wildcard and not fnmatch.fnmatch(remote_path, match_wildcard):
                continue

            if remote_path.startswith(remote_url):
                # Strip the remote prefix and re-root under local_folder. Joining (instead of
                # a textual replace) keeps a separator between the folder and the relative
                # path even when only one of remote_url / local_folder ends with '/'.
                relative = remote_path[len(remote_url):].lstrip("/")
                if not relative:
                    # The entry is the prefix itself, it would map onto local_folder
                    # (a directory), so there is nothing to download into.
                    continue
                local_url = os.path.join(local_folder, relative)
            else:
                # NOTE(review): listing entry is not under remote_url as a plain string
                # prefix - keep the original textual mapping; confirm what helper.list() returns.
                local_url = remote_path.replace(remote_url, local_folder)

            # `overwrite` must bypass the "already downloaded" shortcut, otherwise existing
            # files are never refreshed even though the caller asked to always download.
            if overwrite or not os.path.exists(local_url) or os.path.getsize(local_url) == 0:
                results.append(
                    pool.apply_async(
                        helper.download_to_file,
                        args=(remote_path, local_url),
                        kwds={"overwrite_existing": overwrite},
                    )
                )
        # Wait while still inside the `with` block: leaving it terminates the pool,
        # which would abort downloads that have not finished yet.
        for res in results:
            res.wait()

    return local_folder