Fix output_uri support for local folders

This commit is contained in:
allegroai 2019-10-15 22:35:37 +03:00
parent 67fc8e3eb0
commit 1a658e9d89
2 changed files with 118 additions and 32 deletions

View File

@ -392,6 +392,12 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
fd, local_filename = mkstemp(suffix='.'+ext) fd, local_filename = mkstemp(suffix='.'+ext)
os.close(fd) os.close(fd)
local_download = helper.download_to_file(uri, local_path=local_filename, overwrite_existing=True, verbose=True) local_download = helper.download_to_file(uri, local_path=local_filename, overwrite_existing=True, verbose=True)
# if we ended up without any local copy, delete the temp file
if local_download != local_filename:
try:
Path(local_filename).unlink()
except Exception:
pass
# save local model, so we can later query what was the original one # save local model, so we can later query what was the original one
Model._local_model_to_id_uri[str(local_download)] = (self.model_id, uri) Model._local_model_to_id_uri[str(local_download)] = (self.model_id, uri)
return local_download return local_download

View File

@ -7,6 +7,7 @@ import os
import shutil import shutil
import sys import sys
import threading import threading
from abc import ABCMeta, abstractmethod
from collections import namedtuple from collections import namedtuple
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from copy import copy from copy import copy
@ -80,6 +81,50 @@ class _DownloadProgressReport(object):
(self.downloaded_mb, self._total_size, speed, self._remote_path)) (self.downloaded_mb, self._total_size, speed, self._remote_path))
@six.add_metaclass(ABCMeta)
class _Driver(object):
    """ Abstract interface shared by the storage driver classes.

    Every method is abstract; a concrete driver has to provide all of them
    before it can be instantiated. Each signature ends with ``**kwargs`` so
    an implementation may accept additional, driver-specific keyword arguments.
    """

    @abstractmethod
    def get_container(self, container_name, config=None, **kwargs):
        """ Return the container object identified by ``container_name``. """

    @abstractmethod
    def test_upload(self, test_path, config, **kwargs):
        """ Report whether an upload to ``test_path`` can be performed with ``config``. """

    @abstractmethod
    def upload_object_via_stream(self, iterator, container, object_name, extra, **kwargs):
        """ Upload the data produced by ``iterator`` into ``container`` as ``object_name``. """

    @abstractmethod
    def list_container_objects(self, container, ex_prefix, **kwargs):
        """ List the objects of ``container`` (``ex_prefix`` is presumably a name prefix filter - verify). """

    @abstractmethod
    def get_direct_access(self, remote_path, **kwargs):
        """ Return a path through which ``remote_path`` can be used directly, or None if there is none. """

    @abstractmethod
    def download_object(self, obj, local_path, overwrite_existing, delete_on_failure, callback, **kwargs):
        """ Download ``obj`` into the file ``local_path``. """

    @abstractmethod
    def download_object_as_stream(self, obj, chunk_size, **kwargs):
        """ Return the content of ``obj`` as a stream / iterable read in ``chunk_size`` pieces. """

    @abstractmethod
    def delete_object(self, obj, **kwargs):
        """ Delete ``obj`` from the storage. """

    @abstractmethod
    def upload_object(self, file_path, container, object_name, extra, **kwargs):
        """ Upload the local file ``file_path`` into ``container`` as ``object_name``. """

    @abstractmethod
    def get_object(self, container_name, object_name, **kwargs):
        """ Return the object ``object_name`` stored under ``container_name``. """
class StorageHelper(object): class StorageHelper(object):
""" Storage helper. """ Storage helper.
Used by the entire system to download/upload files. Used by the entire system to download/upload files.
@ -572,6 +617,11 @@ class StorageHelper(object):
remote_path = self._canonize_url(remote_path) remote_path = self._canonize_url(remote_path)
verbose = self._verbose if verbose is None else verbose verbose = self._verbose if verbose is None else verbose
# Check if driver type supports direct access:
direct_access_path = self._driver.get_direct_access(remote_path)
if direct_access_path:
return direct_access_path
temp_local_path = None temp_local_path = None
try: try:
if verbose: if verbose:
@ -893,8 +943,9 @@ class StorageHelper(object):
return None return None
class _HttpDriver(object): class _HttpDriver(_Driver):
""" LibCloud http/https adapter (simple, enough for now) """ """ LibCloud http/https adapter (simple, enough for now) """
timeout = (5.0, 30.) timeout = (5.0, 30.)
class _Container(object): class _Container(object):
@ -920,7 +971,7 @@ class _HttpDriver(object):
self._retries = retries self._retries = retries
self._containers = {} self._containers = {}
def get_container(self, container_name, config=None, **kwargs):
    """ Return the cached container for ``container_name``, creating it on first use.

    :param container_name: key under which the container is cached
    :param config: accepted for signature compatibility, not used by this driver
    :param kwargs: forwarded to the ``_Container`` constructor when a new container is created
    :return: the container instance stored in ``self._containers``
    """
    cache = self._containers
    if container_name not in cache:
        # first request for this name: build the container with the driver's retry setting
        cache[container_name] = self._Container(name=container_name, retries=self._retries, **kwargs)
    return cache[container_name]
