clearml/examples/services/aws-autoscaler/aws_autoscaler.py
2020-07-14 23:40:05 +03:00

200 lines
6.6 KiB
Python

from argparse import ArgumentParser
from collections import defaultdict
from pathlib import Path
from typing import Tuple
import yaml
from six.moves import input
from trains import Task
from trains.automation.aws_auto_scaler import AwsAutoScaler
from trains.config import running_remotely
from trains.utilities.wizard.user_input import get_input, input_int, input_bool
# File name (resolved against the current working directory in main()) where the
# wizard's answers are saved and from which they can be reloaded on a later run.
CONF_FILE = "aws_autoscaler.yaml"
# Docker image offered as the default answer by the wizard's docker image prompt.
DEFAULT_DOCKER_IMAGE = "nvidia/cuda"
def main():
    """Entry point: obtain the autoscaler settings, register a task, optionally start.

    Settings come from one of three places:

    * when ``running_remotely()`` is true, from freshly constructed
      ``AwsAutoScaler.Settings`` / ``AwsAutoScaler.Configuration`` objects
      (presumably overridden by the values stored on the task once they are
      connected below -- confirm against the trains documentation);
    * from the YAML file ``CONF_FILE`` in the current directory, if it exists
      and the user agrees to load it;
    * otherwise from the interactive wizard, whose answers are then written
      back to ``CONF_FILE``. If that write fails an error is printed and the
      function returns without creating a task.

    The autoscaler is only started when running remotely or when ``--run``
    was passed on the command line.
    """
    arg_parser = ArgumentParser()
    arg_parser.add_argument(
        "--run",
        action="store_true",
        default=False,
        help="Run the autoscaler after wizard finished",
    )
    cli_args = arg_parser.parse_args()

    if running_remotely():
        hyper_params = AwsAutoScaler.Settings().as_dict()
        configurations = AwsAutoScaler.Configuration().as_dict()
    else:
        print("AWS Autoscaler setup\n")

        conf_path = Path(CONF_FILE).absolute()
        load_prompt = "Load configurations from config file '{}' [Y/n]? ".format(
            str(CONF_FILE)
        )
        # Only ask about loading when there is actually a file to load.
        reuse_saved = conf_path.exists() and input_bool(load_prompt, default=True)

        if reuse_saved:
            with conf_path.open("r") as stream:
                saved = yaml.load(stream, Loader=yaml.SafeLoader)
            hyper_params = saved["hyper_params"]
            configurations = saved["configurations"]
        else:
            configurations, hyper_params = run_wizard()

            # Persist the answers so the next run can skip the wizard.
            try:
                with conf_path.open("w+") as stream:
                    yaml.safe_dump(
                        {
                            "hyper_params": hyper_params,
                            "configurations": configurations,
                        },
                        stream,
                    )
            except Exception:
                print(
                    "Error! Could not write configuration file at: {}".format(
                        str(CONF_FILE)
                    )
                )
                return

    task = Task.init(project_name="Auto-Scaler", task_name="AWS Auto-Scaler")
    task.connect(hyper_params)
    task.connect_configuration(configurations)

    autoscaler = AwsAutoScaler(hyper_params, configurations)

    if running_remotely() or cli_args.run:
        autoscaler.start()
def run_wizard():
    # type: () -> Tuple[dict, dict]
    """Interactively collect all autoscaler settings from the terminal.

    Prompts, in order, for: AWS credentials and region, git credentials,
    the default docker image, one or more machine (resource) definitions,
    an optional pre-execution bash script, the per-queue instance budget,
    and the idle / polling intervals.

    :return: ``(configurations, hyper_params)`` -- note the order -- each
        converted with ``as_dict()`` from ``AwsAutoScaler.Configuration`` and
        ``AwsAutoScaler.Settings`` respectively.
    """
    hyper_params = AwsAutoScaler.Settings()
    configurations = AwsAutoScaler.Configuration()

    hyper_params.cloud_credentials_key = get_input("AWS Access Key ID", required=True)
    hyper_params.cloud_credentials_secret = get_input(
        "AWS Secret Access Key", required=True
    )
    hyper_params.cloud_credentials_region = get_input("AWS region name", required=True)

    # get GIT User/Pass for cloning
    hyper_params.git_user, hyper_params.git_pass = _prompt_git_credentials()

    hyper_params.default_docker_image = get_input(
        "default docker image/parameters",
        "to use [default is {}]".format(DEFAULT_DOCKER_IMAGE),
        default=DEFAULT_DOCKER_IMAGE,
        new_line=True,
    )

    print("\nDefine the type of machines you want the autoscaler to use")
    configurations.resource_configurations = _prompt_resource_configurations()

    configurations.extra_vm_bash_script = input(
        "\nEnter any pre-execution bash script to be executed on the newly created instances: "
    )

    print("\nSet up the budget\n")
    configurations.queues = _prompt_queues()

    hyper_params.max_idle_time_min = input_int(
        "maximum idle time",
        "for the autoscaler (in minutes, default is 15)",
        default=15,
        new_line=True,
    )
    hyper_params.polling_interval_time_min = input_int(
        "polling interval", "for the autoscaler (in minutes, default is 5)", default=5,
    )

    return configurations.as_dict(), hyper_params.as_dict()


