diff --git a/examples/services/aws-autoscaler/aws_autoscaler.py b/examples/services/aws-autoscaler/aws_autoscaler.py index 04e4a8e8..f92dfabc 100644 --- a/examples/services/aws-autoscaler/aws_autoscaler.py +++ b/examples/services/aws-autoscaler/aws_autoscaler.py @@ -10,7 +10,7 @@ from six.moves import input from trains import Task from trains.automation.aws_auto_scaler import AwsAutoScaler from trains.config import running_remotely -from trains.utilities.wizard.user_input import get_input, input_int, input_bool +from trains.utilities.wizard.user_input import get_input, input_int, input_bool, multiline_input CONF_FILE = "aws_autoscaler.yaml" DEFAULT_DOCKER_IMAGE = "nvidia/cuda:10.1-runtime-ubuntu18.04" @@ -190,9 +190,15 @@ def run_wizard(): configurations.resource_configurations = resource_configurations - configurations.extra_vm_bash_script = input( - "\nEnter any pre-execution bash script to be executed on the newly created instances []: " + configurations.extra_vm_bash_script, num_lines_bash_script = multiline_input( + "\nEnter any pre-execution bash script to be executed on the newly created instances []" ) + print("Entered {} lines of pre-execution bash script".format(num_lines_bash_script)) + + configurations.extra_trains_conf, num_lines_trains_conf = multiline_input( + "\nEnter anything you'd like to include in your trains.conf file []" + ) + print("Entered {} extra lines for trains.conf file".format(num_lines_trains_conf)) print("\nDefine the machines budget:") print("-----------------------------") diff --git a/trains/utilities/wizard/user_input.py b/trains/utilities/wizard/user_input.py index 5ea078f0..bd1bf542 100644 --- a/trains/utilities/wizard/user_input.py +++ b/trains/utilities/wizard/user_input.py @@ -69,3 +69,15 @@ def input_bool(question, default=False): raise ValueError() except ValueError: print("Invalid input: please enter 'yes' or 'no'") + + +def multiline_input(description=""): + print("{} \nNote: two consecutive empty lines would terminate the input : ".format(description)) + lines = [] + empty_lines = 0 + while empty_lines < 2: + line = input() + lines.append(line) + empty_lines = 0 if line else empty_lines + 1 + res = "\n".join(lines[:-1]) + return res, len(res.splitlines())