clearml/examples/services/jupyter-service/execute_jupyter_notebook_server.py
2020-12-24 00:30:32 +02:00

167 lines
4.7 KiB
Python

import os
import shlex
import socket
import subprocess
import sys
from copy import deepcopy
from tempfile import mkstemp

import psutil
# make sure we have jupyter in the auto requirements
import jupyter  # noqa
from clearml import Task
# Register this script with the ClearML server as a long-running "service" task.
# Once initialised, ClearML logs this process automatically from here on.
task_identity = dict(
    project_name="DevOps",
    task_name="Allocate Jupyter Notebook Instance",
)
task = Task.init(task_type=Task.TaskTypes.service, **task_identity)
# Suffixes (the part after the TRAINS / CLEARML prefix) of the environment variables
# the Jupyter server must keep so that it can still reach the ClearML server. Every
# other TRAINS* / CLEARML* variable belongs to this runtime task and is dropped.
preserve = (
    "_API_HOST",
    "_WEB_HOST",
    "_FILES_HOST",
    "_CONFIG_FILE",
    "_API_ACCESS_KEY",
    "_API_SECRET_KEY",
    "_API_HOST_VERIFY_CERT",
    "_DOCKER_IMAGE",
)


def build_child_env(environ, keep_suffixes=preserve, prefixes=("TRAINS", "CLEARML")):
    """Return a plain-dict copy of *environ* without the runtime ClearML variables.

    A variable is dropped when its name starts with one of *prefixes* and the rest
    of the name is not listed in *keep_suffixes*; all other variables are copied
    unchanged. *environ* itself is never modified.

    :param environ: mapping of environment variable names to values.
    :param keep_suffixes: name suffixes (after the prefix) that must be kept.
    :param prefixes: name prefixes identifying ClearML/Trains variables.
    :return: a new ``dict`` suitable for ``subprocess.Popen(env=...)``.
    """
    return {
        key: value
        for key, value in environ.items()
        if not any(
            key.startswith(prefix) and key[len(prefix):] not in keep_suffixes
            for prefix in prefixes
        )
    }


# Environment for the Jupyter child process. A plain dict is required here:
# deepcopy(os.environ) yields another os._Environ whose pop()/del still call
# os.unsetenv(), which would strip these variables from *this* process as well.
env = build_child_env(os.environ)
# Configuration of this service, connected to the task so it is recorded with the
# run (connected values may replace these defaults).
param = dict(
    jupyter_server_base_directory="~/",
    ssh_server=True,
    ssh_password="training",
    default_docker_for_jupyter="nvidia/cuda",
)
task.connect(param)

# Use the configured image as the default docker image and run the task in it,
# sharing the host network.
docker_image = param['default_docker_for_jupyter']
os.environ["CLEARML_DOCKER_IMAGE"] = docker_image
task.set_base_docker("{} --network host".format(docker_image))
def get_ip_addresses(family):
    """Yield the address of every local network-interface entry of *family*.

    :param family: socket address family to select, e.g. ``socket.AF_INET``.
    """
    for snics in psutil.net_if_addrs().values():
        for snic in snics:
            if snic.family == family:
                yield snic.address


# Determine how this machine is shown to users: a host name plus its address(es).
hostname = None
# noinspection PyBroadException
try:
    hostname = socket.gethostname()
    hostnames = socket.gethostbyname(hostname)
except Exception:
    # Lookup failed: fall back to the IPv4 addresses of the local interfaces
    # (a list here, whereas the success path yields a single address string).
    hostnames = list(get_ip_addresses(socket.AF_INET))
    # Prefer the first interface address. When no IPv4 interface exists, keep the
    # host name if we got one, instead of crashing on hostnames[0].
    hostname = hostnames[0] if hostnames else (hostname or "localhost")
