Improve trains-init wizard

This commit is contained in:
allegroai 2019-12-04 23:47:39 +02:00
parent f92278750a
commit 9add031fe8

View File

@ -4,17 +4,19 @@ from six.moves import input
from pathlib2 import Path from pathlib2 import Path
from six.moves.urllib.parse import urlparse from six.moves.urllib.parse import urlparse
from trains.utilities.pyhocon import ConfigFactory from trains.backend_api.session import Session
from trains.backend_api.session.defs import ENV_HOST from trains.backend_api.session.defs import ENV_HOST
from trains.backend_config.defs import LOCAL_CONFIG_FILES from trains.backend_config.defs import LOCAL_CONFIG_FILES
from trains.config import config_obj from trains.config import config_obj
from trains.utilities.pyhocon import ConfigFactory
description = """ description = """
Please create new credentials using the web app: {}/profile Please create new trains credentials through the profile page in your trains web app (e.g. https://demoapp.trains.allegro.ai/profile)
In the Admin page, press "Create new credentials", then press "Copy to clipboard" In the profile page, press "Create new credentials", then press "Copy to clipboard".
Paste credentials here: """ Paste copied configuration here:
"""
try: try:
def_host = ENV_HOST.get(default=config_obj.get("api.web_server")) or 'http://localhost:8080' def_host = ENV_HOST.get(default=config_obj.get("api.web_server")) or 'http://localhost:8080'
@ -38,9 +40,36 @@ def main():
print('Leaving setup, feel free to edit the configuration file.') print('Leaving setup, feel free to edit the configuration file.')
return return
print(host_description) print(description, end='')
web_host = input_url('Web Application Host', '') parse_input = input()
parsed_host = verify_url(web_host) credentials = None
api_host = None
web_server = None
# noinspection PyBroadException
try:
parsed = ConfigFactory.parse_string(parse_input)
if parsed:
# Take the credentials in raw form or from api section
credentials = parsed.get("credentials", None) or parsed.get("api", {}).get("credentials")
api_host = parsed.get("api", {}).get("host") or parsed.get("api", {}).get("api_server")
web_server = parsed.get("api", {}).get("web_server")
except Exception:
credentials = credentials or None
api_host = api_host or None
web_server = web_server or None
while not credentials or set(credentials) != {"access_key", "secret_key"}:
print('Could not parse credentials, please try entering them manually.')
credentials = read_manual_credentials()
print('Detected credentials key=\"{}\" secret=\"{}\"'.format(credentials['access_key'],
credentials['secret_key'][0:4] + "***"))
if api_host:
api_host = input_url('API Host', api_host)
else:
print(host_description)
api_host = input_url('API Host', '')
parsed_host = verify_url(api_host)
if parsed_host.port == 8008: if parsed_host.port == 8008:
print('Port 8008 is the api port. Replacing 8080 with 8008 for Web application') print('Port 8008 is the api port. Replacing 8080 with 8008 for Web application')
@ -91,47 +120,23 @@ def main():
if not api_host: if not api_host:
api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
api_host = input_url('API Host', api_host) web_host = input_url('Web Application Host', web_server if web_server else web_host)
files_host = input_url('File Store Host', files_host) files_host = input_url('File Store Host', files_host)
print('\nTRAINS Hosts configuration:\nWeb App: {}\nAPI: {}\nFile Store: {}\n'.format( print('\nTRAINS Hosts configuration:\nWeb App: {}\nAPI: {}\nFile Store: {}\n'.format(
web_host, api_host, files_host)) web_host, api_host, files_host))
while True: retry = 1
print(description.format(web_host), end='') max_retries = 2
parse_input = input() while retry <= max_retries: # Up to 2 tries by the user
# check if these are valid credentials if verify_credentials(api_host, credentials):
credentials = None
# noinspection PyBroadException
try:
parsed = ConfigFactory.parse_string(parse_input)
if parsed:
credentials = parsed.get("credentials", None)
except Exception:
credentials = None
if not credentials or set(credentials) != {"access_key", "secret_key"}:
print('Could not parse user credentials, try again one after the other.')
credentials = {}
# parse individual
print('Enter user access key: ', end='')
credentials['access_key'] = input()
print('Enter user secret: ', end='')
credentials['secret_key'] = input()
print('Detected credentials key=\"{}\" secret=\"{}\"'.format(credentials['access_key'],
credentials['secret_key'], ))
from trains.backend_api.session import Session
# noinspection PyBroadException
try:
print('Verifying credentials ...')
Session(api_key=credentials['access_key'], secret_key=credentials['secret_key'], host=api_host)
print('Credentials verified!')
break break
except Exception: retry += 1
print('Error: could not verify credentials: host={} access={} secret={}'.format( if retry < max_retries + 1:
api_host, credentials['access_key'], credentials['secret_key'])) credentials = read_manual_credentials()
else:
print('Exiting setup without creating configuration file')
return
# noinspection PyBroadException # noinspection PyBroadException
try: try:
@ -165,6 +170,32 @@ def main():
print('TRAINS setup completed successfully.') print('TRAINS setup completed successfully.')
def verify_credentials(api_host, credentials):
"""check if the credentials are valid"""
# noinspection PyBroadException
try:
print('Verifying credentials ...')
if api_host:
Session(api_key=credentials['access_key'], secret_key=credentials['secret_key'], host=api_host)
print('Credentials verified!')
return True
else:
print("Can't verify credentials")
return False
except Exception:
print('Error: could not verify credentials: key={} secret={}'.format(
credentials.get('access_key'), credentials.get('secret_key')))
return False
def read_manual_credentials():
print('Enter user access key: ', end='')
access_key = input()
print('Enter user secret: ', end='')
secret_key = input()
return {"access_key": access_key, "secret_key": secret_key}
def input_url(host_type, host=None): def input_url(host_type, host=None):
while True: while True:
print('{} configured to: [{}] '.format(host_type, host), end='') print('{} configured to: [{}] '.format(host_type, host), end='')