"""Interactive setup wizard and launcher for the trains AWS autoscaler."""
from argparse import ArgumentParser
from collections import defaultdict
from pathlib import Path
from typing import Optional, Tuple

import yaml
from six.moves import input
from trains import Task
from trains.automation.aws_auto_scaler import AwsAutoScaler
from trains.config import running_remotely

CONF_FILE = "aws_autoscaler.yaml"
DEFAULT_DOCKER_IMAGE = "nvidia/cuda"

# Accepted yes/no spellings (compared after lower-casing). This is the value set
# documented for ``distutils.util.strtobool``, which this script used to call;
# ``distutils`` is deprecated (PEP 632) and gone in Python 3.12, and a bare
# ``import distutils`` never guaranteed that ``distutils.util`` was loaded.
_TRUE_STRINGS = ("y", "yes", "t", "true", "on", "1")
_FALSE_STRINGS = ("n", "no", "f", "false", "off", "0")


def main():
    """Collect autoscaler settings, register them on a Task, optionally start it.

    When running remotely, settings are taken from freshly constructed
    ``AwsAutoScaler.Settings()`` / ``AwsAutoScaler.Configuration()`` objects
    (presumably overridden by ``task.connect`` with the values stored on the
    remote Task -- confirm against trains docs). Locally, the user may reuse
    ``CONF_FILE`` if it exists; otherwise the interactive wizard runs and its
    answers are written to ``CONF_FILE``. If that write fails, the function
    prints an error and returns without creating a Task.

    The autoscaler is started when running remotely or when ``--run`` is given.
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--run",
        help="Run the autoscaler after wizard finished",
        action="store_true",
        default=False,
    )
    args = parser.parse_args()

    if running_remotely():
        hyper_params = AwsAutoScaler.Settings().as_dict()
        configurations = AwsAutoScaler.Configuration().as_dict()
    else:
        print("AWS Autoscaler setup\n")

        config_file = Path(CONF_FILE).absolute()
        if config_file.exists() and input_bool(
            "Load configurations from config file '{}' [Y/n]? ".format(str(CONF_FILE)),
            default=True,
        ):
            with config_file.open("r") as f:
                conf = yaml.load(f, Loader=yaml.SafeLoader)
            hyper_params = conf["hyper_params"]
            configurations = conf["configurations"]
        else:
            configurations, hyper_params = run_wizard()

            try:
                with config_file.open("w+") as f:
                    conf = {
                        "hyper_params": hyper_params,
                        "configurations": configurations,
                    }
                    yaml.safe_dump(conf, f)
            except Exception as ex:
                # Top-level boundary: report why the write failed, then abort.
                print(
                    "Error! Could not write configuration file at: {} ({})".format(
                        str(CONF_FILE), ex
                    )
                )
                return

    task = Task.init(project_name="Auto-Scaler", task_name="AWS Auto-Scaler")
    task.connect(hyper_params)
    task.connect_configuration(configurations)

    autoscaler = AwsAutoScaler(hyper_params, configurations)

    if running_remotely() or args.run:
        autoscaler.start()


def run_wizard():
    # type: () -> Tuple[dict, dict]
    """Interactively ask the user for every autoscaler setting.

    Prompts for AWS credentials, optional git credentials, the default docker
    image, one or more machine resource definitions, the per-queue budget
    (resource type + max instance count), and the idle / polling intervals.

    Returns:
        ``(configurations, hyper_params)`` -- the ``as_dict()`` forms of the
        populated ``AwsAutoScaler.Configuration`` and ``AwsAutoScaler.Settings``.
    """
    hyper_params = AwsAutoScaler.Settings()
    configurations = AwsAutoScaler.Configuration()

    hyper_params.cloud_credentials_key = get_input("AWS Access Key ID", required=True)
    hyper_params.cloud_credentials_secret = get_input(
        "AWS Secret Access Key", required=True
    )
    hyper_params.cloud_credentials_region = get_input("AWS region name", required=True)

    # get GIT User/Pass for cloning
    print(
        "\nGIT credentials:"
        "\nEnter GIT username for repository cloning (leave blank for SSH key authentication): [] ",
        end="",
    )
    git_user = input()
    if git_user.strip():
        print("Enter password for user '{}': ".format(git_user), end="")
        git_pass = input()
        # NOTE(review): the password is read with echo on and printed back in
        # clear text below; consider getpass / masking.
        print(
            "Git repository cloning will be using user={} password={}".format(
                git_user, git_pass
            )
        )
    else:
        # Blank username -> rely on SSH key authentication.
        git_user = None
        git_pass = None

    hyper_params.git_user = git_user
    hyper_params.git_pass = git_pass

    hyper_params.default_docker_image = get_input(
        "default docker image/parameters",
        "to use [default is {}]".format(DEFAULT_DOCKER_IMAGE),
        default=DEFAULT_DOCKER_IMAGE,
        new_line=True,
    )

    print("\nDefine the type of machines you want the autoscaler to use")
    resource_configurations = {}
    while True:
        resource_name = get_input(
            "machine type name",
            "(remember it, we will later use it in the budget section)",
            required=True,
            new_line=True,
        )
        resource_configurations[resource_name] = {
            "instance_type": get_input(
                "instance type",
                "for resource '{}' [default is 'g4dn.4xlarge']".format(resource_name),
                default="g4dn.4xlarge",
            ),
            "is_spot": input_bool(
                "is '{}' resource using spot instances? [t/F]".format(resource_name)
            ),
            "availability_zone": get_input(
                "availability zone",
                "for resource '{}' [default is 'us-east-1b']".format(resource_name),
                default="us-east-1b",
            ),
            "ami_id": get_input(
                "ami_id",
                "for resource '{}' [default is 'ami-07c95cafbb788face']".format(
                    resource_name
                ),
                default="ami-07c95cafbb788face",
            ),
            "ebs_device_name": get_input(
                "ebs_device_name",
                "for resource '{}' [default is '/dev/xvda']".format(resource_name),
                default="/dev/xvda",
            ),
            "ebs_volume_size": input_int(
                "ebs_volume_size",
                " for resource '{}' [default is '100']".format(resource_name),
                default=100,
            ),
            "ebs_volume_type": get_input(
                "ebs_volume_type",
                "for resource '{}' [default is 'gp2']".format(resource_name),
                default="gp2",
            ),
        }
        if not input_bool("\nDefine another resource? [y/N]"):
            break
    configurations.resource_configurations = resource_configurations
    configurations.extra_vm_bash_script = input(
        "\nEnter any pre-execution bash script to be executed on the newly created instances: "
    )

    print("\nSet up the budget\n")
    # queue name -> list of (resource name, max instances) pairs
    queues = defaultdict(list)
    while True:
        queue_name = get_input("queue name", required=True)
        while True:
            queue_type = get_input(
                "queue type",
                "(use the resources names defined earlier)",
                required=True,
            )
            max_instances = input_int(
                "maximum number of instances allowed", required=True
            )
            queues[queue_name].append((queue_type, max_instances))
            if not input_bool("\nAdd another type to queue? [y/N]: "):
                break
        if not input_bool("Define another queue? [y/N]: "):
            break
    configurations.queues = dict(queues)

    hyper_params.max_idle_time_min = input_int(
        "maximum idle time",
        "for the autoscaler (in minutes, default is 15)",
        default=15,
        new_line=True,
    )
    hyper_params.polling_interval_time_min = input_int(
        "polling interval",
        "for the autoscaler (in minutes, default is 5)",
        default=5,
    )

    return configurations.as_dict(), hyper_params.as_dict()


def get_input(
    key,  # type: str
    description="",  # type: str
    question="Enter",  # type: str
    required=False,  # type: bool
    default=None,  # type: Optional[str]
    new_line=False,  # type: bool
):
    # type: (...) -> Optional[str]
    """Prompt ``"<question> <key> <description>: "`` until an answer is usable.

    Args:
        key: Name of the value being asked for (also used in error text).
        description: Extra hint appended to the prompt.
        question: Leading verb of the prompt.
        required: If true, blank answers are rejected and the prompt repeats.
        default: Returned unchanged for a blank answer when not ``required``.
        new_line: Print an empty line before the first prompt.

    Returns:
        The answer with surrounding whitespace removed, or ``default``.
    """
    if new_line:
        print()
    while True:
        # Strip so pasted credentials/ids don't carry stray spaces or newlines.
        value = input("{} {} {}: ".format(question, key, description)).strip()
        if value:
            return value
        if not required:
            return default
        print("{} is required".format(key))


def input_int(
    key,  # type: str
    description="",  # type: str
    required=False,  # type: bool
    default=None,  # type: Optional[int]
    new_line=False,  # type: bool
):
    # type: (...) -> Optional[int]
    """Prompt via ``get_input`` until the answer parses as an integer.

    Returns ``default`` for a blank answer when not ``required``; if that
    default is ``None`` the result is ``None`` (previously this crashed with
    ``TypeError`` from ``int(None)``).
    """
    while True:
        value = get_input(
            key,
            description,
            required=required,
            default=default,
            new_line=new_line,
        )
        if value is None:
            return None
        try:
            return int(value)
        except ValueError:
            print(
                "Invalid input: {} should be a number. Please enter an integer".format(
                    key
                )
            )


def input_bool(question, default=False):
    # type: (str, bool) -> bool
    """Ask a yes/no ``question`` until the answer is recognised.

    A blank (or whitespace-only) answer returns ``default``. Otherwise the
    answer is matched case-insensitively against ``_TRUE_STRINGS`` /
    ``_FALSE_STRINGS``; anything else re-prompts.
    """
    while True:
        response = input("{}: ".format(question)).strip().lower()
        if not response:
            return default
        if response in _TRUE_STRINGS:
            return True
        if response in _FALSE_STRINGS:
            return False
        print("Invalid input: please enter yes or no")


if __name__ == "__main__":
    main()