Black formatting

This commit is contained in:
clearml 2024-10-03 21:02:03 +03:00
parent ffcda558e7
commit 972696450e

View File

@ -16,34 +16,34 @@ try:
Task.add_requirements("boto3")
except ImportError as err:
raise ImportError(
"AwsAutoScaler requires 'boto3' package, it was not found\n"
"install with: pip install boto3"
"AwsAutoScaler requires 'boto3' package, it was not found\n" "install with: pip install boto3"
) from err
@attr.s
class AWSDriver(CloudDriver):
"""AWS Driver"""
aws_access_key_id = attr.ib(validator=instance_of(str), default='')
aws_secret_access_key = attr.ib(validator=instance_of(str), default='')
aws_session_token = attr.ib(validator=instance_of(str), default='')
aws_region = attr.ib(validator=instance_of(str), default='')
aws_access_key_id = attr.ib(validator=instance_of(str), default="")
aws_secret_access_key = attr.ib(validator=instance_of(str), default="")
aws_session_token = attr.ib(validator=instance_of(str), default="")
aws_region = attr.ib(validator=instance_of(str), default="")
use_credentials_chain = attr.ib(validator=instance_of(bool), default=False)
use_iam_instance_profile = attr.ib(validator=instance_of(bool), default=False)
iam_arn = attr.ib(validator=instance_of(str), default='')
iam_name = attr.ib(validator=instance_of(str), default='')
iam_arn = attr.ib(validator=instance_of(str), default="")
iam_name = attr.ib(validator=instance_of(str), default="")
@classmethod
def from_config(cls, config):
obj = super().from_config(config)
obj.aws_access_key_id = config['hyper_params'].get('cloud_credentials_key')
obj.aws_secret_access_key = config['hyper_params'].get('cloud_credentials_secret')
obj.aws_session_token = config['hyper_params'].get('cloud_credentials_token')
obj.aws_region = config['hyper_params'].get('cloud_credentials_region')
obj.use_credentials_chain = config['hyper_params'].get('use_credentials_chain', False)
obj.use_iam_instance_profile = config['hyper_params'].get('use_iam_instance_profile', False)
obj.iam_arn = config['hyper_params'].get('iam_arn')
obj.iam_name = config['hyper_params'].get('iam_name')
obj.aws_access_key_id = config["hyper_params"].get("cloud_credentials_key")
obj.aws_secret_access_key = config["hyper_params"].get("cloud_credentials_secret")
obj.aws_session_token = config["hyper_params"].get("cloud_credentials_token")
obj.aws_region = config["hyper_params"].get("cloud_credentials_region")
obj.use_credentials_chain = config["hyper_params"].get("use_credentials_chain", False)
obj.use_iam_instance_profile = config["hyper_params"].get("use_iam_instance_profile", False)
obj.iam_arn = config["hyper_params"].get("iam_arn")
obj.iam_name = config["hyper_params"].get("iam_name")
return obj
def __attrs_post_init__(self):
@ -60,7 +60,7 @@ class AWSDriver(CloudDriver):
launch_specification = ConfigFactory.from_dict(
{
"ImageId": resource_conf["ami_id"],
"Monitoring": {'Enabled': bool(resource_conf.get('enable_monitoring', False))},
"Monitoring": {"Enabled": bool(resource_conf.get("enable_monitoring", False))},
"InstanceType": resource_conf["instance_type"],
}
)
@ -70,9 +70,7 @@ class AWSDriver(CloudDriver):
launch_specification["BlockDeviceMappings"] = [
{
"DeviceName": resource_conf["ebs_device_name"],
"Ebs": {
"SnapshotId": resource_conf["ebs_snapshot_id"]
}
"Ebs": {"SnapshotId": resource_conf["ebs_snapshot_id"]},
}
]
elif resource_conf.get("ebs_device_name"):
@ -81,8 +79,8 @@ class AWSDriver(CloudDriver):
"DeviceName": resource_conf["ebs_device_name"],
"Ebs": {
"VolumeSize": resource_conf.get("ebs_volume_size", 80),
"VolumeType": resource_conf.get("ebs_volume_type", "gp3")
}
"VolumeType": resource_conf.get("ebs_volume_type", "gp3"),
},
}
]
@ -91,45 +89,33 @@ class AWSDriver(CloudDriver):
elif resource_conf.get("availability_zone", None):
launch_specification["Placement"] = {"AvailabilityZone": resource_conf["availability_zone"]}
else:
raise Exception('subnet_id or availability_zone must to be specified in the config')
raise Exception("subnet_id or availability_zone must to be specified in the config")
if resource_conf.get("key_name", None):
launch_specification["KeyName"] = resource_conf["key_name"]
if resource_conf.get("security_group_ids", None):
launch_specification["SecurityGroupIds"] = resource_conf[
"security_group_ids"
]
launch_specification["SecurityGroupIds"] = resource_conf["security_group_ids"]
# Adding iam role - you can have Arn OR Name, not both, Arn getting priority
if self.iam_arn:
launch_specification["IamInstanceProfile"] = {
'Arn': self.iam_arn,
"Arn": self.iam_arn,
}
elif self.iam_name:
launch_specification["IamInstanceProfile"] = {
'Name': self.iam_name
}
launch_specification["IamInstanceProfile"] = {"Name": self.iam_name}
if resource_conf["is_spot"]:
# Create a request for a spot instance in AWS
encoded_user_data = base64.b64encode(user_data.encode("ascii")).decode(
"ascii"
)
encoded_user_data = base64.b64encode(user_data.encode("ascii")).decode("ascii")
launch_specification["UserData"] = encoded_user_data
ConfigTree.merge_configs(
launch_specification, resource_conf.get("extra_configurations", {})
)
ConfigTree.merge_configs(launch_specification, resource_conf.get("extra_configurations", {}))
instances = ec2.request_spot_instances(
LaunchSpecification=launch_specification
)
instances = ec2.request_spot_instances(LaunchSpecification=launch_specification)
# Wait until spot request is fulfilled
request_id = instances["SpotInstanceRequests"][0]["SpotInstanceRequestId"]
waiter = ec2.get_waiter("spot_instance_request_fulfilled")
waiter.wait(SpotInstanceRequestIds=[request_id])
# Get the instance object for later use
response = ec2.describe_spot_instance_requests(
SpotInstanceRequestIds=[request_id]
)
response = ec2.describe_spot_instance_requests(SpotInstanceRequestIds=[request_id])
instance_id = response["SpotInstanceRequests"][0]["InstanceId"]
else:
@ -140,9 +126,7 @@ class AWSDriver(CloudDriver):
UserData=user_data,
InstanceInitiatedShutdownBehavior="terminate",
)
ConfigTree.merge_configs(
launch_specification, resource_conf.get("extra_configurations", {})
)
ConfigTree.merge_configs(launch_specification, resource_conf.get("extra_configurations", {}))
instances = ec2.run_instances(**launch_specification)
@ -165,30 +149,32 @@ class AWSDriver(CloudDriver):
def creds(self):
creds = {
'region_name': self.aws_region or None,
"region_name": self.aws_region or None,
}
if not self.use_credentials_chain:
creds.update({
'aws_secret_access_key': self.aws_secret_access_key or None,
'aws_access_key_id': self.aws_access_key_id or None,
'aws_session_token': self.aws_session_token or None,
})
creds.update(
{
"aws_secret_access_key": self.aws_secret_access_key or None,
"aws_access_key_id": self.aws_access_key_id or None,
"aws_session_token": self.aws_session_token or None,
}
)
return creds
def instance_id_command(self):
return 'curl http://169.254.169.254/latest/meta-data/instance-id'
return "curl http://169.254.169.254/latest/meta-data/instance-id"
def instance_type_key(self):
return 'instance_type'
return "instance_type"
def kind(self):
return 'AWS'
return "AWS"
def console_log(self, instance_id):
ec2 = boto3.client("ec2", **self.creds())
try:
out = ec2.get_console_output(InstanceId=instance_id)
return out.get('Output', '')
return out.get("Output", "")
except ClientError as err:
return 'error: cannot get logs for {}:\n{}'.format(instance_id, err)
return "error: cannot get logs for {}:\n{}".format(instance_id, err)