# clearml-session/clearml_session/interactive_session_task.py
# (source listing: 684 lines, 27 KiB, Python; revision dated 2020-12-22 19:32:02 +00:00)
import json
import os
import socket
import subprocess
import sys
from time import sleep
# revision: 2020-12-22 19:32:02 +00:00
import requests
from copy import deepcopy
from tempfile import mkstemp
import psutil
from pathlib2 import Path
from clearml import Task, StorageManager
# Default SSH host keys written to disk by setup_ssh_server().
# Keys ending with '__pub' are written as '<name>.pub' files (the '__pub' suffix is replaced by '.pub');
# a None value means no file is written for that key.
# init_task() connects this dict as the 'SSH' configuration object, so the values may be overridden
# from the task configuration.
# NOTE(review): these private host keys are embedded in the source, so every session shares them and
# anyone holding this file can impersonate the session host. Presumably this is accepted to give ssh
# clients a stable host fingerprint across sessions - confirm this trade-off is intended.
# noinspection SpellCheckingInspection
default_ssh_fingerprint = {
    'ssh_host_ecdsa_key':
        r"-----BEGIN EC PRIVATE KEY-----"+"\n"
        r"MHcCAQEEIOCAf3KEN9Hrde53rqQM4eR8VfCnO0oc4XTEBw0w6lCfoAoGCCqGSM49"+"\n"
        r"AwEHoUQDQgAEn/LlC/1UN1q6myfjs03LJdHY2LB0b1hBjAsLvQnDMt8QE6Rml3UF"+"\n"
        r"QK/UFw4mEqCFCD+dcbyWqFsKxTm6WtFStg=="+"\n"
        r"-----END EC PRIVATE KEY-----"+"\n",
    'ssh_host_ed25519_key':
        r"-----BEGIN OPENSSH PRIVATE KEY-----"+"\n"
        r"b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW"+"\n"
        r"QyNTUxOQAAACDvweeJHnUKtwY7/WRqDJEZTDk8AajWKFt/BXmEI3+A8gAAAJiEMTXOhDE1"+"\n"
        r"zgAAAAtzc2gtZWQyNTUxOQAAACDvweeJHnUKtwY7/WRqDJEZTDk8AajWKFt/BXmEI3+A8g"+"\n"
        r"AAAEBCHpidTBUN3+W8s3qRNkyaJpA/So4vEqDvOhseSqJeH+/B54kedQq3Bjv9ZGoMkRlM"+"\n"
        r"OTwBqNYoW38FeYQjf4DyAAAAEXJvb3RAODQ1NmQ5YTdlYTk4AQIDBA=="+"\n"
        r"-----END OPENSSH PRIVATE KEY-----"+"\n",
    'ssh_host_rsa_key':
        r"-----BEGIN RSA PRIVATE KEY-----"+"\n"
        r"MIIEowIBAAKCAQEAs8R3BrinMM/k9Jak7UqsoONqLQoasYgkeBVOOfRJ6ORYWW5R"+"\n"
        r"WLkYnPPUGRpbcoM1Imh7ODBgKzs0mh5/j3y0SKP/MpvT4bf38e+QGjuC+6fR4Ah0"+"\n"
        r"L5ohGIMyqhAiBoXgj0k2BE6en/4Rb3BwNPMocCTus82SwajzMNgWneRC6GCq2M0n"+"\n"
        r"0PWenhS0IQz7jUlw3JU8z6T3ROPiMBPU7ubBhiNlAzMYPr76Z7J6ZNrCclAvdGkI"+"\n"
        r"YxK7RNq0HwfoUj0UFD9iaEHswDIlNc34p93lP6GIAbh7uVYfGhg4z7HdBoN2qweN"+"\n"
        r"szo7iQX9N8EFP4WfpLzNFteThzgN/bdso8iv0wIDAQABAoIBAQCPvbF64110b1dg"+"\n"
        r"p7AauVINl6oHd4PensCicE7LkmUi3qsyXz6WVfKzVVgr9mJWz0lGSQr14+CR0NZ/"+"\n"
        r"wZE393vkdZWSLv2eB88vWeH8x8c1WHw9yiS1B2YdRpLVXu8GDjh/+gdCLGc0ASCJ"+"\n"
        r"3fsqq5+TBEUF6oPFbEWAsdhryeAiFAokeIVEKkxRnIDvPCP6i0evUHAxEP+wOngu"+"\n"
        r"4XONkixNmATNa1jP2YAjmh3uQbAf2BvDZuywJmqV8fqZa/BwuK3W+R/92t0ySZ5Q"+"\n"
        r"Z7RCZzPzFvWY683/Cfx5+BH3XcIetbcZ/HKuc+TdBvvFgqrLNIJ4OXMp3osjZDMO"+"\n"
        r"YZIE6DdBAoGBAOG8cgm2N+Kl2dl0q1r4S+hf//zPaDorNasvcXJcj/ypy1MdmDXt"+"\n"
        r"whLSAuTN4r8axgbuws2Z870pIGd28koqg78U+pOPabkphloo8Fc97RO28ZJCK2g0"+"\n"
        r"/prPgwSYymkhrvwdzIbI11BPL/rr9cLJ1eYDnzGDSqvXJDL79XxrzwMzAoGBAMve"+"\n"
        r"ULkfqaYVlgY58d38XruyCpSmRSq39LTeTYRWkJTNFL6rkqL9A69z/ITdpSStEuR8"+"\n"
        r"8MXQSsPz8xUhFrA2bEjW7AT0r6OqGbjljKeh1whYOfgGfMKQltTfikkrf5w0UrLw"+"\n"
        r"NQ8USfpwWdFnBGQG0yE/AFknyLH14/pqfRlLzaDhAoGAcN3IJxL03l4OjqvHAbUk"+"\n"
        r"PwvA8qbBdlQkgXM3RfcCB1LeVrB1aoF2h/J5f+1xchvw54Z54FMZi3sEuLbAblTT"+"\n"
        r"irbyktUiB3K7uli90uEjqLfQEVEEYxYcN0uKNsIucmJlG6nKmZnSDlWJp+xS9RH1"+"\n"
        r"4QvujNMYgtMPRm60T4GYAAECgYB6J9LMqik4CDUls/C2J7MH2m22lk5Zg3JQMefW"+"\n"
        r"xRvK3XtxqFKr8NkVd3U2k6yRZlcsq6SFkwJJmdHsti/nFCUcHBO+AHOBqLnS7VCz"+"\n"
        r"XSkAqgTKFfEJkCOgl/U/VJ4ZFcz7xSy1xV1yf4GCFK0v1lsJz7tAsLLz1zdsZARj"+"\n"
        r"dOVYYQKBgC3IQHfd++r9kcL3+vU7bDVU4aKq0JFDA79DLhKDpSTVxqTwBT+/BIpS"+"\n"
        r"8z79zBTjNy5gMqxZp/SWBVWmsO8d7IUk9O2L/bMhHF0lOKbaHQQ9oveCzIwDewcf"+"\n"
        r"5I45LjjGPJS84IBYv4NElptRk/2eFFejr75xdm4lWfpLb1SXPOPB"+"\n"
        r"-----END RSA PRIVATE KEY-----"+"\n",
    # public keys: setup_ssh_server appends ' root@<hostname>' when writing them
    'ssh_host_rsa_key__pub':
        r'ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQCzxHcGuKcwz+T0lqTtSqyg42otChqxiCR4FU459Eno5FhZblFYuRic89QZGlt'
        r'ygzUiaHs4MGArOzSaHn+PfLRIo/8ym9Pht/fx75AaO4L7p9HgCHQvmiEYgzKqECIGheCPSTYETp6f/hFvcHA08yhwJO6zzZLBqPM'
        r'w2Bad5ELoYKrYzSfQ9Z6eFLQhDPuNSXDclTzPpPdE4+IwE9Tu5sGGI2UDMxg+vvpnsnpk2sJyUC90aQhjErtE2rQfB+hSPRQUP2Jo'
        r'QezAMiU1zfin3eU/oYgBuHu5Vh8aGDjPsd0Gg3arB42zOjuJBf03wQU/hZ+kvM0W15OHOA39t2yjyK/T',
    'ssh_host_ecdsa_key__pub':
        r'ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBJ/y5Qv9VDdaupsn47NNyyXR2Niwd'
        r'G9YQYwLC70JwzLfEBOkZpd1BUCv1BcOJhKghQg/nXG8lqhbCsU5ulrRUrY=',
    # no ed25519 public key is provided, so no .pub file is written for it
    'ssh_host_ed25519_key__pub': None,
}
# name of the hyper-parameter section the session parameters are connected to (init_task)
config_section_name = 'interactive_session'
# name of the configuration object holding default_ssh_fingerprint (init_task)
config_object_section_ssh = 'SSH'
# name of the configuration object holding the user's bash init script (run_user_init_script)
config_object_section_bash_init = 'interactive_init_script'
# ports already handed out by get_free_port() in this process, so repeated calls never collide
__allocated_ports = []
def get_free_port(range_min, range_max):
    """
    Return a TCP port in ``[range_min, range_max)`` that is not in use and was not handed out before.

    Ports currently used by local connections (as reported by ``psutil.net_connections()``) are
    skipped, and so are ports previously returned by this function (kept in the module-level
    ``__allocated_ports`` list). Repeated calls (e.g. the sshd port and its proxy port) therefore
    never return the same value, even before anything has bound to the first one.

    :param range_min: first port to consider (inclusive)
    :param range_max: end of the range (exclusive)
    :return: the selected port number (also recorded in ``__allocated_ports``)
    :raises ValueError: if every port in the range is taken
    """
    # build one set up-front so each candidate check is O(1); some entries may carry an empty laddr
    taken = {conn.laddr.port for conn in psutil.net_connections() if conn.laddr}
    taken.update(__allocated_ports)
    for port in range(range_min, range_max):
        if port not in taken:
            # the list is only mutated (never rebound), so no `global` statement is needed
            __allocated_ports.append(port)
            return port
    raise ValueError('No free port available in range [{}, {})'.format(range_min, range_max))
