# Mirror of https://github.com/clearml/clearml-session
# (synced 2025-02-07 13:22:17 +00:00; 691 lines, 27 KiB, Python)
import json
|
|
import os
|
|
import socket
|
|
import subprocess
|
|
import sys
|
|
from time import sleep
|
|
|
|
import requests
|
|
from copy import deepcopy
|
|
from tempfile import mkstemp
|
|
|
|
import psutil
|
|
from pathlib2 import Path
|
|
|
|
from clearml import Task, StorageManager
|
|
|
|
|
|
# noinspection SpellCheckingInspection
# Default SSH host key material used by the session's sshd.
# Each dict key is the name of a file that setup_ssh_server() writes into the
# sshd configuration directory (a '__pub' suffix becomes '.pub'); entries whose
# value is None are removed from that directory and not re-created.
# main() hands this dict to init_task(), which connects it to the task as the
# 'SSH' configuration object.
# NOTE(review): the private host keys are embedded in the source; presumably
# they are meant to be overridden through the task configuration -- confirm.
default_ssh_fingerprint = {
    # ECDSA private host key (PEM)
    'ssh_host_ecdsa_key':
        r"-----BEGIN EC PRIVATE KEY-----"+"\n"
        r"MHcCAQEEIOCAf3KEN9Hrde53rqQM4eR8VfCnO0oc4XTEBw0w6lCfoAoGCCqGSM49"+"\n"
        r"AwEHoUQDQgAEn/LlC/1UN1q6myfjs03LJdHY2LB0b1hBjAsLvQnDMt8QE6Rml3UF"+"\n"
        r"QK/UFw4mEqCFCD+dcbyWqFsKxTm6WtFStg=="+"\n"
        r"-----END EC PRIVATE KEY-----"+"\n",

    # ED25519 private host key (OpenSSH format)
    'ssh_host_ed25519_key':
        r"-----BEGIN OPENSSH PRIVATE KEY-----"+"\n"
        r"b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW"+"\n"
        r"QyNTUxOQAAACDvweeJHnUKtwY7/WRqDJEZTDk8AajWKFt/BXmEI3+A8gAAAJiEMTXOhDE1"+"\n"
        r"zgAAAAtzc2gtZWQyNTUxOQAAACDvweeJHnUKtwY7/WRqDJEZTDk8AajWKFt/BXmEI3+A8g"+"\n"
        r"AAAEBCHpidTBUN3+W8s3qRNkyaJpA/So4vEqDvOhseSqJeH+/B54kedQq3Bjv9ZGoMkRlM"+"\n"
        r"OTwBqNYoW38FeYQjf4DyAAAAEXJvb3RAODQ1NmQ5YTdlYTk4AQIDBA=="+"\n"
        r"-----END OPENSSH PRIVATE KEY-----"+"\n",

    # RSA private host key (PEM)
    'ssh_host_rsa_key':
        r"-----BEGIN RSA PRIVATE KEY-----"+"\n"
        r"MIIEowIBAAKCAQEAs8R3BrinMM/k9Jak7UqsoONqLQoasYgkeBVOOfRJ6ORYWW5R"+"\n"
        r"WLkYnPPUGRpbcoM1Imh7ODBgKzs0mh5/j3y0SKP/MpvT4bf38e+QGjuC+6fR4Ah0"+"\n"
        r"L5ohGIMyqhAiBoXgj0k2BE6en/4Rb3BwNPMocCTus82SwajzMNgWneRC6GCq2M0n"+"\n"
        r"0PWenhS0IQz7jUlw3JU8z6T3ROPiMBPU7ubBhiNlAzMYPr76Z7J6ZNrCclAvdGkI"+"\n"
        r"YxK7RNq0HwfoUj0UFD9iaEHswDIlNc34p93lP6GIAbh7uVYfGhg4z7HdBoN2qweN"+"\n"
        r"szo7iQX9N8EFP4WfpLzNFteThzgN/bdso8iv0wIDAQABAoIBAQCPvbF64110b1dg"+"\n"
        r"p7AauVINl6oHd4PensCicE7LkmUi3qsyXz6WVfKzVVgr9mJWz0lGSQr14+CR0NZ/"+"\n"
        r"wZE393vkdZWSLv2eB88vWeH8x8c1WHw9yiS1B2YdRpLVXu8GDjh/+gdCLGc0ASCJ"+"\n"
        r"3fsqq5+TBEUF6oPFbEWAsdhryeAiFAokeIVEKkxRnIDvPCP6i0evUHAxEP+wOngu"+"\n"
        r"4XONkixNmATNa1jP2YAjmh3uQbAf2BvDZuywJmqV8fqZa/BwuK3W+R/92t0ySZ5Q"+"\n"
        r"Z7RCZzPzFvWY683/Cfx5+BH3XcIetbcZ/HKuc+TdBvvFgqrLNIJ4OXMp3osjZDMO"+"\n"
        r"YZIE6DdBAoGBAOG8cgm2N+Kl2dl0q1r4S+hf//zPaDorNasvcXJcj/ypy1MdmDXt"+"\n"
        r"whLSAuTN4r8axgbuws2Z870pIGd28koqg78U+pOPabkphloo8Fc97RO28ZJCK2g0"+"\n"
        r"/prPgwSYymkhrvwdzIbI11BPL/rr9cLJ1eYDnzGDSqvXJDL79XxrzwMzAoGBAMve"+"\n"
        r"ULkfqaYVlgY58d38XruyCpSmRSq39LTeTYRWkJTNFL6rkqL9A69z/ITdpSStEuR8"+"\n"
        r"8MXQSsPz8xUhFrA2bEjW7AT0r6OqGbjljKeh1whYOfgGfMKQltTfikkrf5w0UrLw"+"\n"
        r"NQ8USfpwWdFnBGQG0yE/AFknyLH14/pqfRlLzaDhAoGAcN3IJxL03l4OjqvHAbUk"+"\n"
        r"PwvA8qbBdlQkgXM3RfcCB1LeVrB1aoF2h/J5f+1xchvw54Z54FMZi3sEuLbAblTT"+"\n"
        r"irbyktUiB3K7uli90uEjqLfQEVEEYxYcN0uKNsIucmJlG6nKmZnSDlWJp+xS9RH1"+"\n"
        r"4QvujNMYgtMPRm60T4GYAAECgYB6J9LMqik4CDUls/C2J7MH2m22lk5Zg3JQMefW"+"\n"
        r"xRvK3XtxqFKr8NkVd3U2k6yRZlcsq6SFkwJJmdHsti/nFCUcHBO+AHOBqLnS7VCz"+"\n"
        r"XSkAqgTKFfEJkCOgl/U/VJ4ZFcz7xSy1xV1yf4GCFK0v1lsJz7tAsLLz1zdsZARj"+"\n"
        r"dOVYYQKBgC3IQHfd++r9kcL3+vU7bDVU4aKq0JFDA79DLhKDpSTVxqTwBT+/BIpS"+"\n"
        r"8z79zBTjNy5gMqxZp/SWBVWmsO8d7IUk9O2L/bMhHF0lOKbaHQQ9oveCzIwDewcf"+"\n"
        r"5I45LjjGPJS84IBYv4NElptRk/2eFFejr75xdm4lWfpLb1SXPOPB"+"\n"
        r"-----END RSA PRIVATE KEY-----"+"\n",

    # public halves; ' root@<hostname>' is appended when the file is written
    'ssh_host_rsa_key__pub':
        r'ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQCzxHcGuKcwz+T0lqTtSqyg42otChqxiCR4FU459Eno5FhZblFYuRic89QZGlt'
        r'ygzUiaHs4MGArOzSaHn+PfLRIo/8ym9Pht/fx75AaO4L7p9HgCHQvmiEYgzKqECIGheCPSTYETp6f/hFvcHA08yhwJO6zzZLBqPM'
        r'w2Bad5ELoYKrYzSfQ9Z6eFLQhDPuNSXDclTzPpPdE4+IwE9Tu5sGGI2UDMxg+vvpnsnpk2sJyUC90aQhjErtE2rQfB+hSPRQUP2Jo'
        r'QezAMiU1zfin3eU/oYgBuHu5Vh8aGDjPsd0Gg3arB42zOjuJBf03wQU/hZ+kvM0W15OHOA39t2yjyK/T',
    'ssh_host_ecdsa_key__pub':
        r'ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBJ/y5Qv9VDdaupsn47NNyyXR2Niwd'
        r'G9YQYwLC70JwzLfEBOkZpd1BUCv1BcOJhKghQg/nXG8lqhbCsU5ulrRUrY=',
    # no public ED25519 key file is written
    'ssh_host_ed25519_key__pub': None,
}
|
|
# name under which the session parameter dict is connected to the task
config_section_name = 'interactive_session'
# name of the task configuration object holding the SSH host keys
config_object_section_ssh = 'SSH'
# name of the task configuration object holding the user's bash init script
config_object_section_bash_init = 'interactive_init_script'