@ -951,11 +1002,11 @@ class _HttpDriver(object):
raise ValueError('Failed getting object %s (%d): %s' % (object_name, res.status_code, res.text)) raise ValueError('Failed getting object %s (%d): %s' % (object_name, res.status_code, res.text))
return res return res
def download_object_as_stream(self, obj, chunk_size=64 * 1024, **_):
    """ Return an iterable over the content of ``obj``.

    :param obj: object exposing ``iter_content(chunk_size=...)``
        (presumably the response returned by ``get_object`` - verify against callers)
    :param chunk_size: size passed on to ``iter_content``, 64KB by default
    :return: whatever ``obj.iter_content`` returns
    """
    content_chunks = obj.iter_content(chunk_size=chunk_size)
    return content_chunks
def download_object(self, obj, local_path, overwrite_existing=True, delete_on_failure=True, callback=None): def download_object(self, obj, local_path, overwrite_existing=True, delete_on_failure=True, callback=None, **_):
p = Path(local_path) p = Path(local_path)
if not overwrite_existing and p.is_file(): if not overwrite_existing and p.is_file():
log.warning('failed saving after download: overwrite=False and file exists (%s)' % str(p)) log.warning('failed saving after download: overwrite=False and file exists (%s)' % str(p))
@ -974,6 +1025,17 @@ class _HttpDriver(object):
return length return length
def get_direct_access(self, remote_path, **_):
    """ Direct access is not available for this driver.

    :param remote_path: ignored
    :return: always None
    """
    return None
def test_upload(self, test_path, config, **kwargs):
return True
def upload_object(self, file_path, container, object_name, extra, **kwargs):
    """ Upload a local file by streaming it through ``upload_object_via_stream``.

    :param file_path: path of the local file, opened in binary read mode
    :param container: passed on unchanged
    :param object_name: passed on unchanged
    :param extra: passed on unchanged
    :param kwargs: additional keyword arguments, passed on unchanged
    :return: the value returned by ``upload_object_via_stream``
    """
    with open(file_path, 'rb') as source:
        # the upload has to complete while the file is still open
        upload_result = self.upload_object_via_stream(
            iterator=source, container=container, object_name=object_name, extra=extra, **kwargs)
    return upload_result
class _Stream(object): class _Stream(object):
encoding = None encoding = None
@ -1068,8 +1130,9 @@ class _Stream(object):
self.write(s) self.write(s)
class _Boto3Driver(object): class _Boto3Driver(_Driver):
""" Boto3 storage adapter (simple, enough for now) """ """ Boto3 storage adapter (simple, enough for now) """
_max_multipart_concurrency = config.get('aws.boto3.max_multipart_concurrency', 16) _max_multipart_concurrency = config.get('aws.boto3.max_multipart_concurrency', 16)
_min_pool_connections = 512 _min_pool_connections = 512
@ -1134,9 +1197,9 @@ class _Boto3Driver(object):
self._stream_download_pool = ThreadPoolExecutor(max_workers=self._stream_download_pool_connections) self._stream_download_pool = ThreadPoolExecutor(max_workers=self._stream_download_pool_connections)
return self._stream_download_pool return self._stream_download_pool
def get_container(self, container_name, *_, **kwargs): def get_container(self, container_name, config=None, **kwargs):
if container_name not in self._containers: if container_name not in self._containers:
self._containers[container_name] = self._Container(name=container_name, cfg=kwargs.get('config')) self._containers[container_name] = self._Container(name=container_name, cfg=config)
self._containers[container_name].config.retries = kwargs.get('retries', 5) self._containers[container_name].config.retries = kwargs.get('retries', 5)
return self._containers[container_name] return self._containers[container_name]
@ -1183,7 +1246,7 @@ class _Boto3Driver(object):
obj.container_name = full_container_name obj.container_name = full_container_name
return obj return obj
def download_object_as_stream(self, obj, chunk_size=64 * 1024): def download_object_as_stream(self, obj, chunk_size=64 * 1024, **_):
def async_download(a_obj, a_stream, cfg): def async_download(a_obj, a_stream, cfg):
try: try:
a_obj.download_fileobj(a_stream, Config=cfg) a_obj.download_fileobj(a_stream, Config=cfg)
@ -1203,7 +1266,7 @@ class _Boto3Driver(object):
return stream return stream
def download_object(self, obj, local_path, overwrite_existing=True, delete_on_failure=True, callback=None): def download_object(self, obj, local_path, overwrite_existing=True, delete_on_failure=True, callback=None, **_):
import boto3.s3.transfer import boto3.s3.transfer
p = Path(local_path) p = Path(local_path)
if not overwrite_existing and p.is_file(): if not overwrite_existing and p.is_file():
@ -1311,8 +1374,14 @@ class _Boto3Driver(object):
return None return None
def get_direct_access(self, remote_path, **_):
    """ Direct access is not available for this driver.

    :param remote_path: ignored
    :return: always None
    """
    return None
