Improve trains-init configuration wizard

This commit is contained in:
allegroai 2019-12-15 00:11:01 +02:00
parent c1cc80ba1b
commit 3bd997c4dc
2 changed files with 37 additions and 16 deletions

View File

@ -8,7 +8,7 @@ from trains.backend_api.session import Session
from trains.backend_api.session.defs import ENV_HOST from trains.backend_api.session.defs import ENV_HOST
from trains.backend_config.defs import LOCAL_CONFIG_FILES from trains.backend_config.defs import LOCAL_CONFIG_FILES
from trains.config import config_obj from trains.config import config_obj
from trains.utilities.pyhocon import ConfigFactory from trains.utilities.pyhocon import ConfigFactory, ConfigMissingException
description = """ description = """
@ -41,7 +41,8 @@ def main():
return return
print(description, end='') print(description, end='')
parse_input = input() sentinel = ''
parse_input = '\n'.join(iter(input, sentinel))
credentials = None credentials = None
api_host = None api_host = None
web_server = None web_server = None
@ -50,9 +51,9 @@ def main():
parsed = ConfigFactory.parse_string(parse_input) parsed = ConfigFactory.parse_string(parse_input)
if parsed: if parsed:
# Take the credentials in raw form or from api section # Take the credentials in raw form or from api section
credentials = parsed.get("credentials", None) or parsed.get("api", {}).get("credentials") credentials = get_parsed_field(parsed, ["credentials"])
api_host = parsed.get("api", {}).get("host") or parsed.get("api", {}).get("api_server") api_host = get_parsed_field(parsed, ["api_server", "host"])
web_server = parsed.get("api", {}).get("web_server") web_server = get_parsed_field(parsed, ["web_server"])
except Exception: except Exception:
credentials = credentials or None credentials = credentials or None
api_host = api_host or None api_host = api_host or None
@ -71,16 +72,7 @@ def main():
api_host = input_url('API Host', '') api_host = input_url('API Host', '')
parsed_host = verify_url(api_host) parsed_host = verify_url(api_host)
if parsed_host.port == 8008: if parsed_host.netloc.startswith('demoapp.'):
print('Port 8008 is the api port. Replacing 8080 with 8008 for Web application')
api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8080', 1) + parsed_host.path
files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8081', 1) + parsed_host.path
elif parsed_host.port == 8080:
api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8008', 1) + parsed_host.path
web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8081', 1) + parsed_host.path
elif parsed_host.netloc.startswith('demoapp.'):
# this is our demo server # this is our demo server
api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapp.', 'demoapi.', 1) + parsed_host.path api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapp.', 'demoapi.', 1) + parsed_host.path
web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
@ -102,6 +94,15 @@ def main():
api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('api.', 'app.', 1) + parsed_host.path web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('api.', 'app.', 1) + parsed_host.path
files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('api.', 'files.', 1) + parsed_host.path files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('api.', 'files.', 1) + parsed_host.path
elif parsed_host.port == 8008:
print('Port 8008 is the api port. Replacing 8080 with 8008 for Web application')
api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8080', 1) + parsed_host.path
files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8081', 1) + parsed_host.path
elif parsed_host.port == 8080:
api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8008', 1) + parsed_host.path
web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8081', 1) + parsed_host.path
else: else:
api_host = '' api_host = ''
web_host = '' web_host = ''
@ -188,6 +189,26 @@ def verify_credentials(api_host, credentials):
return False return False
def get_parsed_field(parsed_config, fields):
"""
Parsed the value from web profile page, 'copy to clipboard' option
:param parsed_config: The parsed value from the web ui
:type parsed_config: Config object
:param fields: list of values to parse, will parse by the list order
:type fields: List[str]
:return: parsed value if found, None else
"""
try:
return parsed_config.get("api").get(fields[0])
except ConfigMissingException: # fallback - try to parse the field like it was in web older version
if len(fields) == 1:
return parsed_config.get(fields[0])
elif len(fields) == 2:
return parsed_config.get(fields[1])
else:
return None
def read_manual_credentials(): def read_manual_credentials():
print('Enter user access key: ', end='') print('Enter user access key: ', end='')
access_key = input() access_key = input()

View File

@ -1,3 +1,3 @@
from .config_parser import ConfigParser, ConfigFactory from .config_parser import ConfigParser, ConfigFactory, ConfigMissingException
from .config_tree import ConfigTree from .config_tree import ConfigTree
from .converter import HOCONConverter from .converter import HOCONConverter