Add StorageManager.upload_folder() and StorageManager.download_folder()
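
A minimal usage sketch of the new API (the bucket URL, local paths, and wildcard
below are placeholder values, not part of this commit; assumes the package is
imported as `clearml`):

    from clearml import StorageManager

    # Recursively upload every JSON file under ~/data/, preserving the sub-folder tree
    StorageManager.upload_folder('~/data/', 's3://my-bucket/backup/', match_wildcard='*.json')

    # Mirror the remote tree back to a local folder; returns the target folder path
    local_copy = StorageManager.download_folder('s3://my-bucket/backup/', local_folder='~/restore/')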

allegroai 2021-01-18 11:35:23 +02:00
parent d06504c32c
commit 8c4d3aca5b


@@ -1,5 +1,8 @@
+import fnmatch
+import os
 import shutil
 import tarfile
+from multiprocessing.pool import ThreadPool
 from random import random
 from time import time
 from typing import Optional
@@ -7,9 +10,10 @@ from zipfile import ZipFile
 
 from pathlib2 import Path
 
+from .cache import CacheManager
+from .helper import StorageHelper
 from .util import encode_string_to_filename
 from ..debugging.log import LoggerRoot
-from .cache import CacheManager
 
 
 class StorageManager(object):
@@ -184,3 +188,98 @@ class StorageManager(object):
     def get_files_server(cls):
         from ..backend_api import Session
         return Session.get_files_server_host()
+
+    @classmethod
+    def upload_folder(cls, local_folder, remote_url, match_wildcard=None):
+        # type: (str, str, Optional[str]) -> None
+        """
+        Upload a local folder recursively to remote storage, maintaining the sub-folder
+        structure in the remote storage. For example:
+        given a local file ~/folder/sub/file.ext,
+        StorageManager.upload_folder('~/folder/', 's3://bucket/')
+        will create s3://bucket/sub/file.ext
+
+        :param local_folder: Local folder to upload recursively
+        :param remote_url: Target remote storage location; the tree structure of `local_folder`
+            will be created under the target remote_url. Supports Http/S3/GS/Azure and shared
+            filesystems. Example: 's3://bucket/data/'
+        :param match_wildcard: If specified, only upload files matching the `match_wildcard`,
+            for example `*.json`. By default all files are uploaded.
+            Notice: target file size/date are not checked, so matching files are always
+            uploaded, and when uploading to http the target is always overwritten.
+        """
+        base_logger = LoggerRoot.get_base_logger()
+
+        if not Path(local_folder).is_dir():
+            base_logger.error("Local folder '{}' does not exist".format(local_folder))
+            return
+
+        results = []
+        helper = StorageHelper.get(remote_url)
+        with ThreadPool() as pool:
+            for path in Path(local_folder).rglob(match_wildcard or "*"):
+                if not path.is_file():
+                    continue
+                # map each local path to its remote counterpart, preserving the sub-tree
+                results.append(
+                    pool.apply_async(
+                        helper.upload,
+                        args=(str(path), str(path).replace(local_folder, remote_url)),
+                    )
+                )
+            # wait inside the pool context, so workers are not terminated mid-upload
+            for res in results:
+                res.wait()
+
+    @classmethod
+    def download_folder(
+        cls, remote_url, local_folder=None, match_wildcard=None, overwrite=False
+    ):
+        # type: (str, Optional[str], Optional[str], bool) -> Optional[str]
+        """
+        Download a remote folder recursively to the local machine, maintaining the
+        sub-folder structure from the remote storage. For example:
+        given a remote file s3://bucket/sub/file.ext,
+        StorageManager.download_folder('s3://bucket/', '~/folder/')
+        will create ~/folder/sub/file.ext
+
+        :param remote_url: Source remote storage location; the tree structure of `remote_url`
+            will be created under the target local_folder. Supports S3/GS/Azure and shared
+            filesystems. Example: 's3://bucket/data/'
+        :param local_folder: Local target folder for the full tree from remote_url.
+            If None, use the cache folder. (Default: use cache folder)
+        :param match_wildcard: If specified, only download files matching the `match_wildcard`,
+            for example `*.json`
+        :param overwrite: If False, skip files that already exist locally.
+            If True, always download the remote files. (Default: False)
+        :return: Target local folder
+        """
+        base_logger = LoggerRoot.get_base_logger()
+
+        if local_folder:
+            try:
+                Path(local_folder).mkdir(parents=True, exist_ok=True)
+            except OSError as ex:
+                base_logger.error("Failed creating local folder '{}': {}".format(local_folder, ex))
+                return
+        else:
+            local_folder = CacheManager.get_cache_manager().get_cache_folder()
+
+        helper = StorageHelper.get(remote_url)
+        results = []
+        with ThreadPool() as pool:
+            for path in helper.list(prefix=remote_url):
+                remote_path = os.path.join(helper.base_url, path)
+                if match_wildcard and not fnmatch.fnmatch(remote_path, match_wildcard):
+                    continue
+                local_url = remote_path.replace(remote_url, local_folder)
+                # download when overwrite was requested, or when the local copy is missing/empty
+                if overwrite or not os.path.exists(local_url) or os.path.getsize(local_url) == 0:
+                    results.append(
+                        pool.apply_async(
+                            helper.download_to_file,
+                            args=(remote_path, local_url),
+                            kwds={"overwrite_existing": overwrite},
+                        )
+                    )
+            # wait inside the pool context, so workers are not terminated mid-download
+            for res in results:
+                res.wait()
+
+        return local_folder
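
A small sketch of the two download modes described by the docstring above
(the bucket URL and paths are placeholders):

    from clearml import StorageManager

    # With no local_folder, files land under the cache folder;
    # the returned path tells you where.
    cache_copy = StorageManager.download_folder('s3://my-bucket/backup/')

    # Force re-download of matching files even if local copies already exist
    StorageManager.download_folder(
        's3://my-bucket/backup/',
        local_folder='/tmp/backup/',
        match_wildcard='*.csv',
        overwrite=True,
    )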