Add key pair and security groups to AWS auto-scaler

This commit is contained in:
allegroai 2020-10-12 11:12:33 +03:00
parent 2a34d6cec2
commit 98ea965e6d
3 changed files with 45 additions and 4 deletions

View File

@ -1,16 +1,22 @@
from argparse import ArgumentParser from argparse import ArgumentParser
from collections import defaultdict from collections import defaultdict
from pathlib2 import Path
from typing import Tuple
from itertools import chain from itertools import chain
from typing import Tuple
import yaml import yaml
from pathlib2 import Path
from six.moves import input from six.moves import input
from trains import Task from trains import Task
from trains.automation.aws_auto_scaler import AwsAutoScaler from trains.automation.aws_auto_scaler import AwsAutoScaler
from trains.config import running_remotely from trains.config import running_remotely
from trains.utilities.wizard.user_input import get_input, input_int, input_bool, multiline_input from trains.utilities.wizard.user_input import (
get_input,
input_int,
input_bool,
multiline_input,
input_list,
)
CONF_FILE = "aws_autoscaler.yaml" CONF_FILE = "aws_autoscaler.yaml"
DEFAULT_DOCKER_IMAGE = "nvidia/cuda:10.1-runtime-ubuntu18.04" DEFAULT_DOCKER_IMAGE = "nvidia/cuda:10.1-runtime-ubuntu18.04"
@ -170,6 +176,14 @@ def run_wizard():
"['gp2']", "['gp2']",
default="gp2", default="gp2",
), ),
"key_name": get_input(
"the Amazon Key Pair name",
required=True,
),
"security_group_ids": input_list(
"Amazon Security Group ID",
required=True,
),
} }
while True: while True:

View File

@ -84,7 +84,7 @@ class AwsAutoScaler(AutoScaler):
queue=queue_name, queue=queue_name,
git_user=self.git_user or "", git_user=self.git_user or "",
git_pass=self.git_pass or "", git_pass=self.git_pass or "",
trains_conf=self.extra_trains_conf, trains_conf='\\"'.join(self.extra_trains_conf.split('"')),
bash_script=self.extra_vm_bash_script, bash_script=self.extra_vm_bash_script,
docker="--docker '{}'".format(self.default_docker_image) docker="--docker '{}'".format(self.default_docker_image)
if self.default_docker_image if self.default_docker_image
@ -107,6 +107,8 @@ class AwsAutoScaler(AutoScaler):
LaunchSpecification={ LaunchSpecification={
"ImageId": resource_conf["ami_id"], "ImageId": resource_conf["ami_id"],
"InstanceType": resource_conf["instance_type"], "InstanceType": resource_conf["instance_type"],
"KeyName": resource_conf["key_name"],
"SecurityGroupIds": resource_conf["security_group_ids"],
"Placement": { "Placement": {
"AvailabilityZone": resource_conf["availability_zone"] "AvailabilityZone": resource_conf["availability_zone"]
}, },
@ -140,6 +142,8 @@ class AwsAutoScaler(AutoScaler):
MinCount=1, MinCount=1,
MaxCount=1, MaxCount=1,
InstanceType=resource_conf["instance_type"], InstanceType=resource_conf["instance_type"],
KeyName=resource_conf["key_name"],
SecurityGroupIds=resource_conf["security_group_ids"],
UserData=user_data, UserData=user_data,
InstanceInitiatedShutdownBehavior="terminate", InstanceInitiatedShutdownBehavior="terminate",
BlockDeviceMappings=[ BlockDeviceMappings=[

View File

@ -71,6 +71,29 @@ def input_bool(question, default=False):
print("Invalid input: please enter 'yes' or 'no'") print("Invalid input: please enter 'yes' or 'no'")
def input_list(
key, # type: str
description="", # type: str
question="Enter", # type: str
required=False, # type: bool
default=None, # type: Optional[str]
new_line=False, # type: bool
):
res_list = [get_input(key, description, question, required, default, new_line)]
while input_bool("\nDefine another {}? [y/N]".format(key)):
response = get_input(
key=key,
description=description,
question=question,
required=False,
default=default,
new_line=new_line,
)
if response:
res_list.append(response)
return res_list
def multiline_input(description=""): def multiline_input(description=""):
print("{} \nNote: two consecutive empty lines would terminate the input : ".format(description)) print("{} \nNote: two consecutive empty lines would terminate the input : ".format(description))
lines = [] lines = []