# ports already handed out by get_free_port() during this process' lifetime
__allocated_ports = []
|
|
|
|
|
|
def get_free_port(range_min, range_max):
    """
    Reserve a TCP port number that is currently not in use.

    A port is considered taken when psutil reports a connection bound to it
    or when a previous call of this function already returned it.

    :param range_min: first port number to consider (inclusive)
    :param range_max: end of the port range (exclusive)
    :return: the selected port number (also remembered as allocated)
    :raises ValueError: if no port in the range is available
        (previously a bare StopIteration escaped from ``next()``)
    """
    global __allocated_ports
    # sets give O(1) membership tests while scanning the range
    taken_ports = {conn.laddr.port for conn in psutil.net_connections()}
    taken_ports.update(__allocated_ports)

    port = next((candidate for candidate in range(range_min, range_max) if candidate not in taken_ports), None)
    if port is None:
        raise ValueError("No free port found in range [{}, {})".format(range_min, range_max))

    __allocated_ports.append(port)
    return port
|
|
|
|
|
|
def init_task(param, a_default_ssh_fingerprint):
    """
    Create the ClearML service task for the session and hand execution over to a remote agent.

    :param param: session parameter dict, connected to the task under ``config_section_name``
    :param a_default_ssh_fingerprint: SSH host key dict, connected as the ``config_object_section_ssh``
        configuration object; its original content is restored in place if connecting it fails
    :return: the initialized Task
    """
    # packages the session needs
    for package_name in ('jupyter', 'jupyterlab', 'jupyterlab_git'):
        Task.add_requirements(package_name)

    task = Task.init(
        project_name="DevOps",
        task_name="Allocate Jupyter Notebook Instance",
        task_type=Task.TaskTypes.service,
    )

    # expose the session parameters on the task
    task.connect(param, name=config_section_name)

    # connect the ssh fingerprint configuration, keeping a backup to fall back on
    fingerprint_backup = deepcopy(a_default_ssh_fingerprint)
    try:
        task.connect_configuration(configuration=a_default_ssh_fingerprint, name=config_object_section_ssh)
    except (TypeError, ValueError):
        a_default_ssh_fingerprint.clear()
        a_default_ssh_fingerprint.update(fingerprint_backup)

    docker_image = param.get('default_docker')
    if docker_image:
        task.set_base_docker("{} --network host".format(docker_image))

    # stop the local process here, the rest only runs remotely
    task.execute_remotely()
    return task
