Improve k8s integration

This commit is contained in:
allegroai 2024-05-20 15:56:57 +03:00
parent 8b100f0ccb
commit cee824dd8f
2 changed files with 49 additions and 20 deletions

View File

@ -671,6 +671,7 @@ def clone_task(state, project_id=None):
task_params["{}/vscode_extensions".format(section)] = state.get('vscode_extensions') or ''
task_params["{}/force_dropbear".format(section)] = bool(state.get('force_dropbear'))
task_params["{}/store_workspace".format(section)] = state.get('store_workspace')
task_params["{}/use_ssh_proxy".format(section)] = state.get('keepalive')
if state.get('user_folder'):
task_params['{}/user_base_directory'.format(section)] = state.get('user_folder')
docker = state.get('docker') or task.get_base_docker()

View File

@ -577,7 +577,11 @@ def setup_ssh_server(hostname, hostnames, param, task, env):
min_port = int(ssh_port.split(":")[0])
max_port = max(min_port+32, int(ssh_port.split(":")[-1]))
port = get_free_port(min_port, max_port)
proxy_port = get_free_port(min_port, max_port)
if param.get("use_ssh_proxy"):
proxy_port = port
port = get_free_port(min_port, max_port)
else:
proxy_port = None
use_dropbear = bool(param.get("force_dropbear", False))
# if we are root, install open-ssh
@ -629,7 +633,7 @@ def setup_ssh_server(hostname, hostnames, param, task, env):
except Exception:
# noinspection PyBroadException
try:
print('SSHd was not found default to user space dropbear sshd server')
print('WARNING: SSHd was not found defaulting to user-space dropbear sshd server')
dropbear_download_link = \
os.environ.get("CLEARML_DROPBEAR_EXEC") or \
'https://github.com/allegroai/dropbear/releases/download/DROPBEAR_CLEARML_2023.02/dropbearmulti'
@ -678,6 +682,15 @@ def setup_ssh_server(hostname, hostnames, param, task, env):
ssh_password = None
task.set_parameter('{}/ssh_password'.format(config_section_name), '')
# get current user:
# noinspection PyBroadException
try:
current_user = getpass.getuser() or "root"
except Exception:
# we failed getting the user, let's assume root
print("Warning: failed getting active user name, assuming 'root'")
current_user = "root"
# create fingerprint files
Path(ssh_config_path).mkdir(parents=True, exist_ok=True)
keys_filename = {}
@ -689,8 +702,7 @@ def setup_ssh_server(hostname, hostnames, param, task, env):
pass
if v:
with open(filename, 'wt') as f:
f.write(v + (' {}@{}'.format(
getpass.getuser() or "root", hostname) if filename.endswith('.pub') else ''))
f.write(v + (' {}@{}'.format(current_user, hostname) if filename.endswith('.pub') else ''))
os.chmod(filename, 0o600 if filename.endswith('.pub') else 0o600)
keys_filename[k] = filename
@ -727,19 +739,26 @@ def setup_ssh_server(hostname, hostnames, param, task, env):
if result != 0:
raise ValueError("Failed launching sshd: ", proc_args)
# noinspection PyBroadException
try:
TcpProxy(listen_port=proxy_port, target_port=port, proxy_state={}, verbose=False, # noqa
keep_connection=True, is_connection_server=True)
except Exception as ex:
print('Warning: Could not setup stable ssh port, {}'.format(ex))
proxy_port = None
if proxy_port:
# noinspection PyBroadException
try:
TcpProxy(listen_port=proxy_port, target_port=port, proxy_state={}, verbose=False, # noqa
keep_connection=True, is_connection_server=True)
except Exception as ex:
print('Warning: Could not setup stable ssh port, {}'.format(ex))
proxy_port = None
if task:
if proxy_port:
task.set_parameter(name='properties/internal_stable_ssh_port', value=str(proxy_port))
task.set_parameter(name='properties/internal_ssh_port', value=str(port))
task._set_runtime_properties(runtime_properties={'internal_ssh_port': str(port)})
# noinspection PyProtectedMember
task._set_runtime_properties(
runtime_properties={
'internal_ssh_port': str(proxy_port or port),
'_ssh_user': current_user,
}
)
print(
"\n#\n# SSH Server running on {} [{}] port {}\n# LOGIN u:root p:{}\n#\n".format(
@ -970,14 +989,22 @@ def get_host_name(task, param):
pass
# update host name
if not task.get_parameter(name='properties/external_address'):
external_addr = hostnames
if param.get('public_ip'):
# noinspection PyBroadException
try:
external_addr = requests.get('https://checkip.amazonaws.com').text.strip()
except Exception:
pass
if (not task.get_parameter(name='properties/external_address') and
not task.get_parameter(name='properties/k8s-gateway-address')):
if task._get_runtime_properties().get("external_address"):
external_addr = task._get_runtime_properties().get("external_address")
else:
external_addr = hostnames
if param.get('public_ip'):
# noinspection PyBroadException
try:
external_addr = requests.get('https://checkip.amazonaws.com').text.strip()
except Exception:
pass
# make sure we set it to the runtime properties
task._set_runtime_properties({"external_address": external_addr})
# make sure we set it back to the Task user properties
task.set_parameter(name='properties/external_address', value=str(external_addr))
return hostname, hostnames
@ -1207,6 +1234,7 @@ def main():
"ssh_ports": None,
"force_dropbear": False,
"store_workspace": None,
"use_ssh_proxy": False,
}
task = init_task(param, default_ssh_fingerprint)