Add trains-init support for config file env override (as well as argument)

This commit is contained in:
allegroai 2020-03-12 18:09:03 +02:00
parent b3dff9a4eb
commit 84a34428b6

View File

@ -1,18 +1,19 @@
""" Trains configuration wizard"""
from __future__ import print_function from __future__ import print_function
import argparse
import os import os
from six.moves import input
from pathlib2 import Path from pathlib2 import Path
from six.moves import input
from six.moves.urllib.parse import urlparse from six.moves.urllib.parse import urlparse
from trains.backend_api.session import Session from trains.backend_api.session import Session
from trains.backend_api.session.defs import ENV_HOST from trains.backend_api.session.defs import ENV_HOST
from trains.backend_config.defs import LOCAL_CONFIG_FILES from trains.backend_config.defs import LOCAL_CONFIG_FILES, LOCAL_CONFIG_FILE_OVERRIDE_VAR
from trains.config import config_obj from trains.config import config_obj
from trains.utilities.pyhocon import ConfigFactory, ConfigMissingException from trains.utilities.pyhocon import ConfigFactory, ConfigMissingException
description = """ description = """
Please create new trains credentials through the profile page in your trains web app (e.g. http://localhost:8080/profile) Please create new trains credentials through the profile page in your trains web app (e.g. http://localhost:8080/profile)
In the profile page, press "Create new credentials", then press "Copy to clipboard". In the profile page, press "Create new credentials", then press "Copy to clipboard".
@ -25,14 +26,6 @@ try:
except Exception: except Exception:
def_host = 'http://localhost:8080' def_host = 'http://localhost:8080'
host_description = """
Editing configuration file: {CONFIG_FILE}
Enter the url of the trains-server's Web service, for example: {HOST}
""".format(
CONFIG_FILE=LOCAL_CONFIG_FILES[0],
HOST=def_host,
)
def get_user_input(): def get_user_input():
""" """
@ -64,9 +57,27 @@ def get_user_input():
return os.linesep.join(input_list) return os.linesep.join(input_list)
def validate_file(string):
if not string:
raise argparse.ArgumentTypeError("expected a valid file path")
return string
def main(): def main():
default_config_file = os.getenv(LOCAL_CONFIG_FILE_OVERRIDE_VAR) or LOCAL_CONFIG_FILES[0]
p = argparse.ArgumentParser(description=__doc__)
p.add_argument(
"--file", "-F", help="Target configuration file path (default is %(default)s)",
default=default_config_file,
type=validate_file
)
args = p.parse_args()
print('TRAINS SDK setup process') print('TRAINS SDK setup process')
conf_file = Path(LOCAL_CONFIG_FILES[0]).absolute()
conf_file = Path(args.file).absolute()
if conf_file.exists() and conf_file.is_file() and conf_file.stat().st_size > 0: if conf_file.exists() and conf_file.is_file() and conf_file.stat().st_size > 0:
print('Configuration file already exists: {}'.format(str(conf_file))) print('Configuration file already exists: {}'.format(str(conf_file)))
print('Leaving setup, feel free to edit the configuration file.') print('Leaving setup, feel free to edit the configuration file.')
@ -96,6 +107,15 @@ def main():
print('Detected credentials key=\"{}\" secret=\"{}\"'.format(credentials['access_key'], print('Detected credentials key=\"{}\" secret=\"{}\"'.format(credentials['access_key'],
credentials['secret_key'][0:4] + "***")) credentials['secret_key'][0:4] + "***"))
host_description = """
Editing configuration file: {CONFIG_FILE}
Enter the url of the trains-server's Web service, for example: {HOST}
""".format(
CONFIG_FILE=args.file,
HOST=def_host,
)
if api_host: if api_host:
api_host = input_url('API Host', api_host) api_host = input_url('API Host', api_host)
else: else:
@ -265,8 +285,8 @@ def input_url(host_type, host=None):
def input_host_port(host_type, parsed_host): def input_host_port(host_type, parsed_host):
print('Enter port for {} host '.format(host_type), end='') print('Enter port for {} host '.format(host_type), end='')
replace_port = input().lower() replace_port = input().lower()
return parsed_host.scheme + "://" + parsed_host.netloc + (':{}'.format(replace_port) if replace_port else '') + \ return (parsed_host.scheme + "://" + parsed_host.netloc + (':{}'.format(replace_port) if replace_port else '') +
parsed_host.path parsed_host.path)
def verify_url(parse_input): def verify_url(parse_input):