mirror of
https://github.com/clearml/clearml
synced 2025-02-07 13:23:40 +00:00
Add key pair and security groups to AWS auto-scaler
This commit is contained in:
parent
2a34d6cec2
commit
98ea965e6d
@ -1,16 +1,22 @@
|
|||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib2 import Path
|
|
||||||
from typing import Tuple
|
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
from pathlib2 import Path
|
||||||
from six.moves import input
|
from six.moves import input
|
||||||
|
|
||||||
from trains import Task
|
from trains import Task
|
||||||
from trains.automation.aws_auto_scaler import AwsAutoScaler
|
from trains.automation.aws_auto_scaler import AwsAutoScaler
|
||||||
from trains.config import running_remotely
|
from trains.config import running_remotely
|
||||||
from trains.utilities.wizard.user_input import get_input, input_int, input_bool, multiline_input
|
from trains.utilities.wizard.user_input import (
|
||||||
|
get_input,
|
||||||
|
input_int,
|
||||||
|
input_bool,
|
||||||
|
multiline_input,
|
||||||
|
input_list,
|
||||||
|
)
|
||||||
|
|
||||||
CONF_FILE = "aws_autoscaler.yaml"
|
CONF_FILE = "aws_autoscaler.yaml"
|
||||||
DEFAULT_DOCKER_IMAGE = "nvidia/cuda:10.1-runtime-ubuntu18.04"
|
DEFAULT_DOCKER_IMAGE = "nvidia/cuda:10.1-runtime-ubuntu18.04"
|
||||||
@ -170,6 +176,14 @@ def run_wizard():
|
|||||||
"['gp2']",
|
"['gp2']",
|
||||||
default="gp2",
|
default="gp2",
|
||||||
),
|
),
|
||||||
|
"key_name": get_input(
|
||||||
|
"the Amazon Key Pair name",
|
||||||
|
required=True,
|
||||||
|
),
|
||||||
|
"security_group_ids": input_list(
|
||||||
|
"Amazon Security Group ID",
|
||||||
|
required=True,
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
|
@ -84,7 +84,7 @@ class AwsAutoScaler(AutoScaler):
|
|||||||
queue=queue_name,
|
queue=queue_name,
|
||||||
git_user=self.git_user or "",
|
git_user=self.git_user or "",
|
||||||
git_pass=self.git_pass or "",
|
git_pass=self.git_pass or "",
|
||||||
trains_conf=self.extra_trains_conf,
|
trains_conf='\\"'.join(self.extra_trains_conf.split('"')),
|
||||||
bash_script=self.extra_vm_bash_script,
|
bash_script=self.extra_vm_bash_script,
|
||||||
docker="--docker '{}'".format(self.default_docker_image)
|
docker="--docker '{}'".format(self.default_docker_image)
|
||||||
if self.default_docker_image
|
if self.default_docker_image
|
||||||
@ -107,6 +107,8 @@ class AwsAutoScaler(AutoScaler):
|
|||||||
LaunchSpecification={
|
LaunchSpecification={
|
||||||
"ImageId": resource_conf["ami_id"],
|
"ImageId": resource_conf["ami_id"],
|
||||||
"InstanceType": resource_conf["instance_type"],
|
"InstanceType": resource_conf["instance_type"],
|
||||||
|
"KeyName": resource_conf["key_name"],
|
||||||
|
"SecurityGroupIds": resource_conf["security_group_ids"],
|
||||||
"Placement": {
|
"Placement": {
|
||||||
"AvailabilityZone": resource_conf["availability_zone"]
|
"AvailabilityZone": resource_conf["availability_zone"]
|
||||||
},
|
},
|
||||||
@ -140,6 +142,8 @@ class AwsAutoScaler(AutoScaler):
|
|||||||
MinCount=1,
|
MinCount=1,
|
||||||
MaxCount=1,
|
MaxCount=1,
|
||||||
InstanceType=resource_conf["instance_type"],
|
InstanceType=resource_conf["instance_type"],
|
||||||
|
KeyName=resource_conf["key_name"],
|
||||||
|
SecurityGroupIds=resource_conf["security_group_ids"],
|
||||||
UserData=user_data,
|
UserData=user_data,
|
||||||
InstanceInitiatedShutdownBehavior="terminate",
|
InstanceInitiatedShutdownBehavior="terminate",
|
||||||
BlockDeviceMappings=[
|
BlockDeviceMappings=[
|
||||||
|
@ -71,6 +71,29 @@ def input_bool(question, default=False):
|
|||||||
print("Invalid input: please enter 'yes' or 'no'")
|
print("Invalid input: please enter 'yes' or 'no'")
|
||||||
|
|
||||||
|
|
||||||
|
def input_list(
|
||||||
|
key, # type: str
|
||||||
|
description="", # type: str
|
||||||
|
question="Enter", # type: str
|
||||||
|
required=False, # type: bool
|
||||||
|
default=None, # type: Optional[str]
|
||||||
|
new_line=False, # type: bool
|
||||||
|
):
|
||||||
|
res_list = [get_input(key, description, question, required, default, new_line)]
|
||||||
|
while input_bool("\nDefine another {}? [y/N]".format(key)):
|
||||||
|
response = get_input(
|
||||||
|
key=key,
|
||||||
|
description=description,
|
||||||
|
question=question,
|
||||||
|
required=False,
|
||||||
|
default=default,
|
||||||
|
new_line=new_line,
|
||||||
|
)
|
||||||
|
if response:
|
||||||
|
res_list.append(response)
|
||||||
|
return res_list
|
||||||
|
|
||||||
|
|
||||||
def multiline_input(description=""):
|
def multiline_input(description=""):
|
||||||
print("{} \nNote: two consecutive empty lines would terminate the input : ".format(description))
|
print("{} \nNote: two consecutive empty lines would terminate the input : ".format(description))
|
||||||
lines = []
|
lines = []
|
||||||
|
Loading…
Reference in New Issue
Block a user