Black formatting

This commit is contained in:
clearml 2024-10-03 21:02:03 +03:00
parent ffcda558e7
commit 972696450e

View File

@ -16,34 +16,34 @@ try:
Task.add_requirements("boto3") Task.add_requirements("boto3")
except ImportError as err: except ImportError as err:
raise ImportError( raise ImportError(
"AwsAutoScaler requires 'boto3' package, it was not found\n" "AwsAutoScaler requires 'boto3' package, it was not found\n" "install with: pip install boto3"
"install with: pip install boto3"
) from err ) from err
@attr.s @attr.s
class AWSDriver(CloudDriver): class AWSDriver(CloudDriver):
"""AWS Driver""" """AWS Driver"""
aws_access_key_id = attr.ib(validator=instance_of(str), default='')
aws_secret_access_key = attr.ib(validator=instance_of(str), default='') aws_access_key_id = attr.ib(validator=instance_of(str), default="")
aws_session_token = attr.ib(validator=instance_of(str), default='') aws_secret_access_key = attr.ib(validator=instance_of(str), default="")
aws_region = attr.ib(validator=instance_of(str), default='') aws_session_token = attr.ib(validator=instance_of(str), default="")
aws_region = attr.ib(validator=instance_of(str), default="")
use_credentials_chain = attr.ib(validator=instance_of(bool), default=False) use_credentials_chain = attr.ib(validator=instance_of(bool), default=False)
use_iam_instance_profile = attr.ib(validator=instance_of(bool), default=False) use_iam_instance_profile = attr.ib(validator=instance_of(bool), default=False)
iam_arn = attr.ib(validator=instance_of(str), default='') iam_arn = attr.ib(validator=instance_of(str), default="")
iam_name = attr.ib(validator=instance_of(str), default='') iam_name = attr.ib(validator=instance_of(str), default="")
@classmethod @classmethod
def from_config(cls, config): def from_config(cls, config):
obj = super().from_config(config) obj = super().from_config(config)
obj.aws_access_key_id = config['hyper_params'].get('cloud_credentials_key') obj.aws_access_key_id = config["hyper_params"].get("cloud_credentials_key")
obj.aws_secret_access_key = config['hyper_params'].get('cloud_credentials_secret') obj.aws_secret_access_key = config["hyper_params"].get("cloud_credentials_secret")
obj.aws_session_token = config['hyper_params'].get('cloud_credentials_token') obj.aws_session_token = config["hyper_params"].get("cloud_credentials_token")
obj.aws_region = config['hyper_params'].get('cloud_credentials_region') obj.aws_region = config["hyper_params"].get("cloud_credentials_region")
obj.use_credentials_chain = config['hyper_params'].get('use_credentials_chain', False) obj.use_credentials_chain = config["hyper_params"].get("use_credentials_chain", False)
obj.use_iam_instance_profile = config['hyper_params'].get('use_iam_instance_profile', False) obj.use_iam_instance_profile = config["hyper_params"].get("use_iam_instance_profile", False)
obj.iam_arn = config['hyper_params'].get('iam_arn') obj.iam_arn = config["hyper_params"].get("iam_arn")
obj.iam_name = config['hyper_params'].get('iam_name') obj.iam_name = config["hyper_params"].get("iam_name")
return obj return obj
def __attrs_post_init__(self): def __attrs_post_init__(self):
@ -60,7 +60,7 @@ class AWSDriver(CloudDriver):
launch_specification = ConfigFactory.from_dict( launch_specification = ConfigFactory.from_dict(
{ {
"ImageId": resource_conf["ami_id"], "ImageId": resource_conf["ami_id"],
"Monitoring": {'Enabled': bool(resource_conf.get('enable_monitoring', False))}, "Monitoring": {"Enabled": bool(resource_conf.get("enable_monitoring", False))},
"InstanceType": resource_conf["instance_type"], "InstanceType": resource_conf["instance_type"],
} }
) )
@ -70,9 +70,7 @@ class AWSDriver(CloudDriver):
launch_specification["BlockDeviceMappings"] = [ launch_specification["BlockDeviceMappings"] = [
{ {
"DeviceName": resource_conf["ebs_device_name"], "DeviceName": resource_conf["ebs_device_name"],
"Ebs": { "Ebs": {"SnapshotId": resource_conf["ebs_snapshot_id"]},
"SnapshotId": resource_conf["ebs_snapshot_id"]
}
} }
] ]
elif resource_conf.get("ebs_device_name"): elif resource_conf.get("ebs_device_name"):
@ -81,8 +79,8 @@ class AWSDriver(CloudDriver):
"DeviceName": resource_conf["ebs_device_name"], "DeviceName": resource_conf["ebs_device_name"],
"Ebs": { "Ebs": {
"VolumeSize": resource_conf.get("ebs_volume_size", 80), "VolumeSize": resource_conf.get("ebs_volume_size", 80),
"VolumeType": resource_conf.get("ebs_volume_type", "gp3") "VolumeType": resource_conf.get("ebs_volume_type", "gp3"),
} },
} }
] ]
@ -91,45 +89,33 @@ class AWSDriver(CloudDriver):
elif resource_conf.get("availability_zone", None): elif resource_conf.get("availability_zone", None):
launch_specification["Placement"] = {"AvailabilityZone": resource_conf["availability_zone"]} launch_specification["Placement"] = {"AvailabilityZone": resource_conf["availability_zone"]}
else: else:
raise Exception('subnet_id or availability_zone must to be specified in the config') raise Exception("subnet_id or availability_zone must to be specified in the config")
if resource_conf.get("key_name", None): if resource_conf.get("key_name", None):
launch_specification["KeyName"] = resource_conf["key_name"] launch_specification["KeyName"] = resource_conf["key_name"]
if resource_conf.get("security_group_ids", None): if resource_conf.get("security_group_ids", None):
launch_specification["SecurityGroupIds"] = resource_conf[ launch_specification["SecurityGroupIds"] = resource_conf["security_group_ids"]
"security_group_ids"
]
# Adding iam role - you can have Arn OR Name, not both, Arn getting priority # Adding iam role - you can have Arn OR Name, not both, Arn getting priority
if self.iam_arn: if self.iam_arn:
launch_specification["IamInstanceProfile"] = { launch_specification["IamInstanceProfile"] = {
'Arn': self.iam_arn, "Arn": self.iam_arn,
} }
elif self.iam_name: elif self.iam_name:
launch_specification["IamInstanceProfile"] = { launch_specification["IamInstanceProfile"] = {"Name": self.iam_name}
'Name': self.iam_name
}
if resource_conf["is_spot"]: if resource_conf["is_spot"]:
# Create a request for a spot instance in AWS # Create a request for a spot instance in AWS
encoded_user_data = base64.b64encode(user_data.encode("ascii")).decode( encoded_user_data = base64.b64encode(user_data.encode("ascii")).decode("ascii")
"ascii"
)
launch_specification["UserData"] = encoded_user_data launch_specification["UserData"] = encoded_user_data
ConfigTree.merge_configs( ConfigTree.merge_configs(launch_specification, resource_conf.get("extra_configurations", {}))
launch_specification, resource_conf.get("extra_configurations", {})
)
instances = ec2.request_spot_instances( instances = ec2.request_spot_instances(LaunchSpecification=launch_specification)
LaunchSpecification=launch_specification
)
# Wait until spot request is fulfilled # Wait until spot request is fulfilled
request_id = instances["SpotInstanceRequests"][0]["SpotInstanceRequestId"] request_id = instances["SpotInstanceRequests"][0]["SpotInstanceRequestId"]
waiter = ec2.get_waiter("spot_instance_request_fulfilled") waiter = ec2.get_waiter("spot_instance_request_fulfilled")
waiter.wait(SpotInstanceRequestIds=[request_id]) waiter.wait(SpotInstanceRequestIds=[request_id])
# Get the instance object for later use # Get the instance object for later use
response = ec2.describe_spot_instance_requests( response = ec2.describe_spot_instance_requests(SpotInstanceRequestIds=[request_id])
SpotInstanceRequestIds=[request_id]
)
instance_id = response["SpotInstanceRequests"][0]["InstanceId"] instance_id = response["SpotInstanceRequests"][0]["InstanceId"]
else: else:
@ -140,9 +126,7 @@ class AWSDriver(CloudDriver):
UserData=user_data, UserData=user_data,
InstanceInitiatedShutdownBehavior="terminate", InstanceInitiatedShutdownBehavior="terminate",
) )
ConfigTree.merge_configs( ConfigTree.merge_configs(launch_specification, resource_conf.get("extra_configurations", {}))
launch_specification, resource_conf.get("extra_configurations", {})
)
instances = ec2.run_instances(**launch_specification) instances = ec2.run_instances(**launch_specification)
@ -165,30 +149,32 @@ class AWSDriver(CloudDriver):
def creds(self): def creds(self):
creds = { creds = {
'region_name': self.aws_region or None, "region_name": self.aws_region or None,
} }
if not self.use_credentials_chain: if not self.use_credentials_chain:
creds.update({ creds.update(
'aws_secret_access_key': self.aws_secret_access_key or None, {
'aws_access_key_id': self.aws_access_key_id or None, "aws_secret_access_key": self.aws_secret_access_key or None,
'aws_session_token': self.aws_session_token or None, "aws_access_key_id": self.aws_access_key_id or None,
}) "aws_session_token": self.aws_session_token or None,
}
)
return creds return creds
def instance_id_command(self): def instance_id_command(self):
return 'curl http://169.254.169.254/latest/meta-data/instance-id' return "curl http://169.254.169.254/latest/meta-data/instance-id"
def instance_type_key(self): def instance_type_key(self):
return 'instance_type' return "instance_type"
def kind(self): def kind(self):
return 'AWS' return "AWS"
def console_log(self, instance_id): def console_log(self, instance_id):
ec2 = boto3.client("ec2", **self.creds()) ec2 = boto3.client("ec2", **self.creds())
try: try:
out = ec2.get_console_output(InstanceId=instance_id) out = ec2.get_console_output(InstanceId=instance_id)
return out.get('Output', '') return out.get("Output", "")
except ClientError as err: except ClientError as err:
return 'error: cannot get logs for {}:\n{}'.format(instance_id, err) return "error: cannot get logs for {}:\n{}".format(instance_id, err)