Fix fileserver upload not supporting a path in the fileserver URL

allegroai 2022-02-26 13:40:26 +02:00
parent 0cca7396e7
commit 83a71d0fb0

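The change below targets setups where the configured files server URL carries a path component rather than a bare host. A minimal illustration of the distinction (host and path are hypothetical, chosen only for this sketch):

from urllib.parse import urlparse

# Hypothetical files server URL that includes a path component, e.g. a server
# mounted behind a reverse proxy (values chosen only for illustration).
files_server = "http://localhost:8081/files"

parsed = urlparse(files_server)
print(parsed.netloc)  # localhost:8081
print(parsed.path)    # /files  <- with the old code this prefix never reached
                      # the upload request URL, so uploads to such a server failed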

@@ -33,6 +33,7 @@ from six.moves.urllib.request import url2pathname
 from .callbacks import UploadProgressReport, DownloadProgressReport
 from .util import quote_url
+from ..backend_api.session import Session
 from ..backend_api.utils import get_http_session_with_retry
 from ..backend_config.bucket_config import S3BucketConfigurations, GSBucketConfigurations, AzureContainerConfigurations
 from ..config import config, deferred_config
@@ -104,6 +105,8 @@ class StorageHelper(object):
     """
     _temp_download_suffix = '.partially'
 
+    _file_server_host = None
+
     @classmethod
     def _get_logger(cls):
         return get_logger('storage')
@@ -876,10 +879,11 @@ class StorageHelper(object):
         elif parsed.scheme == _GoogleCloudStorageDriver.scheme:
             conf = cls._gs_configurations.get_config_by_uri(base_url)
             return str(furl(scheme=parsed.scheme, netloc=conf.bucket))
-        elif parsed.scheme == 'http':
-            return 'http://'
-        elif parsed.scheme == 'https':
-            return 'https://'
+        elif parsed.scheme in ('http', 'https'):
+            files_server = cls._get_file_server_host()
+            if base_url.startswith(files_server):
+                return files_server
+            return parsed.scheme + "://"
         else:  # if parsed.scheme == 'file':
             # if we do not know what it is, we assume file
             return 'file://'
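A standalone sketch of the new http/https branch; the function name and example URLs here are illustrative, while the real code runs inside StorageHelper and reads the cached files server host:

from urllib.parse import urlparse


def resolve_http_base_url(base_url, files_server):
    # Mirrors the new branch: a URL under the configured files server keeps the
    # whole files-server prefix (host plus any path) as its base URL, anything
    # else falls back to just "scheme://" as before.
    parsed = urlparse(base_url)
    assert parsed.scheme in ('http', 'https')
    if base_url.startswith(files_server):
        return files_server
    return parsed.scheme + "://"


# Hypothetical files server whose URL includes a path component:
print(resolve_http_base_url("http://localhost:8081/files/project/artifact.bin",
                            "http://localhost:8081/files"))  # http://localhost:8081/files
print(resolve_http_base_url("https://other.host/some/object",
                            "http://localhost:8081/files"))  # https://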
@@ -906,6 +910,12 @@ class StorageHelper(object):
         return folder_uri
 
+    @classmethod
+    def _get_file_server_host(cls):
+        if cls._file_server_host is None:
+            cls._file_server_host = Session.get_files_server_host()
+        return cls._file_server_host
+
     def _absolute_object_name(self, path):
         """ Returns absolute remote path, including any prefix that is handled by the container """
         if not path.startswith(self.base_url):
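The new helper memoizes the configured files server address at class level, so repeated base-URL resolution only queries the session configuration once. A minimal standalone sketch of the same pattern (the cache class and the fetch callable are illustrative; the real code calls Session.get_files_server_host):

class FilesServerHostCache(object):
    _file_server_host = None

    @classmethod
    def get(cls, fetch_host):
        # fetch_host: zero-argument callable returning the files server URL,
        # standing in for Session.get_files_server_host in the real code.
        if cls._file_server_host is None:
            cls._file_server_host = fetch_host()
        return cls._file_server_host


print(FilesServerHostCache.get(lambda: "http://localhost:8081/files"))  # fetched once
print(FilesServerHostCache.get(lambda: "ignored"))                      # served from the cache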
@@ -1081,11 +1091,14 @@ class _HttpDriver(_Driver):
         return self._containers[container_name]
 
     def upload_object_via_stream(self, iterator, container, object_name, extra=None, callback=None, **kwargs):
-        url = object_name[:object_name.index('/')]
-        url_path = object_name[len(url) + 1:]
-        full_url = container.name + url
         # when sending data in post, there is no connection timeout, just an entire upload timeout
         timeout = int(self.timeout_total)
+        url = container.name
+        path = object_name
+        if not urlparse(url).netloc:
+            host, _, path = object_name.partition('/')
+            url += host + '/'
         stream_size = 0
         if hasattr(iterator, 'tell') and hasattr(iterator, 'seek'):
             pos = iterator.tell()
@@ -1093,9 +1106,9 @@ class _HttpDriver(_Driver):
             stream_size = iterator.tell() - pos
             iterator.seek(pos, 0)
             timeout = max(timeout, (stream_size / 1024) / float(self.min_kbps_speed))
-        res = container.session.post(full_url, files={url_path: iterator}, timeout=timeout,
-                                     headers=container.get_headers(full_url))
+        res = container.session.post(
+            url, files={path: iterator}, timeout=timeout, headers=container.get_headers(url)
+        )
         if res.status_code != requests.codes.ok:
             raise ValueError('Failed uploading object %s (%d): %s' % (object_name, res.status_code, res.text))
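A standalone sketch of the new request-URL / form-field split in upload_object_via_stream; the function name and example values are illustrative, assuming container.name is either a bare scheme prefix or a full files-server base URL:

from urllib.parse import urlparse


def split_upload_target(container_name, object_name):
    # Mirrors the new logic: when the container name is only a scheme prefix
    # (no host), take the first segment of object_name as the host; otherwise
    # the container name already carries host and path, and object_name is used
    # as the multipart field name as-is.
    url = container_name
    path = object_name
    if not urlparse(url).netloc:
        host, _, path = object_name.partition('/')
        url += host + '/'
    return url, path


# Scheme-only container name (previous behaviour, still supported):
print(split_upload_target('https://', 'files.example.com/project/artifact.bin'))
# ('https://files.example.com/', 'project/artifact.bin')

# Files server base URL that already includes host and path (the fixed case):
print(split_upload_target('http://localhost:8081/files', 'project/artifact.bin'))
# ('http://localhost:8081/files', 'project/artifact.bin')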