Add support for AWS Session Token in AWS Storage configuration

This commit is contained in:
allegroai 2022-04-15 19:20:05 +03:00
parent 7a1b42c5ed
commit a1709d5d41
2 changed files with 39 additions and 12 deletions

View File

@ -25,6 +25,7 @@ class S3BucketConfig(object):
host = attrib(type=str, converter=_none_to_empty_string, default="")
key = attrib(type=str, converter=_none_to_empty_string, default="")
secret = attrib(type=str, converter=_none_to_empty_string, default="")
token = attrib(type=str, converter=_none_to_empty_string, default="")
multipart = attrib(type=bool, default=True)
acl = attrib(type=str, converter=_none_to_empty_string, default="")
secure = attrib(type=bool, default=True)
@ -32,9 +33,10 @@ class S3BucketConfig(object):
verify = attrib(type=bool, default=True)
use_credentials_chain = attrib(type=bool, default=False)
def update(self, key, secret, multipart=True, region=None, use_credentials_chain=False):
def update(self, key, secret, multipart=True, region=None, use_credentials_chain=False, token=""):
self.key = key
self.secret = secret
self.token = token
self.multipart = multipart
self.region = region
self.use_credentials_chain = use_credentials_chain
@ -91,12 +93,19 @@ class BaseBucketConfigurations(object):
class S3BucketConfigurations(BaseBucketConfigurations):
def __init__(
self, buckets=None, default_key="", default_secret="", default_region="", default_use_credentials_chain=False
self,
buckets=None,
default_key="",
default_secret="",
default_region="",
default_use_credentials_chain=False,
default_token="",
):
super(S3BucketConfigurations, self).__init__()
self._buckets = buckets if buckets else list()
self._default_key = default_key
self._default_secret = default_secret
self._default_token = default_token
self._default_region = default_region
self._default_multipart = True
self._default_use_credentials_chain = default_use_credentials_chain
@ -107,16 +116,18 @@ class S3BucketConfigurations(BaseBucketConfigurations):
s3_configuration.get("credentials", [])
)
default_key = s3_configuration.get("key") or getenv("AWS_ACCESS_KEY_ID", "")
default_secret = s3_configuration.get("secret") or getenv("AWS_SECRET_ACCESS_KEY", "")
default_region = s3_configuration.get("region") or getenv("AWS_DEFAULT_REGION", "")
default_key = s3_configuration.get("key", "") or getenv("AWS_ACCESS_KEY_ID", "")
default_secret = s3_configuration.get("secret", "") or getenv("AWS_SECRET_ACCESS_KEY", "")
default_token = s3_configuration.get("token", "") or getenv("AWS_SESSION_TOKEN", "")
default_region = s3_configuration.get("region", "") or getenv("AWS_DEFAULT_REGION", "")
default_use_credentials_chain = s3_configuration.get("use_credentials_chain") or False
default_key = _none_to_empty_string(default_key)
default_secret = _none_to_empty_string(default_secret)
default_token = _none_to_empty_string(default_token)
default_region = _none_to_empty_string(default_region)
return cls(config_list, default_key, default_secret, default_region, default_use_credentials_chain)
return cls(config_list, default_key, default_secret, default_region, default_use_credentials_chain, default_token)
def add_config(self, bucket_config):
self._buckets.insert(0, bucket_config)
@ -144,7 +155,8 @@ class S3BucketConfigurations(BaseBucketConfigurations):
secret=self._default_secret,
region=bucket_config.region or self._default_region,
multipart=bucket_config.multipart or self._default_multipart,
use_credentials_chain=self._default_use_credentials_chain
use_credentials_chain=self._default_use_credentials_chain,
token=self._default_token,
)
def _get_prefix_from_bucket_config(self, config):
@ -209,6 +221,7 @@ class S3BucketConfigurations(BaseBucketConfigurations):
use_credentials_chain=self._default_use_credentials_chain,
bucket=bucket,
host=host,
token=self._default_token
)

View File

@ -280,9 +280,20 @@ class StorageHelper(object):
local_path = mktemp(suffix=file_name)
return helper.download_to_file(remote_url, local_path, skip_zero_size_check=skip_zero_size_check)
def __init__(self, base_url, url, key=None, secret=None, region=None, verbose=False, logger=None, retries=5,
**kwargs):
level = config.get('storage.log.level', None)
def __init__(
self,
base_url,
url,
key=None,
secret=None,
region=None,
verbose=False,
logger=None,
retries=5,
token=None,
**kwargs,
):
level = config.get("storage.log.level", None)
if level:
try:
@ -332,7 +343,8 @@ class StorageHelper(object):
secret=secret or self._conf.secret,
multipart=self._conf.multipart,
region=final_region,
use_credentials_chain=self._conf.use_credentials_chain
use_credentials_chain=self._conf.use_credentials_chain,
token=token or self._conf.token
)
if not self._conf.use_credentials_chain:
@ -1352,9 +1364,11 @@ class _Boto3Driver(_Driver):
if not cfg.use_credentials_chain:
boto_kwargs["aws_access_key_id"] = cfg.key
boto_kwargs["aws_secret_access_key"] = cfg.secret
if cfg.token:
boto_kwargs["aws_session_token"] = cfg.token
self.resource = boto3.resource(
's3',
"s3",
**boto_kwargs
)