Fix handling of legacy fileserver

This commit is contained in:
allegroai 2022-03-11 23:53:07 +02:00
parent fcdc561f5f
commit 05791ba6f8
2 changed files with 18 additions and 17 deletions

View File

@ -83,6 +83,8 @@ class Session(TokenManager):
default_secret = "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8" default_secret = "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8"
force_max_api_version = None force_max_api_version = None
legacy_file_servers = ["https://files.community.clear.ml"]
# TODO: add requests.codes.gateway_timeout once we support async commits # TODO: add requests.codes.gateway_timeout once we support async commits
_retry_codes = [ _retry_codes = [
requests.codes.bad_gateway, requests.codes.bad_gateway,

View File

@ -52,6 +52,7 @@ class DownloadError(Exception):
@six.add_metaclass(ABCMeta) @six.add_metaclass(ABCMeta)
class _Driver(object): class _Driver(object):
_file_server_hosts = None
@classmethod @classmethod
def get_logger(cls): def get_logger(cls):
@ -97,6 +98,12 @@ class _Driver(object):
def get_object(self, container_name, object_name, **kwargs): def get_object(self, container_name, object_name, **kwargs):
pass pass
@classmethod
def get_file_server_hosts(cls):
    """Return the list of known file server hosts, building it on first use.

    The list consists of the host returned by ``Session.get_files_server_host()``
    followed by the entries of ``Session.legacy_file_servers`` (if any). Once
    built, it is stored in ``cls._file_server_hosts`` and returned as-is on
    later calls.
    """
    known_hosts = cls._file_server_hosts
    if known_hosts is not None:
        # Already resolved earlier; reuse the stored list.
        return known_hosts

    current_host = Session.get_files_server_host()
    # ``legacy_file_servers`` may be empty/None; fall back to an empty list.
    legacy_hosts = Session.legacy_file_servers or []
    known_hosts = [current_host] + legacy_hosts
    cls._file_server_hosts = known_hosts
    return known_hosts
class StorageHelper(object): class StorageHelper(object):
""" Storage helper. """ Storage helper.
@ -105,8 +112,6 @@ class StorageHelper(object):
""" """
_temp_download_suffix = '.partially' _temp_download_suffix = '.partially'
_file_server_host = None
@classmethod @classmethod
def _get_logger(cls): def _get_logger(cls):
return get_logger('storage') return get_logger('storage')
@ -880,9 +885,9 @@ class StorageHelper(object):
conf = cls._gs_configurations.get_config_by_uri(base_url) conf = cls._gs_configurations.get_config_by_uri(base_url)
return str(furl(scheme=parsed.scheme, netloc=conf.bucket)) return str(furl(scheme=parsed.scheme, netloc=conf.bucket))
elif parsed.scheme in _HttpDriver.schemes: elif parsed.scheme in _HttpDriver.schemes:
files_server = cls._get_file_server_host() for files_server in _Driver.get_file_server_hosts():
if base_url.startswith(files_server): if base_url.startswith(files_server):
return files_server return files_server
return parsed.scheme + "://" return parsed.scheme + "://"
else: # if parsed.scheme == 'file': else: # if parsed.scheme == 'file':
# if we do not know what it is, we assume file # if we do not know what it is, we assume file
@ -910,12 +915,6 @@ class StorageHelper(object):
return folder_uri return folder_uri
@classmethod
def _get_file_server_host(cls):
    """Return the files server host, caching it on the class.

    While ``cls._file_server_host`` is ``None`` the value is fetched from
    ``Session.get_files_server_host()`` and stored; otherwise the stored
    value is returned.
    """
    cached_host = cls._file_server_host
    if cached_host is None:
        cached_host = Session.get_files_server_host()
        cls._file_server_host = cached_host
    return cached_host
def _absolute_object_name(self, path): def _absolute_object_name(self, path):
""" Returns absolute remote path, including any prefix that is handled by the container """ """ Returns absolute remote path, including any prefix that is handled by the container """
if not path.startswith(self.base_url): if not path.startswith(self.base_url):
@ -1044,7 +1043,6 @@ class _HttpDriver(_Driver):
class _Container(object): class _Container(object):
_default_backend_session = None _default_backend_session = None
_default_files_server_host = None
def __init__(self, name, retries=5, **kwargs): def __init__(self, name, retries=5, **kwargs):
self.name = name self.name = name
@ -1064,17 +1062,18 @@ class _HttpDriver(_Driver):
requests_codes.too_many_requests, requests_codes.too_many_requests,
] ]
) )
self.attach_auth_header = any(
(name.rstrip('/') == host.rstrip('/') or name.startswith(host.rstrip('/') + '/'))
for host in _HttpDriver.get_file_server_hosts()
)
def get_headers(self, url): def get_headers(self, _):
if not self._default_backend_session: if not self._default_backend_session:
from ..backend_interface.base import InterfaceBase from ..backend_interface.base import InterfaceBase
self._default_backend_session = InterfaceBase._get_default_session() self._default_backend_session = InterfaceBase._get_default_session()
if self._default_files_server_host is None:
self._default_files_server_host = self._default_backend_session.get_files_server_host().rstrip('/')
if url == self._default_files_server_host or url.startswith(self._default_files_server_host + '/'): if self.attach_auth_header:
return self._default_backend_session.add_auth_headers({}) return self._default_backend_session.add_auth_headers({})
return None
class _HttpSessionHandle(object): class _HttpSessionHandle(object):
def __init__(self, url, is_stream, container_name, object_name): def __init__(self, url, is_stream, container_name, object_name):