|
|
|
|
|
|
def setup_os_env(param):
    """
    Build the environment passed to the session's child processes.

    The result is a copy of ``os.environ`` without the TRAINS*/CLEARML* runtime
    variables, except those whose name ends with one of the preserved suffixes
    (server hosts, credentials, config file, docker image).

    If ``param['default_docker']`` is set, TRAINS_DOCKER_IMAGE and
    CLEARML_DOCKER_IMAGE are first set in the current process environment.

    :param param: session parameter dict (``default_docker`` is optional and may be None)
    :return: a plain ``dict`` with the filtered environment
    """
    # name suffixes of the TRAINS_*/CLEARML_* variables we keep
    preserve = (
        "_API_HOST",
        "_WEB_HOST",
        "_FILES_HOST",
        "_CONFIG_FILE",
        "_API_ACCESS_KEY",
        "_API_SECRET_KEY",
        "_API_HOST_VERIFY_CERT",
        "_DOCKER_IMAGE",
    )

    # set default docker image, with network configuration
    docker_image = (param.get('default_docker') or '').strip()
    if docker_image:
        os.environ["TRAINS_DOCKER_IMAGE"] = docker_image
        os.environ["CLEARML_DOCKER_IMAGE"] = docker_image

    # Copy into a plain dict: a deepcopy of os.environ is still an os._Environ,
    # so removing keys from it calls unsetenv() on the real process environment.
    return {
        key: value
        for key, value in os.environ.items()
        if not key.startswith(("TRAINS", "CLEARML")) or key.endswith(preserve)
    }
|
|
|
|
|
|
def monitor_jupyter_server(fd, local_filename, process, task, jupyter_port, hostnames):
    """
    Relay the Jupyter server output and publish its access token on the task.

    Blocks until ``process`` exits. While it runs, the log file is read, printed
    and truncated periodically. Once a line holding a ``?token=`` URL shows up,
    the token, port and full URL are stored on the task as
    ``properties/jupyter_token``, ``properties/jupyter_port`` and ``properties/jupyter_url``.
    When the process ends, ``fd`` is closed and the log file is deleted.

    :param fd: open file descriptor of the log file the process writes to
    :param local_filename: path of that log file
    :param process: the running server process (``subprocess.Popen``)
    :param task: object used to store the properties (``set_parameter(name=, value=)``)
    :param jupyter_port: port the server was started on (used instead of the reported one)
    :param hostnames: address placed in the published URL
    """
    # todo: add auto spin down see: https://tljh.jupyter.org/en/latest/topic/idle-culler.html
    token = None
    process_running = True
    while process_running:
        # poll quickly until the token is known, slower afterwards
        try:
            process.wait(timeout=15.0 if token else 2.0)
            process_running = False
        except subprocess.TimeoutExpired:
            process_running = True

        # noinspection PyBroadException
        try:
            with open(local_filename, "rt") as f:
                new_lines = f.readlines()
            if not new_lines:
                continue
            # everything was read, reset the log file
            # NOTE: output written between the read and the truncate is lost
            os.lseek(fd, 0, 0)
            os.ftruncate(fd, 0)
        except Exception:
            continue

        # print stdout/stderr
        print("".join(new_lines))

        # if we already have the token, do nothing, just monitor
        if token:
            continue

        # look for the first url carrying a token
        scheme = 'http'
        for line in new_lines:
            if "http://" not in line and "https://" not in line:
                continue
            parts = line.split('?token=', 1)
            if len(parts) != 2:
                continue
            # readlines() keeps the line terminator, do not leak it into the token
            candidate = parts[1].strip()
            if not candidate:
                continue
            token = candidate
            # remember the scheme of the line the token came from
            scheme = 'https' if "https://" in line else 'http'
            break

        # we could not locate the token, try again
        if not token:
            continue

        # update the task with the correct links and token
        task.set_parameter(name='properties/jupyter_token', value=str(token))
        # we ignore the reported port, because jupyter server will get confused
        # if we have multiple servers running and will point to the wrong port/server
        task.set_parameter(name='properties/jupyter_port', value=str(jupyter_port))
        jupyter_url = '{}://{}:{}?token={}'.format(scheme, hostnames, jupyter_port, token)
        print('\nJupyter Lab URL: {}\n'.format(jupyter_url))
        task.set_parameter(name='properties/jupyter_url', value=jupyter_url)

    # cleanup (best effort)
    # noinspection PyBroadException
    try:
        os.close(fd)
    except Exception:
        pass
    # noinspection PyBroadException
    try:
        os.unlink(local_filename)
    except Exception:
        pass
