From 84a34428b61a55131736881b5b71ba4e1c8b3e60 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Thu, 12 Mar 2020 18:09:03 +0200 Subject: [PATCH] Add trains-init support for config file env override (as well as argument) --- trains/config/default/__main__.py | 48 ++++++++++++++++++++++--------- 1 file changed, 34 insertions(+), 14 deletions(-) diff --git a/trains/config/default/__main__.py b/trains/config/default/__main__.py index b8c13c27..ee8b1878 100644 --- a/trains/config/default/__main__.py +++ b/trains/config/default/__main__.py @@ -1,18 +1,19 @@ +""" Trains configuration wizard""" from __future__ import print_function +import argparse import os -from six.moves import input from pathlib2 import Path +from six.moves import input from six.moves.urllib.parse import urlparse from trains.backend_api.session import Session from trains.backend_api.session.defs import ENV_HOST -from trains.backend_config.defs import LOCAL_CONFIG_FILES +from trains.backend_config.defs import LOCAL_CONFIG_FILES, LOCAL_CONFIG_FILE_OVERRIDE_VAR from trains.config import config_obj from trains.utilities.pyhocon import ConfigFactory, ConfigMissingException - description = """ Please create new trains credentials through the profile page in your trains web app (e.g. http://localhost:8080/profile) In the profile page, press "Create new credentials", then press "Copy to clipboard". @@ -25,14 +26,6 @@ try: except Exception: def_host = 'http://localhost:8080' -host_description = """ -Editing configuration file: {CONFIG_FILE} -Enter the url of the trains-server's Web service, for example: {HOST} -""".format( - CONFIG_FILE=LOCAL_CONFIG_FILES[0], - HOST=def_host, -) - def get_user_input(): """ @@ -64,9 +57,27 @@ def get_user_input(): return os.linesep.join(input_list) +def validate_file(string): + if not string: + raise argparse.ArgumentTypeError("expected a valid file path") + return string + + def main(): + default_config_file = os.getenv(LOCAL_CONFIG_FILE_OVERRIDE_VAR) or LOCAL_CONFIG_FILES[0] + + p = argparse.ArgumentParser(description=__doc__) + p.add_argument( + "--file", "-F", help="Target configuration file path (default is %(default)s)", + default=default_config_file, + type=validate_file + ) + + args = p.parse_args() + print('TRAINS SDK setup process') - conf_file = Path(LOCAL_CONFIG_FILES[0]).absolute() + + conf_file = Path(args.file).absolute() if conf_file.exists() and conf_file.is_file() and conf_file.stat().st_size > 0: print('Configuration file already exists: {}'.format(str(conf_file))) print('Leaving setup, feel free to edit the configuration file.') @@ -96,6 +107,15 @@ def main(): print('Detected credentials key=\"{}\" secret=\"{}\"'.format(credentials['access_key'], credentials['secret_key'][0:4] + "***")) + + host_description = """ + Editing configuration file: {CONFIG_FILE} + Enter the url of the trains-server's Web service, for example: {HOST} + """.format( + CONFIG_FILE=args.file, + HOST=def_host, + ) + if api_host: api_host = input_url('API Host', api_host) else: @@ -265,8 +285,8 @@ def input_url(host_type, host=None): def input_host_port(host_type, parsed_host): print('Enter port for {} host '.format(host_type), end='') replace_port = input().lower() - return parsed_host.scheme + "://" + parsed_host.netloc + (':{}'.format(replace_port) if replace_port else '') + \ - parsed_host.path + return (parsed_host.scheme + "://" + parsed_host.netloc + (':{}'.format(replace_port) if replace_port else '') + + parsed_host.path) def verify_url(parse_input):