diff --git a/README.md b/README.md index 7576195..474fd11 100644 --- a/README.md +++ b/README.md @@ -418,10 +418,15 @@ optional arguments: or previously used one) --username USERNAME Advanced: Select ssh username for the interactive session (default: `root` or previously used one) + --randomize Advanced: Recreate a new random ssh password for the interactive session options: + `--randomize` one time recreate, --randomize `always` create a new random password for + every session --force-dropbear [true/false] Force using `dropbear` instead of SSHd --disable-store-defaults If set, do not store current setup as new default configuration + --disable-fingerprint-check + Advanced: If set, ignore the remote SSH server fingerprint check --verbose Advanced: If set, print verbose progress information, e.g. the remote machine setup process log --yes, -y Automatic yes to prompts; assume "yes" as answer to all prompts and run non-interactively diff --git a/clearml_session/__main__.py b/clearml_session/__main__.py index 25561ed..c09568b 100644 --- a/clearml_session/__main__.py +++ b/clearml_session/__main__.py @@ -10,6 +10,7 @@ from argparse import ArgumentParser, FileType from functools import reduce, partial from getpass import getpass from io import TextIOBase, StringIO +from tempfile import NamedTemporaryFile from time import time, sleep from uuid import uuid4 @@ -468,7 +469,7 @@ def get_user_inputs(args, parser, state, client): for a in user_args: v = getattr(args, a, None) - if a in ('requirements', 'packages', 'attach', 'config_file'): + if a in ('requirements', 'packages', 'attach', 'config_file', 'randomize'): continue if isinstance(v, TextIOBase): state[a] = v.read() @@ -491,8 +492,9 @@ def get_user_inputs(args, parser, state, client): "\nCould not locate previously used value of '{}', please provide it?" "\n Help: {}\n> ".format( a, parser._option_string_actions['--{}'.format(a.replace('_', '-'))].help)) + # if no password was set, create a new random one - if not state.get('password'): + if not state.get('password') or state.get("randomize") is not False: state['password'] = hashlib.sha256("seed me {} {}".format(uuid4(), time()).encode()).hexdigest() # store the requirements from the requirements.txt @@ -576,7 +578,7 @@ def save_state(state, state_file): print("INFO: current configuration stored as new default") -def load_state(state_file): +def load_state(state_file, args=None): # noinspection PyBroadException try: with open(state_file, 'rt') as f: @@ -590,6 +592,25 @@ def load_state(state_file): state.pop('upload_files', None) state.pop('continue_session', None) state.pop('disable_store_defaults', None) + state.pop('disable_fingerprint_check', None) + + # update back based on args + if args: + # make sure we can override randomize + if "always" in (state.get("randomize") or []) and args.randomize is not False: + state["randomize"] = [] + elif args.randomize is False and "always" not in (state.get("randomize") or []): + state["randomize"] = False + elif "always" in (args.randomize or []): + state["randomize"] = ["always"] + + if args.verbose: + state['verbose'] = args.verbose + + state['shell'] = bool(args.shell) + state['disable_store_defaults'] = bool(args.disable_store_defaults) + state['disable_fingerprint_check'] = bool(args.disable_fingerprint_check) + return state @@ -832,17 +853,50 @@ def wait_for_machine(state, task, only_wait_for_ssh=False): return task -def start_ssh_tunnel(username, remote_address, ssh_port, ssh_password, local_remote_pair_list, debug=False): +def start_ssh_tunnel(username, remote_address, ssh_port, ssh_password, local_remote_pair_list, + debug=False, task=None, ignore_fingerprint_verification=False): print('Starting SSH tunnel to {}@{}, port {}'.format(username, remote_address, ssh_port)) child = None args = ['-C', '{}@{}'.format(username, remote_address), '-p', '{}'.format(ssh_port), - '-o', 'UserKnownHostsFile=/dev/null', '-o', 'Compression=yes', - '-o', 'StrictHostKeyChecking=no', '-o', 'ServerAliveInterval=10', '-o', 'ServerAliveCountMax=10', ] + found_server_ssh_fingerprint = None + if task: + if Session.check_min_api_version('2.20'): + # noinspection PyBroadException + try: + res = task.session.send_request( + "users", "get_vaults", + params="enabled=true&types=remote_session_ssh_server&" + "types=remote_session_ssh_server").json() + found_server_ssh_fingerprint = json.loads(res['data']['vaults'][-1]['data']) + except Exception: + pass + + known_host_lines = "" + if found_server_ssh_fingerprint: + # create the known host file + for k in found_server_ssh_fingerprint: + if k.endswith("__pub"): + known_host_lines += "{} {}\n".format(remote_address, found_server_ssh_fingerprint[k]) + + temp_host_file = None + if known_host_lines: + print("SECURING CONNECTION: using secure remote host fingerprinting") + temp_host_file = NamedTemporaryFile( + prefix="remote_ssh_host_", suffix=".pub", mode="wt", delete=True) + temp_host_file.write(known_host_lines) + temp_host_file.flush() + args += ['-o', 'UserKnownHostsFile={}'.format(temp_host_file.name)] + else: + args += [ + '-o', 'UserKnownHostsFile=/dev/null', + '-o', 'StrictHostKeyChecking=no', + ] + for local, remote in local_remote_pair_list: args.extend(['-L', '{}:localhost:{}'.format(local, remote)]) @@ -875,6 +929,14 @@ def start_ssh_tunnel(username, remote_address, ssh_port, ssh_password, local_rem pass elif i == 1: + if known_host_lines: + print("{}! Secure fingerprint of remote server failed to verify!".format( + "WARNING" if ignore_fingerprint_verification else "ERROR")) + if not ignore_fingerprint_verification: + # we should have never gotten here! + child.terminate(force=True) + exit(1) + child.sendline("yes") ret1 = child.expect([r"(?i)password:", pexpect.EOF]) if ret1 == 0: @@ -909,7 +971,7 @@ def start_ssh_tunnel(username, remote_address, ssh_port, ssh_password, local_rem print('\n') if child: child.logfile = None - return child, ssh_password + return child, ssh_password, temp_host_file def monitor_ssh_tunnel(state, task, ssh_setup_completed_callback=None): @@ -1037,11 +1099,13 @@ def monitor_ssh_tunnel(state, task, ssh_setup_completed_callback=None): "Enter \"r\" (\"reconnect\"), `s` (\"shell\"), `Ctrl-C` (\"quit\") or \"Shutdown\"" if not ssh_process or not ssh_process.isalive(): - ssh_process, ssh_password = start_ssh_tunnel( + ssh_process, ssh_password, known_host_file = start_ssh_tunnel( state.get('username') or 'root', remote_address, ssh_port, ssh_password, local_remote_pair_list=local_remote_pair_list, debug=state.get('verbose', False), + task=task, + ignore_fingerprint_verification=state.get('disable_fingerprint_check', False), ) if ssh_process and ssh_process.isalive(): @@ -1358,6 +1422,10 @@ def setup_parser(parser): parser.add_argument('--password', type=str, default=None, help='Advanced: Select ssh password for the interactive session ' '(default: `randomly-generated` or previously used one)') + parser.add_argument('--randomize', type=str, nargs='*', default=False, + help='Advanced: Recreate a new random ssh password for the interactive session ' + 'options: `--randomize` one time recreate random password, ' + '--randomize `always` create a new random password for every session') parser.add_argument('--username', type=str, default=None, help='Advanced: Select ssh username for the interactive session ' '(default: `root` or previously used one)') @@ -1366,6 +1434,8 @@ def setup_parser(parser): help='Force using `dropbear` instead of SSHd') parser.add_argument('--disable-store-defaults', action='store_true', default=None, help='If set, do not store current setup as new default configuration') + parser.add_argument('--disable-fingerprint-check', action='store_true', default=None, + help='Advanced: If set, ignore the remote SSH server fingerprint check') parser.add_argument('--verbose', action='store_true', default=None, help='Advanced: If set, print verbose progress information, ' 'e.g. the remote machine setup process log') @@ -1417,13 +1487,7 @@ def cli(): # load previous state state_file = os.path.abspath(os.path.expandvars(os.path.expanduser(args.config_file))) - state = load_state(state_file) - - if args.verbose: - state['verbose'] = args.verbose - - state['shell'] = bool(args.shell) - state['disable_store_defaults'] = bool(args.disable_store_defaults) + state = load_state(state_file, args=args) if args.command: if args.command in ("info", "shutdown") and not args.id: diff --git a/clearml_session/interactive_session_task.py b/clearml_session/interactive_session_task.py index 5e27cb2..91253ed 100644 --- a/clearml_session/interactive_session_task.py +++ b/clearml_session/interactive_session_task.py @@ -127,13 +127,36 @@ def init_task(param, a_default_ssh_fingerprint): # connect ssh fingerprint configuration (with fallback if section is missing) old_default_ssh_fingerprint = deepcopy(a_default_ssh_fingerprint) - try: - task.connect_configuration(configuration=a_default_ssh_fingerprint, name=config_object_section_ssh) - except (TypeError, ValueError): - a_default_ssh_fingerprint.clear() - a_default_ssh_fingerprint.update(old_default_ssh_fingerprint) + found_server_ssh_fingerprint = None + if Session.check_min_api_version('2.20'): + print("INFO: checking remote ssh server fingerprint from server vault") + # noinspection PyBroadException + try: + res = task.session.send_request( + "users", "get_vaults", + params="enabled=true&types=remote_session_ssh_server&" + "types=remote_session_ssh_server").json() + if res.get('data', {}).get('vaults'): + found_server_ssh_fingerprint = json.loads(res['data']['vaults'][-1]['data']) + a_default_ssh_fingerprint.update(found_server_ssh_fingerprint) + print("INFO: loading fingerprint from server vault successfully: {}".format( + list(found_server_ssh_fingerprint.keys()))) + else: + print("INFO: server side fingerprint was not found") + except Exception as ex: + print("DEBUG: server side fingerprint parsing error: {}".format(ex)) + + if not found_server_ssh_fingerprint: + try: + # print("DEBUG: loading fingerprint from task") + task.connect_configuration(configuration=a_default_ssh_fingerprint, name=config_object_section_ssh) + except (TypeError, ValueError): + a_default_ssh_fingerprint.clear() + a_default_ssh_fingerprint.update(old_default_ssh_fingerprint) + if param.get('default_docker') and task.running_locally(): task.set_base_docker("{} --network host".format(param['default_docker'])) + # leave local process, only run remotely task.execute_remotely() return task