diff --git a/README.md b/README.md index 3f5ee42..63a96a2 100644 --- a/README.md +++ b/README.md @@ -238,7 +238,7 @@ usage: clearml-session [-h] [--version] [--attach [ATTACH]] [--queue-excluded-tag [QUEUE_EXCLUDED_TAG [QUEUE_EXCLUDED_TAG ...]]] [--queue-include-tag [QUEUE_INCLUDE_TAG [QUEUE_INCLUDE_TAG ...]]] [--skip-docker-network] [--password PASSWORD] - [--username USERNAME] [--verbose] + [--username USERNAME] [--verbose] [--yes] clearml-session - CLI for launching JupyterLab / VSCode on a remote machine @@ -331,6 +331,8 @@ optional arguments: session (default: `root` or previously used one) --verbose Advanced: If set, print verbose progress information, e.g. the remote machine setup process log - + --yes, -y Automatic yes to prompts; assume "yes" as answer to + all prompts and run non-interactively + Notice! all arguments are stored as new defaults for the next session ``` diff --git a/clearml_session/__main__.py b/clearml_session/__main__.py index 65735d6..492bd86 100644 --- a/clearml_session/__main__.py +++ b/clearml_session/__main__.py @@ -25,11 +25,14 @@ from clearml.backend_api import Session from .tcp_proxy import TcpProxy from .single_thread_proxy import SingleThreadProxy - system_tag = 'interactive' default_docker_image = 'nvidia/cuda:10.1-runtime-ubuntu18.04' +class NonInteractiveError(Exception): + pass + + def _read_std_input(timeout): # wait for user input with timeout, return None if timeout or user input if sys.platform == 'win32': @@ -156,6 +159,7 @@ def create_base_task(state, project_name=None, task_name=None): section, _, _ = _get_config_section_name() if Session.check_min_api_version('2.13'): + # noinspection PyProtectedMember _runtime_prop = dict(task._get_runtime_properties()) _runtime_prop.update({ "_user_key": '', @@ -226,6 +230,7 @@ def create_debugging_task(state, debug_task_id): section, _, _ = _get_config_section_name() if Session.check_min_api_version('2.13'): + # noinspection PyProtectedMember _runtime_prop = dict(task._get_runtime_properties()) _runtime_prop.update({ "_user_key": '', @@ -342,6 +347,7 @@ def get_project_id(project_name): def get_user_inputs(args, parser, state, client): default_needed_args = tuple() + assume_yes = args.yes user_args = sorted([a for a in args.__dict__ if not a.startswith('_')]) # clear some states if we replace the base_task_id @@ -373,6 +379,9 @@ def get_user_inputs(args, parser, state, client): state[a] = v if a in default_needed_args and not state.get(a): + if assume_yes: + raise NonInteractiveError( + "Using `--yes` but could not locate previously used value of '{}'".format(a)) # noinspection PyProtectedMember state[a] = input( "\nCould not locate previously used value of '{}', please provide it?" @@ -392,10 +401,16 @@ def get_user_inputs(args, parser, state, client): # allow to select queue ask_queues = not state.get('queue') - if state.get('queue'): + + if assume_yes: + if ask_queues: + raise NonInteractiveError("Using `--yes` but no queue provided or previously used") + print("Using previous queue (resource) '{}'".format(state["queue"])) + elif state.get('queue'): choice = input('Use previous queue (resource) \'{}\' [Y]/n? '.format(state['queue'])) if str(choice).strip().lower() in ('n', 'no'): ask_queues = True + if ask_queues: print('Select the queue (resource) you request:') queues = None @@ -422,6 +437,10 @@ def get_user_inputs(args, parser, state, client): print("\nInteractive session config:\n{}\n".format( json.dumps({k: v for k, v in state.items() if not str(k).startswith('__')}, indent=4, sort_keys=True))) + # no need to ask just return the value + if assume_yes: + return state + choice = input('Launch interactive session [Y]/n? ') if str(choice).strip().lower() in ('n', 'no'): print('User aborted') @@ -452,8 +471,9 @@ def load_state(state_file): state = json.load(f) except Exception: state = {} - # never reload --verbose state + # never reload --verbose and --yes states state.pop('verbose', None) + state.pop('yes', None) return state @@ -719,7 +739,7 @@ def start_ssh_tunnel(username, remote_address, ssh_port, ssh_password, local_rem raise ValueError('Incorrect password') except pexpect.TIMEOUT: pass - except Exception as ex: + except Exception: child.terminate(force=True) child = None print('\n') @@ -968,6 +988,10 @@ def setup_parser(parser): parser.add_argument('--verbose', action='store_true', default=None, help='Advanced: If set, print verbose progress information, ' 'e.g. the remote machine setup process log') + parser.add_argument("--yes", "-y", + action="store_true", default=False, + help="Automatic yes to prompts; assume \"yes\" as answer " + "to all prompts and run non-interactively",) def get_version(): @@ -1056,6 +1080,7 @@ def cli(): def _check_previous_session(client, args, state): + assume_yes = args.yes # now let's see if we have the requested Task if args.attach: task_id = args.attach @@ -1087,9 +1112,12 @@ def _check_previous_session(client, args, state): # a single running session if len(running_task_ids_created) == 1: task_id = running_task_ids_created[0][0] - choice = input('Connect to active session id={} [Y]/n? '.format(task_id)) - if str(choice).strip().lower() not in ('n', 'no'): - return Task.get_task(task_id=task_id) + if assume_yes: + print("Connecting to active session {}".format(task_id)) + else: + choice = input("Connect to active session id={} [Y]/n? ".format(task_id)) + if str(choice).strip().lower() not in ("n", "no"): + return Task.get_task(task_id=task_id) # multiple sessions running print('Active sessions:') @@ -1099,23 +1127,31 @@ def _check_previous_session(client, args, state): except StopIteration: default_i = None - session_list = '\n'.join( - '{}{}] {} id={}'.format(i, '*' if i == default_i else '', dt.strftime("%Y-%m-%d %H:%M:%S"), tid) - for i, (tid, dt, _) in enumerate(running_task_ids_created)) - while True: - try: - choice = input(session_list+'\nConnect to session [{}] or \'N\' to skip: '.format( - '0' if len(running_task_ids_created) <= 1 else '0-{}'.format(len(running_task_ids_created)-1))) - if choice.strip().lower().startswith('n'): - choice = None - elif default_i is not None and not choice.strip(): - choice = default_i - else: - choice = int(choice) - assert 0 <= choice < len(running_task_ids_created) - break - except (TypeError, ValueError, AssertionError): - pass + session_list = "\n".join( + "{}{}] {} id={}".format(i, "*" if i == default_i else "", dt.strftime("%Y-%m-%d %H:%M:%S"), tid) + for i, (tid, dt, _) in enumerate(running_task_ids_created) + ) + if assume_yes: + choice = 0 + else: + while True: + try: + choice = input( + session_list + + "\nConnect to session [{}] or 'N' to skip: ".format( + "0" if len(running_task_ids_created) <= 1 else "0-{}".format(len(running_task_ids_created) - 1) + ) + ) + if choice.strip().lower().startswith("n"): + choice = None + elif default_i is not None and not choice.strip(): + choice = default_i + else: + choice = int(choice) + assert 0 <= choice < len(running_task_ids_created) + break + except (TypeError, ValueError, AssertionError): + pass if choice is None: return None return Task.get_task(task_id=running_task_ids_created[choice][0])