Add key pair and security groups to AWS auto-scaler

This commit is contained in:
allegroai 2020-10-12 11:12:33 +03:00
parent 2a34d6cec2
commit 98ea965e6d
3 changed files with 45 additions and 4 deletions

View File

@ -1,16 +1,22 @@
from argparse import ArgumentParser
from collections import defaultdict
from pathlib2 import Path
from typing import Tuple
from itertools import chain
from typing import Tuple
import yaml
from pathlib2 import Path
from six.moves import input
from trains import Task
from trains.automation.aws_auto_scaler import AwsAutoScaler
from trains.config import running_remotely
from trains.utilities.wizard.user_input import get_input, input_int, input_bool, multiline_input
from trains.utilities.wizard.user_input import (
get_input,
input_int,
input_bool,
multiline_input,
input_list,
)
CONF_FILE = "aws_autoscaler.yaml"
DEFAULT_DOCKER_IMAGE = "nvidia/cuda:10.1-runtime-ubuntu18.04"
@ -170,6 +176,14 @@ def run_wizard():
"['gp2']",
default="gp2",
),
"key_name": get_input(
"the Amazon Key Pair name",
required=True,
),
"security_group_ids": input_list(
"Amazon Security Group ID",
required=True,
),
}
while True:

View File

@ -84,7 +84,7 @@ class AwsAutoScaler(AutoScaler):
queue=queue_name,
git_user=self.git_user or "",
git_pass=self.git_pass or "",
trains_conf=self.extra_trains_conf,
trains_conf='\\"'.join(self.extra_trains_conf.split('"')),
bash_script=self.extra_vm_bash_script,
docker="--docker '{}'".format(self.default_docker_image)
if self.default_docker_image
@ -107,6 +107,8 @@ class AwsAutoScaler(AutoScaler):
LaunchSpecification={
"ImageId": resource_conf["ami_id"],
"InstanceType": resource_conf["instance_type"],
"KeyName": resource_conf["key_name"],
"SecurityGroupIds": resource_conf["security_group_ids"],
"Placement": {
"AvailabilityZone": resource_conf["availability_zone"]
},
@ -140,6 +142,8 @@ class AwsAutoScaler(AutoScaler):
MinCount=1,
MaxCount=1,
InstanceType=resource_conf["instance_type"],
KeyName=resource_conf["key_name"],
SecurityGroupIds=resource_conf["security_group_ids"],
UserData=user_data,
InstanceInitiatedShutdownBehavior="terminate",
BlockDeviceMappings=[

View File

@ -71,6 +71,29 @@ def input_bool(question, default=False):
print("Invalid input: please enter 'yes' or 'no'")
def input_list(
key, # type: str
description="", # type: str
question="Enter", # type: str
required=False, # type: bool
default=None, # type: Optional[str]
new_line=False, # type: bool
):
res_list = [get_input(key, description, question, required, default, new_line)]
while input_bool("\nDefine another {}? [y/N]".format(key)):
response = get_input(
key=key,
description=description,
question=question,
required=False,
default=default,
new_line=new_line,
)
if response:
res_list.append(response)
return res_list
def multiline_input(description=""):
print("{} \nNote: two consecutive empty lines would terminate the input : ".format(description))
lines = []