|
|
|
|
|
|
def start_vscode_server(hostname, hostnames, param, task, env):
    """
    Install (root only) and launch a code-server (VSCode) instance for the session.

    Does nothing unless ``param['vscode_server']`` is truthy. When running as root,
    code-server and the python extension are downloaded and installed with dpkg;
    otherwise an existing ``code-server`` executable is looked up. The server listens on
    127.0.0.1 on a free port in [9000, 9100), which is stored on the task as
    ``properties/vscode_port`` (``-1`` when no executable is available to a non-root user).

    :param hostname: host name, used for logging only
    :param hostnames: host address, used for logging only
    :param param: session parameter dict (``vscode_server``, ``user_base_directory``)
    :param task: task on which the selected port is stored
    :param env: environment for the server process (copied, PYTHONPATH removed)
    """
    if not param.get("vscode_server"):
        return

    # make a copy of env and remove the pythonpath from it.
    env = dict(**env)
    env.pop('PYTHONPATH', None)

    # find a free tcp port
    port = get_free_port(9000, 9100)

    if os.geteuid() == 0:
        # installing VSCODE:
        try:
            python_ext = StorageManager.get_local_copy(
                'https://github.com/microsoft/vscode-python/releases/download/2020.10.332292344/ms-python-release.vsix',
                extract_archive=False)
            code_server_deb = StorageManager.get_local_copy(
                'https://github.com/cdr/code-server/releases/download/v3.7.4/code-server_3.7.4_amd64.deb',
                extract_archive=False)
            os.system("dpkg -i {}".format(code_server_deb))
        except Exception as ex:
            print("Failed installing vscode server: {}".format(ex))
            return
        vscode_path = 'code-server'
    else:
        python_ext = None
        # check if code-server exists
        # noinspection PyBroadException
        try:
            vscode_path = subprocess.check_output('which code-server', shell=True).decode().strip()
            # explicit check instead of `assert`, which is stripped under -O
            if not vscode_path:
                raise ValueError('code-server executable not found')
        except Exception:
            print('Error: Cannot install code-server (not root) and could not find code-server executable, skipping.')
            task.set_parameter(name='properties/vscode_port', value=str(-1))
            return

    cwd = (
        os.path.expandvars(os.path.expanduser(param["user_base_directory"]))
        if param["user_base_directory"]
        else os.getcwd()
    )
    # make sure we have the needed cwd
    # noinspection PyBroadException
    try:
        Path(cwd).mkdir(parents=True, exist_ok=True)
    except Exception:
        pass
    print("Running VSCode Server on {} [{}] port {} at {}".format(hostname, hostnames, port, cwd))
    print("VSCode Server available: http://{}:{}/\n".format(hostnames, port))
    user_folder = os.path.join(cwd, ".vscode/user/")
    exts_folder = os.path.join(cwd, ".vscode/exts/")

    fd = None
    try:
        fd, _ = mkstemp()

        # install the extensions; the python extension is optional.
        # (The original expression `[...] + [...] if python_ext else []` evaluated to an
        # empty argument list whenever python_ext was missing, so Popen always failed.)
        install_cmd = [
            vscode_path,
            "--auth",
            "none",
            "--bind-addr",
            "127.0.0.1:{}".format(port),
            "--user-data-dir", user_folder,
            "--extensions-dir", exts_folder,
            "--install-extension", "ms-toolsai.jupyter",
            # "--install-extension", "donjayamanne.python-extension-pack"
        ]
        if python_ext:
            install_cmd += ["--install-extension", python_ext]
        subprocess.Popen(
            install_cmd,
            env=env,
            stdout=fd,
            stderr=fd,
        )

        # merge our defaults into the user settings file (best effort)
        settings = Path(os.path.expanduser(os.path.join(user_folder, 'User/settings.json')))
        settings.parent.mkdir(parents=True, exist_ok=True)
        # noinspection PyBroadException
        try:
            with open(settings.as_posix(), 'rt') as f:
                base_json = json.load(f)
        except Exception:
            base_json = {}
        # noinspection PyBroadException
        try:
            base_json.update({
                "extensions.autoCheckUpdates": False,
                "extensions.autoUpdate": False,
                "python.pythonPath": sys.executable,
                "terminal.integrated.shell.linux": "/bin/bash" if Path("/bin/bash").is_file() else None,
            })
            with open(settings.as_posix(), 'wt') as f:
                json.dump(base_json, f)
        except Exception:
            pass

        # launch the server itself
        proc = subprocess.Popen(
            ['bash', '-c',
             '{} --auth none --bind-addr 127.0.0.1:{} --disable-update-check '
             '--user-data-dir {} --extensions-dir {}'.format(vscode_path, port, user_folder, exts_folder)],
            env=env,
            stdout=fd,
            stderr=fd,
            cwd=cwd,
        )
        # a server that exits within the first second failed to start
        try:
            error_code = proc.wait(timeout=1)
            raise ValueError("code-server failed starting, return code {}".format(error_code))
        except subprocess.TimeoutExpired:
            pass

    except Exception as ex:
        print('Failed running vscode server: {}'.format(ex))
        return
    finally:
        # the child processes hold their own copy of the descriptor
        if fd is not None:
            try:
                os.close(fd)
            except OSError:
                pass

    task.set_parameter(name='properties/vscode_port', value=str(port))
