mirror of
https://github.com/clearml/clearml-agent
synced 2025-05-08 22:09:33 +00:00
Improve configuration wizard
This commit is contained in:
parent
88f1031e5d
commit
1f0bb4906b
@ -1,19 +1,21 @@
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from six.moves import input
|
from six.moves import input
|
||||||
from pyhocon import ConfigFactory
|
from pyhocon import ConfigFactory, ConfigMissingException
|
||||||
from pathlib2 import Path
|
from pathlib2 import Path
|
||||||
from six.moves.urllib.parse import urlparse
|
from six.moves.urllib.parse import urlparse
|
||||||
|
|
||||||
|
from trains_agent.backend_api.session import Session
|
||||||
from trains_agent.backend_api.session.defs import ENV_HOST
|
from trains_agent.backend_api.session.defs import ENV_HOST
|
||||||
from trains_agent.backend_config.defs import LOCAL_CONFIG_FILES
|
from trains_agent.backend_config.defs import LOCAL_CONFIG_FILES
|
||||||
|
|
||||||
|
|
||||||
description = """
|
description = """
|
||||||
Please create new credentials using the web app: {}/profile
|
Please create new trains credentials through the profile page in your trains web app (e.g. https://demoapp.trains.allegro.ai/profile)
|
||||||
In the Admin page, press "Create new credentials", then press "Copy to clipboard"
|
In the profile page, press "Create new credentials", then press "Copy to clipboard".
|
||||||
|
|
||||||
Paste credentials here: """
|
Paste copied configuration here:
|
||||||
|
"""
|
||||||
|
|
||||||
def_host = 'http://localhost:8080'
|
def_host = 'http://localhost:8080'
|
||||||
try:
|
try:
|
||||||
@ -38,20 +40,39 @@ def main():
|
|||||||
print('Leaving setup, feel free to edit the configuration file.')
|
print('Leaving setup, feel free to edit the configuration file.')
|
||||||
return
|
return
|
||||||
|
|
||||||
print(host_description)
|
print(description, end='')
|
||||||
web_host = input_url('Web Application Host', '')
|
sentinel = ''
|
||||||
parsed_host = verify_url(web_host)
|
parse_input = '\n'.join(iter(input, sentinel))
|
||||||
|
credentials = None
|
||||||
|
api_host = None
|
||||||
|
web_server = None
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
parsed = ConfigFactory.parse_string(parse_input)
|
||||||
|
if parsed:
|
||||||
|
# Take the credentials in raw form or from api section
|
||||||
|
credentials = get_parsed_field(parsed, ["credentials"])
|
||||||
|
api_host = get_parsed_field(parsed, ["api_server", "host"])
|
||||||
|
web_server = get_parsed_field(parsed, ["web_server"])
|
||||||
|
except Exception:
|
||||||
|
credentials = credentials or None
|
||||||
|
api_host = api_host or None
|
||||||
|
web_server = web_server or None
|
||||||
|
|
||||||
if parsed_host.port == 8008:
|
while not credentials or set(credentials) != {"access_key", "secret_key"}:
|
||||||
print('Port 8008 is the api port. Replacing 8080 with 8008 for Web application')
|
print('Could not parse credentials, please try entering them manually.')
|
||||||
api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
|
credentials = read_manual_credentials()
|
||||||
web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8080', 1) + parsed_host.path
|
|
||||||
files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8081', 1) + parsed_host.path
|
print('Detected credentials key=\"{}\" secret=\"{}\"'.format(credentials['access_key'],
|
||||||
elif parsed_host.port == 8080:
|
credentials['secret_key'][0:4] + "***"))
|
||||||
api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8008', 1) + parsed_host.path
|
if api_host:
|
||||||
web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
|
api_host = input_url('API Host', api_host)
|
||||||
files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8081', 1) + parsed_host.path
|
else:
|
||||||
elif parsed_host.netloc.startswith('demoapp.'):
|
print(host_description)
|
||||||
|
api_host = input_url('API Host', '')
|
||||||
|
parsed_host = verify_url(api_host)
|
||||||
|
|
||||||
|
if parsed_host.netloc.startswith('demoapp.'):
|
||||||
# this is our demo server
|
# this is our demo server
|
||||||
api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapp.', 'demoapi.', 1) + parsed_host.path
|
api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapp.', 'demoapi.', 1) + parsed_host.path
|
||||||
web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
|
web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
|
||||||
@ -73,6 +94,15 @@ def main():
|
|||||||
api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
|
api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
|
||||||
web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('api.', 'app.', 1) + parsed_host.path
|
web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('api.', 'app.', 1) + parsed_host.path
|
||||||
files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('api.', 'files.', 1) + parsed_host.path
|
files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('api.', 'files.', 1) + parsed_host.path
|
||||||
|
elif parsed_host.port == 8008:
|
||||||
|
print('Port 8008 is the api port. Replacing 8080 with 8008 for Web application')
|
||||||
|
api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
|
||||||
|
web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8080', 1) + parsed_host.path
|
||||||
|
files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8081', 1) + parsed_host.path
|
||||||
|
elif parsed_host.port == 8080:
|
||||||
|
api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8008', 1) + parsed_host.path
|
||||||
|
web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
|
||||||
|
files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8081', 1) + parsed_host.path
|
||||||
else:
|
else:
|
||||||
api_host = ''
|
api_host = ''
|
||||||
web_host = ''
|
web_host = ''
|
||||||
@ -91,47 +121,23 @@ def main():
|
|||||||
if not api_host:
|
if not api_host:
|
||||||
api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
|
api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
|
||||||
|
|
||||||
api_host = input_url('API Host', api_host)
|
web_host = input_url('Web Application Host', web_server if web_server else web_host)
|
||||||
files_host = input_url('File Store Host', files_host)
|
files_host = input_url('File Store Host', files_host)
|
||||||
|
|
||||||
print('\nTRAINS Hosts configuration:\nWeb App: {}\nAPI: {}\nFile Store: {}\n'.format(
|
print('\nTRAINS Hosts configuration:\nWeb App: {}\nAPI: {}\nFile Store: {}\n'.format(
|
||||||
web_host, api_host, files_host))
|
web_host, api_host, files_host))
|
||||||
|
|
||||||
while True:
|
retry = 1
|
||||||
print(description.format(web_host), end='')
|
max_retries = 2
|
||||||
parse_input = input()
|
while retry <= max_retries: # Up to 2 tries by the user
|
||||||
# check if these are valid credentials
|
if verify_credentials(api_host, credentials):
|
||||||
credentials = None
|
|
||||||
# noinspection PyBroadException
|
|
||||||
try:
|
|
||||||
parsed = ConfigFactory.parse_string(parse_input)
|
|
||||||
if parsed:
|
|
||||||
credentials = parsed.get("credentials", None)
|
|
||||||
except Exception:
|
|
||||||
credentials = None
|
|
||||||
|
|
||||||
if not credentials or set(credentials) != {"access_key", "secret_key"}:
|
|
||||||
print('Could not parse user credentials, try again one after the other.')
|
|
||||||
credentials = {}
|
|
||||||
# parse individual
|
|
||||||
print('Enter user access key: ', end='')
|
|
||||||
credentials['access_key'] = input()
|
|
||||||
print('Enter user secret: ', end='')
|
|
||||||
credentials['secret_key'] = input()
|
|
||||||
|
|
||||||
print('Detected credentials key=\"{}\" secret=\"{}\"'.format(credentials['access_key'],
|
|
||||||
credentials['secret_key'], ))
|
|
||||||
|
|
||||||
from trains_agent.backend_api.session import Session
|
|
||||||
# noinspection PyBroadException
|
|
||||||
try:
|
|
||||||
print('Verifying credentials ...')
|
|
||||||
Session(api_key=credentials['access_key'], secret_key=credentials['secret_key'], host=api_host)
|
|
||||||
print('Credentials verified!')
|
|
||||||
break
|
break
|
||||||
except Exception:
|
retry += 1
|
||||||
print('Error: could not verify credentials: host={} access={} secret={}'.format(
|
if retry < max_retries + 1:
|
||||||
api_host, credentials['access_key'], credentials['secret_key']))
|
credentials = read_manual_credentials()
|
||||||
|
else:
|
||||||
|
print('Exiting setup without creating configuration file')
|
||||||
|
return
|
||||||
|
|
||||||
# get GIT User/Pass for cloning
|
# get GIT User/Pass for cloning
|
||||||
print('Enter git username for repository cloning (leave blank for SSH key authentication): [] ', end='')
|
print('Enter git username for repository cloning (leave blank for SSH key authentication): [] ', end='')
|
||||||
@ -186,6 +192,52 @@ def main():
|
|||||||
print('TRAINS-AGENT setup completed successfully.')
|
print('TRAINS-AGENT setup completed successfully.')
|
||||||
|
|
||||||
|
|
||||||
|
def verify_credentials(api_host, credentials):
|
||||||
|
"""check if the credentials are valid"""
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
print('Verifying credentials ...')
|
||||||
|
if api_host:
|
||||||
|
Session(api_key=credentials['access_key'], secret_key=credentials['secret_key'], host=api_host)
|
||||||
|
print('Credentials verified!')
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
print("Can't verify credentials")
|
||||||
|
return False
|
||||||
|
except Exception:
|
||||||
|
print('Error: could not verify credentials: key={} secret={}'.format(
|
||||||
|
credentials.get('access_key'), credentials.get('secret_key')))
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def get_parsed_field(parsed_config, fields):
|
||||||
|
"""
|
||||||
|
Parsed the value from web profile page, 'copy to clipboard' option
|
||||||
|
:param parsed_config: The parsed value from the web ui
|
||||||
|
:type parsed_config: Config object
|
||||||
|
:param fields: list of values to parse, will parse by the list order
|
||||||
|
:type fields: List[str]
|
||||||
|
:return: parsed value if found, None else
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return parsed_config.get("api").get(fields[0])
|
||||||
|
except ConfigMissingException: # fallback - try to parse the field like it was in web older version
|
||||||
|
if len(fields) == 1:
|
||||||
|
return parsed_config.get(fields[0])
|
||||||
|
elif len(fields) == 2:
|
||||||
|
return parsed_config.get(fields[1])
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def read_manual_credentials():
|
||||||
|
print('Enter user access key: ', end='')
|
||||||
|
access_key = input()
|
||||||
|
print('Enter user secret: ', end='')
|
||||||
|
secret_key = input()
|
||||||
|
return {"access_key": access_key, "secret_key": secret_key}
|
||||||
|
|
||||||
|
|
||||||
def input_url(host_type, host=None):
|
def input_url(host_type, host=None):
|
||||||
while True:
|
while True:
|
||||||
print('{} configured to: [{}] '.format(host_type, host), end='')
|
print('{} configured to: [{}] '.format(host_type, host), end='')
|
||||||
|
Loading…
Reference in New Issue
Block a user