Allow specifying extension when using custom artifact serializer functions

This commit is contained in:
allegroai 2022-09-13 15:02:54 +03:00
parent b4490a8525
commit bc1a243ecd

View File

@ -401,11 +401,14 @@ class StorageHelper(object):
@classmethod
def get_aws_storage_uri_from_config(cls, bucket_config):
return (
uri = (
"s3://{}/{}".format(bucket_config.host, bucket_config.bucket)
if bucket_config.host
else "s3://{}".format(bucket_config.bucket)
)
if bucket_config.subdir:
uri += "/" + bucket_config.subdir
return uri
@classmethod
def get_gcp_storage_uri_from_config(cls, bucket_config):
@ -455,11 +458,11 @@ class StorageHelper(object):
# Test bucket config, fails if unsuccessful
if _test_config:
_Boto3Driver._test_bucket_config(bucket_config, log) # noqa
if existing:
if log:
log.warning("Overriding existing configuration for '{}'".format(uri))
configs.remove_config(existing)
configs.add_config(bucket_config)
else:
# Try to use existing configuration
good_config = False
@ -485,10 +488,12 @@ class StorageHelper(object):
configs = cls._gs_configurations
uri = cls.get_gcp_storage_uri_from_config(bucket_config)
if not use_existing and existing:
if not use_existing:
if existing:
if log:
log.warning("Overriding existing configuration for '{}'".format(uri))
configs.remove_config(existing)
configs.add_config(bucket_config)
else:
good_config = False
if existing:
@ -507,10 +512,13 @@ class StorageHelper(object):
existing = cls.get_azure_configuration(bucket_config)
configs = cls._azure_configurations
uri = cls.get_azure_storage_uri_from_config(bucket_config)
if not use_existing and existing:
if not use_existing:
if existing:
if log:
log.warning("Overriding existing configuration for '{}'".format(uri))
configs.remove_config(existing)
configs.add_config(bucket_config)
else:
good_config = False
if existing:
@ -1642,6 +1650,8 @@ class _Boto3Driver(_Driver):
fullname = furl(conf.bucket).add(path=test_path).add(path='%s-upload_test' % cls.__module__)
bucket_name = str(fullname.path.segments[0])
filename = str(furl(path=fullname.path.segments[1:]))
if conf.subdir:
filename = "{}/{}".format(conf.subdir, filename)
data = {
'user': getpass.getuser(),
@ -1651,7 +1661,7 @@ class _Boto3Driver(_Driver):
boto_session = boto3.Session(conf.key, conf.secret, aws_session_token=conf.token)
endpoint = (('https://' if conf.secure else 'http://') + conf.host) if conf.host else None
boto_resource = boto_session.resource('s3', region_name=conf.region, endpoint_url=endpoint)
boto_resource = boto_session.resource('s3', region_name=conf.region or None, endpoint_url=endpoint)
bucket = boto_resource.Bucket(bucket_name)
bucket.put_object(Key=filename, Body=six.b(json.dumps(data)))