def init_task(param, a_default_ssh_fingerprint):
    """
    Create the ClearML session task, connect its configuration and hand execution to the remote side.

    :param param: session parameters dict, connected as the ``config_section_name`` hyper-parameter section
    :param a_default_ssh_fingerprint: ssh host keys dict, connected as the ``config_object_section_ssh``
        configuration object; restored from a backup if connecting it fails
    :return: the initialized Task
    """
    # initialize ClearML (declare the packages the session needs)
    for requirement in ('jupyter', 'jupyterlab', 'jupyterlab_git'):
        Task.add_requirements(requirement)

    task = Task.init(
        project_name="DevOps",
        task_name="Allocate Jupyter Notebook Instance",
        task_type=Task.TaskTypes.service,
    )
    # Add jupyter server base folder
    task.connect(param, name=config_section_name)

    # connect ssh finger print configuration (with fallback if section is missing)
    fingerprint_backup = deepcopy(a_default_ssh_fingerprint)
    try:
        task.connect_configuration(configuration=a_default_ssh_fingerprint, name=config_object_section_ssh)
    except (TypeError, ValueError):
        # in case the dict was modified before the failure, put the original keys back in place
        a_default_ssh_fingerprint.clear()
        a_default_ssh_fingerprint.update(fingerprint_backup)

    docker_image = param.get('default_docker')
    if docker_image:
        task.set_base_docker("{} --network host".format(docker_image))

    # leave local process, only run remotely
    task.execute_remotely()
    return task
