Add finer grade ssh port routing

This commit is contained in:
allegroai 2024-05-20 15:53:14 +03:00
parent f6fe1cff56
commit 6bca7dfd7b
2 changed files with 27 additions and 4 deletions

View File

@ -36,6 +36,7 @@ except Exception:
system_tag = 'interactive'
default_docker_image = 'nvidia/cuda:11.6.2-runtime-ubuntu20.04'
internal_tcp_port_request = 10022
class NonInteractiveError(Exception):
@ -186,6 +187,7 @@ def create_base_task(state, project_name=None, task_name=None, continue_task_id=
"_user_secret": '',
"_jupyter_token": '',
"_ssh_password": "training",
"internal_tcp_port_request": str(internal_tcp_port_request),
})
# noinspection PyProtectedMember
task._set_runtime_properties(_runtime_prop)
@ -267,6 +269,7 @@ def create_debugging_task(state, debug_task_id, task_name=None, task_project_id=
"_user_secret": '',
"_jupyter_token": '',
"_ssh_password": "training",
"internal_tcp_port_request": str(internal_tcp_port_request),
})
# noinspection PyProtectedMember
task._set_runtime_properties(_runtime_prop)
@ -635,6 +638,7 @@ def clone_task(state, project_id=None):
runtime_properties['_ssh_password'] = str(state['password'])
runtime_properties['_user_key'] = str(config_obj.get("api.credentials.access_key"))
runtime_properties['_user_secret'] = (config_obj.get("api.credentials.secret_key"))
runtime_properties['internal_tcp_port_request'] = str(internal_tcp_port_request)
# noinspection PyProtectedMember
task._set_runtime_properties(runtime_properties)
@ -955,25 +959,35 @@ def monitor_ssh_tunnel(state, task):
remote_address
]):
task.reload()
internal_ssh_port = None
remote_address = None
ssh_port = None
task_parameters = task.get_parameters()
if Session.check_min_api_version("2.13"):
# noinspection PyProtectedMember
runtime_prop = task._get_runtime_properties()
ssh_password = runtime_prop.get('_ssh_password') or state.get('password', '')
jupyter_token = runtime_prop.get('_jupyter_token')
internal_ssh_port = runtime_prop.get('internal_tcp_port')
remote_address = runtime_prop.get('external_address')
ssh_port = runtime_prop.get('external_tcp_port')
else:
section = 'General' if 'General/ssh_server' in task_parameters else default_section
ssh_password = task_parameters.get('{}/ssh_password'.format(section)) or state.get('password', '')
jupyter_token = task_parameters.get('properties/jupyter_token')
remote_address = \
remote_address = remote_address or \
task_parameters.get('properties/k8s-gateway-address') or \
task_parameters.get('properties/external_address')
internal_ssh_port = task_parameters.get('properties/internal_ssh_port')
internal_ssh_port = internal_ssh_port or task_parameters.get('properties/internal_ssh_port')
jupyter_port = task_parameters.get('properties/jupyter_port')
ssh_port = \
ssh_port = ssh_port or \
task_parameters.get('properties/k8s-pod-port') or \
task_parameters.get('properties/external_ssh_port') or internal_ssh_port
if state.get('keepalive'):
internal_ssh_port = task_parameters.get('properties/internal_stable_ssh_port') or internal_ssh_port
local_remote_pair_list = [(local_ssh_port_, internal_ssh_port)]

View File

@ -562,9 +562,18 @@ def setup_ssh_server(hostname, hostnames, param, task, env):
print("Installing SSH Server on {} [{}]".format(hostname, hostnames))
ssh_password = param.get("ssh_password", "training")
ssh_port = None
if Session.check_min_api_version("2.13"):
try:
# noinspection PyProtectedMember
ssh_port = task._get_runtime_properties().get("internal_tcp_port")
except Exception as ex:
print("Failed retrieving internal TCP port for SSH daemon: {}".format(ex))
# noinspection PyBroadException
try:
ssh_port = param.get("ssh_ports") or "10022:15000"
ssh_port = ssh_port or param.get("ssh_ports") or "10022:15000"
min_port = int(ssh_port.split(":")[0])
max_port = max(min_port+32, int(ssh_port.split(":")[-1]))
port = get_free_port(min_port, max_port)