From 3bd997c4dcfb2a22c73051eec699f7d78446d9c9 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Sun, 15 Dec 2019 00:11:01 +0200 Subject: [PATCH] Improve trains-init configuration wizard --- trains/config/default/__main__.py | 51 ++++++++++++++++++++-------- trains/utilities/pyhocon/__init__.py | 2 +- 2 files changed, 37 insertions(+), 16 deletions(-) diff --git a/trains/config/default/__main__.py b/trains/config/default/__main__.py index 93f29080..3b358e02 100644 --- a/trains/config/default/__main__.py +++ b/trains/config/default/__main__.py @@ -8,7 +8,7 @@ from trains.backend_api.session import Session from trains.backend_api.session.defs import ENV_HOST from trains.backend_config.defs import LOCAL_CONFIG_FILES from trains.config import config_obj -from trains.utilities.pyhocon import ConfigFactory +from trains.utilities.pyhocon import ConfigFactory, ConfigMissingException description = """ @@ -41,7 +41,8 @@ def main(): return print(description, end='') - parse_input = input() + sentinel = '' + parse_input = '\n'.join(iter(input, sentinel)) credentials = None api_host = None web_server = None @@ -50,9 +51,9 @@ def main(): parsed = ConfigFactory.parse_string(parse_input) if parsed: # Take the credentials in raw form or from api section - credentials = parsed.get("credentials", None) or parsed.get("api", {}).get("credentials") - api_host = parsed.get("api", {}).get("host") or parsed.get("api", {}).get("api_server") - web_server = parsed.get("api", {}).get("web_server") + credentials = get_parsed_field(parsed, ["credentials"]) + api_host = get_parsed_field(parsed, ["api_server", "host"]) + web_server = get_parsed_field(parsed, ["web_server"]) except Exception: credentials = credentials or None api_host = api_host or None @@ -71,16 +72,7 @@ def main(): api_host = input_url('API Host', '') parsed_host = verify_url(api_host) - if parsed_host.port == 8008: - print('Port 8008 is the api port. Replacing 8080 with 8008 for Web application') - api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path - web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8080', 1) + parsed_host.path - files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8081', 1) + parsed_host.path - elif parsed_host.port == 8080: - api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8008', 1) + parsed_host.path - web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path - files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8081', 1) + parsed_host.path - elif parsed_host.netloc.startswith('demoapp.'): + if parsed_host.netloc.startswith('demoapp.'): # this is our demo server api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapp.', 'demoapi.', 1) + parsed_host.path web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path @@ -102,6 +94,15 @@ def main(): api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('api.', 'app.', 1) + parsed_host.path files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('api.', 'files.', 1) + parsed_host.path + elif parsed_host.port == 8008: + print('Port 8008 is the api port. Replacing 8080 with 8008 for Web application') + api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path + web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8080', 1) + parsed_host.path + files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8081', 1) + parsed_host.path + elif parsed_host.port == 8080: + api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8008', 1) + parsed_host.path + web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path + files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8081', 1) + parsed_host.path else: api_host = '' web_host = '' @@ -188,6 +189,26 @@ def verify_credentials(api_host, credentials): return False +def get_parsed_field(parsed_config, fields): + """ + Parsed the value from web profile page, 'copy to clipboard' option + :param parsed_config: The parsed value from the web ui + :type parsed_config: Config object + :param fields: list of values to parse, will parse by the list order + :type fields: List[str] + :return: parsed value if found, None else + """ + try: + return parsed_config.get("api").get(fields[0]) + except ConfigMissingException: # fallback - try to parse the field like it was in web older version + if len(fields) == 1: + return parsed_config.get(fields[0]) + elif len(fields) == 2: + return parsed_config.get(fields[1]) + else: + return None + + def read_manual_credentials(): print('Enter user access key: ', end='') access_key = input() diff --git a/trains/utilities/pyhocon/__init__.py b/trains/utilities/pyhocon/__init__.py index 4aa44f0c..fd2bbdc3 100755 --- a/trains/utilities/pyhocon/__init__.py +++ b/trains/utilities/pyhocon/__init__.py @@ -1,3 +1,3 @@ -from .config_parser import ConfigParser, ConfigFactory +from .config_parser import ConfigParser, ConfigFactory, ConfigMissingException from .config_tree import ConfigTree from .converter import HOCONConverter