mirror of
https://github.com/clearml/clearml
synced 2025-01-31 00:56:57 +00:00
Black formatting
This commit is contained in:
parent
ffcda558e7
commit
972696450e
@ -16,34 +16,34 @@ try:
|
||||
Task.add_requirements("boto3")
|
||||
except ImportError as err:
|
||||
raise ImportError(
|
||||
"AwsAutoScaler requires 'boto3' package, it was not found\n"
|
||||
"install with: pip install boto3"
|
||||
"AwsAutoScaler requires 'boto3' package, it was not found\n" "install with: pip install boto3"
|
||||
) from err
|
||||
|
||||
|
||||
@attr.s
|
||||
class AWSDriver(CloudDriver):
|
||||
"""AWS Driver"""
|
||||
aws_access_key_id = attr.ib(validator=instance_of(str), default='')
|
||||
aws_secret_access_key = attr.ib(validator=instance_of(str), default='')
|
||||
aws_session_token = attr.ib(validator=instance_of(str), default='')
|
||||
aws_region = attr.ib(validator=instance_of(str), default='')
|
||||
|
||||
aws_access_key_id = attr.ib(validator=instance_of(str), default="")
|
||||
aws_secret_access_key = attr.ib(validator=instance_of(str), default="")
|
||||
aws_session_token = attr.ib(validator=instance_of(str), default="")
|
||||
aws_region = attr.ib(validator=instance_of(str), default="")
|
||||
use_credentials_chain = attr.ib(validator=instance_of(bool), default=False)
|
||||
use_iam_instance_profile = attr.ib(validator=instance_of(bool), default=False)
|
||||
iam_arn = attr.ib(validator=instance_of(str), default='')
|
||||
iam_name = attr.ib(validator=instance_of(str), default='')
|
||||
iam_arn = attr.ib(validator=instance_of(str), default="")
|
||||
iam_name = attr.ib(validator=instance_of(str), default="")
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config):
|
||||
obj = super().from_config(config)
|
||||
obj.aws_access_key_id = config['hyper_params'].get('cloud_credentials_key')
|
||||
obj.aws_secret_access_key = config['hyper_params'].get('cloud_credentials_secret')
|
||||
obj.aws_session_token = config['hyper_params'].get('cloud_credentials_token')
|
||||
obj.aws_region = config['hyper_params'].get('cloud_credentials_region')
|
||||
obj.use_credentials_chain = config['hyper_params'].get('use_credentials_chain', False)
|
||||
obj.use_iam_instance_profile = config['hyper_params'].get('use_iam_instance_profile', False)
|
||||
obj.iam_arn = config['hyper_params'].get('iam_arn')
|
||||
obj.iam_name = config['hyper_params'].get('iam_name')
|
||||
obj.aws_access_key_id = config["hyper_params"].get("cloud_credentials_key")
|
||||
obj.aws_secret_access_key = config["hyper_params"].get("cloud_credentials_secret")
|
||||
obj.aws_session_token = config["hyper_params"].get("cloud_credentials_token")
|
||||
obj.aws_region = config["hyper_params"].get("cloud_credentials_region")
|
||||
obj.use_credentials_chain = config["hyper_params"].get("use_credentials_chain", False)
|
||||
obj.use_iam_instance_profile = config["hyper_params"].get("use_iam_instance_profile", False)
|
||||
obj.iam_arn = config["hyper_params"].get("iam_arn")
|
||||
obj.iam_name = config["hyper_params"].get("iam_name")
|
||||
return obj
|
||||
|
||||
def __attrs_post_init__(self):
|
||||
@ -60,7 +60,7 @@ class AWSDriver(CloudDriver):
|
||||
launch_specification = ConfigFactory.from_dict(
|
||||
{
|
||||
"ImageId": resource_conf["ami_id"],
|
||||
"Monitoring": {'Enabled': bool(resource_conf.get('enable_monitoring', False))},
|
||||
"Monitoring": {"Enabled": bool(resource_conf.get("enable_monitoring", False))},
|
||||
"InstanceType": resource_conf["instance_type"],
|
||||
}
|
||||
)
|
||||
@ -70,9 +70,7 @@ class AWSDriver(CloudDriver):
|
||||
launch_specification["BlockDeviceMappings"] = [
|
||||
{
|
||||
"DeviceName": resource_conf["ebs_device_name"],
|
||||
"Ebs": {
|
||||
"SnapshotId": resource_conf["ebs_snapshot_id"]
|
||||
}
|
||||
"Ebs": {"SnapshotId": resource_conf["ebs_snapshot_id"]},
|
||||
}
|
||||
]
|
||||
elif resource_conf.get("ebs_device_name"):
|
||||
@ -81,8 +79,8 @@ class AWSDriver(CloudDriver):
|
||||
"DeviceName": resource_conf["ebs_device_name"],
|
||||
"Ebs": {
|
||||
"VolumeSize": resource_conf.get("ebs_volume_size", 80),
|
||||
"VolumeType": resource_conf.get("ebs_volume_type", "gp3")
|
||||
}
|
||||
"VolumeType": resource_conf.get("ebs_volume_type", "gp3"),
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
@ -91,45 +89,33 @@ class AWSDriver(CloudDriver):
|
||||
elif resource_conf.get("availability_zone", None):
|
||||
launch_specification["Placement"] = {"AvailabilityZone": resource_conf["availability_zone"]}
|
||||
else:
|
||||
raise Exception('subnet_id or availability_zone must to be specified in the config')
|
||||
raise Exception("subnet_id or availability_zone must to be specified in the config")
|
||||
if resource_conf.get("key_name", None):
|
||||
launch_specification["KeyName"] = resource_conf["key_name"]
|
||||
if resource_conf.get("security_group_ids", None):
|
||||
launch_specification["SecurityGroupIds"] = resource_conf[
|
||||
"security_group_ids"
|
||||
]
|
||||
launch_specification["SecurityGroupIds"] = resource_conf["security_group_ids"]
|
||||
# Adding iam role - you can have Arn OR Name, not both, Arn getting priority
|
||||
if self.iam_arn:
|
||||
launch_specification["IamInstanceProfile"] = {
|
||||
'Arn': self.iam_arn,
|
||||
"Arn": self.iam_arn,
|
||||
}
|
||||
elif self.iam_name:
|
||||
launch_specification["IamInstanceProfile"] = {
|
||||
'Name': self.iam_name
|
||||
}
|
||||
launch_specification["IamInstanceProfile"] = {"Name": self.iam_name}
|
||||
|
||||
if resource_conf["is_spot"]:
|
||||
# Create a request for a spot instance in AWS
|
||||
encoded_user_data = base64.b64encode(user_data.encode("ascii")).decode(
|
||||
"ascii"
|
||||
)
|
||||
encoded_user_data = base64.b64encode(user_data.encode("ascii")).decode("ascii")
|
||||
launch_specification["UserData"] = encoded_user_data
|
||||
ConfigTree.merge_configs(
|
||||
launch_specification, resource_conf.get("extra_configurations", {})
|
||||
)
|
||||
ConfigTree.merge_configs(launch_specification, resource_conf.get("extra_configurations", {}))
|
||||
|
||||
instances = ec2.request_spot_instances(
|
||||
LaunchSpecification=launch_specification
|
||||
)
|
||||
instances = ec2.request_spot_instances(LaunchSpecification=launch_specification)
|
||||
|
||||
# Wait until spot request is fulfilled
|
||||
request_id = instances["SpotInstanceRequests"][0]["SpotInstanceRequestId"]
|
||||
waiter = ec2.get_waiter("spot_instance_request_fulfilled")
|
||||
waiter.wait(SpotInstanceRequestIds=[request_id])
|
||||
# Get the instance object for later use
|
||||
response = ec2.describe_spot_instance_requests(
|
||||
SpotInstanceRequestIds=[request_id]
|
||||
)
|
||||
response = ec2.describe_spot_instance_requests(SpotInstanceRequestIds=[request_id])
|
||||
instance_id = response["SpotInstanceRequests"][0]["InstanceId"]
|
||||
|
||||
else:
|
||||
@ -140,9 +126,7 @@ class AWSDriver(CloudDriver):
|
||||
UserData=user_data,
|
||||
InstanceInitiatedShutdownBehavior="terminate",
|
||||
)
|
||||
ConfigTree.merge_configs(
|
||||
launch_specification, resource_conf.get("extra_configurations", {})
|
||||
)
|
||||
ConfigTree.merge_configs(launch_specification, resource_conf.get("extra_configurations", {}))
|
||||
|
||||
instances = ec2.run_instances(**launch_specification)
|
||||
|
||||
@ -165,30 +149,32 @@ class AWSDriver(CloudDriver):
|
||||
|
||||
def creds(self):
|
||||
creds = {
|
||||
'region_name': self.aws_region or None,
|
||||
"region_name": self.aws_region or None,
|
||||
}
|
||||
|
||||
if not self.use_credentials_chain:
|
||||
creds.update({
|
||||
'aws_secret_access_key': self.aws_secret_access_key or None,
|
||||
'aws_access_key_id': self.aws_access_key_id or None,
|
||||
'aws_session_token': self.aws_session_token or None,
|
||||
})
|
||||
creds.update(
|
||||
{
|
||||
"aws_secret_access_key": self.aws_secret_access_key or None,
|
||||
"aws_access_key_id": self.aws_access_key_id or None,
|
||||
"aws_session_token": self.aws_session_token or None,
|
||||
}
|
||||
)
|
||||
return creds
|
||||
|
||||
def instance_id_command(self):
|
||||
return 'curl http://169.254.169.254/latest/meta-data/instance-id'
|
||||
return "curl http://169.254.169.254/latest/meta-data/instance-id"
|
||||
|
||||
def instance_type_key(self):
|
||||
return 'instance_type'
|
||||
return "instance_type"
|
||||
|
||||
def kind(self):
|
||||
return 'AWS'
|
||||
return "AWS"
|
||||
|
||||
def console_log(self, instance_id):
|
||||
ec2 = boto3.client("ec2", **self.creds())
|
||||
try:
|
||||
out = ec2.get_console_output(InstanceId=instance_id)
|
||||
return out.get('Output', '')
|
||||
return out.get("Output", "")
|
||||
except ClientError as err:
|
||||
return 'error: cannot get logs for {}:\n{}'.format(instance_id, err)
|
||||
return "error: cannot get logs for {}:\n{}".format(instance_id, err)
|
||||
|
Loading…
Reference in New Issue
Block a user