diff --git a/clearml_session/__main__.py b/clearml_session/__main__.py index 2fb9aa1..5490230 100644 --- a/clearml_session/__main__.py +++ b/clearml_session/__main__.py @@ -36,6 +36,7 @@ except Exception: system_tag = 'interactive' default_docker_image = 'nvidia/cuda:11.6.2-runtime-ubuntu20.04' +internal_tcp_port_request = 10022 class NonInteractiveError(Exception): @@ -186,6 +187,7 @@ def create_base_task(state, project_name=None, task_name=None, continue_task_id= "_user_secret": '', "_jupyter_token": '', "_ssh_password": "training", + "internal_tcp_port_request": str(internal_tcp_port_request), }) # noinspection PyProtectedMember task._set_runtime_properties(_runtime_prop) @@ -267,6 +269,7 @@ def create_debugging_task(state, debug_task_id, task_name=None, task_project_id= "_user_secret": '', "_jupyter_token": '', "_ssh_password": "training", + "internal_tcp_port_request": str(internal_tcp_port_request), }) # noinspection PyProtectedMember task._set_runtime_properties(_runtime_prop) @@ -635,6 +638,7 @@ def clone_task(state, project_id=None): runtime_properties['_ssh_password'] = str(state['password']) runtime_properties['_user_key'] = str(config_obj.get("api.credentials.access_key")) runtime_properties['_user_secret'] = (config_obj.get("api.credentials.secret_key")) + runtime_properties['internal_tcp_port_request'] = str(internal_tcp_port_request) # noinspection PyProtectedMember task._set_runtime_properties(runtime_properties) @@ -955,25 +959,35 @@ def monitor_ssh_tunnel(state, task): remote_address ]): task.reload() + internal_ssh_port = None + remote_address = None + ssh_port = None task_parameters = task.get_parameters() if Session.check_min_api_version("2.13"): # noinspection PyProtectedMember runtime_prop = task._get_runtime_properties() ssh_password = runtime_prop.get('_ssh_password') or state.get('password', '') jupyter_token = runtime_prop.get('_jupyter_token') + internal_ssh_port = runtime_prop.get('internal_tcp_port') + remote_address = runtime_prop.get('external_address') + ssh_port = runtime_prop.get('external_tcp_port') else: section = 'General' if 'General/ssh_server' in task_parameters else default_section ssh_password = task_parameters.get('{}/ssh_password'.format(section)) or state.get('password', '') jupyter_token = task_parameters.get('properties/jupyter_token') - remote_address = \ + remote_address = remote_address or \ task_parameters.get('properties/k8s-gateway-address') or \ task_parameters.get('properties/external_address') - internal_ssh_port = task_parameters.get('properties/internal_ssh_port') + + internal_ssh_port = internal_ssh_port or task_parameters.get('properties/internal_ssh_port') + jupyter_port = task_parameters.get('properties/jupyter_port') - ssh_port = \ + + ssh_port = ssh_port or \ task_parameters.get('properties/k8s-pod-port') or \ task_parameters.get('properties/external_ssh_port') or internal_ssh_port + if state.get('keepalive'): internal_ssh_port = task_parameters.get('properties/internal_stable_ssh_port') or internal_ssh_port local_remote_pair_list = [(local_ssh_port_, internal_ssh_port)] diff --git a/clearml_session/interactive_session_task.py b/clearml_session/interactive_session_task.py index fcdd8f1..af79a01 100644 --- a/clearml_session/interactive_session_task.py +++ b/clearml_session/interactive_session_task.py @@ -562,9 +562,18 @@ def setup_ssh_server(hostname, hostnames, param, task, env): print("Installing SSH Server on {} [{}]".format(hostname, hostnames)) ssh_password = param.get("ssh_password", "training") + + ssh_port = None + if Session.check_min_api_version("2.13"): + try: + # noinspection PyProtectedMember + ssh_port = task._get_runtime_properties().get("internal_tcp_port") + except Exception as ex: + print("Failed retrieving internal TCP port for SSH daemon: {}".format(ex)) + # noinspection PyBroadException try: - ssh_port = param.get("ssh_ports") or "10022:15000" + ssh_port = ssh_port or param.get("ssh_ports") or "10022:15000" min_port = int(ssh_port.split(":")[0]) max_port = max(min_port+32, int(ssh_port.split(":")[-1])) port = get_free_port(min_port, max_port)