def _prompt_git_credentials():
    # type: () -> tuple
    """Ask for the git user/password used for repository cloning.

    :return: ``(git_user, git_pass)``; both are ``None`` when the user name is
        left blank (SSH key authentication).
    """
    # Local import: keeps this block self-contained without touching the
    # file-level import section.
    from getpass import getpass

    print(
        "\nGIT credentials:"
        "\nEnter GIT username for repository cloning (leave blank for SSH key authentication): [] ",
        end="",
    )
    git_user = input()
    if not git_user.strip():
        return None, None

    # Read the password without echoing it, and never print it back:
    # the confirmation line only shows a fixed-width mask.
    git_pass = getpass("Enter password for user '{}': ".format(git_user))
    print(
        "Git repository cloning will be using user={} password={}".format(
            git_user, "********"
        )
    )
    return git_user, git_pass


def _prompt_resource_configurations():
    # type: () -> dict
    """Ask for one or more machine definitions, keyed by the user-chosen name.

    Entering an already used name replaces the earlier definition.
    """
    resource_configurations = {}
    while True:
        resource_name = get_input(
            "machine type name",
            "(remember it, we will later use it in the budget section)",
            required=True,
            new_line=True,
        )
        # The dict literal is evaluated top to bottom, which fixes the order
        # in which the questions are asked.
        resource_configurations[resource_name] = {
            "instance_type": get_input(
                "instance type",
                "for resource '{}' [default is 'g4dn.4xlarge']".format(resource_name),
                default="g4dn.4xlarge",
            ),
            "is_spot": input_bool(
                "is '{}' resource using spot instances? [t/F]".format(resource_name)
            ),
            "availability_zone": get_input(
                "availability zone",
                "for resource '{}' [default is 'us-east-1b']".format(resource_name),
                default="us-east-1b",
            ),
            "ami_id": get_input(
                "ami_id",
                "for resource '{}' [default is 'ami-07c95cafbb788face']".format(
                    resource_name
                ),
                default="ami-07c95cafbb788face",
            ),
            "ebs_device_name": get_input(
                "ebs_device_name",
                "for resource '{}' [default is '/dev/xvda']".format(resource_name),
                default="/dev/xvda",
            ),
            "ebs_volume_size": input_int(
                "ebs_volume_size",
                " for resource '{}' [default is '100']".format(resource_name),
                default=100,
            ),
            "ebs_volume_type": get_input(
                "ebs_volume_type",
                "for resource '{}' [default is 'gp2']".format(resource_name),
                default="gp2",
            ),
        }
        if not input_bool("\nDefine another resource? [y/N]"):
            break
    return resource_configurations


def _prompt_queues():
    # type: () -> dict
    """Ask for the budget: per queue, a list of (resource name, max instances).

    The resource names are not checked against the machine definitions
    entered earlier; the user is only reminded to reuse them.
    """
    queues = defaultdict(list)
    while True:
        queue_name = get_input("queue name", required=True)
        while True:
            queue_type = get_input(
                "queue type",
                "(use the resources names defined earlier)",
                required=True,
            )
            max_instances = input_int(
                "maximum number of instances allowed", required=True
            )
            queues[queue_name].append((queue_type, max_instances))
            if not input_bool("\nAdd another type to queue? [y/N]: "):
                break
        if not input_bool("Define another queue? [y/N]: "):
            break
    # Hand back a plain dict rather than the defaultdict used for grouping.
    return dict(queues)
# Run the wizard / autoscaler only when executed as a script, not on import.
if __name__ == "__main__":
    main()