def setup_os_env(param):
    """
    Build the environment for the processes started by this session.

    All ``TRAINS*`` / ``CLEARML*`` variables are dropped (the runtime variables of this task),
    except those whose name ends with one of the ``preserve`` suffixes (server addresses,
    credentials, config file, docker image). If ``param['default_docker']`` is set, it is first
    exported into this process' ``os.environ`` as ``TRAINS_DOCKER_IMAGE`` / ``CLEARML_DOCKER_IMAGE``.

    :param param: session parameters dict (``default_docker`` may be missing, empty or None)
    :return: a plain ``dict`` copy of the filtered environment
    """
    # get rid of all the runtime ClearML
    preserve = (
        "_API_HOST",
        "_WEB_HOST",
        "_FILES_HOST",
        "_CONFIG_FILE",
        "_API_ACCESS_KEY",
        "_API_SECRET_KEY",
        "_API_HOST_VERIFY_CERT",
        "_DOCKER_IMAGE",
    )
    # set default docker image, with network configuration (tolerate a None value)
    default_docker = (param.get('default_docker') or '').strip()
    if default_docker:
        os.environ["TRAINS_DOCKER_IMAGE"] = default_docker
        os.environ["CLEARML_DOCKER_IMAGE"] = default_docker

    # NOTE: this must be a plain dict. deepcopy(os.environ) returns another os._Environ whose
    # __setitem__/__delitem__ call putenv()/unsetenv(), so editing that "copy" silently changed
    # this process' real environment.
    return {
        key: value
        for key, value in os.environ.items()
        if not (key.startswith(("TRAINS", "CLEARML")) and not key.endswith(preserve))
    }
def _extract_jupyter_token(line):
    """
    Parse one line of Jupyter output for an access URL such as ``http://127.0.0.1:8888/lab?token=abc``.

    :param line: a single output line (may end with a newline)
    :return: ``(token, scheme)`` when the line holds an http(s) URL with a numeric port and a
        non-empty token, otherwise ``None``
    """
    if "http://" not in line and "https://" not in line:
        return None
    url_part, separator, token_part = line.partition('?token=')
    if not separator:
        return None
    # the token ends at the first whitespace (at least the trailing newline)
    token_words = token_part.split()
    if not token_words:
        return None
    # validate the URL has an explicit numeric port: "http://127.0.0.1:8888/lab" -> "8888"
    host_and_path = url_part.rsplit('://', 1)[-1]
    port = host_and_path.split('/', 1)[0].rsplit(':', 1)[-1]
    try:
        int(port)
    except ValueError:
        return None
    scheme = 'https' if "https://" in url_part else 'http'
    return token_words[0], scheme