|
|
|
|
|
|
def start_jupyter_server(hostname, hostnames, param, task, env):
    """
    Launch JupyterLab and keep monitoring it until it exits.

    When ``param['jupyterlab']`` is explicitly falsy no server is started and the
    function sleeps forever. Otherwise JupyterLab is started on 127.0.0.1, on a free
    port in [8888, 9000), with its output redirected to a temporary file that
    ``monitor_jupyter_server`` consumes.

    :param hostname: host name, used for logging only
    :param hostnames: host address, forwarded to the monitor for the published URL
    :param param: session parameter dict (``jupyterlab``, ``user_base_directory``)
    :param task: task forwarded to the monitor
    :param env: environment for the server process (copied before being modified)
    :return: whatever ``monitor_jupyter_server`` returns
    """
    if not param.get('jupyterlab', True):
        print('no jupyterlab to monitor - going to sleep')
        while True:
            sleep(10.)

    # temporary file collecting the server output
    fd, local_filename = mkstemp()

    # working folder of the server
    if param["user_base_directory"]:
        cwd = os.path.expandvars(os.path.expanduser(param["user_base_directory"]))
    else:
        cwd = os.getcwd()

    # find a free tcp port
    port = get_free_port(8888, 9000)

    # if we are not running as root, make sure the sys executable is in the PATH
    env = dict(**env)
    python_bin_folder = Path(sys.executable).parent.as_posix()
    env['PATH'] = '{}:{}'.format(python_bin_folder, env.get('PATH', ''))

    # make sure we have the needed cwd
    # noinspection PyBroadException
    try:
        Path(cwd).mkdir(parents=True, exist_ok=True)
    except Exception:
        pass

    print("Running Jupyter Notebook Server on {} [{}] port {} at {}".format(hostname, hostnames, port, cwd))

    jupyter_cmd = [sys.executable, "-m", "jupyter", "lab", "--no-browser", "--allow-root"]
    jupyter_cmd += ["--ip", "127.0.0.1"]
    jupyter_cmd += ["--port", str(port)]
    process = subprocess.Popen(jupyter_cmd, env=env, stdout=fd, stderr=fd, cwd=cwd)

    return monitor_jupyter_server(fd, local_filename, process, task, port, hostnames)
