From 6be75abc862e27ef0f62bd40374ffe91889497b3 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Wed, 7 Dec 2022 22:06:10 +0200 Subject: [PATCH] Add default output URI selection to "clearml-agent init" --- clearml_agent/commands/config.py | 48 ++++++++++++++++++++++++++++++-- 1 file changed, 45 insertions(+), 3 deletions(-) diff --git a/clearml_agent/commands/config.py b/clearml_agent/commands/config.py index 4c8d7a6..7ca4a15 100644 --- a/clearml_agent/commands/config.py +++ b/clearml_agent/commands/config.py @@ -1,14 +1,15 @@ from __future__ import print_function -from six.moves import input +from typing import Dict, Optional + from pathlib2 import Path +from six.moves import input from six.moves.urllib.parse import urlparse -from clearml_agent.external.pyhocon import ConfigFactory, ConfigMissingException from clearml_agent.backend_api.session import Session from clearml_agent.backend_api.session.defs import ENV_HOST from clearml_agent.backend_config.defs import LOCAL_CONFIG_FILES - +from clearml_agent.external.pyhocon import ConfigFactory, ConfigMissingException description = """ Please create new clearml credentials through the settings page in your `clearml-server` web app, @@ -112,6 +113,21 @@ def main(): print('Exiting setup without creating configuration file') return + selection = input_options( + 'Default Output URI (used to automatically store models and artifacts)', + {'N': 'None', 'S': 'ClearML Server', 'C': 'Custom'}, + default='None' + ) + if selection == 'Custom': + print('Custom Default Output URI: ', end='') + default_output_uri = input().strip() + elif selection == "ClearML Server": + default_output_uri = files_host + else: + default_output_uri = None + + print('\nDefault Output URI: {}'.format(default_output_uri if default_output_uri else 'not set')) + # get GIT User/Pass for cloning print('Enter git username for repository cloning (leave blank for SSH key authentication): [] ', end='') git_user = input() @@ -179,6 +195,13 @@ def main(): 'agent.package_manager.extra_index_url= ' \ '[\n{}\n]\n\n'.format("\n".join(map("\"{}\"".format, extra_index_urls))) f.write(extra_index_str) + if default_output_uri: + default_output_url_str = '# Default Task output_uri. if output_uri is not provided to Task.init, ' \ + 'default_output_uri will be used instead.\n' \ + 'sdk.development.default_output_uri="{}"\n' \ + '\n'.format(default_output_uri.strip('"')) + f.write(default_output_url_str) + default_conf = default_conf.replace('default_output_uri: ""', '# default_output_uri: ""') f.write(default_conf) except Exception: print('Error! Could not write configuration file at: {}'.format(str(conf_file))) @@ -305,6 +328,25 @@ def input_url(host_type, host=None): return host +def input_options(message, options, default=None): + # type: (str, Dict[str, str], Optional[str]) -> str + options_msg = "/".join( + "".join(('(' + c.upper() + ')') if c == o else c for c in option) + for o, option in options.items() + ) + if default: + options_msg += " [{}]".format(default) + while True: + print('{}: {} '.format(message, options_msg), end='') + res = input().strip() + if not res: + return default + elif res.lower() in options: + return options[res.lower()] + elif res.upper() in options: + return options[res.upper()] + + def input_host_port(host_type, parsed_host): print('Enter port for {} host '.format(host_type), end='') replace_port = input().lower()