def monitor_jupyter_server(fd, local_filename, process, task, jupyter_port, hostnames):
    """
    Follow the Jupyter server output until the process exits, publishing its access link once found.

    The server writes stdout/stderr into ``local_filename`` (opened as ``fd``). Each poll prints the
    new output and truncates the file. The first URL line carrying a token sets
    ``properties/jupyter_token``, ``properties/jupyter_port`` and ``properties/jupyter_url`` on the task.
    On exit ``fd`` is closed and ``local_filename`` is removed (best effort).

    :param fd: OS-level file descriptor of the output file (shared with the server process)
    :param local_filename: path of the output file
    :param process: the ``subprocess.Popen`` of the Jupyter server
    :param task: the session Task to update
    :param jupyter_port: the port we started the server on (reported instead of the printed one)
    :param hostnames: host address used to build the published URL
    """
    # todo: add auto spin down see: https://tljh.jupyter.org/en/latest/topic/idle-culler.html
    token = None
    process_running = True
    while process_running:
        # poll quickly until we have the token, then just keep relaying the output
        try:
            process.wait(timeout=15.0 if token else 2.0)
            process_running = False
        except subprocess.TimeoutExpired:
            process_running = True

        # noinspection PyBroadException
        try:
            with open(local_filename, "rt") as f:
                new_lines = f.readlines()
            if not new_lines:
                continue
            # consume what we read: rewind and truncate the shared output file
            os.lseek(fd, 0, 0)
            os.ftruncate(fd, 0)
        except Exception:
            continue

        print("".join(new_lines))

        # if we already have the token, do nothing, just monitor
        if token:
            continue

        for line in new_lines:
            found = _extract_jupyter_token(line)
            if found:
                # only commit the token once the whole line validated
                token, scheme = found
                break
        else:
            # we could not locate the token, try again
            continue

        # update the task with the correct links and token
        task.set_parameter(name='properties/jupyter_token', value=str(token))
        # we ignore the reported port, because jupyter server will get confused
        # if we have multiple servers running and will point to the wrong port/server
        task.set_parameter(name='properties/jupyter_port', value=str(jupyter_port))
        jupyter_url = '{}://{}:{}?token={}'.format(scheme, hostnames, jupyter_port, token)
        print('\nJupyter Lab URL: {}\n'.format(jupyter_url))
        task.set_parameter(name='properties/jupyter_url', value=jupyter_url)

    # cleanup
    try:
        os.close(fd)
    except OSError:
        pass
    try:
        os.unlink(local_filename)
    except OSError:
        pass
def _update_vscode_settings(user_folder):
    """
    Merge the session defaults into code-server's ``User/settings.json`` under ``user_folder``.

    Creating the settings folder may raise; reading and writing the json file is best effort.

    :param user_folder: code-server ``--user-data-dir`` folder
    """
    settings = Path(os.path.expanduser(os.path.join(user_folder, 'User/settings.json')))
    settings.parent.mkdir(parents=True, exist_ok=True)
    # noinspection PyBroadException
    try:
        with open(settings.as_posix(), 'rt') as f:
            base_json = json.load(f)
    except Exception:
        base_json = {}
    # noinspection PyBroadException
    try:
        base_json.update({
            "extensions.autoCheckUpdates": False,
            "extensions.autoUpdate": False,
            "python.pythonPath": sys.executable,
            "terminal.integrated.shell.linux": "/bin/bash" if Path("/bin/bash").is_file() else None,
        })
        with open(settings.as_posix(), 'wt') as f:
            json.dump(base_json, f)
    except Exception:
        pass


def start_vscode_server(hostname, hostnames, param, task, env):
    """
    Install (when root) and launch a code-server instance bound to 127.0.0.1 on a free port.

    As root, code-server and the python extension are downloaded and installed (dpkg). Otherwise an
    existing ``code-server`` executable is required; if none is found ``properties/vscode_port`` is
    set to -1. On success ``properties/vscode_port`` is set to the chosen port. Failures are printed,
    not raised.

    :param hostname: host name (for messages)
    :param hostnames: host address (for messages)
    :param param: session parameters (``vscode_server``, ``user_base_directory``)
    :param task: the session Task
    :param env: environment for the child processes (copied, PYTHONPATH removed)
    """
    if not param.get("vscode_server"):
        return

    # make a copy of env and remove the pythonpath from it.
    env = dict(**env)
    env.pop('PYTHONPATH', None)

    # find a free tcp port
    port = get_free_port(9000, 9100)

    if os.geteuid() == 0:
        # installing VSCODE:
        try:
            python_ext = StorageManager.get_local_copy(
                'https://github.com/microsoft/vscode-python/releases/download/2020.10.332292344/ms-python-release.vsix',
                extract_archive=False)
            code_server_deb = StorageManager.get_local_copy(
                'https://github.com/cdr/code-server/releases/download/v3.7.4/code-server_3.7.4_amd64.deb',
                extract_archive=False)
            os.system("dpkg -i {}".format(code_server_deb))
        except Exception as ex:
            print("Failed installing vscode server: {}".format(ex))
            return
        vscode_path = 'code-server'
    else:
        python_ext = None
        # check if code-server exists
        # noinspection PyBroadException
        try:
            vscode_path = subprocess.check_output('which code-server', shell=True).decode().strip()
            assert vscode_path
        except Exception:
            print('Error: Cannot install code-server (not root) and could not find code-server executable, skipping.')
            task.set_parameter(name='properties/vscode_port', value=str(-1))
            return

    base_dir = param.get("user_base_directory")
    cwd = os.path.expandvars(os.path.expanduser(base_dir)) if base_dir else os.getcwd()
    # make sure we have the needed cwd
    # noinspection PyBroadException
    try:
        Path(cwd).mkdir(parents=True, exist_ok=True)
    except Exception:
        pass

    print("Running VSCode Server on {} [{}] port {} at {}".format(hostname, hostnames, port, cwd))
    print("VSCode Server available: http://{}:{}/\n".format(hostnames, port))
    user_folder = os.path.join(cwd, ".vscode/user/")
    exts_folder = os.path.join(cwd, ".vscode/exts/")
    common_args = ["--auth", "none", "--bind-addr", "127.0.0.1:{}".format(port)]

    fd = None
    try:
        # code-server output goes to a temp file (not read back by this function)
        fd, _ = mkstemp()
        install_args = [vscode_path] + common_args + [
            "--user-data-dir", user_folder,
            "--extensions-dir", exts_folder,
            "--install-extension", "ms-toolsai.jupyter",
            # "--install-extension", "donjayamanne.python-extension-pack"
        ]
        # NOTE: previously `[...] + [...] if python_ext else []` made the whole argv empty when
        # python_ext was None (non-root), so Popen([]) failed and code-server never started.
        if python_ext:
            install_args += ["--install-extension", python_ext]
        subprocess.Popen(install_args, env=env, stdout=fd, stderr=fd)

        _update_vscode_settings(user_folder)

        # pass argv directly (no "bash -c" string) so folders with spaces / shell characters work
        proc = subprocess.Popen(
            [vscode_path] + common_args + [
                "--disable-update-check",
                "--user-data-dir", user_folder,
                "--extensions-dir", exts_folder,
            ],
            env=env,
            stdout=fd,
            stderr=fd,
            cwd=cwd,
        )
        try:
            # the server is expected to keep running; exiting within a second means it failed
            error_code = proc.wait(timeout=1)
            raise ValueError("code-server failed starting, return code {}".format(error_code))
        except subprocess.TimeoutExpired:
            pass
    except Exception as ex:
        print('Failed running vscode server: {}'.format(ex))
        return
    finally:
        # the child processes hold their own copies of the descriptor
        if fd is not None:
            try:
                os.close(fd)
            except OSError:
                pass

    task.set_parameter(name='properties/vscode_port', value=str(port))
