mirror of
https://github.com/clearml/clearml
synced 2025-06-04 03:47:57 +00:00
Black formatting
This commit is contained in:
parent
ffcda558e7
commit
972696450e
@ -16,34 +16,34 @@ try:
|
|||||||
Task.add_requirements("boto3")
|
Task.add_requirements("boto3")
|
||||||
except ImportError as err:
|
except ImportError as err:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"AwsAutoScaler requires 'boto3' package, it was not found\n"
|
"AwsAutoScaler requires 'boto3' package, it was not found\n" "install with: pip install boto3"
|
||||||
"install with: pip install boto3"
|
|
||||||
) from err
|
) from err
|
||||||
|
|
||||||
|
|
||||||
@attr.s
|
@attr.s
|
||||||
class AWSDriver(CloudDriver):
|
class AWSDriver(CloudDriver):
|
||||||
"""AWS Driver"""
|
"""AWS Driver"""
|
||||||
aws_access_key_id = attr.ib(validator=instance_of(str), default='')
|
|
||||||
aws_secret_access_key = attr.ib(validator=instance_of(str), default='')
|
aws_access_key_id = attr.ib(validator=instance_of(str), default="")
|
||||||
aws_session_token = attr.ib(validator=instance_of(str), default='')
|
aws_secret_access_key = attr.ib(validator=instance_of(str), default="")
|
||||||
aws_region = attr.ib(validator=instance_of(str), default='')
|
aws_session_token = attr.ib(validator=instance_of(str), default="")
|
||||||
|
aws_region = attr.ib(validator=instance_of(str), default="")
|
||||||
use_credentials_chain = attr.ib(validator=instance_of(bool), default=False)
|
use_credentials_chain = attr.ib(validator=instance_of(bool), default=False)
|
||||||
use_iam_instance_profile = attr.ib(validator=instance_of(bool), default=False)
|
use_iam_instance_profile = attr.ib(validator=instance_of(bool), default=False)
|
||||||
iam_arn = attr.ib(validator=instance_of(str), default='')
|
iam_arn = attr.ib(validator=instance_of(str), default="")
|
||||||
iam_name = attr.ib(validator=instance_of(str), default='')
|
iam_name = attr.ib(validator=instance_of(str), default="")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(cls, config):
|
def from_config(cls, config):
|
||||||
obj = super().from_config(config)
|
obj = super().from_config(config)
|
||||||
obj.aws_access_key_id = config['hyper_params'].get('cloud_credentials_key')
|
obj.aws_access_key_id = config["hyper_params"].get("cloud_credentials_key")
|
||||||
obj.aws_secret_access_key = config['hyper_params'].get('cloud_credentials_secret')
|
obj.aws_secret_access_key = config["hyper_params"].get("cloud_credentials_secret")
|
||||||
obj.aws_session_token = config['hyper_params'].get('cloud_credentials_token')
|
obj.aws_session_token = config["hyper_params"].get("cloud_credentials_token")
|
||||||
obj.aws_region = config['hyper_params'].get('cloud_credentials_region')
|
obj.aws_region = config["hyper_params"].get("cloud_credentials_region")
|
||||||
obj.use_credentials_chain = config['hyper_params'].get('use_credentials_chain', False)
|
obj.use_credentials_chain = config["hyper_params"].get("use_credentials_chain", False)
|
||||||
obj.use_iam_instance_profile = config['hyper_params'].get('use_iam_instance_profile', False)
|
obj.use_iam_instance_profile = config["hyper_params"].get("use_iam_instance_profile", False)
|
||||||
obj.iam_arn = config['hyper_params'].get('iam_arn')
|
obj.iam_arn = config["hyper_params"].get("iam_arn")
|
||||||
obj.iam_name = config['hyper_params'].get('iam_name')
|
obj.iam_name = config["hyper_params"].get("iam_name")
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
def __attrs_post_init__(self):
|
def __attrs_post_init__(self):
|
||||||
@ -60,7 +60,7 @@ class AWSDriver(CloudDriver):
|
|||||||
launch_specification = ConfigFactory.from_dict(
|
launch_specification = ConfigFactory.from_dict(
|
||||||
{
|
{
|
||||||
"ImageId": resource_conf["ami_id"],
|
"ImageId": resource_conf["ami_id"],
|
||||||
"Monitoring": {'Enabled': bool(resource_conf.get('enable_monitoring', False))},
|
"Monitoring": {"Enabled": bool(resource_conf.get("enable_monitoring", False))},
|
||||||
"InstanceType": resource_conf["instance_type"],
|
"InstanceType": resource_conf["instance_type"],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@ -70,9 +70,7 @@ class AWSDriver(CloudDriver):
|
|||||||
launch_specification["BlockDeviceMappings"] = [
|
launch_specification["BlockDeviceMappings"] = [
|
||||||
{
|
{
|
||||||
"DeviceName": resource_conf["ebs_device_name"],
|
"DeviceName": resource_conf["ebs_device_name"],
|
||||||
"Ebs": {
|
"Ebs": {"SnapshotId": resource_conf["ebs_snapshot_id"]},
|
||||||
"SnapshotId": resource_conf["ebs_snapshot_id"]
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
elif resource_conf.get("ebs_device_name"):
|
elif resource_conf.get("ebs_device_name"):
|
||||||
@ -81,8 +79,8 @@ class AWSDriver(CloudDriver):
|
|||||||
"DeviceName": resource_conf["ebs_device_name"],
|
"DeviceName": resource_conf["ebs_device_name"],
|
||||||
"Ebs": {
|
"Ebs": {
|
||||||
"VolumeSize": resource_conf.get("ebs_volume_size", 80),
|
"VolumeSize": resource_conf.get("ebs_volume_size", 80),
|
||||||
"VolumeType": resource_conf.get("ebs_volume_type", "gp3")
|
"VolumeType": resource_conf.get("ebs_volume_type", "gp3"),
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -91,45 +89,33 @@ class AWSDriver(CloudDriver):
|
|||||||
elif resource_conf.get("availability_zone", None):
|
elif resource_conf.get("availability_zone", None):
|
||||||
launch_specification["Placement"] = {"AvailabilityZone": resource_conf["availability_zone"]}
|
launch_specification["Placement"] = {"AvailabilityZone": resource_conf["availability_zone"]}
|
||||||
else:
|
else:
|
||||||
raise Exception('subnet_id or availability_zone must to be specified in the config')
|
raise Exception("subnet_id or availability_zone must to be specified in the config")
|
||||||
if resource_conf.get("key_name", None):
|
if resource_conf.get("key_name", None):
|
||||||
launch_specification["KeyName"] = resource_conf["key_name"]
|
launch_specification["KeyName"] = resource_conf["key_name"]
|
||||||
if resource_conf.get("security_group_ids", None):
|
if resource_conf.get("security_group_ids", None):
|
||||||
launch_specification["SecurityGroupIds"] = resource_conf[
|
launch_specification["SecurityGroupIds"] = resource_conf["security_group_ids"]
|
||||||
"security_group_ids"
|
|
||||||
]
|
|
||||||
# Adding iam role - you can have Arn OR Name, not both, Arn getting priority
|
# Adding iam role - you can have Arn OR Name, not both, Arn getting priority
|
||||||
if self.iam_arn:
|
if self.iam_arn:
|
||||||
launch_specification["IamInstanceProfile"] = {
|
launch_specification["IamInstanceProfile"] = {
|
||||||
'Arn': self.iam_arn,
|
"Arn": self.iam_arn,
|
||||||
}
|
}
|
||||||
elif self.iam_name:
|
elif self.iam_name:
|
||||||
launch_specification["IamInstanceProfile"] = {
|
launch_specification["IamInstanceProfile"] = {"Name": self.iam_name}
|
||||||
'Name': self.iam_name
|
|
||||||
}
|
|
||||||
|
|
||||||
if resource_conf["is_spot"]:
|
if resource_conf["is_spot"]:
|
||||||
# Create a request for a spot instance in AWS
|
# Create a request for a spot instance in AWS
|
||||||
encoded_user_data = base64.b64encode(user_data.encode("ascii")).decode(
|
encoded_user_data = base64.b64encode(user_data.encode("ascii")).decode("ascii")
|
||||||
"ascii"
|
|
||||||
)
|
|
||||||
launch_specification["UserData"] = encoded_user_data
|
launch_specification["UserData"] = encoded_user_data
|
||||||
ConfigTree.merge_configs(
|
ConfigTree.merge_configs(launch_specification, resource_conf.get("extra_configurations", {}))
|
||||||
launch_specification, resource_conf.get("extra_configurations", {})
|
|
||||||
)
|
|
||||||
|
|
||||||
instances = ec2.request_spot_instances(
|
instances = ec2.request_spot_instances(LaunchSpecification=launch_specification)
|
||||||
LaunchSpecification=launch_specification
|
|
||||||
)
|
|
||||||
|
|
||||||
# Wait until spot request is fulfilled
|
# Wait until spot request is fulfilled
|
||||||
request_id = instances["SpotInstanceRequests"][0]["SpotInstanceRequestId"]
|
request_id = instances["SpotInstanceRequests"][0]["SpotInstanceRequestId"]
|
||||||
waiter = ec2.get_waiter("spot_instance_request_fulfilled")
|
waiter = ec2.get_waiter("spot_instance_request_fulfilled")
|
||||||
waiter.wait(SpotInstanceRequestIds=[request_id])
|
waiter.wait(SpotInstanceRequestIds=[request_id])
|
||||||
# Get the instance object for later use
|
# Get the instance object for later use
|
||||||
response = ec2.describe_spot_instance_requests(
|
response = ec2.describe_spot_instance_requests(SpotInstanceRequestIds=[request_id])
|
||||||
SpotInstanceRequestIds=[request_id]
|
|
||||||
)
|
|
||||||
instance_id = response["SpotInstanceRequests"][0]["InstanceId"]
|
instance_id = response["SpotInstanceRequests"][0]["InstanceId"]
|
||||||
|
|
||||||
else:
|
else:
|
||||||
@ -140,9 +126,7 @@ class AWSDriver(CloudDriver):
|
|||||||
UserData=user_data,
|
UserData=user_data,
|
||||||
InstanceInitiatedShutdownBehavior="terminate",
|
InstanceInitiatedShutdownBehavior="terminate",
|
||||||
)
|
)
|
||||||
ConfigTree.merge_configs(
|
ConfigTree.merge_configs(launch_specification, resource_conf.get("extra_configurations", {}))
|
||||||
launch_specification, resource_conf.get("extra_configurations", {})
|
|
||||||
)
|
|
||||||
|
|
||||||
instances = ec2.run_instances(**launch_specification)
|
instances = ec2.run_instances(**launch_specification)
|
||||||
|
|
||||||
@ -165,30 +149,32 @@ class AWSDriver(CloudDriver):
|
|||||||
|
|
||||||
def creds(self):
|
def creds(self):
|
||||||
creds = {
|
creds = {
|
||||||
'region_name': self.aws_region or None,
|
"region_name": self.aws_region or None,
|
||||||
}
|
}
|
||||||
|
|
||||||
if not self.use_credentials_chain:
|
if not self.use_credentials_chain:
|
||||||
creds.update({
|
creds.update(
|
||||||
'aws_secret_access_key': self.aws_secret_access_key or None,
|
{
|
||||||
'aws_access_key_id': self.aws_access_key_id or None,
|
"aws_secret_access_key": self.aws_secret_access_key or None,
|
||||||
'aws_session_token': self.aws_session_token or None,
|
"aws_access_key_id": self.aws_access_key_id or None,
|
||||||
})
|
"aws_session_token": self.aws_session_token or None,
|
||||||
|
}
|
||||||
|
)
|
||||||
return creds
|
return creds
|
||||||
|
|
||||||
def instance_id_command(self):
|
def instance_id_command(self):
|
||||||
return 'curl http://169.254.169.254/latest/meta-data/instance-id'
|
return "curl http://169.254.169.254/latest/meta-data/instance-id"
|
||||||
|
|
||||||
def instance_type_key(self):
|
def instance_type_key(self):
|
||||||
return 'instance_type'
|
return "instance_type"
|
||||||
|
|
||||||
def kind(self):
|
def kind(self):
|
||||||
return 'AWS'
|
return "AWS"
|
||||||
|
|
||||||
def console_log(self, instance_id):
|
def console_log(self, instance_id):
|
||||||
ec2 = boto3.client("ec2", **self.creds())
|
ec2 = boto3.client("ec2", **self.creds())
|
||||||
try:
|
try:
|
||||||
out = ec2.get_console_output(InstanceId=instance_id)
|
out = ec2.get_console_output(InstanceId=instance_id)
|
||||||
return out.get('Output', '')
|
return out.get("Output", "")
|
||||||
except ClientError as err:
|
except ClientError as err:
|
||||||
return 'error: cannot get logs for {}:\n{}'.format(instance_id, err)
|
return "error: cannot get logs for {}:\n{}".format(instance_id, err)
|
||||||
|
Loading…
Reference in New Issue
Block a user