diff --git a/clearml/backend_api/session/session.py b/clearml/backend_api/session/session.py index 19dda1a2..13f56f6b 100644 --- a/clearml/backend_api/session/session.py +++ b/clearml/backend_api/session/session.py @@ -83,6 +83,8 @@ class Session(TokenManager): default_secret = "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8" force_max_api_version = None + legacy_file_servers = ["https://files.community.clear.ml"] + # TODO: add requests.codes.gateway_timeout once we support async commits _retry_codes = [ requests.codes.bad_gateway, diff --git a/clearml/storage/helper.py b/clearml/storage/helper.py index 547665e3..a741f1af 100644 --- a/clearml/storage/helper.py +++ b/clearml/storage/helper.py @@ -52,6 +52,7 @@ class DownloadError(Exception): @six.add_metaclass(ABCMeta) class _Driver(object): + _file_server_hosts = None @classmethod def get_logger(cls): @@ -97,6 +98,12 @@ class _Driver(object): def get_object(self, container_name, object_name, **kwargs): pass + @classmethod + def get_file_server_hosts(cls): + if cls._file_server_hosts is None: + cls._file_server_hosts = [Session.get_files_server_host()] + (Session.legacy_file_servers or []) + return cls._file_server_hosts + class StorageHelper(object): """ Storage helper. @@ -105,8 +112,6 @@ class StorageHelper(object): """ _temp_download_suffix = '.partially' - _file_server_host = None - @classmethod def _get_logger(cls): return get_logger('storage') @@ -880,9 +885,9 @@ class StorageHelper(object): conf = cls._gs_configurations.get_config_by_uri(base_url) return str(furl(scheme=parsed.scheme, netloc=conf.bucket)) elif parsed.scheme in _HttpDriver.schemes: - files_server = cls._get_file_server_host() - if base_url.startswith(files_server): - return files_server + for files_server in _Driver.get_file_server_hosts(): + if base_url.startswith(files_server): + return files_server return parsed.scheme + "://" else: # if parsed.scheme == 'file': # if we do not know what it is, we assume file @@ -910,12 +915,6 @@ class StorageHelper(object): return folder_uri - @classmethod - def _get_file_server_host(cls): - if cls._file_server_host is None: - cls._file_server_host = Session.get_files_server_host() - return cls._file_server_host - def _absolute_object_name(self, path): """ Returns absolute remote path, including any prefix that is handled by the container """ if not path.startswith(self.base_url): @@ -1044,7 +1043,6 @@ class _HttpDriver(_Driver): class _Container(object): _default_backend_session = None - _default_files_server_host = None def __init__(self, name, retries=5, **kwargs): self.name = name @@ -1064,17 +1062,18 @@ class _HttpDriver(_Driver): requests_codes.too_many_requests, ] ) + self.attach_auth_header = any( + (name.rstrip('/') == host.rstrip('/') or name.startswith(host.rstrip('/') + '/')) + for host in _HttpDriver.get_file_server_hosts() + ) - def get_headers(self, url): + def get_headers(self, _): if not self._default_backend_session: from ..backend_interface.base import InterfaceBase self._default_backend_session = InterfaceBase._get_default_session() - if self._default_files_server_host is None: - self._default_files_server_host = self._default_backend_session.get_files_server_host().rstrip('/') - if url == self._default_files_server_host or url.startswith(self._default_files_server_host + '/'): + if self.attach_auth_header: return self._default_backend_session.add_auth_headers({}) - return None class _HttpSessionHandle(object): def __init__(self, url, is_stream, container_name, object_name):