|
|
|
|
|
|
def setup_ssh_server(hostname, hostnames, param, task):
    """
    Install/configure and launch an sshd instance for the session.

    Does nothing unless ``param['ssh_server']`` is truthy. As root, openssh-server is
    installed with apt-get, the root password is set and the system sshd configuration
    is extended. As a regular user, an existing ``sshd`` executable is used with a private
    configuration under ``./.clearml_session_sshd`` and the ssh password is cleared, since
    it cannot be changed. The host keys in ``default_ssh_fingerprint`` are written to the
    configuration folder, sshd is started in the foreground on a free port in
    [10022, 15000), and the ports are stored on the task. Any failure is printed, not raised.

    :param hostname: host name (logging, and comment of the public key files)
    :param hostnames: host address, used for logging only
    :param param: session parameter dict (``ssh_server``, ``ssh_password``)
    :param task: task on which the ssh ports are stored (may be falsy to skip that)
    """
    if not param.get("ssh_server"):
        return

    print("Installing SSH Server on {} [{}]".format(hostname, hostnames))
    ssh_password = param.get("ssh_password", "training")
    # noinspection PyBroadException
    try:
        port = get_free_port(10022, 15000)
        proxy_port = get_free_port(10022, 15000)

        # if we are root, install open-ssh
        if os.geteuid() == 0:
            from shlex import quote  # only the root branch builds a shell command line

            # noinspection SpellCheckingInspection
            commands = [
                "export PYTHONPATH=\"\"",
                "apt-get install -y openssh-server",
                "mkdir -p /var/run/sshd",
                # quote the password so shell metacharacters cannot break the command
                "echo {} | chpasswd".format(quote("root:{}".format(ssh_password))),
                "echo 'PermitRootLogin yes' >> /etc/ssh/sshd_config",
                "sed -i 's/PermitRootLogin prohibit-password/PermitRootLogin yes/' /etc/ssh/sshd_config",
                "sed 's@session\\s*required\\s*pam_loginuid.so@session optional pam_loginuid.so@g' "
                "-i /etc/pam.d/sshd",
                "echo 'ClientAliveInterval 10' >> /etc/ssh/sshd_config",
                "echo 'ClientAliveCountMax 20' >> /etc/ssh/sshd_config",
                "echo 'AcceptEnv TRAINS_API_ACCESS_KEY TRAINS_API_SECRET_KEY "
                "CLEARML_API_ACCESS_KEY CLEARML_API_SECRET_KEY' >> /etc/ssh/sshd_config",
                'echo "export VISIBLE=now" >> /etc/profile',
                'echo "export PATH=$PATH" >> /etc/profile',
                'echo "ldconfig" >> /etc/profile',
            ]
            # only export the configuration file when there is one (otherwise the
            # literal string "None" would end up in /etc/profile)
            trains_config_file = os.environ.get("CLEARML_CONFIG_FILE") or os.environ.get("TRAINS_CONFIG_FILE")
            if trains_config_file:
                commands.append(
                    'echo "export TRAINS_CONFIG_FILE={}" >> /etc/profile'.format(trains_config_file))
            os.system(" && ".join(commands))
            sshd_path = '/usr/sbin/sshd'
            ssh_config_path = '/etc/ssh/'
            custom_ssh_conf = None
        else:
            # check if sshd exists
            # noinspection PyBroadException
            try:
                sshd_path = subprocess.check_output('which sshd', shell=True).decode().strip()
                ssh_config_path = os.path.join(os.getcwd(), '.clearml_session_sshd')
                Path(ssh_config_path).mkdir(parents=True, exist_ok=True)
                custom_ssh_conf = os.path.join(ssh_config_path, 'sshd_config')
                conf_lines = [
                    "PermitRootLogin yes",
                    "ClientAliveInterval 10",
                    "ClientAliveCountMax 20",
                    "AllowTcpForwarding yes",
                    "UsePAM yes",
                    "AuthorizedKeysFile {}".format(os.path.join(ssh_config_path, 'authorized_keys')),
                    "PidFile {}".format(os.path.join(ssh_config_path, 'sshd.pid')),
                    "AcceptEnv TRAINS_API_ACCESS_KEY TRAINS_API_SECRET_KEY "
                    "CLEARML_API_ACCESS_KEY CLEARML_API_SECRET_KEY",
                ]
                # HostKey must point at private key files that are actually written below
                for key_name, key_value in default_ssh_fingerprint.items():
                    if key_name.endswith('__pub') or not key_value:
                        continue
                    conf_lines.append("HostKey {}".format(os.path.join(ssh_config_path, key_name)))
                with open(custom_ssh_conf, 'wt') as f:
                    f.write("\n".join(conf_lines) + "\n")
            except Exception:
                print('Error: Cannot install sshd (not root) and could not find sshd executable, leaving!')
                return
            # clear the ssh password, we cannot change it
            ssh_password = None
            task.set_parameter('{}/ssh_password'.format(config_section_name), '')

        # create fingerprint files
        Path(ssh_config_path).mkdir(parents=True, exist_ok=True)
        for k, v in default_ssh_fingerprint.items():
            filename = os.path.join(ssh_config_path, '{}'.format(k.replace('__pub', '.pub')))
            # always drop a previous file, entries without content stay removed
            try:
                os.unlink(filename)
            except Exception:  # noqa
                pass
            if v:
                with open(filename, 'wt') as f:
                    f.write(v + (' root@{}'.format(hostname) if filename.endswith('.pub') else ''))
                os.chmod(filename, 0o600)

        # run server in foreground so it gets killed with us
        proc_args = [sshd_path, "-D", "-p", str(port)] + (["-f", custom_ssh_conf] if custom_ssh_conf else [])
        proc = subprocess.Popen(args=proc_args)
        # sshd still running after a second (or any wait error) counts as success
        # noinspection PyBroadException
        try:
            result = proc.wait(timeout=1)
        except Exception:
            result = 0

        if result != 0:
            raise ValueError("Failed launching sshd: {}".format(proc_args))

        # NOTE(review): TcpProxy is neither defined nor imported in the visible part of
        # this file; if it is missing the NameError is reported by the warning below.
        # noinspection PyBroadException
        try:
            TcpProxy(listen_port=proxy_port, target_port=port, proxy_state={}, verbose=False,  # noqa
                     keep_connection=True, is_connection_server=True)
        except Exception as ex:
            print('Warning: Could not setup stable ssh port, {}'.format(ex))
            proxy_port = None

        if task:
            if proxy_port:
                task.set_parameter(name='properties/internal_stable_ssh_port', value=str(proxy_port))
            task.set_parameter(name='properties/internal_ssh_port', value=str(port))

        print(
            "\n#\n# SSH Server running on {} [{}] port {}\n# LOGIN u:root p:{}\n#\n".format(
                hostname, hostnames, port, ssh_password
            )
        )

    except Exception as ex:
        print("Error: {}\n\n#\n# Error: SSH server could not be launched\n#\n".format(ex))