def start_jupyter_server(hostname, hostnames, param, task, env):
    """
    Launch JupyterLab on 127.0.0.1 and block while monitoring it.

    When ``param['jupyterlab']`` is false nothing is started and this call sleeps forever,
    keeping the session process alive.

    :param hostname: host name (for messages)
    :param hostnames: host address (for messages and the published URL)
    :param param: session parameters (``jupyterlab``, ``user_base_directory``)
    :param task: the session Task (updated by monitor_jupyter_server)
    :param env: base environment for the Jupyter process (copied)
    :return: whatever monitor_jupyter_server returns
    """
    if not param.get('jupyterlab', True):
        print('no jupyterlab to monitor - going to sleep')
        # never returns: keep the process (and anything it started) alive
        while True:
            sleep(10.)

    # execute jupyter notebook, its output is captured in a temp file
    fd, local_filename = mkstemp()
    base_dir = param["user_base_directory"]
    if base_dir:
        cwd = os.path.expandvars(os.path.expanduser(base_dir))
    else:
        cwd = os.getcwd()

    # find a free tcp port
    port = get_free_port(8888, 9000)

    # put the running interpreter's folder first in PATH (needed when not running as root)
    jupyter_env = dict(**env)
    interpreter_dir = Path(sys.executable).parent.as_posix()
    jupyter_env['PATH'] = '{}:{}'.format(interpreter_dir, jupyter_env.get('PATH', ''))

    # make sure we have the needed cwd
    # noinspection PyBroadException
    try:
        Path(cwd).mkdir(parents=True, exist_ok=True)
    except Exception:
        pass

    print("Running Jupyter Notebook Server on {} [{}] port {} at {}".format(hostname, hostnames, port, cwd))

    jupyter_cmd = [
        sys.executable, "-m", "jupyter", "lab",
        "--no-browser",
        "--allow-root",
        "--ip", "127.0.0.1",
        "--port", str(port),
    ]
    process = subprocess.Popen(jupyter_cmd, env=jupyter_env, stdout=fd, stderr=fd, cwd=cwd)
    return monitor_jupyter_server(fd, local_filename, process, task, port, hostnames)
def _install_root_sshd(ssh_password):
    """
    Install openssh-server with apt and configure it for root password login (must run as root).

    The commands are chained with ``&&`` through ``os.system``; a non-zero exit status is printed.

    :param ssh_password: password to set for the root user
    """
    import shlex

    trains_config_file = os.environ.get("CLEARML_CONFIG_FILE") or os.environ.get("TRAINS_CONFIG_FILE")
    # noinspection SpellCheckingInspection
    commands = [
        'export PYTHONPATH=""',
        "apt-get install -y openssh-server",
        "mkdir -p /var/run/sshd",
        # quote the credentials so quotes / shell characters in the password cannot break the command
        "echo {} | chpasswd".format(shlex.quote("root:{}".format(ssh_password))),
        "echo 'PermitRootLogin yes' >> /etc/ssh/sshd_config",
        "sed -i 's/PermitRootLogin prohibit-password/PermitRootLogin yes/' /etc/ssh/sshd_config",
        # raw string: `\s` is a sed regex escape, not a Python one
        r"sed 's@session\s*required\s*pam_loginuid.so@session optional pam_loginuid.so@g' -i /etc/pam.d/sshd",
        "echo 'ClientAliveInterval 10' >> /etc/ssh/sshd_config",
        "echo 'ClientAliveCountMax 20' >> /etc/ssh/sshd_config",
        "echo 'AcceptEnv TRAINS_API_ACCESS_KEY TRAINS_API_SECRET_KEY "
        "CLEARML_API_ACCESS_KEY CLEARML_API_SECRET_KEY' >> /etc/ssh/sshd_config",
        'echo "export VISIBLE=now" >> /etc/profile',
        # $PATH is expanded by the shell now, capturing this process' PATH for login shells
        'echo "export PATH=$PATH" >> /etc/profile',
        'echo "ldconfig" >> /etc/profile',
    ]
    # only export the config file when there is one (previously "None" was written literally)
    if trains_config_file:
        commands.append('echo "export TRAINS_CONFIG_FILE={}" >> /etc/profile'.format(trains_config_file))
    exit_status = os.system(" && ".join(commands))
    if exit_status:
        print('Warning: sshd installation/configuration returned exit status {}'.format(exit_status))