class _GoogleCloudStorageDriver(object): def test_upload(self, test_path, config, **_):
return True
class _GoogleCloudStorageDriver(_Driver):
"""Storage driver for google cloud storage""" """Storage driver for google cloud storage"""
_stream_download_pool_connections = 128 _stream_download_pool_connections = 128
@ -1350,9 +1419,9 @@ class _GoogleCloudStorageDriver(object):
self._stream_download_pool = ThreadPoolExecutor(max_workers=self._stream_download_pool_connections) self._stream_download_pool = ThreadPoolExecutor(max_workers=self._stream_download_pool_connections)
return self._stream_download_pool return self._stream_download_pool
def get_container(self, container_name, *_, **kwargs): def get_container(self, container_name, config=None, **kwargs):
if container_name not in self._containers: if container_name not in self._containers:
self._containers[container_name] = self._Container(name=container_name, cfg=kwargs.get('config')) self._containers[container_name] = self._Container(name=container_name, cfg=config)
self._containers[container_name].config.retries = kwargs.get('retries', 5) self._containers[container_name].config.retries = kwargs.get('retries', 5)
return self._containers[container_name] return self._containers[container_name]
@ -1387,7 +1456,7 @@ class _GoogleCloudStorageDriver(object):
obj.container_name = full_container_name obj.container_name = full_container_name
return obj return obj
def download_object_as_stream(self, obj, chunk_size=256 * 1024): def download_object_as_stream(self, obj, chunk_size=256 * 1024, **_):
raise NotImplementedError('Unsupported for google storage') raise NotImplementedError('Unsupported for google storage')
def async_download(a_obj, a_stream): def async_download(a_obj, a_stream):
@ -1404,14 +1473,14 @@ class _GoogleCloudStorageDriver(object):
return stream return stream
def download_object(self, obj, local_path, overwrite_existing=True, delete_on_failure=True, callback=None, **_):
    """ Download ``obj`` into ``local_path`` using ``obj.download_to_filename``.

    :param obj: object exposing ``download_to_filename(path)``
    :param local_path: destination file path
    :param overwrite_existing: when False and the destination file already exists,
        a warning is logged and nothing is downloaded
    :param delete_on_failure: not used by this implementation
    :param callback: not used by this implementation
    :return: None
    """
    target = Path(local_path)
    if overwrite_existing or not target.is_file():
        obj.download_to_filename(str(target))
        return
    # destination exists and we were asked not to overwrite it
    log.warning('failed saving after download: overwrite=False and file exists (%s)' % str(target))
def test_upload(self, test_path, config): def test_upload(self, test_path, config, **_):
bucket_url = str(furl(scheme=self.scheme, netloc=config.bucket, path=config.subdir)) bucket_url = str(furl(scheme=self.scheme, netloc=config.bucket, path=config.subdir))
bucket = self.get_container(container_name=bucket_url, config=config).bucket bucket = self.get_container(container_name=bucket_url, config=config).bucket
@ -1429,8 +1498,11 @@ class _GoogleCloudStorageDriver(object):
permissions_to_test = ('storage.objects.get', 'storage.objects.update') permissions_to_test = ('storage.objects.get', 'storage.objects.update')
return set(test_obj.test_iam_permissions(permissions_to_test)) == set(permissions_to_test) return set(test_obj.test_iam_permissions(permissions_to_test)) == set(permissions_to_test)
def get_direct_access(self, remote_path, **_):
    """ Direct access is not available for this driver.

    :param remote_path: ignored
    :return: always None
    """
    return None