# Optionally install and start an OpenSSH server so users can also log in as root.
# Failures are reported but never stop the Jupyter launch that follows.
if param.get("ssh_server"):
    print("Installing SSH Server on {} [{}]".format(hostname, hostnames))
    ssh_password = param.get("ssh_password", "training")
    # ClearML configuration file of this process, exported to SSH login shells via
    # /etc/profile. The export is skipped when neither variable is set, rather than
    # writing the literal string "None".
    clearml_config_file = os.environ.get("CLEARML_CONFIG_FILE") or os.environ.get(
        "TRAINS_CONFIG_FILE"
    )
    # noinspection PyBroadException
    try:
        # first port in [10022, 15000) not bound by any local inet socket
        used_ports = {conn.laddr.port for conn in psutil.net_connections() if conn.laddr}
        port = next((p for p in range(10022, 15000) if p not in used_ports), None)
        if port is None:
            raise RuntimeError("no free port in range 10022-14999")

        profile_lines = ["export VISIBLE=now"]
        if clearml_config_file:
            quoted_config = shlex.quote(clearml_config_file)
            profile_lines.append("export CLEARML_CONFIG_FILE={}".format(quoted_config))
            profile_lines.append("export TRAINS_CONFIG_FILE={}".format(quoted_config))

        # NOTE(review): there is no `apt-get update` first; images shipped with empty
        # apt lists will fail the install step - confirm for the chosen image.
        commands = [
            "apt-get install -y openssh-server",
            "mkdir -p /var/run/sshd",
            # Shell-quote the credentials. The password comes from the connected task
            # parameters and may contain quotes or other shell metacharacters.
            "echo {} | chpasswd".format(shlex.quote("root:{}".format(ssh_password))),
            "echo 'PermitRootLogin yes' >> /etc/ssh/sshd_config",
            "sed -i 's/PermitRootLogin prohibit-password/PermitRootLogin yes/' /etc/ssh/sshd_config",
            # raw string: "\s" is part of the sed regex, not a Python escape
            r"sed 's@session\s*required\s*pam_loginuid.so@session optional pam_loginuid.so@g' -i /etc/pam.d/sshd",
        ]
        commands.extend(
            "echo {} >> /etc/profile".format(shlex.quote(line)) for line in profile_lines
        )
        commands.append("/usr/sbin/sshd -p {}".format(port))

        # every step must succeed for the next one to run
        result = os.system(" && ".join(commands))
        if result != 0:
            raise RuntimeError("setup command returned {}".format(result))
        print(
            "\n#\n# SSH Server running on {} [{}] port {}\n# LOGIN u:root p:{}\n#\n".format(
                hostname, hostnames, port, ssh_password
            )
        )
    except Exception as ex:
        print("\n#\n# Error: SSH server could not be launched\n#\n")
        print("Reason: {}".format(ex))
# Start the Jupyter Notebook server as a child process. Its stdout and stderr both
# go to a fresh temporary file, which the monitoring loop that follows reads back.
fd, local_filename = mkstemp()
base_directory = param["jupyter_server_base_directory"]
if base_directory:
    # allow "~" and $VARS in the configured folder
    cwd = os.path.expandvars(os.path.expanduser(base_directory))
else:
    cwd = os.getcwd()
print("Running Jupyter Notebook Server on {} [{}] at {}".format(hostname, hostnames, cwd))
jupyter_command = [sys.executable, "-m", "jupyter", "notebook"]
# headless server, allowed to run as root, listening on every interface
jupyter_command += ["--no-browser", "--allow-root", "--ip", "0.0.0.0"]
process = subprocess.Popen(jupyter_command, env=env, stdout=fd, stderr=fd, cwd=cwd)
# Relay the server's console output until it exits, and record its access links in
# the task comment. The child writes to the temp file through the descriptor `fd`
# handed to Popen. After each read the file is rewound and truncated, so the next
# read returns only newer output.
# NOTE(review): output written between the read and the truncation is lost.
links_recorded = False
prev_line_count = 0
process_running = True
try:
    while process_running:
        process_running = False
        try:
            # poll quickly until the first output arrives, then every 15 seconds
            process.wait(timeout=2.0 if prev_line_count == 0 else 15.0)
        except subprocess.TimeoutExpired:
            process_running = True
        with open(local_filename, "rt") as f:
            new_lines = f.readlines()
        if not new_lines:
            continue
        print("".join(new_lines))
        # The URL lines may arrive after other startup output, so keep scanning
        # each chunk until at least one link has been added to the task comment.
        if not links_recorded:
            links = [line for line in new_lines if "http://" in line or "https://" in line]
            if links:
                task.comment = (task.comment or "") + "\n" + "".join(links)
                links_recorded = True
        prev_line_count += len(new_lines)
        os.lseek(fd, 0, 0)
        os.ftruncate(fd, 0)
finally:
    # cleanup runs even if the loop is interrupted
    os.close(fd)
    try:
        os.unlink(local_filename)
    except OSError:
        # best effort: the temp file may already be gone
        pass