def _write_user_sshd_config(ssh_config_path):
    """
    Write a private ``sshd_config`` for running sshd as a non-root user.

    :param ssh_config_path: folder holding the config, host keys, authorized_keys and pid file
    :return: path of the written config file
    """
    custom_ssh_conf = os.path.join(ssh_config_path, 'sshd_config')
    lines = [
        "PermitRootLogin yes",
        "ClientAliveInterval 10",
        "ClientAliveCountMax 20",
        "AllowTcpForwarding yes",
        "UsePAM yes",
        "AuthorizedKeysFile {}".format(os.path.join(ssh_config_path, 'authorized_keys')),
        "PidFile {}".format(os.path.join(ssh_config_path, 'sshd.pid')),
        "AcceptEnv TRAINS_API_ACCESS_KEY TRAINS_API_SECRET_KEY CLEARML_API_ACCESS_KEY CLEARML_API_SECRET_KEY",
    ]
    # HostKey entries must reference the private key files that are actually written;
    # '__pub' entries are public keys (and may be None), so they are not listed
    lines += [
        "HostKey {}".format(os.path.join(ssh_config_path, key_name))
        for key_name, key_value in default_ssh_fingerprint.items()
        if key_value and not key_name.endswith('__pub')
    ]
    with open(custom_ssh_conf, 'wt') as f:
        f.write("\n".join(lines) + "\n")
    return custom_ssh_conf


def _write_host_key_files(ssh_config_path, hostname):
    """
    (Re)create the host key files from ``default_ssh_fingerprint`` inside ``ssh_config_path``.

    Keys named ``*__pub`` are written as ``*.pub`` with a ``root@<hostname>`` comment; ``None`` values
    only remove any existing file.

    :param ssh_config_path: target folder (created if missing)
    :param hostname: host name used in the public key comment
    """
    Path(ssh_config_path).mkdir(parents=True, exist_ok=True)
    for key_name, key_value in default_ssh_fingerprint.items():
        filename = os.path.join(ssh_config_path, key_name.replace('__pub', '.pub'))
        try:
            os.unlink(filename)
        except OSError:
            pass
        if not key_value:
            continue
        is_public = filename.endswith('.pub')
        with open(filename, 'wt') as f:
            f.write(key_value + (' root@{}'.format(hostname) if is_public else ''))
        # private keys owner-only, public keys world-readable
        os.chmod(filename, 0o644 if is_public else 0o600)


def setup_ssh_server(hostname, hostnames, param, task):
    """
    Start an sshd for the session and publish its port(s) on the task.

    As root, openssh-server is installed and configured for ``root`` password login with
    ``param['ssh_password']``. Otherwise an existing ``sshd`` executable is run with a private config in
    ``./.clearml_session_sshd`` and the ``ssh_password`` parameter is cleared (it cannot be set).
    The host keys from ``default_ssh_fingerprint`` are written, sshd runs in the foreground (``-D``) on a
    free port in [10022, 15000), and a ``TcpProxy`` is set up on a second port. Errors are printed.

    :param hostname: host name (messages and public key comment)
    :param hostnames: host address (messages)
    :param param: session parameters (``ssh_server``, ``ssh_password``)
    :param task: the session Task (may be None; then no parameters are reported)
    """
    if not param.get("ssh_server"):
        return

    print("Installing SSH Server on {} [{}]".format(hostname, hostnames))
    ssh_password = param.get("ssh_password", "training")
    # noinspection PyBroadException
    try:
        port = get_free_port(10022, 15000)
        proxy_port = get_free_port(10022, 15000)

        if os.geteuid() == 0:
            # if we are root, install open-ssh
            _install_root_sshd(ssh_password)
            sshd_path = '/usr/sbin/sshd'
            ssh_config_path = '/etc/ssh/'
            custom_ssh_conf = None
        else:
            # check if sshd exists
            # noinspection PyBroadException
            try:
                sshd_path = subprocess.check_output('which sshd', shell=True).decode().strip()
                ssh_config_path = os.path.join(os.getcwd(), '.clearml_session_sshd')
                Path(ssh_config_path).mkdir(parents=True, exist_ok=True)
                custom_ssh_conf = _write_user_sshd_config(ssh_config_path)
            except Exception:
                print('Error: Cannot install sshd (not root) and could not find sshd executable, leaving!')
                return

            # clear the ssh password, we cannot change it
            ssh_password = None
            if task:
                task.set_parameter('{}/ssh_password'.format(config_section_name), '')

        # create fingerprint files
        _write_host_key_files(ssh_config_path, hostname)

        # run server in foreground so it gets killed with us
        proc_args = [sshd_path, "-D", "-p", str(port)] + (["-f", custom_ssh_conf] if custom_ssh_conf else [])
        proc = subprocess.Popen(args=proc_args)
        # noinspection PyBroadException
        try:
            # still running after a second (timeout) counts as a successful start
            result = proc.wait(timeout=1)
        except Exception:
            result = 0

        if result != 0:
            raise ValueError("Failed launching sshd: {}".format(" ".join(proc_args)))

        # noinspection PyBroadException
        try:
            # TcpProxy is not defined in this part of the file (hence the noqa);
            # presumably provided elsewhere in the session package - confirm
            TcpProxy(listen_port=proxy_port, target_port=port, proxy_state={}, verbose=False,  # noqa
                     keep_connection=True, is_connection_server=True)
        except Exception as ex:
            print('Warning: Could not setup stable ssh port, {}'.format(ex))
            proxy_port = None

        if task:
            if proxy_port:
                task.set_parameter(name='properties/internal_stable_ssh_port', value=str(proxy_port))
            task.set_parameter(name='properties/internal_ssh_port', value=str(port))

        print(
            "\n#\n# SSH Server running on {} [{}] port {}\n# LOGIN u:root p:{}\n#\n".format(
                hostname, hostnames, port, ssh_password
            )
        )

    except Exception as ex:
        print("Error: {}\n\n#\n# Error: SSH server could not be launched\n#\n".format(ex))