|
|
|
|
|
|
def _append_profile_line(profile, line):
    """Append ``line`` to the shell start-up file ``profile``; a failure is reported, not raised."""
    try:
        with open(os.path.expanduser(profile), 'at') as f:
            f.write('{}\n'.format(line))
    except (IOError, OSError) as ex:
        print('Could not update {}: {}'.format(profile, ex))


def _write_git_file(filename, content, mode=None):
    """Write ``content`` to ``filename`` (creating parent folders), optionally restricting its mode."""
    # noinspection PyBroadException
    try:
        Path(filename).parent.mkdir(parents=True, exist_ok=True)
        with open(filename, 'wt') as f:
            f.write(content)
        if mode is not None:
            os.chmod(filename, mode)
    except Exception:
        print('Could not write {} file'.format(filename))


def setup_user_env(param, task):
    """
    Prepare the user's shell environment for the interactive session.

    Always builds the child-process environment with ``setup_os_env``. For a non-root
    user only the credentials are added to that environment. For root, additionally:
    ``~/environment`` is linked to the python environment, ``~/.bashrc`` and ``~/.profile``
    get the credentials, docker image, base folder and environment activation, and the
    ``git_credentials`` / ``git_config`` task configurations are written under ``~/.config/git``.

    :param param: session parameter dict
        (``user_key``, ``user_secret``, ``default_docker``, ``user_base_directory``)
    :param task: task the git configuration texts are read from
    :return: environment dict for the session's child processes
    """
    env = setup_os_env(param)
    user_key = param.get("user_key")
    user_secret = param.get("user_secret")

    # do not change user bash/profile
    if os.geteuid() != 0:
        if user_key and user_secret:
            env['TRAINS_API_ACCESS_KEY'] = user_key
            env['TRAINS_API_SECRET_KEY'] = user_secret
        return env

    # create symbolic link to the venv
    environment = os.path.expanduser('~/environment')
    # noinspection PyBroadException
    try:
        os.symlink(os.path.abspath(os.path.join(os.path.abspath(sys.executable), '..', '..')), environment)
        print('Virtual environment are available at {}'.format(environment))
    except Exception:
        pass

    # The start-up files are appended to directly instead of through `echo ... >> file`
    # shell commands, so quotes inside a value cannot break (or inject into) a command.
    profiles = ('~/.bashrc', '~/.profile')

    # set default user credentials
    if user_key and user_secret:
        # default_docker may be None
        docker_image = (param.get("default_docker") or "").strip() or env.get('TRAINS_DOCKER_IMAGE', '')
        exports = (
            # '$' is escaped because the value is double-quoted in the profile
            ('TRAINS_API_ACCESS_KEY', user_key.replace('$', '\\$')),
            ('TRAINS_API_SECRET_KEY', user_secret.replace('$', '\\$')),
            ('TRAINS_DOCKER_IMAGE', docker_image),
        )
        for profile in profiles:
            for name, value in exports:
                _append_profile_line(profile, 'export {}="{}"'.format(name, value))
        env['TRAINS_API_ACCESS_KEY'] = user_key
        env['TRAINS_API_SECRET_KEY'] = user_secret

    # set default folder for user
    if param.get("user_base_directory"):
        base_dir = param.get("user_base_directory")
        if ' ' in base_dir:
            base_dir = '\"{}\"'.format(base_dir)
        for profile in profiles:
            _append_profile_line(profile, 'cd {}'.format(base_dir))

    # make sure we activate the venv in the bash
    activate_script = os.path.join(environment, 'bin', 'activate')
    _append_profile_line('~/.bashrc', 'source {}'.format(activate_script))
    _append_profile_line('~/.profile', '. {}'.format(activate_script))

    # check if we need to create .git-credentials
    # noinspection PyProtectedMember
    git_credentials = task._get_configuration_text('git_credentials')
    if git_credentials:
        # credentials must not be readable by other users
        _write_git_file(os.path.expanduser('~/.config/git/credentials'), git_credentials, mode=0o600)

    # noinspection PyProtectedMember
    git_config = task._get_configuration_text('git_config')
    if git_config:
        _write_git_file(os.path.expanduser('~/.config/git/config'), git_config)

    return env
