mirror of
https://github.com/clearml/clearml-agent
synced 2025-01-31 09:06:52 +00:00
Improve configuration wizard
This commit is contained in:
parent
88f1031e5d
commit
1f0bb4906b
@ -1,19 +1,21 @@
|
||||
from __future__ import print_function
|
||||
|
||||
from six.moves import input
|
||||
from pyhocon import ConfigFactory
|
||||
from pyhocon import ConfigFactory, ConfigMissingException
|
||||
from pathlib2 import Path
|
||||
from six.moves.urllib.parse import urlparse
|
||||
|
||||
from trains_agent.backend_api.session import Session
|
||||
from trains_agent.backend_api.session.defs import ENV_HOST
|
||||
from trains_agent.backend_config.defs import LOCAL_CONFIG_FILES
|
||||
|
||||
|
||||
description = """
|
||||
Please create new credentials using the web app: {}/profile
|
||||
In the Admin page, press "Create new credentials", then press "Copy to clipboard"
|
||||
Please create new trains credentials through the profile page in your trains web app (e.g. https://demoapp.trains.allegro.ai/profile)
|
||||
In the profile page, press "Create new credentials", then press "Copy to clipboard".
|
||||
|
||||
Paste credentials here: """
|
||||
Paste copied configuration here:
|
||||
"""
|
||||
|
||||
def_host = 'http://localhost:8080'
|
||||
try:
|
||||
@ -38,20 +40,39 @@ def main():
|
||||
print('Leaving setup, feel free to edit the configuration file.')
|
||||
return
|
||||
|
||||
print(host_description)
|
||||
web_host = input_url('Web Application Host', '')
|
||||
parsed_host = verify_url(web_host)
|
||||
print(description, end='')
|
||||
sentinel = ''
|
||||
parse_input = '\n'.join(iter(input, sentinel))
|
||||
credentials = None
|
||||
api_host = None
|
||||
web_server = None
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
parsed = ConfigFactory.parse_string(parse_input)
|
||||
if parsed:
|
||||
# Take the credentials in raw form or from api section
|
||||
credentials = get_parsed_field(parsed, ["credentials"])
|
||||
api_host = get_parsed_field(parsed, ["api_server", "host"])
|
||||
web_server = get_parsed_field(parsed, ["web_server"])
|
||||
except Exception:
|
||||
credentials = credentials or None
|
||||
api_host = api_host or None
|
||||
web_server = web_server or None
|
||||
|
||||
if parsed_host.port == 8008:
|
||||
print('Port 8008 is the api port. Replacing 8080 with 8008 for Web application')
|
||||
api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
|
||||
web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8080', 1) + parsed_host.path
|
||||
files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8081', 1) + parsed_host.path
|
||||
elif parsed_host.port == 8080:
|
||||
api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8008', 1) + parsed_host.path
|
||||
web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
|
||||
files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8081', 1) + parsed_host.path
|
||||
elif parsed_host.netloc.startswith('demoapp.'):
|
||||
while not credentials or set(credentials) != {"access_key", "secret_key"}:
|
||||
print('Could not parse credentials, please try entering them manually.')
|
||||
credentials = read_manual_credentials()
|
||||
|
||||
print('Detected credentials key=\"{}\" secret=\"{}\"'.format(credentials['access_key'],
|
||||
credentials['secret_key'][0:4] + "***"))
|
||||
if api_host:
|
||||
api_host = input_url('API Host', api_host)
|
||||
else:
|
||||
print(host_description)
|
||||
api_host = input_url('API Host', '')
|
||||
parsed_host = verify_url(api_host)
|
||||
|
||||
if parsed_host.netloc.startswith('demoapp.'):
|
||||
# this is our demo server
|
||||
api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapp.', 'demoapi.', 1) + parsed_host.path
|
||||
web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
|
||||
@ -73,6 +94,15 @@ def main():
|
||||
api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
|
||||
web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('api.', 'app.', 1) + parsed_host.path
|
||||
files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('api.', 'files.', 1) + parsed_host.path
|
||||
elif parsed_host.port == 8008:
|
||||
print('Port 8008 is the api port. Replacing 8080 with 8008 for Web application')
|
||||
api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
|
||||
web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8080', 1) + parsed_host.path
|
||||
files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8081', 1) + parsed_host.path
|
||||
elif parsed_host.port == 8080:
|
||||
api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8008', 1) + parsed_host.path
|
||||
web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
|
||||
files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8081', 1) + parsed_host.path
|
||||
else:
|
||||
api_host = ''
|
||||
web_host = ''
|
||||
@ -91,47 +121,23 @@ def main():
|
||||
if not api_host:
|
||||
api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
|
||||
|
||||
api_host = input_url('API Host', api_host)
|
||||
web_host = input_url('Web Application Host', web_server if web_server else web_host)
|
||||
files_host = input_url('File Store Host', files_host)
|
||||
|
||||
print('\nTRAINS Hosts configuration:\nWeb App: {}\nAPI: {}\nFile Store: {}\n'.format(
|
||||
web_host, api_host, files_host))
|
||||
|
||||
while True:
|
||||
print(description.format(web_host), end='')
|
||||
parse_input = input()
|
||||
# check if these are valid credentials
|
||||
credentials = None
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
parsed = ConfigFactory.parse_string(parse_input)
|
||||
if parsed:
|
||||
credentials = parsed.get("credentials", None)
|
||||
except Exception:
|
||||
credentials = None
|
||||
|
||||
if not credentials or set(credentials) != {"access_key", "secret_key"}:
|
||||
print('Could not parse user credentials, try again one after the other.')
|
||||
credentials = {}
|
||||
# parse individual
|
||||
print('Enter user access key: ', end='')
|
||||
credentials['access_key'] = input()
|
||||
print('Enter user secret: ', end='')
|
||||
credentials['secret_key'] = input()
|
||||
|
||||
print('Detected credentials key=\"{}\" secret=\"{}\"'.format(credentials['access_key'],
|
||||
credentials['secret_key'], ))
|
||||
|
||||
from trains_agent.backend_api.session import Session
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
print('Verifying credentials ...')
|
||||
Session(api_key=credentials['access_key'], secret_key=credentials['secret_key'], host=api_host)
|
||||
print('Credentials verified!')
|
||||
retry = 1
|
||||
max_retries = 2
|
||||
while retry <= max_retries: # Up to 2 tries by the user
|
||||
if verify_credentials(api_host, credentials):
|
||||
break
|
||||
except Exception:
|
||||
print('Error: could not verify credentials: host={} access={} secret={}'.format(
|
||||
api_host, credentials['access_key'], credentials['secret_key']))
|
||||
retry += 1
|
||||
if retry < max_retries + 1:
|
||||
credentials = read_manual_credentials()
|
||||
else:
|
||||
print('Exiting setup without creating configuration file')
|
||||
return
|
||||
|
||||
# get GIT User/Pass for cloning
|
||||
print('Enter git username for repository cloning (leave blank for SSH key authentication): [] ', end='')
|
||||
@ -186,6 +192,52 @@ def main():
|
||||
print('TRAINS-AGENT setup completed successfully.')
|
||||
|
||||
|
||||
def verify_credentials(api_host, credentials):
|
||||
"""check if the credentials are valid"""
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
print('Verifying credentials ...')
|
||||
if api_host:
|
||||
Session(api_key=credentials['access_key'], secret_key=credentials['secret_key'], host=api_host)
|
||||
print('Credentials verified!')
|
||||
return True
|
||||
else:
|
||||
print("Can't verify credentials")
|
||||
return False
|
||||
except Exception:
|
||||
print('Error: could not verify credentials: key={} secret={}'.format(
|
||||
credentials.get('access_key'), credentials.get('secret_key')))
|
||||
return False
|
||||
|
||||
|
||||
def get_parsed_field(parsed_config, fields):
|
||||
"""
|
||||
Parsed the value from web profile page, 'copy to clipboard' option
|
||||
:param parsed_config: The parsed value from the web ui
|
||||
:type parsed_config: Config object
|
||||
:param fields: list of values to parse, will parse by the list order
|
||||
:type fields: List[str]
|
||||
:return: parsed value if found, None else
|
||||
"""
|
||||
try:
|
||||
return parsed_config.get("api").get(fields[0])
|
||||
except ConfigMissingException: # fallback - try to parse the field like it was in web older version
|
||||
if len(fields) == 1:
|
||||
return parsed_config.get(fields[0])
|
||||
elif len(fields) == 2:
|
||||
return parsed_config.get(fields[1])
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def read_manual_credentials():
|
||||
print('Enter user access key: ', end='')
|
||||
access_key = input()
|
||||
print('Enter user secret: ', end='')
|
||||
secret_key = input()
|
||||
return {"access_key": access_key, "secret_key": secret_key}
|
||||
|
||||
|
||||
def input_url(host_type, host=None):
|
||||
while True:
|
||||
print('{} configured to: [{}] '.format(host_type, host), end='')
|
||||
|
Loading…
Reference in New Issue
Block a user