class _AzureBlobServiceStorageDriver(object):
class _AzureBlobServiceStorageDriver(_Driver):
scheme = 'azure' scheme = 'azure'
_containers = {} _containers = {}
@ -1459,8 +1531,8 @@ class _AzureBlobServiceStorageDriver(object):
blob_name = attrib() blob_name = attrib()
content_length = attrib() content_length = attrib()
def get_container(self, config, *_, **kwargs): def get_container(self, container_name=None, config=None, **kwargs):
container_name = config.container_name container_name = container_name or config.container_name
if container_name not in self._containers: if container_name not in self._containers:
self._containers[container_name] = self._Container(name=container_name, config=config) self._containers[container_name] = self._Container(name=container_name, config=config)
# self._containers[container_name].config.retries = kwargs.get('retries', 5) # self._containers[container_name].config.retries = kwargs.get('retries', 5)
@ -1542,7 +1614,7 @@ class _AzureBlobServiceStorageDriver(object):
) )
return blob.content return blob.content
def download_object(self, obj, local_path, overwrite_existing=True, delete_on_failure=True, callback=None): def download_object(self, obj, local_path, overwrite_existing=True, delete_on_failure=True, callback=None, **_):
p = Path(local_path) p = Path(local_path)
if not overwrite_existing and p.is_file(): if not overwrite_existing and p.is_file():
log.warning('failed saving after download: overwrite=False and file exists (%s)' % str(p)) log.warning('failed saving after download: overwrite=False and file exists (%s)' % str(p))
@ -1570,7 +1642,7 @@ class _AzureBlobServiceStorageDriver(object):
) )
download_done.wait() download_done.wait()
def test_upload(self, test_path, config): def test_upload(self, test_path, config, **_):
container = self.get_container(config=config) container = self.get_container(config=config)
try: try:
container.blob_service.get_container_properties(container.name) container.blob_service.get_container_properties(container.name)
@ -1620,8 +1692,11 @@ class _AzureBlobServiceStorageDriver(object):
return name return name
def get_direct_access(self, remote_path, **_):
    """ Direct access is not available for this driver.

    :param remote_path: ignored
    :return: always None
    """
    return None
class _FileStorageDriver(object):
class _FileStorageDriver(_Driver):
""" """
A base StorageDriver to derive from. A base StorageDriver to derive from.
""" """
@ -1775,7 +1850,7 @@ class _FileStorageDriver(object):
return self._get_objects(container) return self._get_objects(container)
def get_container(self, container_name): def get_container(self, container_name, **_):
""" """
Return a container instance. Return a container instance.
@ -1807,7 +1882,7 @@ class _FileStorageDriver(object):
return path return path
def get_object(self, container_name, object_name): def get_object(self, container_name, object_name, **_):
""" """
Return an object instance. Return an object instance.
@ -1835,8 +1910,7 @@ class _FileStorageDriver(object):
""" """
return os.path.realpath(os.path.join(self.base_path, obj.container.name, obj.name)) return os.path.realpath(os.path.join(self.base_path, obj.container.name, obj.name))
def download_object(self, obj, destination_path, overwrite_existing=False, def download_object(self, obj, destination_path, overwrite_existing=False, delete_on_failure=True, **_):
delete_on_failure=True):
""" """
Download an object to the specified destination path. Download an object to the specified destination path.
@ -1888,7 +1962,7 @@ class _FileStorageDriver(object):
return True return True
def download_object_as_stream(self, obj, chunk_size=None): def download_object_as_stream(self, obj, chunk_size=None, **_):
""" """
Return a generator which yields object data. Return a generator which yields object data.
@ -1906,8 +1980,7 @@ class _FileStorageDriver(object):
for data in self._read_in_chunks(obj_file, chunk_size=chunk_size): for data in self._read_in_chunks(obj_file, chunk_size=chunk_size):
yield data yield data
def upload_object(self, file_path, container, object_name, extra=None, def upload_object(self, file_path, container, object_name, extra=None, verify_hash=True, **_):
verify_hash=True):
""" """
Upload an object currently located on a disk. Upload an object currently located on a disk.
@ -1941,9 +2014,7 @@ class _FileStorageDriver(object):
return self._make_object(container, object_name) return self._make_object(container, object_name)
def upload_object_via_stream(self, iterator, container, def upload_object_via_stream(self, iterator, container, object_name, extra=None, **kwargs):
object_name,
extra=None):
""" """
Upload an object using an iterator. Upload an object using an iterator.
@ -1990,7 +2061,7 @@ class _FileStorageDriver(object):
os.chmod(obj_path, int('664', 8)) os.chmod(obj_path, int('664', 8))
return self._make_object(container, object_name) return self._make_object(container, object_name)
def delete_object(self, obj): def delete_object(self, obj, **_):
""" """
Delete an object. Delete an object.
@ -2144,3 +2215,12 @@ class _FileStorageDriver(object):
else: else:
yield data yield data
data = bytes('') data = bytes('')
def get_direct_access(self, remote_path, **_):
    """ Return ``remote_path`` as a plain posix-style path.

    The path is first normalized with ``StorageHelper.conform_url`` and the leading
    ``file://`` (7 characters) is then dropped from the result.

    :param remote_path: path or url handled by this driver
    :return: the remaining path, formatted with ``Path.as_posix()``
    """
    prefix_length = len('file://')
    # conform_url is relied upon to always produce a full path carrying the file:// prefix
    canonical_url = StorageHelper.conform_url(remote_path)
    path_without_scheme = canonical_url[prefix_length:]
    return Path(path_without_scheme).as_posix()
def test_upload(self, test_path, config, **kwargs):
return True