Fix output_uri support for local folders

This commit is contained in:
allegroai 2019-10-15 22:35:37 +03:00
parent 67fc8e3eb0
commit 1a658e9d89
2 changed files with 118 additions and 32 deletions

View File

@ -392,6 +392,12 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
fd, local_filename = mkstemp(suffix='.'+ext)
os.close(fd)
local_download = helper.download_to_file(uri, local_path=local_filename, overwrite_existing=True, verbose=True)
# if we ended up without any local copy, delete the temp file
if local_download != local_filename:
try:
Path(local_filename).unlink()
except Exception:
pass
# save local model, so we can later query what was the original one
Model._local_model_to_id_uri[str(local_download)] = (self.model_id, uri)
return local_download

View File

@ -7,6 +7,7 @@ import os
import shutil
import sys
import threading
from abc import ABCMeta, abstractmethod
from collections import namedtuple
from concurrent.futures import ThreadPoolExecutor
from copy import copy
@ -80,6 +81,50 @@ class _DownloadProgressReport(object):
(self.downloaded_mb, self._total_size, speed, self._remote_path))
@six.add_metaclass(ABCMeta)
class _Driver(object):
    """Abstract interface that every storage backend driver implements.

    Concrete drivers in this file (HTTP, boto3/S3, Google Cloud Storage,
    Azure blob, local filesystem) subclass this so that ``StorageHelper``
    can talk to any backend through one uniform set of methods.
    """

    @abstractmethod
    def get_container(self, container_name, config=None, **kwargs):
        """Return a container/bucket handle for ``container_name``."""
        pass

    @abstractmethod
    def test_upload(self, test_path, config, **kwargs):
        """Return True if uploading to ``test_path`` with ``config`` is expected to work."""
        pass

    @abstractmethod
    def upload_object_via_stream(self, iterator, container, object_name, extra, **kwargs):
        """Upload data read from ``iterator`` as ``object_name`` inside ``container``."""
        pass

    @abstractmethod
    def list_container_objects(self, container, ex_prefix, **kwargs):
        """List the objects in ``container`` (filtered by ``ex_prefix`` — presumably a name prefix; confirm per driver)."""
        pass

    @abstractmethod
    def get_direct_access(self, remote_path, **kwargs):
        """Return a direct local filesystem path for ``remote_path``, or None if the backend has no direct access."""
        pass

    @abstractmethod
    def download_object(self, obj, local_path, overwrite_existing, delete_on_failure, callback, **kwargs):
        """Download ``obj`` to ``local_path`` on disk."""
        pass

    @abstractmethod
    def download_object_as_stream(self, obj, chunk_size, **kwargs):
        """Return an iterable/stream yielding the content of ``obj`` in chunks."""
        pass

    @abstractmethod
    def delete_object(self, obj, **kwargs):
        """Delete ``obj`` from its container."""
        pass

    @abstractmethod
    def upload_object(self, file_path, container, object_name, extra, **kwargs):
        """Upload the local file at ``file_path`` as ``object_name`` inside ``container``."""
        pass

    @abstractmethod
    def get_object(self, container_name, object_name, **kwargs):
        """Return a handle to ``object_name`` inside the container named ``container_name``."""
        pass
