mirror of
https://github.com/clearml/clearml
synced 2025-01-31 09:07:00 +00:00
Add key pair and security groups to AWS auto-scaler
This commit is contained in:
parent
2a34d6cec2
commit
98ea965e6d
@ -1,16 +1,22 @@
|
||||
from argparse import ArgumentParser
|
||||
from collections import defaultdict
|
||||
from pathlib2 import Path
|
||||
from typing import Tuple
|
||||
from itertools import chain
|
||||
from typing import Tuple
|
||||
|
||||
import yaml
|
||||
from pathlib2 import Path
|
||||
from six.moves import input
|
||||
|
||||
from trains import Task
|
||||
from trains.automation.aws_auto_scaler import AwsAutoScaler
|
||||
from trains.config import running_remotely
|
||||
from trains.utilities.wizard.user_input import get_input, input_int, input_bool, multiline_input
|
||||
from trains.utilities.wizard.user_input import (
|
||||
get_input,
|
||||
input_int,
|
||||
input_bool,
|
||||
multiline_input,
|
||||
input_list,
|
||||
)
|
||||
|
||||
CONF_FILE = "aws_autoscaler.yaml"
|
||||
DEFAULT_DOCKER_IMAGE = "nvidia/cuda:10.1-runtime-ubuntu18.04"
|
||||
@ -170,6 +176,14 @@ def run_wizard():
|
||||
"['gp2']",
|
||||
default="gp2",
|
||||
),
|
||||
"key_name": get_input(
|
||||
"the Amazon Key Pair name",
|
||||
required=True,
|
||||
),
|
||||
"security_group_ids": input_list(
|
||||
"Amazon Security Group ID",
|
||||
required=True,
|
||||
),
|
||||
}
|
||||
|
||||
while True:
|
||||
|
@ -84,7 +84,7 @@ class AwsAutoScaler(AutoScaler):
|
||||
queue=queue_name,
|
||||
git_user=self.git_user or "",
|
||||
git_pass=self.git_pass or "",
|
||||
trains_conf=self.extra_trains_conf,
|
||||
trains_conf='\\"'.join(self.extra_trains_conf.split('"')),
|
||||
bash_script=self.extra_vm_bash_script,
|
||||
docker="--docker '{}'".format(self.default_docker_image)
|
||||
if self.default_docker_image
|
||||
@ -107,6 +107,8 @@ class AwsAutoScaler(AutoScaler):
|
||||
LaunchSpecification={
|
||||
"ImageId": resource_conf["ami_id"],
|
||||
"InstanceType": resource_conf["instance_type"],
|
||||
"KeyName": resource_conf["key_name"],
|
||||
"SecurityGroupIds": resource_conf["security_group_ids"],
|
||||
"Placement": {
|
||||
"AvailabilityZone": resource_conf["availability_zone"]
|
||||
},
|
||||
@ -140,6 +142,8 @@ class AwsAutoScaler(AutoScaler):
|
||||
MinCount=1,
|
||||
MaxCount=1,
|
||||
InstanceType=resource_conf["instance_type"],
|
||||
KeyName=resource_conf["key_name"],
|
||||
SecurityGroupIds=resource_conf["security_group_ids"],
|
||||
UserData=user_data,
|
||||
InstanceInitiatedShutdownBehavior="terminate",
|
||||
BlockDeviceMappings=[
|
||||
|
@ -71,6 +71,29 @@ def input_bool(question, default=False):
|
||||
print("Invalid input: please enter 'yes' or 'no'")
|
||||
|
||||
|
||||
def input_list(
|
||||
key, # type: str
|
||||
description="", # type: str
|
||||
question="Enter", # type: str
|
||||
required=False, # type: bool
|
||||
default=None, # type: Optional[str]
|
||||
new_line=False, # type: bool
|
||||
):
|
||||
res_list = [get_input(key, description, question, required, default, new_line)]
|
||||
while input_bool("\nDefine another {}? [y/N]".format(key)):
|
||||
response = get_input(
|
||||
key=key,
|
||||
description=description,
|
||||
question=question,
|
||||
required=False,
|
||||
default=default,
|
||||
new_line=new_line,
|
||||
)
|
||||
if response:
|
||||
res_list.append(response)
|
||||
return res_list
|
||||
|
||||
|
||||
def multiline_input(description=""):
|
||||
print("{} \nNote: two consecutive empty lines would terminate the input : ".format(description))
|
||||
lines = []
|
||||
|
Loading…
Reference in New Issue
Block a user