From 98ea965e6d0a39019b31e5fec146ab55da0d52e8 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Mon, 12 Oct 2020 11:12:33 +0300 Subject: [PATCH] Add key pair and security groups to AWS auto-scaler --- .../services/aws-autoscaler/aws_autoscaler.py | 20 +++++++++++++--- trains/automation/aws_auto_scaler.py | 6 ++++- trains/utilities/wizard/user_input.py | 23 +++++++++++++++++++ 3 files changed, 45 insertions(+), 4 deletions(-) diff --git a/examples/services/aws-autoscaler/aws_autoscaler.py b/examples/services/aws-autoscaler/aws_autoscaler.py index f92dfabc..8be76eba 100644 --- a/examples/services/aws-autoscaler/aws_autoscaler.py +++ b/examples/services/aws-autoscaler/aws_autoscaler.py @@ -1,16 +1,22 @@ from argparse import ArgumentParser from collections import defaultdict -from pathlib2 import Path -from typing import Tuple from itertools import chain +from typing import Tuple import yaml +from pathlib2 import Path from six.moves import input from trains import Task from trains.automation.aws_auto_scaler import AwsAutoScaler from trains.config import running_remotely -from trains.utilities.wizard.user_input import get_input, input_int, input_bool, multiline_input +from trains.utilities.wizard.user_input import ( + get_input, + input_int, + input_bool, + multiline_input, + input_list, +) CONF_FILE = "aws_autoscaler.yaml" DEFAULT_DOCKER_IMAGE = "nvidia/cuda:10.1-runtime-ubuntu18.04" @@ -170,6 +176,14 @@ def run_wizard(): "['gp2']", default="gp2", ), + "key_name": get_input( + "the Amazon Key Pair name", + required=True, + ), + "security_group_ids": input_list( + "Amazon Security Group ID", + required=True, + ), } while True: diff --git a/trains/automation/aws_auto_scaler.py b/trains/automation/aws_auto_scaler.py index 56ffca30..bc100440 100644 --- a/trains/automation/aws_auto_scaler.py +++ b/trains/automation/aws_auto_scaler.py @@ -84,7 +84,7 @@ class AwsAutoScaler(AutoScaler): queue=queue_name, git_user=self.git_user or "", git_pass=self.git_pass or "", - trains_conf=self.extra_trains_conf, + trains_conf='\\"'.join(self.extra_trains_conf.split('"')), bash_script=self.extra_vm_bash_script, docker="--docker '{}'".format(self.default_docker_image) if self.default_docker_image @@ -107,6 +107,8 @@ class AwsAutoScaler(AutoScaler): LaunchSpecification={ "ImageId": resource_conf["ami_id"], "InstanceType": resource_conf["instance_type"], + "KeyName": resource_conf["key_name"], + "SecurityGroupIds": resource_conf["security_group_ids"], "Placement": { "AvailabilityZone": resource_conf["availability_zone"] }, @@ -140,6 +142,8 @@ class AwsAutoScaler(AutoScaler): MinCount=1, MaxCount=1, InstanceType=resource_conf["instance_type"], + KeyName=resource_conf["key_name"], + SecurityGroupIds=resource_conf["security_group_ids"], UserData=user_data, InstanceInitiatedShutdownBehavior="terminate", BlockDeviceMappings=[ diff --git a/trains/utilities/wizard/user_input.py b/trains/utilities/wizard/user_input.py index bd1bf542..c5b4bb84 100644 --- a/trains/utilities/wizard/user_input.py +++ b/trains/utilities/wizard/user_input.py @@ -71,6 +71,29 @@ def input_bool(question, default=False): print("Invalid input: please enter 'yes' or 'no'") +def input_list( + key, # type: str + description="", # type: str + question="Enter", # type: str + required=False, # type: bool + default=None, # type: Optional[str] + new_line=False, # type: bool +): + res_list = [get_input(key, description, question, required, default, new_line)] + while input_bool("\nDefine another {}? [y/N]".format(key)): + response = get_input( + key=key, + description=description, + question=question, + required=False, + default=default, + new_line=new_line, + ) + if response: + res_list.append(response) + return res_list + + def multiline_input(description=""): print("{} \nNote: two consecutive empty lines would terminate the input : ".format(description)) lines = []