diff --git a/trains_agent/commands/config.py b/trains_agent/commands/config.py index 725c1ab..af28ba0 100644 --- a/trains_agent/commands/config.py +++ b/trains_agent/commands/config.py @@ -1,19 +1,21 @@ from __future__ import print_function from six.moves import input -from pyhocon import ConfigFactory +from pyhocon import ConfigFactory, ConfigMissingException from pathlib2 import Path from six.moves.urllib.parse import urlparse +from trains_agent.backend_api.session import Session from trains_agent.backend_api.session.defs import ENV_HOST from trains_agent.backend_config.defs import LOCAL_CONFIG_FILES description = """ -Please create new credentials using the web app: {}/profile -In the Admin page, press "Create new credentials", then press "Copy to clipboard" +Please create new trains credentials through the profile page in your trains web app (e.g. https://demoapp.trains.allegro.ai/profile) +In the profile page, press "Create new credentials", then press "Copy to clipboard". -Paste credentials here: """ +Paste copied configuration here: +""" def_host = 'http://localhost:8080' try: @@ -38,20 +40,39 @@ def main(): print('Leaving setup, feel free to edit the configuration file.') return - print(host_description) - web_host = input_url('Web Application Host', '') - parsed_host = verify_url(web_host) + print(description, end='') + sentinel = '' + parse_input = '\n'.join(iter(input, sentinel)) + credentials = None + api_host = None + web_server = None + # noinspection PyBroadException + try: + parsed = ConfigFactory.parse_string(parse_input) + if parsed: + # Take the credentials in raw form or from api section + credentials = get_parsed_field(parsed, ["credentials"]) + api_host = get_parsed_field(parsed, ["api_server", "host"]) + web_server = get_parsed_field(parsed, ["web_server"]) + except Exception: + credentials = credentials or None + api_host = api_host or None + web_server = web_server or None - if parsed_host.port == 8008: - print('Port 8008 is the api port. Replacing 8080 with 8008 for Web application') - api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path - web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8080', 1) + parsed_host.path - files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8081', 1) + parsed_host.path - elif parsed_host.port == 8080: - api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8008', 1) + parsed_host.path - web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path - files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8081', 1) + parsed_host.path - elif parsed_host.netloc.startswith('demoapp.'): + while not credentials or set(credentials) != {"access_key", "secret_key"}: + print('Could not parse credentials, please try entering them manually.') + credentials = read_manual_credentials() + + print('Detected credentials key=\"{}\" secret=\"{}\"'.format(credentials['access_key'], + credentials['secret_key'][0:4] + "***")) + if api_host: + api_host = input_url('API Host', api_host) + else: + print(host_description) + api_host = input_url('API Host', '') + parsed_host = verify_url(api_host) + + if parsed_host.netloc.startswith('demoapp.'): # this is our demo server api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapp.', 'demoapi.', 1) + parsed_host.path web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path @@ -73,6 +94,15 @@ def main(): api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('api.', 'app.', 1) + parsed_host.path files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('api.', 'files.', 1) + parsed_host.path + elif parsed_host.port == 8008: + print('Port 8008 is the api port. Replacing 8080 with 8008 for Web application') + api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path + web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8080', 1) + parsed_host.path + files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8081', 1) + parsed_host.path + elif parsed_host.port == 8080: + api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8008', 1) + parsed_host.path + web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path + files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8081', 1) + parsed_host.path else: api_host = '' web_host = '' @@ -91,47 +121,23 @@ def main(): if not api_host: api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path - api_host = input_url('API Host', api_host) + web_host = input_url('Web Application Host', web_server if web_server else web_host) files_host = input_url('File Store Host', files_host) print('\nTRAINS Hosts configuration:\nWeb App: {}\nAPI: {}\nFile Store: {}\n'.format( web_host, api_host, files_host)) - while True: - print(description.format(web_host), end='') - parse_input = input() - # check if these are valid credentials - credentials = None - # noinspection PyBroadException - try: - parsed = ConfigFactory.parse_string(parse_input) - if parsed: - credentials = parsed.get("credentials", None) - except Exception: - credentials = None - - if not credentials or set(credentials) != {"access_key", "secret_key"}: - print('Could not parse user credentials, try again one after the other.') - credentials = {} - # parse individual - print('Enter user access key: ', end='') - credentials['access_key'] = input() - print('Enter user secret: ', end='') - credentials['secret_key'] = input() - - print('Detected credentials key=\"{}\" secret=\"{}\"'.format(credentials['access_key'], - credentials['secret_key'], )) - - from trains_agent.backend_api.session import Session - # noinspection PyBroadException - try: - print('Verifying credentials ...') - Session(api_key=credentials['access_key'], secret_key=credentials['secret_key'], host=api_host) - print('Credentials verified!') + retry = 1 + max_retries = 2 + while retry <= max_retries: # Up to 2 tries by the user + if verify_credentials(api_host, credentials): break - except Exception: - print('Error: could not verify credentials: host={} access={} secret={}'.format( - api_host, credentials['access_key'], credentials['secret_key'])) + retry += 1 + if retry < max_retries + 1: + credentials = read_manual_credentials() + else: + print('Exiting setup without creating configuration file') + return # get GIT User/Pass for cloning print('Enter git username for repository cloning (leave blank for SSH key authentication): [] ', end='') @@ -186,6 +192,52 @@ def main(): print('TRAINS-AGENT setup completed successfully.') +def verify_credentials(api_host, credentials): + """check if the credentials are valid""" + # noinspection PyBroadException + try: + print('Verifying credentials ...') + if api_host: + Session(api_key=credentials['access_key'], secret_key=credentials['secret_key'], host=api_host) + print('Credentials verified!') + return True + else: + print("Can't verify credentials") + return False + except Exception: + print('Error: could not verify credentials: key={} secret={}'.format( + credentials.get('access_key'), credentials.get('secret_key'))) + return False + + +def get_parsed_field(parsed_config, fields): + """ + Parsed the value from web profile page, 'copy to clipboard' option + :param parsed_config: The parsed value from the web ui + :type parsed_config: Config object + :param fields: list of values to parse, will parse by the list order + :type fields: List[str] + :return: parsed value if found, None else + """ + try: + return parsed_config.get("api").get(fields[0]) + except ConfigMissingException: # fallback - try to parse the field like it was in web older version + if len(fields) == 1: + return parsed_config.get(fields[0]) + elif len(fields) == 2: + return parsed_config.get(fields[1]) + else: + return None + + +def read_manual_credentials(): + print('Enter user access key: ', end='') + access_key = input() + print('Enter user secret: ', end='') + secret_key = input() + return {"access_key": access_key, "secret_key": secret_key} + + def input_url(host_type, host=None): while True: print('{} configured to: [{}] '.format(host_type, host), end='')