def _write_user_git_file(path, content):
    """Write ``content`` to ``path`` (creating parent folders); failures are printed, not raised."""
    # noinspection PyBroadException
    try:
        Path(path).parent.mkdir(parents=True, exist_ok=True)
        with open(path, 'wt') as f:
            f.write(content)
    except Exception:
        print('Could not write {} file'.format(path))


def setup_user_env(param, task):
    """
    Prepare the interactive user's shell environment and return the environment for child processes.

    - links ``~/environment`` to the running virtual environment;
    - appends to ``~/.bashrc`` / ``~/.profile``: the user's ClearML credentials and default docker
      image (when ``user_key`` and ``user_secret`` are set), a ``cd`` into ``user_base_directory``,
      and activation of ``~/environment``;
    - writes ``~/.config/git/credentials`` / ``~/.config/git/config`` from the task configuration.

    :param param: session parameters
    :param task: the session Task (source of the git configuration objects)
    :return: environment dict from setup_os_env(), with the user's credentials set when provided
    """
    import shlex

    env = setup_os_env(param)

    # create symbolic link to the venv
    environment = os.path.expanduser('~/environment')
    # noinspection PyBroadException
    try:
        os.symlink(os.path.abspath(os.path.join(os.path.abspath(sys.executable), '..', '..')), environment)
        print('Virtual environment are available at {}'.format(environment))
    except Exception:
        pass

    # shell lines shared by ~/.bashrc and ~/.profile; values are shell-quoted so quotes, `$` or
    # backticks in credentials are taken literally (the old "\\$" replace was a no-op)
    shared_lines = []

    # set default user credentials
    if param.get("user_key") and param.get("user_secret"):
        docker_image = (param.get("default_docker") or "").strip() or env.get('TRAINS_DOCKER_IMAGE', '')
        shared_lines += [
            "export TRAINS_API_ACCESS_KEY={}".format(shlex.quote(str(param.get("user_key")))),
            "export TRAINS_API_SECRET_KEY={}".format(shlex.quote(str(param.get("user_secret")))),
            "export TRAINS_DOCKER_IMAGE={}".format(shlex.quote(str(docker_image))),
        ]
        env['TRAINS_API_ACCESS_KEY'] = param.get("user_key")
        env['TRAINS_API_SECRET_KEY'] = param.get("user_secret")

    # set default folder for user, expanded the same way the servers compute their cwd
    if param.get("user_base_directory"):
        base_dir = os.path.expandvars(os.path.expanduser(param.get("user_base_directory")))
        shared_lines.append("cd {}".format(shlex.quote(base_dir)))

    # make sure we activate the venv in the bash
    activate_script = shlex.quote(os.path.join(environment, 'bin', 'activate'))
    rc_files = (
        ('~/.bashrc', shared_lines + ["source {}".format(activate_script)]),
        ('~/.profile', shared_lines + [". {}".format(activate_script)]),
    )
    for rc_file, lines in rc_files:
        try:
            with open(os.path.expanduser(rc_file), 'at') as f:
                f.write("".join(line + "\n" for line in lines))
        except OSError as ex:
            print('Could not update {}: {}'.format(rc_file, ex))

    # check if we need to create .git-credentials
    # noinspection PyProtectedMember
    git_credentials = task._get_configuration_text('git_credentials')
    if git_credentials:
        _write_user_git_file(os.path.expanduser('~/.config/git/credentials'), git_credentials)

    # noinspection PyProtectedMember
    git_config = task._get_configuration_text('git_config')
    if git_config:
        _write_user_git_file(os.path.expanduser('~/.config/git/config'), git_config)

    return env