|
|
|
|
|
|
def get_host_name(task, param):
    """
    Resolve the name and address of the machine running the session.

    Also fills ``properties/external_address`` on the task when it is not set yet:
    with the detected address, or with the public IP reported by
    https://checkip.amazonaws.com when ``param['public_ip']`` is truthy and the lookup works.

    :param task: task used to read/store ``properties/external_address``
    :param param: session parameter dict (``public_ip``)
    :return: tuple ``(hostname, hostnames)``; despite its name ``hostnames`` is a single
        address string
    """
    # noinspection PyBroadException
    try:
        hostname = socket.gethostname()
        hostnames = socket.gethostbyname(hostname)
    except Exception:
        # fall back to the first IPv4 address of any network interface
        ipv4_addresses = [
            snic.address
            for snics in psutil.net_if_addrs().values()
            for snic in snics
            if snic.family == socket.AF_INET
        ]
        # no IPv4 interface at all: use loopback instead of failing with IndexError
        hostnames = ipv4_addresses[0] if ipv4_addresses else '127.0.0.1'
        hostname = hostnames

    # try to get external address (if possible)
    # noinspection PyBroadException
    try:
        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        try:
            # connecting a UDP socket sends nothing, it only selects the local
            # address used for routing; doesn't even have to be reachable
            s.connect(('8.255.255.255', 1))
            hostnames = s.getsockname()[0]
        finally:
            s.close()
    except Exception:
        pass

    # update host name
    if not task.get_parameter(name='properties/external_address'):
        external_addr = hostnames
        if param.get('public_ip'):
            # noinspection PyBroadException
            try:
                # requests never times out by default, do not let the session hang here
                external_addr = requests.get('https://checkip.amazonaws.com', timeout=10).text.strip()
            except Exception:
                pass
        task.set_parameter(name='properties/external_address', value=str(external_addr))

    return hostname, hostnames
|
|
|
|
|
|
def run_user_init_script(task):
    """
    Run the user's bash initialization script and import the environment it produced.

    The script is read from the task configuration object named by
    ``config_object_section_bash_init``; nothing happens when it is empty. A python
    one-liner appended to the script dumps the script's final environment to a
    temporary JSON file, which is then merged into ``os.environ`` (PYTHONPATH is
    neither passed to the script nor taken back). Failures are printed, not raised,
    and both temporary files are always removed.

    :param task: task holding the initialization script
    """
    # run initialization script:
    # noinspection PyProtectedMember
    init_script = task._get_configuration_text(config_object_section_bash_init)
    if not init_script or not str(init_script).strip():
        return
    print("Running user initialization bash script:")
    init_filename = os_json_filename = None
    try:
        fd, init_filename = mkstemp(suffix='.init.sh')
        os.close(fd)
        fd, os_json_filename = mkstemp(suffix='.env.json')
        os.close(fd)
        with open(init_filename, 'wt') as f:
            # user script followed by a command dumping the resulting environment
            f.write(init_script +
                    '\n{} -c '
                    '"exec(\\"try:\\n import os\\n import json\\n'
                    ' json.dump(dict(os.environ), open(\\\'{}\\\', \\\'w\\\'))'
                    '\\nexcept: pass\\")"'.format(sys.executable, os_json_filename))
        env = dict(**os.environ)
        # do not pass or update back the PYTHONPATH environment variable
        env.pop('PYTHONPATH', None)
        subprocess.call(['/bin/bash', init_filename], env=env)
        # if the script exited early the dump is empty and loading it fails (reported below)
        with open(os_json_filename, 'rt') as f:
            environ = json.load(f)
        # do not pass or update back the PYTHONPATH environment variable
        environ.pop('PYTHONPATH', None)
        # update environment variables
        os.environ.update(environ)
    except Exception as ex:
        print('User initialization script failed: {}'.format(ex))
    finally:
        # best-effort cleanup; only file-system errors are ignored so that
        # KeyboardInterrupt / SystemExit are no longer swallowed
        for temp_filename in (init_filename, os_json_filename):
            if not temp_filename:
                continue
            try:
                os.unlink(temp_filename)
            except OSError:
                pass
|
|
|
|
|
|
def main():
    """
    Entry point of the interactive session task.

    Defines the default session parameters, creates the task (execution continues on
    the remote machine), then runs the user init script, resolves the host address,
    prepares the user environment and starts the ssh, VSCode and Jupyter servers in
    that order.
    """
    # default session parameters
    param = dict(
        user_base_directory="~/",
        ssh_server=True,
        ssh_password="training",
        default_docker="nvidia/cuda",
        user_key=None,
        user_secret=None,
        vscode_server=True,
        jupyterlab=True,
        public_ip=False,
    )
    task = init_task(param, default_ssh_fingerprint)

    run_user_init_script(task)

    hostname, hostnames = get_host_name(task, param)
    env = setup_user_env(param, task)

    # services sharing the same host/parameters/task
    setup_ssh_server(hostname, hostnames, param, task)
    start_vscode_server(hostname, hostnames, param, task, env)
    start_jupyter_server(hostname, hostnames, param, task, env)

    print('We are done')
|
|
|
|
|
|
# run the session only when executed as a script, not when imported
if __name__ == '__main__':
    main()
|