From 9add031fe8bd0c863cd8c6766d9b7fd6b9664f03 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Wed, 4 Dec 2019 23:47:39 +0200 Subject: [PATCH] Improve trains-init wizard --- trains/config/default/__main__.py | 115 +++++++++++++++++++----------- 1 file changed, 73 insertions(+), 42 deletions(-) diff --git a/trains/config/default/__main__.py b/trains/config/default/__main__.py index 997e9c1e..93f29080 100644 --- a/trains/config/default/__main__.py +++ b/trains/config/default/__main__.py @@ -4,17 +4,19 @@ from six.moves import input from pathlib2 import Path from six.moves.urllib.parse import urlparse -from trains.utilities.pyhocon import ConfigFactory +from trains.backend_api.session import Session from trains.backend_api.session.defs import ENV_HOST from trains.backend_config.defs import LOCAL_CONFIG_FILES from trains.config import config_obj +from trains.utilities.pyhocon import ConfigFactory description = """ -Please create new credentials using the web app: {}/profile -In the Admin page, press "Create new credentials", then press "Copy to clipboard" +Please create new trains credentials through the profile page in your trains web app (e.g. https://demoapp.trains.allegro.ai/profile) +In the profile page, press "Create new credentials", then press "Copy to clipboard". -Paste credentials here: """ +Paste copied configuration here: +""" try: def_host = ENV_HOST.get(default=config_obj.get("api.web_server")) or 'http://localhost:8080' @@ -38,9 +40,36 @@ def main(): print('Leaving setup, feel free to edit the configuration file.') return - print(host_description) - web_host = input_url('Web Application Host', '') - parsed_host = verify_url(web_host) + print(description, end='') + parse_input = input() + credentials = None + api_host = None + web_server = None + # noinspection PyBroadException + try: + parsed = ConfigFactory.parse_string(parse_input) + if parsed: + # Take the credentials in raw form or from api section + credentials = parsed.get("credentials", None) or parsed.get("api", {}).get("credentials") + api_host = parsed.get("api", {}).get("host") or parsed.get("api", {}).get("api_server") + web_server = parsed.get("api", {}).get("web_server") + except Exception: + credentials = credentials or None + api_host = api_host or None + web_server = web_server or None + + while not credentials or set(credentials) != {"access_key", "secret_key"}: + print('Could not parse credentials, please try entering them manually.') + credentials = read_manual_credentials() + + print('Detected credentials key=\"{}\" secret=\"{}\"'.format(credentials['access_key'], + credentials['secret_key'][0:4] + "***")) + if api_host: + api_host = input_url('API Host', api_host) + else: + print(host_description) + api_host = input_url('API Host', '') + parsed_host = verify_url(api_host) if parsed_host.port == 8008: print('Port 8008 is the api port. Replacing 8080 with 8008 for Web application') @@ -91,47 +120,23 @@ def main(): if not api_host: api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path - api_host = input_url('API Host', api_host) + web_host = input_url('Web Application Host', web_server if web_server else web_host) files_host = input_url('File Store Host', files_host) print('\nTRAINS Hosts configuration:\nWeb App: {}\nAPI: {}\nFile Store: {}\n'.format( web_host, api_host, files_host)) - while True: - print(description.format(web_host), end='') - parse_input = input() - # check if these are valid credentials - credentials = None - # noinspection PyBroadException - try: - parsed = ConfigFactory.parse_string(parse_input) - if parsed: - credentials = parsed.get("credentials", None) - except Exception: - credentials = None - - if not credentials or set(credentials) != {"access_key", "secret_key"}: - print('Could not parse user credentials, try again one after the other.') - credentials = {} - # parse individual - print('Enter user access key: ', end='') - credentials['access_key'] = input() - print('Enter user secret: ', end='') - credentials['secret_key'] = input() - - print('Detected credentials key=\"{}\" secret=\"{}\"'.format(credentials['access_key'], - credentials['secret_key'], )) - - from trains.backend_api.session import Session - # noinspection PyBroadException - try: - print('Verifying credentials ...') - Session(api_key=credentials['access_key'], secret_key=credentials['secret_key'], host=api_host) - print('Credentials verified!') + retry = 1 + max_retries = 2 + while retry <= max_retries: # Up to 2 tries by the user + if verify_credentials(api_host, credentials): break - except Exception: - print('Error: could not verify credentials: host={} access={} secret={}'.format( - api_host, credentials['access_key'], credentials['secret_key'])) + retry += 1 + if retry < max_retries + 1: + credentials = read_manual_credentials() + else: + print('Exiting setup without creating configuration file') + return # noinspection PyBroadException try: @@ -165,6 +170,32 @@ def main(): print('TRAINS setup completed successfully.') +def verify_credentials(api_host, credentials): + """check if the credentials are valid""" + # noinspection PyBroadException + try: + print('Verifying credentials ...') + if api_host: + Session(api_key=credentials['access_key'], secret_key=credentials['secret_key'], host=api_host) + print('Credentials verified!') + return True + else: + print("Can't verify credentials") + return False + except Exception: + print('Error: could not verify credentials: key={} secret={}'.format( + credentials.get('access_key'), credentials.get('secret_key'))) + return False + + +def read_manual_credentials(): + print('Enter user access key: ', end='') + access_key = input() + print('Enter user secret: ', end='') + secret_key = input() + return {"access_key": access_key, "secret_key": secret_key} + + def input_url(host_type, host=None): while True: print('{} configured to: [{}] '.format(host_type, host), end='')