Fix handling of legacy fileserver

This commit is contained in:
allegroai 2022-03-11 23:53:07 +02:00
parent fcdc561f5f
commit 05791ba6f8
2 changed files with 18 additions and 17 deletions

View File

@ -83,6 +83,8 @@ class Session(TokenManager):
default_secret = "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8"
force_max_api_version = None
legacy_file_servers = ["https://files.community.clear.ml"]
# TODO: add requests.codes.gateway_timeout once we support async commits
_retry_codes = [
requests.codes.bad_gateway,

View File

@ -52,6 +52,7 @@ class DownloadError(Exception):
@six.add_metaclass(ABCMeta)
class _Driver(object):
_file_server_hosts = None
@classmethod
def get_logger(cls):
@ -97,6 +98,12 @@ class _Driver(object):
def get_object(self, container_name, object_name, **kwargs):
pass
@classmethod
def get_file_server_hosts(cls):
if cls._file_server_hosts is None:
cls._file_server_hosts = [Session.get_files_server_host()] + (Session.legacy_file_servers or [])
return cls._file_server_hosts
class StorageHelper(object):
""" Storage helper.
@ -105,8 +112,6 @@ class StorageHelper(object):
"""
_temp_download_suffix = '.partially'
_file_server_host = None
@classmethod
def _get_logger(cls):
return get_logger('storage')
@ -880,9 +885,9 @@ class StorageHelper(object):
conf = cls._gs_configurations.get_config_by_uri(base_url)
return str(furl(scheme=parsed.scheme, netloc=conf.bucket))
elif parsed.scheme in _HttpDriver.schemes:
files_server = cls._get_file_server_host()
if base_url.startswith(files_server):
return files_server
for files_server in _Driver.get_file_server_hosts():
if base_url.startswith(files_server):
return files_server
return parsed.scheme + "://"
else: # if parsed.scheme == 'file':
# if we do not know what it is, we assume file
@ -910,12 +915,6 @@ class StorageHelper(object):
return folder_uri
@classmethod
def _get_file_server_host(cls):
if cls._file_server_host is None:
cls._file_server_host = Session.get_files_server_host()
return cls._file_server_host
def _absolute_object_name(self, path):
""" Returns absolute remote path, including any prefix that is handled by the container """
if not path.startswith(self.base_url):
@ -1044,7 +1043,6 @@ class _HttpDriver(_Driver):
class _Container(object):
_default_backend_session = None
_default_files_server_host = None
def __init__(self, name, retries=5, **kwargs):
self.name = name
@ -1064,17 +1062,18 @@ class _HttpDriver(_Driver):
requests_codes.too_many_requests,
]
)
self.attach_auth_header = any(
(name.rstrip('/') == host.rstrip('/') or name.startswith(host.rstrip('/') + '/'))
for host in _HttpDriver.get_file_server_hosts()
)
def get_headers(self, url):
def get_headers(self, _):
if not self._default_backend_session:
from ..backend_interface.base import InterfaceBase
self._default_backend_session = InterfaceBase._get_default_session()
if self._default_files_server_host is None:
self._default_files_server_host = self._default_backend_session.get_files_server_host().rstrip('/')
if url == self._default_files_server_host or url.startswith(self._default_files_server_host + '/'):
if self.attach_auth_header:
return self._default_backend_session.add_auth_headers({})
return None
class _HttpSessionHandle(object):
def __init__(self, url, is_stream, container_name, object_name):