diff --git a/examples/services/aws-autoscaler/aws_autoscaler.py b/examples/services/aws-autoscaler/aws_autoscaler.py index 8be76eba..4ff689e8 100644 --- a/examples/services/aws-autoscaler/aws_autoscaler.py +++ b/examples/services/aws-autoscaler/aws_autoscaler.py @@ -178,11 +178,9 @@ def run_wizard(): ), "key_name": get_input( "the Amazon Key Pair name", - required=True, ), "security_group_ids": input_list( "Amazon Security Group ID", - required=True, ), } diff --git a/trains/automation/aws_auto_scaler.py b/trains/automation/aws_auto_scaler.py index bc100440..0c8c4b2d 100644 --- a/trains/automation/aws_auto_scaler.py +++ b/trains/automation/aws_auto_scaler.py @@ -5,15 +5,18 @@ import attr from .auto_scaler import AutoScaler from .. import Task +from ..utilities.pyhocon import ConfigTree, ConfigFactory try: # noinspection PyPackageRequirements import boto3 - Task.add_requirements('boto3') + Task.add_requirements("boto3") except ImportError: - raise ValueError("AwsAutoScaler requires 'boto3' package, it was not found\n" - "install with: pip install boto3") + raise ValueError( + "AwsAutoScaler requires 'boto3' package, it was not found\n" + "install with: pip install boto3" + ) class AwsAutoScaler(AutoScaler): @@ -98,31 +101,41 @@ class AwsAutoScaler(AutoScaler): region_name=self.cloud_credentials_region, ) + launch_specification = ConfigFactory.from_dict( + { + "ImageId": resource_conf["ami_id"], + "InstanceType": resource_conf["instance_type"], + "BlockDeviceMappings": [ + { + "DeviceName": resource_conf["ebs_device_name"], + "Ebs": { + "VolumeSize": resource_conf["ebs_volume_size"], + "VolumeType": resource_conf["ebs_volume_type"], + }, + } + ], + "Placement": {"AvailabilityZone": resource_conf["availability_zone"]}, + } + ) + if resource_conf.get("key_name", None): + launch_specification["KeyName"] = resource_conf["key_name"] + if resource_conf.get("security_group_ids", None): + launch_specification["SecurityGroupIds"] = resource_conf[ + "security_group_ids" + ] + if resource_conf["is_spot"]: # Create a request for a spot instance in AWS encoded_user_data = base64.b64encode(user_data.encode("ascii")).decode( "ascii" ) + launch_specification["UserData"] = encoded_user_data + ConfigTree.merge_configs( + launch_specification, resource_conf.get("extra_configurations", {}) + ) + instances = ec2.request_spot_instances( - LaunchSpecification={ - "ImageId": resource_conf["ami_id"], - "InstanceType": resource_conf["instance_type"], - "KeyName": resource_conf["key_name"], - "SecurityGroupIds": resource_conf["security_group_ids"], - "Placement": { - "AvailabilityZone": resource_conf["availability_zone"] - }, - "UserData": encoded_user_data, - "BlockDeviceMappings": [ - { - "DeviceName": resource_conf["ebs_device_name"], - "Ebs": { - "VolumeSize": resource_conf["ebs_volume_size"], - "VolumeType": resource_conf["ebs_volume_type"], - }, - } - ], - } + LaunchSpecification=launch_specification ) # Wait until spot request is fulfilled @@ -137,25 +150,17 @@ class AwsAutoScaler(AutoScaler): else: # Create a new EC2 instance - instances = ec2.run_instances( - ImageId=resource_conf["ami_id"], + launch_specification.update( MinCount=1, MaxCount=1, - InstanceType=resource_conf["instance_type"], - KeyName=resource_conf["key_name"], - SecurityGroupIds=resource_conf["security_group_ids"], UserData=user_data, InstanceInitiatedShutdownBehavior="terminate", - BlockDeviceMappings=[ - { - "DeviceName": resource_conf["ebs_device_name"], - "Ebs": { - "VolumeSize": resource_conf["ebs_volume_size"], - "VolumeType": resource_conf["ebs_volume_type"], - }, - } - ], ) + ConfigTree.merge_configs( + launch_specification, resource_conf.get("extra_configurations", {}) + ) + + instances = ec2.run_instances(**launch_specification) # Get the instance object for later use instance_id = instances["Instances"][0]["InstanceId"] diff --git a/trains/utilities/wizard/user_input.py b/trains/utilities/wizard/user_input.py index c5b4bb84..b83e5e21 100644 --- a/trains/utilities/wizard/user_input.py +++ b/trains/utilities/wizard/user_input.py @@ -79,7 +79,11 @@ def input_list( default=None, # type: Optional[str] new_line=False, # type: bool ): - res_list = [get_input(key, description, question, required, default, new_line)] + res = get_input(key, description, question, required, default, new_line) + if not res: + return None + + res_list = [res] while input_bool("\nDefine another {}? [y/N]".format(key)): response = get_input( key=key,