def get_host_name(task, param):
    """
    Find this machine's host name and address, and publish the external address on the task.

    The address defaults to ``gethostbyname(gethostname())``, then is replaced by the address of the
    interface used for outgoing traffic when it can be determined. ``properties/external_address`` is
    set only if the task does not already have it; with ``param['public_ip']`` the public address is
    queried from https://checkip.amazonaws.com (10 second timeout, falls back on any error).

    :param task: the session Task
    :param param: session parameters (``public_ip``)
    :return: ``(hostname, hostnames)`` - host name and selected IP address
    """
    # noinspection PyBroadException
    try:
        hostname = socket.gethostname()
        hostnames = socket.gethostbyname(hostname)
    except Exception:
        def get_ip_addresses(family):
            for interface, snics in psutil.net_if_addrs().items():
                for snic in snics:
                    if snic.family == family:
                        yield snic.address

        hostnames = list(get_ip_addresses(socket.AF_INET))[0]
        hostname = hostnames

    # try to get external address (if possible)
    # noinspection PyBroadException
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
            # noinspection PyBroadException
            try:
                # doesn't even have to be reachable (connecting a UDP socket only selects the route)
                s.connect(('8.255.255.255', 1))
                hostnames = s.getsockname()[0]
            except Exception:
                pass
    except Exception:
        pass

    # update host name
    if not task.get_parameter(name='properties/external_address'):
        external_addr = hostnames
        if param.get('public_ip'):
            # noinspection PyBroadException
            try:
                # bounded wait: without a timeout an unresponsive endpoint would hang the session forever
                response = requests.get('https://checkip.amazonaws.com', timeout=10)
                response.raise_for_status()
                external_addr = response.text.strip() or external_addr
            except Exception:
                pass
        task.set_parameter(name='properties/external_address', value=str(external_addr))

    return hostname, hostnames
def run_user_init_script(task):
    """
    Run the user's bash initialization script and import the environment it leaves behind.

    The script comes from the task configuration object ``config_object_section_bash_init``. It runs
    with ``/bin/bash`` (without PYTHONPATH). At its end the resulting environment is dumped to a temp
    json file, which is merged into this process' ``os.environ`` (again without PYTHONPATH). Failures
    are printed; the temp files are always removed.

    :param task: the session Task
    """
    import shlex

    # noinspection PyProtectedMember
    init_script = task._get_configuration_text(config_object_section_bash_init)
    if not init_script or not str(init_script).strip():
        return

    print("Running user initialization bash script:")

    init_filename = os_json_filename = None
    try:
        fd, init_filename = mkstemp(suffix='.init.sh')
        os.close(fd)
        fd, os_json_filename = mkstemp(suffix='.env.json')
        os.close(fd)

        # python snippet appended to the user script: dump the resulting environment as json
        dump_env_code = (
            "try:\n"
            "    import os\n"
            "    import json\n"
            "    with open({!r}, 'w') as env_file:\n"
            "        json.dump(dict(os.environ), env_file)\n"
            "except Exception:\n"
            "    pass\n"
        ).format(os_json_filename)
        with open(init_filename, 'wt') as f:
            # shell-quote both the interpreter path and the code, so spaces / quotes cannot break them
            f.write('{}\n{} -c {}\n'.format(init_script, shlex.quote(sys.executable), shlex.quote(dump_env_code)))

        env = dict(**os.environ)
        # do not pass or update back the PYTHONPATH environment variable
        env.pop('PYTHONPATH', None)
        subprocess.call(['/bin/bash', init_filename], env=env)

        with open(os_json_filename, 'rt') as f:
            environ = json.load(f)
        # do not pass or update back the PYTHONPATH environment variable
        environ.pop('PYTHONPATH', None)
        # update environment variables
        os.environ.update(environ)
    except Exception as ex:
        print('User initialization script failed: {}'.format(ex))
    finally:
        for filename in (init_filename, os_json_filename):
            if filename:
                try:
                    os.unlink(filename)
                except OSError:
                    pass
def main():
    """Entry point: set up the interactive session services (ssh, code-server, jupyter) for this task."""
    # default session parameters; init_task connects them to the task, which may override them
    param = dict(
        user_base_directory="~/",
        ssh_server=True,
        ssh_password="training",
        default_docker="nvidia/cuda",
        user_key=None,
        user_secret=None,
        vscode_server=True,
        jupyterlab=True,
        public_ip=False,
    )

    task = init_task(param, default_ssh_fingerprint)

    # the user script may change os.environ, so it runs before setup_user_env snapshots the environment
    run_user_init_script(task)

    hostname, hostnames = get_host_name(task, param)
    env = setup_user_env(param, task)

    # background services first; the jupyter step blocks while it monitors the server
    setup_ssh_server(hostname, hostnames, param, task)
    start_vscode_server(hostname, hostnames, param, task, env)
    start_jupyter_server(hostname, hostnames, param, task, env)

    print('We are done')


if __name__ == '__main__':
    main()