class StorageHelper(object):
""" Storage helper.
Used by the entire system to download/upload files.
@ -572,6 +617,11 @@ class StorageHelper(object):
remote_path = self._canonize_url(remote_path)
verbose = self._verbose if verbose is None else verbose
# Check if driver type supports direct access:
direct_access_path = self._driver.get_direct_access(remote_path)
if direct_access_path:
return direct_access_path
temp_local_path = None
try:
if verbose:
@ -893,8 +943,9 @@ class StorageHelper(object):
return None
class _HttpDriver(object):
class _HttpDriver(_Driver):
""" LibCloud http/https adapter (simple, enough for now) """
timeout = (5.0, 30.)
class _Container(object):
@ -920,7 +971,7 @@ class _HttpDriver(object):
self._retries = retries
self._containers = {}
def get_container(self, container_name, *_, **kwargs):
def get_container(self, container_name, config=None, **kwargs):
if container_name not in self._containers:
self._containers[container_name] = self._Container(name=container_name, retries=self._retries, **kwargs)
return self._containers[container_name]
@ -951,11 +1002,11 @@ class _HttpDriver(object):
raise ValueError('Failed getting object %s (%d): %s' % (object_name, res.status_code, res.text))
return res
def download_object_as_stream(self, obj, chunk_size=64 * 1024):
def download_object_as_stream(self, obj, chunk_size=64 * 1024, **_):
# return iterable object
return obj.iter_content(chunk_size=chunk_size)
def download_object(self, obj, local_path, overwrite_existing=True, delete_on_failure=True, callback=None):
def download_object(self, obj, local_path, overwrite_existing=True, delete_on_failure=True, callback=None, **_):
p = Path(local_path)
if not overwrite_existing and p.is_file():
log.warning('failed saving after download: overwrite=False and file exists (%s)' % str(p))
@ -974,6 +1025,17 @@ class _HttpDriver(object):
return length
def get_direct_access(self, remote_path, **_):
return None
def test_upload(self, test_path, config, **kwargs):
return True
def upload_object(self, file_path, container, object_name, extra, **kwargs):
with open(file_path, 'rb') as stream:
return self.upload_object_via_stream(iterator=stream, container=container,
object_name=object_name, extra=extra, **kwargs)
class _Stream(object):
encoding = None
@ -1068,8 +1130,9 @@ class _Stream(object):
self.write(s)
class _Boto3Driver(object):
class _Boto3Driver(_Driver):
""" Boto3 storage adapter (simple, enough for now) """
_max_multipart_concurrency = config.get('aws.boto3.max_multipart_concurrency', 16)
_min_pool_connections = 512
@ -1134,9 +1197,9 @@ class _Boto3Driver(object):
self._stream_download_pool = ThreadPoolExecutor(max_workers=self._stream_download_pool_connections)
return self._stream_download_pool
def get_container(self, container_name, *_, **kwargs):
def get_container(self, container_name, config=None, **kwargs):
if container_name not in self._containers:
self._containers[container_name] = self._Container(name=container_name, cfg=kwargs.get('config'))
self._containers[container_name] = self._Container(name=container_name, cfg=config)
self._containers[container_name].config.retries = kwargs.get('retries', 5)
return self._containers[container_name]
@ -1183,7 +1246,7 @@ class _Boto3Driver(object):
obj.container_name = full_container_name
return obj
def download_object_as_stream(self, obj, chunk_size=64 * 1024):
def download_object_as_stream(self, obj, chunk_size=64 * 1024, **_):
def async_download(a_obj, a_stream, cfg):
try:
a_obj.download_fileobj(a_stream, Config=cfg)
@ -1203,7 +1266,7 @@ class _Boto3Driver(object):
return stream
def download_object(self, obj, local_path, overwrite_existing=True, delete_on_failure=True, callback=None):
def download_object(self, obj, local_path, overwrite_existing=True, delete_on_failure=True, callback=None, **_):
import boto3.s3.transfer
p = Path(local_path)
if not overwrite_existing and p.is_file():
@ -1311,8 +1374,14 @@ class _Boto3Driver(object):
return None
def get_direct_access(self, remote_path, **_):
return None
class _GoogleCloudStorageDriver(object):
def test_upload(self, test_path, config, **_):
return True
class _GoogleCloudStorageDriver(_Driver):
"""Storage driver for google cloud storage"""
_stream_download_pool_connections = 128
@ -1350,9 +1419,9 @@ class _GoogleCloudStorageDriver(object):
self._stream_download_pool = ThreadPoolExecutor(max_workers=self._stream_download_pool_connections)
return self._stream_download_pool
def get_container(self, container_name, *_, **kwargs):
def get_container(self, container_name, config=None, **kwargs):
if container_name not in self._containers:
self._containers[container_name] = self._Container(name=container_name, cfg=kwargs.get('config'))
self._containers[container_name] = self._Container(name=container_name, cfg=config)
self._containers[container_name].config.retries = kwargs.get('retries', 5)
return self._containers[container_name]
@ -1387,7 +1456,7 @@ class _GoogleCloudStorageDriver(object):
obj.container_name = full_container_name
return obj
def download_object_as_stream(self, obj, chunk_size=256 * 1024):
def download_object_as_stream(self, obj, chunk_size=256 * 1024, **_):
raise NotImplementedError('Unsupported for google storage')
def async_download(a_obj, a_stream):
@ -1404,14 +1473,14 @@ class _GoogleCloudStorageDriver(object):
return stream
def download_object(self, obj, local_path, overwrite_existing=True, delete_on_failure=True, callback=None):
def download_object(self, obj, local_path, overwrite_existing=True, delete_on_failure=True, callback=None, **_):
p = Path(local_path)
if not overwrite_existing and p.is_file():
log.warning('failed saving after download: overwrite=False and file exists (%s)' % str(p))
return
obj.download_to_filename(str(p))
def test_upload(self, test_path, config):
def test_upload(self, test_path, config, **_):
bucket_url = str(furl(scheme=self.scheme, netloc=config.bucket, path=config.subdir))
bucket = self.get_container(container_name=bucket_url, config=config).bucket
@ -1429,8 +1498,11 @@ class _GoogleCloudStorageDriver(object):
permissions_to_test = ('storage.objects.get', 'storage.objects.update')
return set(test_obj.test_iam_permissions(permissions_to_test)) == set(permissions_to_test)
def get_direct_access(self, remote_path, **_):
return None
class _AzureBlobServiceStorageDriver(object):
class _AzureBlobServiceStorageDriver(_Driver):
scheme = 'azure'
_containers = {}
@ -1459,8 +1531,8 @@ class _AzureBlobServiceStorageDriver(object):
blob_name = attrib()
content_length = attrib()
def get_container(self, config, *_, **kwargs):
container_name = config.container_name
def get_container(self, container_name=None, config=None, **kwargs):
container_name = container_name or config.container_name
if container_name not in self._containers:
self._containers[container_name] = self._Container(name=container_name, config=config)
# self._containers[container_name].config.retries = kwargs.get('retries', 5)
@ -1542,7 +1614,7 @@ class _AzureBlobServiceStorageDriver(object):
)
return blob.content
def download_object(self, obj, local_path, overwrite_existing=True, delete_on_failure=True, callback=None):
def download_object(self, obj, local_path, overwrite_existing=True, delete_on_failure=True, callback=None, **_):
p = Path(local_path)
if not overwrite_existing and p.is_file():
log.warning('failed saving after download: overwrite=False and file exists (%s)' % str(p))
@ -1570,7 +1642,7 @@ class _AzureBlobServiceStorageDriver(object):
)
download_done.wait()
def test_upload(self, test_path, config):
def test_upload(self, test_path, config, **_):
container = self.get_container(config=config)
try:
container.blob_service.get_container_properties(container.name)
@ -1620,8 +1692,11 @@ class _AzureBlobServiceStorageDriver(object):
return name
def get_direct_access(self, remote_path, **_):
return None
class _FileStorageDriver(object):
class _FileStorageDriver(_Driver):
"""
A base StorageDriver to derive from.
"""
@ -1775,7 +1850,7 @@ class _FileStorageDriver(object):
return self._get_objects(container)
def get_container(self, container_name):
def get_container(self, container_name, **_):
"""
Return a container instance.
@ -1807,7 +1882,7 @@ class _FileStorageDriver(object):
return path
def get_object(self, container_name, object_name):
def get_object(self, container_name, object_name, **_):
"""
Return an object instance.
@ -1835,8 +1910,7 @@ class _FileStorageDriver(object):
"""
return os.path.realpath(os.path.join(self.base_path, obj.container.name, obj.name))
def download_object(self, obj, destination_path, overwrite_existing=False,
delete_on_failure=True):
def download_object(self, obj, destination_path, overwrite_existing=False, delete_on_failure=True, **_):
"""
Download an object to the specified destination path.
@ -1888,7 +1962,7 @@ class _FileStorageDriver(object):
return True
def download_object_as_stream(self, obj, chunk_size=None):
def download_object_as_stream(self, obj, chunk_size=None, **_):
"""
Return a generator which yields object data.
@ -1906,8 +1980,7 @@ class _FileStorageDriver(object):
for data in self._read_in_chunks(obj_file, chunk_size=chunk_size):
yield data
def upload_object(self, file_path, container, object_name, extra=None,
verify_hash=True):
def upload_object(self, file_path, container, object_name, extra=None, verify_hash=True, **_):
"""
Upload an object currently located on a disk.
@ -1941,9 +2014,7 @@ class _FileStorageDriver(object):
return self._make_object(container, object_name)
def upload_object_via_stream(self, iterator, container,
object_name,
extra=None):
def upload_object_via_stream(self, iterator, container, object_name, extra=None, **kwargs):
"""
Upload an object using an iterator.
@ -1990,7 +2061,7 @@ class _FileStorageDriver(object):
os.chmod(obj_path, int('664', 8))
return self._make_object(container, object_name)
def delete_object(self, obj):
def delete_object(self, obj, **_):
"""
Delete an object.
@ -2144,3 +2215,12 @@ class _FileStorageDriver(object):
else:
yield data
data = bytes('')
def get_direct_access(self, remote_path, **_):
# this will always make sure we have full path and file:// prefix
full_url = StorageHelper.conform_url(remote_path)
# now get rid of the file:// prefix
return Path(full_url[7:]).as_posix()
def test_upload(self, test_path, config, **kwargs):
    """Always report success: uploading to a local folder needs no connectivity/permission probe here."""
    return True