diff --git a/clearml_session/__main__.py b/clearml_session/__main__.py index 00540d5..f68caa7 100644 --- a/clearml_session/__main__.py +++ b/clearml_session/__main__.py @@ -671,6 +671,7 @@ def clone_task(state, project_id=None): task_params["{}/vscode_extensions".format(section)] = state.get('vscode_extensions') or '' task_params["{}/force_dropbear".format(section)] = bool(state.get('force_dropbear')) task_params["{}/store_workspace".format(section)] = state.get('store_workspace') + task_params["{}/use_ssh_proxy".format(section)] = state.get('keepalive') if state.get('user_folder'): task_params['{}/user_base_directory'.format(section)] = state.get('user_folder') docker = state.get('docker') or task.get_base_docker() diff --git a/clearml_session/interactive_session_task.py b/clearml_session/interactive_session_task.py index c557886..c7056e6 100644 --- a/clearml_session/interactive_session_task.py +++ b/clearml_session/interactive_session_task.py @@ -577,7 +577,11 @@ def setup_ssh_server(hostname, hostnames, param, task, env): min_port = int(ssh_port.split(":")[0]) max_port = max(min_port+32, int(ssh_port.split(":")[-1])) port = get_free_port(min_port, max_port) - proxy_port = get_free_port(min_port, max_port) + if param.get("use_ssh_proxy"): + proxy_port = port + port = get_free_port(min_port, max_port) + else: + proxy_port = None use_dropbear = bool(param.get("force_dropbear", False)) # if we are root, install open-ssh @@ -629,7 +633,7 @@ def setup_ssh_server(hostname, hostnames, param, task, env): except Exception: # noinspection PyBroadException try: - print('SSHd was not found default to user space dropbear sshd server') + print('WARNING: SSHd was not found defaulting to user-space dropbear sshd server') dropbear_download_link = \ os.environ.get("CLEARML_DROPBEAR_EXEC") or \ 'https://github.com/allegroai/dropbear/releases/download/DROPBEAR_CLEARML_2023.02/dropbearmulti' @@ -678,6 +682,15 @@ def setup_ssh_server(hostname, hostnames, param, task, env): ssh_password = None task.set_parameter('{}/ssh_password'.format(config_section_name), '') + # get current user: + # noinspection PyBroadException + try: + current_user = getpass.getuser() or "root" + except Exception: + # we failed getting the user, let's assume root + print("Warning: failed getting active user name, assuming 'root'") + current_user = "root" + # create fingerprint files Path(ssh_config_path).mkdir(parents=True, exist_ok=True) keys_filename = {} @@ -689,8 +702,7 @@ def setup_ssh_server(hostname, hostnames, param, task, env): pass if v: with open(filename, 'wt') as f: - f.write(v + (' {}@{}'.format( - getpass.getuser() or "root", hostname) if filename.endswith('.pub') else '')) + f.write(v + (' {}@{}'.format(current_user, hostname) if filename.endswith('.pub') else '')) os.chmod(filename, 0o600 if filename.endswith('.pub') else 0o600) keys_filename[k] = filename @@ -727,19 +739,26 @@ def setup_ssh_server(hostname, hostnames, param, task, env): if result != 0: raise ValueError("Failed launching sshd: ", proc_args) - # noinspection PyBroadException - try: - TcpProxy(listen_port=proxy_port, target_port=port, proxy_state={}, verbose=False, # noqa - keep_connection=True, is_connection_server=True) - except Exception as ex: - print('Warning: Could not setup stable ssh port, {}'.format(ex)) - proxy_port = None + if proxy_port: + # noinspection PyBroadException + try: + TcpProxy(listen_port=proxy_port, target_port=port, proxy_state={}, verbose=False, # noqa + keep_connection=True, is_connection_server=True) + except Exception as ex: + print('Warning: Could not setup stable ssh port, {}'.format(ex)) + proxy_port = None if task: if proxy_port: task.set_parameter(name='properties/internal_stable_ssh_port', value=str(proxy_port)) task.set_parameter(name='properties/internal_ssh_port', value=str(port)) - task._set_runtime_properties(runtime_properties={'internal_ssh_port': str(port)}) + # noinspection PyProtectedMember + task._set_runtime_properties( + runtime_properties={ + 'internal_ssh_port': str(proxy_port or port), + '_ssh_user': current_user, + } + ) print( "\n#\n# SSH Server running on {} [{}] port {}\n# LOGIN u:root p:{}\n#\n".format( @@ -970,14 +989,22 @@ def get_host_name(task, param): pass # update host name - if not task.get_parameter(name='properties/external_address'): - external_addr = hostnames - if param.get('public_ip'): - # noinspection PyBroadException - try: - external_addr = requests.get('https://checkip.amazonaws.com').text.strip() - except Exception: - pass + if (not task.get_parameter(name='properties/external_address') and + not task.get_parameter(name='properties/k8s-gateway-address')): + if task._get_runtime_properties().get("external_address"): + external_addr = task._get_runtime_properties().get("external_address") + else: + external_addr = hostnames + if param.get('public_ip'): + # noinspection PyBroadException + try: + external_addr = requests.get('https://checkip.amazonaws.com').text.strip() + except Exception: + pass + # make sure we set it to the runtime properties + task._set_runtime_properties({"external_address": external_addr}) + + # make sure we set it back to the Task user properties task.set_parameter(name='properties/external_address', value=str(external_addr)) return hostname, hostnames @@ -1207,6 +1234,7 @@ def main(): "ssh_ports": None, "force_dropbear": False, "store_workspace": None, + "use_ssh_proxy": False, } task = init_task(param, default_ssh_fingerprint)