diff --git a/examples/services/aws-autoscaler/aws_autoscaler.py b/examples/services/aws-autoscaler/aws_autoscaler.py index 24de6c4d..8cd1ed83 100644 --- a/examples/services/aws-autoscaler/aws_autoscaler.py +++ b/examples/services/aws-autoscaler/aws_autoscaler.py @@ -1,8 +1,7 @@ -import distutils from argparse import ArgumentParser from collections import defaultdict from pathlib import Path -from typing import Optional, Tuple +from typing import Tuple import yaml from six.moves import input @@ -10,6 +9,7 @@ from six.moves import input from trains import Task from trains.automation.aws_auto_scaler import AwsAutoScaler from trains.config import running_remotely +from trains.utilities.wizard.user_input import get_input, input_int, input_bool CONF_FILE = "aws_autoscaler.yaml" DEFAULT_DOCKER_IMAGE = "nvidia/cuda" @@ -195,66 +195,5 @@ def run_wizard(): return configurations.as_dict(), hyper_params.as_dict() -def get_input( - key, # type: str - description="", # type: str - question="Enter", # type: str - required=False, # type: bool - default=None, # type: Optional[str] - new_line=False, # type: bool -): - # type: (...) -> Optional[str] - if new_line: - print() - while True: - value = input("{} {} {}: ".format(question, key, description)) - if not value.strip() and required: - print("{} is required".format(key)) - elif not (value.strip() or required): - return default - else: - return value - - -def input_int( - key, # type: str - description="", # type: str - required=False, # type: bool - default=None, # type: Optional[int] - new_line=False, # type: bool -): - # type: (...) -> Optional[int] - while True: - try: - value = int( - get_input( - key, - description, - required=required, - default=default, - new_line=new_line, - ) - ) - return value - except ValueError: - print( - "Invalid input: {} should be a number. Please enter an integer".format( - key - ) - ) - - -def input_bool(question, default=False): - # type: (str, bool) -> bool - while True: - try: - response = input("{}: ".format(question)).lower() - if not response: - return default - return distutils.util.strtobool(response) - except ValueError: - print("Invalid input: please enter yes or no") - - if __name__ == "__main__": main() diff --git a/trains/utilities/wizard/__init__.py b/trains/utilities/wizard/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/trains/utilities/wizard/user_input.py b/trains/utilities/wizard/user_input.py new file mode 100644 index 00000000..d8f69ca8 --- /dev/null +++ b/trains/utilities/wizard/user_input.py @@ -0,0 +1,63 @@ +import distutils +from typing import Optional + + +def get_input( + key, # type: str + description="", # type: str + question="Enter", # type: str + required=False, # type: bool + default=None, # type: Optional[str] + new_line=False, # type: bool +): + # type: (...) -> Optional[str] + if new_line: + print() + while True: + value = input("{} {} {}: ".format(question, key, description)) + if not value.strip() and required: + print("{} is required".format(key)) + elif not (value.strip() or required): + return default + else: + return value + + +def input_int( + key, # type: str + description="", # type: str + required=False, # type: bool + default=None, # type: Optional[int] + new_line=False, # type: bool +): + # type: (...) -> Optional[int] + while True: + try: + value = int( + get_input( + key, + description, + required=required, + default=default, + new_line=new_line, + ) + ) + return value + except ValueError: + print( + "Invalid input: {} should be a number. Please enter an integer".format( + key + ) + ) + + +def input_bool(question, default=False): + # type: (str, bool) -> bool + while True: + try: + response = input("{}: ".format(question)).lower() + if not response: + return default + return distutils.util.strtobool(response) + except ValueError: + print("Invalid input: please enter yes or no")