Improve Auto-Scaler: add extra configurations, key name and security group are now optional

This commit is contained in:
allegroai 2020-12-06 11:20:36 +02:00
parent c3fd3ed7c6
commit 3898563200
3 changed files with 46 additions and 39 deletions

View File

@ -178,11 +178,9 @@ def run_wizard():
), ),
"key_name": get_input( "key_name": get_input(
"the Amazon Key Pair name", "the Amazon Key Pair name",
required=True,
), ),
"security_group_ids": input_list( "security_group_ids": input_list(
"Amazon Security Group ID", "Amazon Security Group ID",
required=True,
), ),
} }

View File

@ -5,15 +5,18 @@ import attr
from .auto_scaler import AutoScaler from .auto_scaler import AutoScaler
from .. import Task from .. import Task
from ..utilities.pyhocon import ConfigTree, ConfigFactory
try: try:
# noinspection PyPackageRequirements # noinspection PyPackageRequirements
import boto3 import boto3
Task.add_requirements('boto3') Task.add_requirements("boto3")
except ImportError: except ImportError:
raise ValueError("AwsAutoScaler requires 'boto3' package, it was not found\n" raise ValueError(
"install with: pip install boto3") "AwsAutoScaler requires 'boto3' package, it was not found\n"
"install with: pip install boto3"
)
class AwsAutoScaler(AutoScaler): class AwsAutoScaler(AutoScaler):
@ -98,31 +101,41 @@ class AwsAutoScaler(AutoScaler):
region_name=self.cloud_credentials_region, region_name=self.cloud_credentials_region,
) )
launch_specification = ConfigFactory.from_dict(
{
"ImageId": resource_conf["ami_id"],
"InstanceType": resource_conf["instance_type"],
"BlockDeviceMappings": [
{
"DeviceName": resource_conf["ebs_device_name"],
"Ebs": {
"VolumeSize": resource_conf["ebs_volume_size"],
"VolumeType": resource_conf["ebs_volume_type"],
},
}
],
"Placement": {"AvailabilityZone": resource_conf["availability_zone"]},
}
)
if resource_conf.get("key_name", None):
launch_specification["KeyName"] = resource_conf["key_name"]
if resource_conf.get("security_group_ids", None):
launch_specification["SecurityGroupIds"] = resource_conf[
"security_group_ids"
]
if resource_conf["is_spot"]: if resource_conf["is_spot"]:
# Create a request for a spot instance in AWS # Create a request for a spot instance in AWS
encoded_user_data = base64.b64encode(user_data.encode("ascii")).decode( encoded_user_data = base64.b64encode(user_data.encode("ascii")).decode(
"ascii" "ascii"
) )
launch_specification["UserData"] = encoded_user_data
ConfigTree.merge_configs(
launch_specification, resource_conf.get("extra_configurations", {})
)
instances = ec2.request_spot_instances( instances = ec2.request_spot_instances(
LaunchSpecification={ LaunchSpecification=launch_specification
"ImageId": resource_conf["ami_id"],
"InstanceType": resource_conf["instance_type"],
"KeyName": resource_conf["key_name"],
"SecurityGroupIds": resource_conf["security_group_ids"],
"Placement": {
"AvailabilityZone": resource_conf["availability_zone"]
},
"UserData": encoded_user_data,
"BlockDeviceMappings": [
{
"DeviceName": resource_conf["ebs_device_name"],
"Ebs": {
"VolumeSize": resource_conf["ebs_volume_size"],
"VolumeType": resource_conf["ebs_volume_type"],
},
}
],
}
) )
# Wait until spot request is fulfilled # Wait until spot request is fulfilled
@ -137,25 +150,17 @@ class AwsAutoScaler(AutoScaler):
else: else:
# Create a new EC2 instance # Create a new EC2 instance
instances = ec2.run_instances( launch_specification.update(
ImageId=resource_conf["ami_id"],
MinCount=1, MinCount=1,
MaxCount=1, MaxCount=1,
InstanceType=resource_conf["instance_type"],
KeyName=resource_conf["key_name"],
SecurityGroupIds=resource_conf["security_group_ids"],
UserData=user_data, UserData=user_data,
InstanceInitiatedShutdownBehavior="terminate", InstanceInitiatedShutdownBehavior="terminate",
BlockDeviceMappings=[
{
"DeviceName": resource_conf["ebs_device_name"],
"Ebs": {
"VolumeSize": resource_conf["ebs_volume_size"],
"VolumeType": resource_conf["ebs_volume_type"],
},
}
],
) )
ConfigTree.merge_configs(
launch_specification, resource_conf.get("extra_configurations", {})
)
instances = ec2.run_instances(**launch_specification)
# Get the instance object for later use # Get the instance object for later use
instance_id = instances["Instances"][0]["InstanceId"] instance_id = instances["Instances"][0]["InstanceId"]

View File

@ -79,7 +79,11 @@ def input_list(
default=None, # type: Optional[str] default=None, # type: Optional[str]
new_line=False, # type: bool new_line=False, # type: bool
): ):
res_list = [get_input(key, description, question, required, default, new_line)] res = get_input(key, description, question, required, default, new_line)
if not res:
return None
res_list = [res]
while input_bool("\nDefine another {}? [y/N]".format(key)): while input_bool("\nDefine another {}? [y/N]".format(key)):
response = get_input( response = get_input(
key=key, key=key,