Move autoscaler input functions to utilities

This commit is contained in:
allegroai 2020-07-14 23:40:05 +03:00
parent 5e0aecf1b2
commit e7864e6ba8
3 changed files with 65 additions and 63 deletions

View File

@ -1,8 +1,7 @@
import distutils
from argparse import ArgumentParser from argparse import ArgumentParser
from collections import defaultdict from collections import defaultdict
from pathlib import Path from pathlib import Path
from typing import Optional, Tuple from typing import Tuple
import yaml import yaml
from six.moves import input from six.moves import input
@ -10,6 +9,7 @@ from six.moves import input
from trains import Task from trains import Task
from trains.automation.aws_auto_scaler import AwsAutoScaler from trains.automation.aws_auto_scaler import AwsAutoScaler
from trains.config import running_remotely from trains.config import running_remotely
from trains.utilities.wizard.user_input import get_input, input_int, input_bool
CONF_FILE = "aws_autoscaler.yaml" CONF_FILE = "aws_autoscaler.yaml"
DEFAULT_DOCKER_IMAGE = "nvidia/cuda" DEFAULT_DOCKER_IMAGE = "nvidia/cuda"
@ -195,66 +195,5 @@ def run_wizard():
return configurations.as_dict(), hyper_params.as_dict() return configurations.as_dict(), hyper_params.as_dict()
def get_input(
key, # type: str
description="", # type: str
question="Enter", # type: str
required=False, # type: bool
default=None, # type: Optional[str]
new_line=False, # type: bool
):
# type: (...) -> Optional[str]
if new_line:
print()
while True:
value = input("{} {} {}: ".format(question, key, description))
if not value.strip() and required:
print("{} is required".format(key))
elif not (value.strip() or required):
return default
else:
return value
def input_int(
key, # type: str
description="", # type: str
required=False, # type: bool
default=None, # type: Optional[int]
new_line=False, # type: bool
):
# type: (...) -> Optional[int]
while True:
try:
value = int(
get_input(
key,
description,
required=required,
default=default,
new_line=new_line,
)
)
return value
except ValueError:
print(
"Invalid input: {} should be a number. Please enter an integer".format(
key
)
)
def input_bool(question, default=False):
# type: (str, bool) -> bool
while True:
try:
response = input("{}: ".format(question)).lower()
if not response:
return default
return distutils.util.strtobool(response)
except ValueError:
print("Invalid input: please enter yes or no")
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

View File

@ -0,0 +1,63 @@
import distutils
from typing import Optional
def get_input(
key, # type: str
description="", # type: str
question="Enter", # type: str
required=False, # type: bool
default=None, # type: Optional[str]
new_line=False, # type: bool
):
# type: (...) -> Optional[str]
if new_line:
print()
while True:
value = input("{} {} {}: ".format(question, key, description))
if not value.strip() and required:
print("{} is required".format(key))
elif not (value.strip() or required):
return default
else:
return value
def input_int(
key, # type: str
description="", # type: str
required=False, # type: bool
default=None, # type: Optional[int]
new_line=False, # type: bool
):
# type: (...) -> Optional[int]
while True:
try:
value = int(
get_input(
key,
description,
required=required,
default=default,
new_line=new_line,
)
)
return value
except ValueError:
print(
"Invalid input: {} should be a number. Please enter an integer".format(
key
)
)
def input_bool(question, default=False):
# type: (str, bool) -> bool
while True:
try:
response = input("{}: ".format(question)).lower()
if not response:
return default
return distutils.util.strtobool(response)
except ValueError:
print("Invalid input: please enter yes or no")