Improve configuration wizard

This commit is contained in:
allegroai 2019-12-15 00:02:04 +02:00
parent 88f1031e5d
commit 1f0bb4906b

View File

@ -1,19 +1,21 @@
from __future__ import print_function from __future__ import print_function
from six.moves import input from six.moves import input
from pyhocon import ConfigFactory from pyhocon import ConfigFactory, ConfigMissingException
from pathlib2 import Path from pathlib2 import Path
from six.moves.urllib.parse import urlparse from six.moves.urllib.parse import urlparse
from trains_agent.backend_api.session import Session
from trains_agent.backend_api.session.defs import ENV_HOST from trains_agent.backend_api.session.defs import ENV_HOST
from trains_agent.backend_config.defs import LOCAL_CONFIG_FILES from trains_agent.backend_config.defs import LOCAL_CONFIG_FILES
description = """ description = """
Please create new credentials using the web app: {}/profile Please create new trains credentials through the profile page in your trains web app (e.g. https://demoapp.trains.allegro.ai/profile)
In the Admin page, press "Create new credentials", then press "Copy to clipboard" In the profile page, press "Create new credentials", then press "Copy to clipboard".
Paste credentials here: """ Paste copied configuration here:
"""
def_host = 'http://localhost:8080' def_host = 'http://localhost:8080'
try: try:
@ -38,20 +40,39 @@ def main():
print('Leaving setup, feel free to edit the configuration file.') print('Leaving setup, feel free to edit the configuration file.')
return return
print(host_description) print(description, end='')
web_host = input_url('Web Application Host', '') sentinel = ''
parsed_host = verify_url(web_host) parse_input = '\n'.join(iter(input, sentinel))
credentials = None
api_host = None
web_server = None
# noinspection PyBroadException
try:
parsed = ConfigFactory.parse_string(parse_input)
if parsed:
# Take the credentials in raw form or from api section
credentials = get_parsed_field(parsed, ["credentials"])
api_host = get_parsed_field(parsed, ["api_server", "host"])
web_server = get_parsed_field(parsed, ["web_server"])
except Exception:
credentials = credentials or None
api_host = api_host or None
web_server = web_server or None
if parsed_host.port == 8008: while not credentials or set(credentials) != {"access_key", "secret_key"}:
print('Port 8008 is the api port. Replacing 8080 with 8008 for Web application') print('Could not parse credentials, please try entering them manually.')
api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path credentials = read_manual_credentials()
web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8080', 1) + parsed_host.path
files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8081', 1) + parsed_host.path print('Detected credentials key=\"{}\" secret=\"{}\"'.format(credentials['access_key'],
elif parsed_host.port == 8080: credentials['secret_key'][0:4] + "***"))
api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8008', 1) + parsed_host.path if api_host:
web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path api_host = input_url('API Host', api_host)
files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8081', 1) + parsed_host.path else:
elif parsed_host.netloc.startswith('demoapp.'): print(host_description)
api_host = input_url('API Host', '')
parsed_host = verify_url(api_host)
if parsed_host.netloc.startswith('demoapp.'):
# this is our demo server # this is our demo server
api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapp.', 'demoapi.', 1) + parsed_host.path api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapp.', 'demoapi.', 1) + parsed_host.path
web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
@ -73,6 +94,15 @@ def main():
api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('api.', 'app.', 1) + parsed_host.path web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('api.', 'app.', 1) + parsed_host.path
files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('api.', 'files.', 1) + parsed_host.path files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('api.', 'files.', 1) + parsed_host.path
elif parsed_host.port == 8008:
print('Port 8008 is the api port. Replacing 8080 with 8008 for Web application')
api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8080', 1) + parsed_host.path
files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8081', 1) + parsed_host.path
elif parsed_host.port == 8080:
api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8008', 1) + parsed_host.path
web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8081', 1) + parsed_host.path
else: else:
api_host = '' api_host = ''
web_host = '' web_host = ''
@ -91,47 +121,23 @@ def main():
if not api_host: if not api_host:
api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
api_host = input_url('API Host', api_host) web_host = input_url('Web Application Host', web_server if web_server else web_host)
files_host = input_url('File Store Host', files_host) files_host = input_url('File Store Host', files_host)
print('\nTRAINS Hosts configuration:\nWeb App: {}\nAPI: {}\nFile Store: {}\n'.format( print('\nTRAINS Hosts configuration:\nWeb App: {}\nAPI: {}\nFile Store: {}\n'.format(
web_host, api_host, files_host)) web_host, api_host, files_host))
while True: retry = 1
print(description.format(web_host), end='') max_retries = 2
parse_input = input() while retry <= max_retries: # Up to 2 tries by the user
# check if these are valid credentials if verify_credentials(api_host, credentials):
credentials = None
# noinspection PyBroadException
try:
parsed = ConfigFactory.parse_string(parse_input)
if parsed:
credentials = parsed.get("credentials", None)
except Exception:
credentials = None
if not credentials or set(credentials) != {"access_key", "secret_key"}:
print('Could not parse user credentials, try again one after the other.')
credentials = {}
# parse individual
print('Enter user access key: ', end='')
credentials['access_key'] = input()
print('Enter user secret: ', end='')
credentials['secret_key'] = input()
print('Detected credentials key=\"{}\" secret=\"{}\"'.format(credentials['access_key'],
credentials['secret_key'], ))
from trains_agent.backend_api.session import Session
# noinspection PyBroadException
try:
print('Verifying credentials ...')
Session(api_key=credentials['access_key'], secret_key=credentials['secret_key'], host=api_host)
print('Credentials verified!')
break break
except Exception: retry += 1
print('Error: could not verify credentials: host={} access={} secret={}'.format( if retry < max_retries + 1:
api_host, credentials['access_key'], credentials['secret_key'])) credentials = read_manual_credentials()
else:
print('Exiting setup without creating configuration file')
return
# get GIT User/Pass for cloning # get GIT User/Pass for cloning
print('Enter git username for repository cloning (leave blank for SSH key authentication): [] ', end='') print('Enter git username for repository cloning (leave blank for SSH key authentication): [] ', end='')
@ -186,6 +192,52 @@ def main():
print('TRAINS-AGENT setup completed successfully.') print('TRAINS-AGENT setup completed successfully.')
def verify_credentials(api_host, credentials):
"""check if the credentials are valid"""
# noinspection PyBroadException
try:
print('Verifying credentials ...')
if api_host:
Session(api_key=credentials['access_key'], secret_key=credentials['secret_key'], host=api_host)
print('Credentials verified!')
return True
else:
print("Can't verify credentials")
return False
except Exception:
print('Error: could not verify credentials: key={} secret={}'.format(
credentials.get('access_key'), credentials.get('secret_key')))
return False
def get_parsed_field(parsed_config, fields):
"""
Parsed the value from web profile page, 'copy to clipboard' option
:param parsed_config: The parsed value from the web ui
:type parsed_config: Config object
:param fields: list of values to parse, will parse by the list order
:type fields: List[str]
:return: parsed value if found, None else
"""
try:
return parsed_config.get("api").get(fields[0])
except ConfigMissingException: # fallback - try to parse the field like it was in web older version
if len(fields) == 1:
return parsed_config.get(fields[0])
elif len(fields) == 2:
return parsed_config.get(fields[1])
else:
return None
def read_manual_credentials():
print('Enter user access key: ', end='')
access_key = input()
print('Enter user secret: ', end='')
secret_key = input()
return {"access_key": access_key, "secret_key": secret_key}
def input_url(host_type, host=None): def input_url(host_type, host=None):
while True: while True:
print('{} configured to: [{}] '.format(host_type, host), end='') print('{} configured to: [{}] '.format(host_type, host), end='')