Add default output URI selection to "clearml-agent init"

This commit is contained in:
allegroai 2022-12-07 22:06:10 +02:00
parent 4c777fa2ee
commit 6be75abc86

View File

@ -1,14 +1,15 @@
from __future__ import print_function from __future__ import print_function
from six.moves import input from typing import Dict, Optional
from pathlib2 import Path from pathlib2 import Path
from six.moves import input
from six.moves.urllib.parse import urlparse from six.moves.urllib.parse import urlparse
from clearml_agent.external.pyhocon import ConfigFactory, ConfigMissingException
from clearml_agent.backend_api.session import Session from clearml_agent.backend_api.session import Session
from clearml_agent.backend_api.session.defs import ENV_HOST from clearml_agent.backend_api.session.defs import ENV_HOST
from clearml_agent.backend_config.defs import LOCAL_CONFIG_FILES from clearml_agent.backend_config.defs import LOCAL_CONFIG_FILES
from clearml_agent.external.pyhocon import ConfigFactory, ConfigMissingException
description = """ description = """
Please create new clearml credentials through the settings page in your `clearml-server` web app, Please create new clearml credentials through the settings page in your `clearml-server` web app,
@ -112,6 +113,21 @@ def main():
print('Exiting setup without creating configuration file') print('Exiting setup without creating configuration file')
return return
selection = input_options(
'Default Output URI (used to automatically store models and artifacts)',
{'N': 'None', 'S': 'ClearML Server', 'C': 'Custom'},
default='None'
)
if selection == 'Custom':
print('Custom Default Output URI: ', end='')
default_output_uri = input().strip()
elif selection == "ClearML Server":
default_output_uri = files_host
else:
default_output_uri = None
print('\nDefault Output URI: {}'.format(default_output_uri if default_output_uri else 'not set'))
# get GIT User/Pass for cloning # get GIT User/Pass for cloning
print('Enter git username for repository cloning (leave blank for SSH key authentication): [] ', end='') print('Enter git username for repository cloning (leave blank for SSH key authentication): [] ', end='')
git_user = input() git_user = input()
@ -179,6 +195,13 @@ def main():
'agent.package_manager.extra_index_url= ' \ 'agent.package_manager.extra_index_url= ' \
'[\n{}\n]\n\n'.format("\n".join(map("\"{}\"".format, extra_index_urls))) '[\n{}\n]\n\n'.format("\n".join(map("\"{}\"".format, extra_index_urls)))
f.write(extra_index_str) f.write(extra_index_str)
if default_output_uri:
default_output_url_str = '# Default Task output_uri. if output_uri is not provided to Task.init, ' \
'default_output_uri will be used instead.\n' \
'sdk.development.default_output_uri="{}"\n' \
'\n'.format(default_output_uri.strip('"'))
f.write(default_output_url_str)
default_conf = default_conf.replace('default_output_uri: ""', '# default_output_uri: ""')
f.write(default_conf) f.write(default_conf)
except Exception: except Exception:
print('Error! Could not write configuration file at: {}'.format(str(conf_file))) print('Error! Could not write configuration file at: {}'.format(str(conf_file)))
@ -305,6 +328,25 @@ def input_url(host_type, host=None):
return host return host
def input_options(message, options, default=None):
# type: (str, Dict[str, str], Optional[str]) -> str
options_msg = "/".join(
"".join(('(' + c.upper() + ')') if c == o else c for c in option)
for o, option in options.items()
)
if default:
options_msg += " [{}]".format(default)
while True:
print('{}: {} '.format(message, options_msg), end='')
res = input().strip()
if not res:
return default
elif res.lower() in options:
return options[res.lower()]
elif res.upper() in options:
return options[res.upper()]
def input_host_port(host_type, parsed_host): def input_host_port(host_type, parsed_host):
print('Enter port for {} host '.format(host_type), end='') print('Enter port for {} host '.format(host_type), end='')
replace_port = input().lower() replace_port = input().lower()