import base64 import json import os import shutil import socket import subprocess import sys from copy import deepcopy import getpass from functools import partial from tempfile import mkstemp, gettempdir, mkdtemp from time import sleep, time from datetime import datetime import psutil import requests from clearml import Task, StorageManager from clearml.backend_api import Session from clearml.backend_api.services import tasks from pathlib2 import Path # noinspection SpellCheckingInspection default_ssh_fingerprint = { 'ssh_host_ecdsa_key': r"-----BEGIN EC PRIVATE KEY-----"+"\n" r"MHcCAQEEIOCAf3KEN9Hrde53rqQM4eR8VfCnO0oc4XTEBw0w6lCfoAoGCCqGSM49"+"\n" r"AwEHoUQDQgAEn/LlC/1UN1q6myfjs03LJdHY2LB0b1hBjAsLvQnDMt8QE6Rml3UF"+"\n" r"QK/UFw4mEqCFCD+dcbyWqFsKxTm6WtFStg=="+"\n" r"-----END EC PRIVATE KEY-----"+"\n", 'ssh_host_ed25519_key': r"-----BEGIN OPENSSH PRIVATE KEY-----"+"\n" r"b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW"+"\n" r"QyNTUxOQAAACDvweeJHnUKtwY7/WRqDJEZTDk8AajWKFt/BXmEI3+A8gAAAJiEMTXOhDE1"+"\n" r"zgAAAAtzc2gtZWQyNTUxOQAAACDvweeJHnUKtwY7/WRqDJEZTDk8AajWKFt/BXmEI3+A8g"+"\n" r"AAAEBCHpidTBUN3+W8s3qRNkyaJpA/So4vEqDvOhseSqJeH+/B54kedQq3Bjv9ZGoMkRlM"+"\n" r"OTwBqNYoW38FeYQjf4DyAAAAEXJvb3RAODQ1NmQ5YTdlYTk4AQIDBA=="+"\n" r"-----END OPENSSH PRIVATE KEY-----"+"\n", 'ssh_host_rsa_key': r"-----BEGIN RSA PRIVATE KEY-----"+"\n" r"MIIEowIBAAKCAQEAs8R3BrinMM/k9Jak7UqsoONqLQoasYgkeBVOOfRJ6ORYWW5R"+"\n" r"WLkYnPPUGRpbcoM1Imh7ODBgKzs0mh5/j3y0SKP/MpvT4bf38e+QGjuC+6fR4Ah0"+"\n" r"L5ohGIMyqhAiBoXgj0k2BE6en/4Rb3BwNPMocCTus82SwajzMNgWneRC6GCq2M0n"+"\n" r"0PWenhS0IQz7jUlw3JU8z6T3ROPiMBPU7ubBhiNlAzMYPr76Z7J6ZNrCclAvdGkI"+"\n" r"YxK7RNq0HwfoUj0UFD9iaEHswDIlNc34p93lP6GIAbh7uVYfGhg4z7HdBoN2qweN"+"\n" r"szo7iQX9N8EFP4WfpLzNFteThzgN/bdso8iv0wIDAQABAoIBAQCPvbF64110b1dg"+"\n" r"p7AauVINl6oHd4PensCicE7LkmUi3qsyXz6WVfKzVVgr9mJWz0lGSQr14+CR0NZ/"+"\n" r"wZE393vkdZWSLv2eB88vWeH8x8c1WHw9yiS1B2YdRpLVXu8GDjh/+gdCLGc0ASCJ"+"\n" r"3fsqq5+TBEUF6oPFbEWAsdhryeAiFAokeIVEKkxRnIDvPCP6i0evUHAxEP+wOngu"+"\n" r"4XONkixNmATNa1jP2YAjmh3uQbAf2BvDZuywJmqV8fqZa/BwuK3W+R/92t0ySZ5Q"+"\n" r"Z7RCZzPzFvWY683/Cfx5+BH3XcIetbcZ/HKuc+TdBvvFgqrLNIJ4OXMp3osjZDMO"+"\n" r"YZIE6DdBAoGBAOG8cgm2N+Kl2dl0q1r4S+hf//zPaDorNasvcXJcj/ypy1MdmDXt"+"\n" r"whLSAuTN4r8axgbuws2Z870pIGd28koqg78U+pOPabkphloo8Fc97RO28ZJCK2g0"+"\n" r"/prPgwSYymkhrvwdzIbI11BPL/rr9cLJ1eYDnzGDSqvXJDL79XxrzwMzAoGBAMve"+"\n" r"ULkfqaYVlgY58d38XruyCpSmRSq39LTeTYRWkJTNFL6rkqL9A69z/ITdpSStEuR8"+"\n" r"8MXQSsPz8xUhFrA2bEjW7AT0r6OqGbjljKeh1whYOfgGfMKQltTfikkrf5w0UrLw"+"\n" r"NQ8USfpwWdFnBGQG0yE/AFknyLH14/pqfRlLzaDhAoGAcN3IJxL03l4OjqvHAbUk"+"\n" r"PwvA8qbBdlQkgXM3RfcCB1LeVrB1aoF2h/J5f+1xchvw54Z54FMZi3sEuLbAblTT"+"\n" r"irbyktUiB3K7uli90uEjqLfQEVEEYxYcN0uKNsIucmJlG6nKmZnSDlWJp+xS9RH1"+"\n" r"4QvujNMYgtMPRm60T4GYAAECgYB6J9LMqik4CDUls/C2J7MH2m22lk5Zg3JQMefW"+"\n" r"xRvK3XtxqFKr8NkVd3U2k6yRZlcsq6SFkwJJmdHsti/nFCUcHBO+AHOBqLnS7VCz"+"\n" r"XSkAqgTKFfEJkCOgl/U/VJ4ZFcz7xSy1xV1yf4GCFK0v1lsJz7tAsLLz1zdsZARj"+"\n" r"dOVYYQKBgC3IQHfd++r9kcL3+vU7bDVU4aKq0JFDA79DLhKDpSTVxqTwBT+/BIpS"+"\n" r"8z79zBTjNy5gMqxZp/SWBVWmsO8d7IUk9O2L/bMhHF0lOKbaHQQ9oveCzIwDewcf"+"\n" r"5I45LjjGPJS84IBYv4NElptRk/2eFFejr75xdm4lWfpLb1SXPOPB"+"\n" r"-----END RSA PRIVATE KEY-----"+"\n", 'ssh_host_rsa_key__pub': r'ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQCzxHcGuKcwz+T0lqTtSqyg42otChqxiCR4FU459Eno5FhZblFYuRic89QZGlt' r'ygzUiaHs4MGArOzSaHn+PfLRIo/8ym9Pht/fx75AaO4L7p9HgCHQvmiEYgzKqECIGheCPSTYETp6f/hFvcHA08yhwJO6zzZLBqPM' r'w2Bad5ELoYKrYzSfQ9Z6eFLQhDPuNSXDclTzPpPdE4+IwE9Tu5sGGI2UDMxg+vvpnsnpk2sJyUC90aQhjErtE2rQfB+hSPRQUP2Jo' r'QezAMiU1zfin3eU/oYgBuHu5Vh8aGDjPsd0Gg3arB42zOjuJBf03wQU/hZ+kvM0W15OHOA39t2yjyK/T', 'ssh_host_ecdsa_key__pub': r'ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBJ/y5Qv9VDdaupsn47NNyyXR2Niwd' r'G9YQYwLC70JwzLfEBOkZpd1BUCv1BcOJhKghQg/nXG8lqhbCsU5ulrRUrY=', 'ssh_host_ed25519_key__pub': None, } config_section_name = 'interactive_session' config_object_section_ssh = 'SSH' config_object_section_bash_init = 'interactive_init_script' artifact_workspace_name = "workspace" sync_runtime_property = "workspace_sync_ts" sync_workspace_creating_id = "created_by_session" __poor_lock = [] __allocated_ports = [] def get_free_port(range_min, range_max): global __allocated_ports used_ports = [i.laddr.port for i in psutil.net_connections()] port = next(i for i in range(range_min, range_max) if i not in used_ports and i not in __allocated_ports) __allocated_ports.append(port) return port def init_task(param, a_default_ssh_fingerprint): # initialize ClearML Task.add_requirements('jupyter') Task.add_requirements('jupyterlab') Task.add_requirements('jupyterlab_git') task = Task.init( project_name="DevOps", task_name="Allocate Jupyter Notebook Instance", task_type=Task.TaskTypes.service ) # Add jupyter server base folder if Session.check_min_api_version('2.13'): param.pop('user_key', None) param.pop('user_secret', None) param.pop('ssh_password', None) task.connect(param, name=config_section_name) # noinspection PyProtectedMember runtime_prop = dict(task._get_runtime_properties()) # remove the user key/secret the moment we have it param['user_key'] = runtime_prop.pop('_user_key', None) param['user_secret'] = runtime_prop.pop('_user_secret', None) # no need to reset, we will need it param['ssh_password'] = runtime_prop.get('_ssh_password') # Force removing properties # noinspection PyProtectedMember task._edit(runtime=runtime_prop) task.reload() else: task.connect(param, name=config_section_name) # connect ssh fingerprint configuration (with fallback if section is missing) old_default_ssh_fingerprint = deepcopy(a_default_ssh_fingerprint) try: task.connect_configuration(configuration=a_default_ssh_fingerprint, name=config_object_section_ssh) except (TypeError, ValueError): a_default_ssh_fingerprint.clear() a_default_ssh_fingerprint.update(old_default_ssh_fingerprint) if param.get('default_docker') and task.running_locally(): task.set_base_docker("{} --network host".format(param['default_docker'])) # leave local process, only run remotely task.execute_remotely() return task def setup_os_env(param): # get rid of all the runtime ClearML preserve = ( "_API_HOST", "_WEB_HOST", "_FILES_HOST", "_CONFIG_FILE", "_API_ACCESS_KEY", "_API_SECRET_KEY", "_API_HOST_VERIFY_CERT", "_DOCKER_IMAGE", "_DOCKER_BASH_SCRIPT", ) # set default docker image, with network configuration if param.get('default_docker', '').strip(): os.environ["CLEARML_DOCKER_IMAGE"] = param['default_docker'].strip() # setup os environment env = deepcopy(os.environ) for key in os.environ: # only set CLEARML_ remove any TRAINS_ if key.startswith("TRAINS") or (key.startswith("CLEARML") and not any(key.endswith(p) for p in preserve)): env.pop(key, None) return env def monitor_jupyter_server(fd, local_filename, process, task, jupyter_port, hostnames): # todo: add auto spin down see: https://tljh.jupyter.org/en/latest/topic/idle-culler.html # print stdout/stderr prev_line_count = 0 process_running = True token = None while process_running: process_running = False try: process.wait(timeout=2.0 if not token else 15.0) except subprocess.TimeoutExpired: process_running = True # noinspection PyBroadException try: with open(local_filename, "rt") as f: # read new lines new_lines = f.readlines() if not new_lines: continue os.lseek(fd, 0, 0) os.ftruncate(fd, 0) except Exception: continue print("".join(new_lines)) prev_line_count += len(new_lines) # if we already have the token, do nothing, just monitor if token: continue # update task with jupyter notebook server links (port / token) line = '' for line in new_lines: if "http://" not in line and "https://" not in line: continue parts = line.split('?token=', 1) if len(parts) != 2: continue token = parts[1] port = parts[0].split(':')[-1] # try to cast to int try: port = int(port) # noqa except (TypeError, ValueError): continue break # we could not locate the token, try again if not token: continue # we ignore the reported port, because jupyter server will get confused # if we have multiple servers running and will point to the wrong port/server task.set_parameter(name='properties/jupyter_port', value=str(jupyter_port)) jupyter_url = '{}://{}:{}?token={}'.format( 'https' if "https://" in line else 'http', hostnames, jupyter_port, token ) # update the task with the correct links and token if Session.check_min_api_version("2.13"): # noinspection PyProtectedMember runtime_prop = task._get_runtime_properties() runtime_prop['_jupyter_token'] = str(token) runtime_prop['_jupyter_url'] = str(jupyter_url) # noinspection PyProtectedMember task._set_runtime_properties(runtime_prop) else: task.set_parameter(name='properties/jupyter_token', value=str(token)) task.set_parameter(name='properties/jupyter_url', value=jupyter_url) print('\nJupyter Lab URL: {}\n'.format(jupyter_url)) # cleanup # noinspection PyBroadException try: os.close(fd) except Exception: pass # noinspection PyBroadException try: os.unlink(local_filename) except Exception: pass def start_vscode_server(hostname, hostnames, param, task, env, bind_ip="127.0.0.1", port=None): if not param.get("vscode_server"): return # get vscode version and python extension version # they are extremely flaky, this combination works, most do not. vscode_version = '4.14.1' python_ext_version = '2023.12.0' if param.get("vscode_version"): vscode_version_parts = param.get("vscode_version").split(':') vscode_version = vscode_version_parts[0] if len(vscode_version_parts) > 1: python_ext_version = vscode_version_parts[1] # make a copy of env and remove the pythonpath from it. env = dict(**env) env.pop('PYTHONPATH', None) # example of CLEARML_SESSION_VSCODE_PY_EXT value # 'https://github.com/microsoft/vscode-python/releases/download/{}/ms-python-release.vsix' python_ext_download_link = os.environ.get("CLEARML_SESSION_VSCODE_PY_EXT") code_server_deb_download_link = \ os.environ.get("CLEARML_SESSION_VSCODE_SERVER_DEB") or \ 'https://github.com/coder/code-server/releases/download/v{version}/code-server_{version}_amd64.deb' pre_installed = False python_ext = None # find a free tcp port port = get_free_port(9000, 9100) if not port else int(port) if os.geteuid() == 0: # check if preinstalled # noinspection PyBroadException try: vscode_path = subprocess.check_output('which code-server', shell=True).decode().strip() pre_installed = bool(vscode_path) except Exception: vscode_path = None if not vscode_path: # installing VSCODE: try: python_ext = StorageManager.get_local_copy( python_ext_download_link.format(python_ext_version), extract_archive=False) if python_ext_download_link else None code_server_deb = StorageManager.get_local_copy( code_server_deb_download_link.format(version=vscode_version), extract_archive=False) os.system("dpkg -i {}".format(code_server_deb)) except Exception as ex: print("Failed installing vscode server: {}".format(ex)) return vscode_path = 'code-server' else: python_ext = None pre_installed = True # check if code-server exists # noinspection PyBroadException try: vscode_path = subprocess.check_output('which code-server', shell=True).decode().strip() assert vscode_path except Exception: print('Error: Cannot install code-server (not root) and could not find code-server executable, skipping.') task.set_parameter(name='properties/vscode_port', value=str(-1)) return cwd = ( os.path.expandvars(os.path.expanduser(param["user_base_directory"])) if param["user_base_directory"] else os.getcwd() ) # make sure we have the needed cwd # noinspection PyBroadException try: Path(cwd).mkdir(parents=True, exist_ok=True) except Exception: print("Warning: failed setting user base directory [{}] reverting to ~/".format(cwd)) cwd = os.path.expanduser("~/") print("Running VSCode Server on {} [{}] port {} at {}".format(hostname, hostnames, port, cwd)) print("VSCode Server available: http://{}:{}/\n".format(hostnames, port)) user_folder = os.path.join(cwd, ".vscode/user/") exts_folder = os.path.join(cwd, ".vscode/exts/") proc = None try: fd, local_filename = mkstemp() if pre_installed: user_folder = os.path.expanduser("~/.local/share/code-server/") if not os.path.isdir(user_folder): user_folder = None exts_folder = None else: exts_folder = os.path.expanduser("~/.local/share/code-server/extensions/") else: vscode_extensions = param.get("vscode_extensions") or "" vscode_extensions_cmd = [] jupyter_ext_version = True for ext in vscode_extensions.split(","): ext = ext.strip() if not ext: continue if ext.startswith("ms-python.python"): python_ext_version = python_ext = None elif ext.startswith("ms-toolsai.jupyter"): jupyter_ext_version = None vscode_extensions_cmd += ["--install-extension", ext] if python_ext: vscode_extensions_cmd += ["--install-extension", "{}".format(python_ext)] elif python_ext_version: vscode_extensions_cmd += ["--install-extension", "ms-python.python@{}".format(python_ext_version)] if jupyter_ext_version: vscode_extensions_cmd += ["--install-extension", "ms-toolsai.jupyter"] print("VScode extensions: {}".format(vscode_extensions_cmd)) subprocess.Popen( [ vscode_path, "--auth", "none", "--bind-addr", "{}:{}".format(bind_ip, port), "--user-data-dir", user_folder, "--extensions-dir", exts_folder, ] + vscode_extensions_cmd, env=env, stdout=fd, stderr=fd, ) if user_folder: # set user level configuration settings = Path(os.path.expanduser(os.path.join(user_folder, 'User/settings.json'))) settings.parent.mkdir(parents=True, exist_ok=True) # noinspection PyBroadException try: with open(settings.as_posix(), 'rt') as f: base_json = json.load(f) except Exception: base_json = {} # noinspection PyBroadException try: # Notice we are Not using "python.defaultInterpreterPath": sys.executable, # because for some reason it breaks the auto python interpreter setup base_json.update({ "extensions.autoCheckUpdates": False, "extensions.autoUpdate": False, "python.pythonPath": sys.executable, "terminal.integrated.shell.linux": "/bin/bash" if Path("/bin/bash").is_file() else None, "security.workspace.trust.untrustedFiles": "open", # "security.workspace.trust.startupPrompt": "never", "security.workspace.trust.enabled": False, }) with open(settings.as_posix(), 'wt') as f: json.dump(base_json, f) except Exception: pass # set machine level configuration settings = Path(os.path.expanduser(os.path.join(user_folder, 'Machine/settings.json'))) settings.parent.mkdir(parents=True, exist_ok=True) # noinspection PyBroadException try: with open(settings.as_posix(), 'rt') as f: base_json = json.load(f) except Exception: base_json = {} # noinspection PyBroadException try: # "python.defaultInterpreterPath" is a machine level setting base_json.update({ "python.defaultInterpreterPath": sys.executable, }) with open(settings.as_posix(), 'wt') as f: json.dump(base_json, f) except Exception: pass proc = subprocess.Popen( ['bash', '-c', '{} --auth none --bind-addr {}:{} --disable-update-check {} {}'.format( vscode_path, bind_ip, port, '--user-data-dir \"{}\"'.format(user_folder) if user_folder else '', '--extensions-dir \"{}\"'.format(exts_folder) if exts_folder else '')], env=env, stdout=fd, stderr=fd, cwd=cwd, ) try: error_code = proc.wait(timeout=1) raise ValueError("code-server failed starting, return code {}".format(error_code)) except subprocess.TimeoutExpired: pass except Exception as ex: print('Failed running vscode server: {}'.format(ex)) task.set_parameter(name='properties/vscode_port', value=str(-1)) return task.set_parameter(name='properties/vscode_port', value=str(port)) return proc def start_jupyter_server(hostname, hostnames, param, task, env, bind_ip="127.0.0.1", port=None): if not param.get('jupyterlab', True): print('no jupyterlab to monitor - going to sleep') while True: sleep(10.) return # noqa # execute jupyter notebook fd, local_filename = mkstemp() cwd = ( os.path.expandvars(os.path.expanduser(param["user_base_directory"])) if param["user_base_directory"] else os.getcwd() ) # find a free tcp port port = get_free_port(8888, 9000) if not port else int(port) # if we are not running as root, make sure the sys executable is in the PATH env = dict(**env) env['PATH'] = '{}:{}'.format(Path(sys.executable).parent.as_posix(), env.get('PATH', '')) try: # set default shell to bash if not defined if not env.get("SHELL") and shutil.which("bash"): env['SHELL'] = shutil.which("bash") except Exception as ex: print("WARNING: failed finding default shell bash: {}".format(ex)) # make sure we have the needed cwd # noinspection PyBroadException try: Path(cwd).mkdir(parents=True, exist_ok=True) except Exception: print("Warning: failed setting user base directory [{}] reverting to ~/".format(cwd)) cwd = os.path.expanduser("~/") # setup jupyter-lab default # noinspection PyBroadException try: settings = Path(os.path.expanduser( "~/.jupyter/lab/user-settings/@jupyterlab/apputils-extension/notification.jupyterlab-settings")) settings.parent.mkdir(parents=True, exist_ok=True) # noinspection PyBroadException try: with open(settings.as_posix(), 'rt') as f: base_json = json.load(f) except Exception: base_json = {} # Notice we are Not using "python.defaultInterpreterPath": sys.executable, # because for some reason it breaks the auto python interpreter setup base_json.update({ "checkForUpdates": False, "doNotDisturbMode": False, "fetchNews": "false" }) with open(settings.as_posix(), 'wt') as f: json.dump(base_json, f) except Exception as ex: print("WARNING: Could not set default jupyter lab settings: {}".format(ex)) print( "Running Jupyter Notebook Server on {} [{}] port {} at {}".format(hostname, hostnames, port, cwd) ) process = subprocess.Popen( [ sys.executable, "-m", "jupyter", "lab", "--no-browser", "--allow-root", "--ip", bind_ip, "--port", str(port), ], env=env, stdout=fd, stderr=fd, cwd=cwd, ) return monitor_jupyter_server(fd, local_filename, process, task, port, hostnames) def setup_ssh_server(hostname, hostnames, param, task, env): if not param.get("ssh_server"): return # make sure we do not pass it along to others, work on a copy env = deepcopy(env) env.pop('LOCAL_PYTHON', None) env.pop('PYTHONPATH', None) env.pop('DEBIAN_FRONTEND', None) print("Installing SSH Server on {} [{}]".format(hostname, hostnames)) ssh_password = param.get("ssh_password", "training") ssh_port = None if Session.check_min_api_version("2.13"): try: # noinspection PyProtectedMember ssh_port = task._get_runtime_properties().get("internal_tcp_port") except Exception as ex: print("Failed retrieving internal TCP port for SSH daemon: {}".format(ex)) # noinspection PyBroadException try: ssh_port = ssh_port or param.get("ssh_ports") or "10022:15000" min_port = int(ssh_port.split(":")[0]) max_port = max(min_port+32, int(ssh_port.split(":")[-1])) port = get_free_port(min_port, max_port) if param.get("use_ssh_proxy"): proxy_port = port port = get_free_port(min_port, max_port) else: proxy_port = None use_dropbear = bool(param.get("force_dropbear", False)) # if we are root, install open-ssh if not use_dropbear and os.geteuid() == 0: # noinspection SpellCheckingInspection os.system( "export PYTHONPATH=\"\" && " "([ ! -z $(which sshd) ] || " "(apt-get update ; DEBIAN_FRONTEND=noninteractive apt-get install -y openssh-server)) && " "mkdir -p /var/run/sshd && " "echo 'root:{password}' | chpasswd && " "echo 'PermitRootLogin yes' >> /etc/ssh/sshd_config && " "sed -i 's/PermitRootLogin prohibit-password/PermitRootLogin yes/' /etc/ssh/sshd_config && " "sed 's@session\s*required\s*pam_loginuid.so@session optional pam_loginuid.so@g' -i /etc/pam.d/sshd " "&& " # noqa: W605 "echo 'ClientAliveInterval 10' >> /etc/ssh/sshd_config && " "echo 'ClientAliveCountMax 20' >> /etc/ssh/sshd_config && " "echo 'AcceptEnv CLEARML_API_ACCESS_KEY CLEARML_API_SECRET_KEY " "CLEARML_API_ACCESS_KEY CLEARML_API_SECRET_KEY' >> /etc/ssh/sshd_config && " 'echo "export VISIBLE=now" >> /etc/profile && ' 'echo "export PATH=$PATH" >> /etc/profile && ' 'echo "ldconfig" 2>/dev/null >> /etc/profile && ' 'echo "export CLEARML_CONFIG_FILE={trains_config_file}" >> /etc/profile'.format( password=ssh_password, port=port, trains_config_file=os.environ.get("CLEARML_CONFIG_FILE") or os.environ.get("TRAINS_CONFIG_FILE"), ) ) sshd_path = '/usr/sbin/sshd' ssh_config_path = '/etc/ssh/' custom_ssh_conf = None else: # check if sshd exists # noinspection PyBroadException try: os.system('echo "export CLEARML_CONFIG_FILE={trains_config_file}" >> $HOME/.profile'.format( trains_config_file=os.environ.get("CLEARML_CONFIG_FILE") or os.environ.get("TRAINS_CONFIG_FILE"), )) except Exception: print("warning failed setting ~/.profile") # check if shd is preinstalled # noinspection PyBroadException try: # try running SSHd as non-root (currently bypassed, use dropbear instead) sshd_path = None ## subprocess.check_output('which sshd', shell=True).decode().strip() if not sshd_path: raise ValueError("sshd was not found") except Exception: # noinspection PyBroadException try: print('WARNING: SSHd was not found defaulting to user-space dropbear sshd server') dropbear_download_link = \ os.environ.get("CLEARML_DROPBEAR_EXEC") or \ 'https://github.com/allegroai/dropbear/releases/download/DROPBEAR_CLEARML_2023.02/dropbearmulti' dropbear = StorageManager.get_local_copy(dropbear_download_link, extract_archive=False) os.chmod(dropbear, 0o744) sshd_path = dropbear use_dropbear = True except Exception: print('Error: failed locating SSHd and failed fetching `dropbear`, leaving!') return # noinspection PyBroadException try: ssh_config_path = os.path.join(os.getcwd(), '.clearml_session_sshd') # noinspection PyBroadException try: Path(ssh_config_path).mkdir(parents=True, exist_ok=True) except Exception: ssh_config_path = os.path.join(gettempdir(), '.clearml_session_sshd') Path(ssh_config_path).mkdir(parents=True, exist_ok=True) custom_ssh_conf = os.path.join(ssh_config_path, 'sshd_config') with open(custom_ssh_conf, 'wt') as f: conf = \ "PermitRootLogin yes" + "\n"\ "ClientAliveInterval 10" + "\n"\ "ClientAliveCountMax 20" + "\n"\ "AllowTcpForwarding yes" + "\n"\ "UsePAM yes" + "\n"\ "AuthorizedKeysFile {}".format(os.path.join(ssh_config_path, 'authorized_keys')) + "\n"\ "PidFile {}".format(os.path.join(ssh_config_path, 'sshd.pid')) + "\n"\ "AcceptEnv CLEARML_API_ACCESS_KEY CLEARML_API_SECRET_KEY "\ "CLEARML_API_ACCESS_KEY CLEARML_API_SECRET_KEY"+"\n" for k in default_ssh_fingerprint: filename = os.path.join(ssh_config_path, '{}'.format(k.replace('__pub', '.pub'))) conf += "HostKey {}\n".format(filename) f.write(conf) except Exception: print('Error: Cannot configure sshd, leaving!') return if not use_dropbear: # clear the ssh password, we cannot change it ssh_password = None task.set_parameter('{}/ssh_password'.format(config_section_name), '') # get current user: # noinspection PyBroadException try: current_user = getpass.getuser() or "root" except Exception: # we failed getting the user, let's assume root print("Warning: failed getting active user name, assuming 'root'") current_user = "root" # create fingerprint files Path(ssh_config_path).mkdir(parents=True, exist_ok=True) keys_filename = {} for k, v in default_ssh_fingerprint.items(): filename = os.path.join(ssh_config_path, '{}'.format(k.replace('__pub', '.pub'))) try: os.unlink(filename) except Exception: # noqa pass if v: with open(filename, 'wt') as f: f.write(v + (' {}@{}'.format(current_user, hostname) if filename.endswith('.pub') else '')) os.chmod(filename, 0o600 if filename.endswith('.pub') else 0o600) keys_filename[k] = filename # run server in foreground so it gets killed with us if use_dropbear: # convert key files dropbear_key_files = [] for k, ssh_key_file in keys_filename.items(): # skip over the public keys, there is no need for them if not ssh_key_file or ssh_key_file.endswith(".pub"): continue drop_key_file = ssh_key_file + ".dropbear" try: os.system("{} dropbearconvert openssh dropbear {} {}".format( sshd_path, ssh_key_file, drop_key_file)) if Path(drop_key_file).is_file(): dropbear_key_files += ["-r", drop_key_file] except Exception: pass proc_args = [sshd_path, "dropbear", "-e", "-K", "30", "-I", "0", "-F", "-p", str(port)] + dropbear_key_files # this is a copy of `env` so there is nothing to worry about if ssh_password: env["DROPBEAR_CLEARML_FIXED_PASSWORD"] = ssh_password else: proc_args = [sshd_path, "-D", "-p", str(port)] + (["-f", custom_ssh_conf] if custom_ssh_conf else []) proc = subprocess.Popen(args=proc_args, env=env) # noinspection PyBroadException try: result = proc.wait(timeout=1) except Exception: result = 0 if result != 0: raise ValueError("Failed launching sshd: ", proc_args) if proxy_port: # noinspection PyBroadException try: TcpProxy(listen_port=proxy_port, target_port=port, proxy_state={}, verbose=False, # noqa keep_connection=True, is_connection_server=True) except Exception as ex: print('Warning: Could not setup stable ssh port, {}'.format(ex)) proxy_port = None if task: if proxy_port: task.set_parameter(name='properties/internal_stable_ssh_port', value=str(proxy_port)) task.set_parameter(name='properties/internal_ssh_port', value=str(port)) # noinspection PyProtectedMember task._set_runtime_properties( runtime_properties={ 'internal_ssh_port': str(proxy_port or port), '_ssh_user': current_user, } ) print( "\n#\n# SSH Server running on {} [{}] port {}\n# LOGIN u:root p:{}\n#\n".format( hostname, hostnames, port, ssh_password ) ) except Exception as ex: print("Error: {}\n\n#\n# Error: SSH server could not be launched\n#\n".format(ex)) def _b64_decode_file(encoded_string): # noinspection PyBroadException try: import gzip value = gzip.decompress(base64.decodebytes(encoded_string.encode('ascii'))).decode('utf8') return value except Exception: return None def setup_user_env(param, task): env = setup_os_env(param) # apply vault if we have it vault_environment = {} if param.get("user_key") and param.get("user_secret"): # noinspection PyBroadException try: print('Applying vault configuration') from clearml.backend_api.session.defs import ENV_ENABLE_ENV_CONFIG_SECTION, ENV_ENABLE_FILES_CONFIG_SECTION prev_env, prev_files = ENV_ENABLE_ENV_CONFIG_SECTION.get(), ENV_ENABLE_FILES_CONFIG_SECTION.get() ENV_ENABLE_ENV_CONFIG_SECTION.set(True), ENV_ENABLE_FILES_CONFIG_SECTION.set(True) prev_envs = deepcopy(os.environ) Session(api_key=param.get("user_key"), secret_key=param.get("user_secret")) vault_environment = {k: v for k, v in os.environ.items() if prev_envs.get(k) != v} if prev_env is None: ENV_ENABLE_ENV_CONFIG_SECTION.pop() else: ENV_ENABLE_ENV_CONFIG_SECTION.set(prev_env) if prev_files is None: ENV_ENABLE_FILES_CONFIG_SECTION.pop() else: ENV_ENABLE_FILES_CONFIG_SECTION.set(prev_files) if vault_environment: print('Vault environment added: {}'.format(list(vault_environment.keys()))) except Exception as ex: print('Applying vault configuration failed: {}'.format(ex)) # do not change user bash/profile if we are not running inside a container if os.geteuid() != 0: # check if we are inside a container is_container = False try: with open("/proc/1/sched", "rt") as f: lines = f.readlines() if lines and lines[0].split()[0] in ("bash", "sh", "zsh"): # this a container is_container = True except Exception: # noqa pass if not is_container: if param.get("user_key") and param.get("user_secret"): env['CLEARML_API_ACCESS_KEY'] = param.get("user_key") env['CLEARML_API_SECRET_KEY'] = param.get("user_secret") return env # target source config source_conf = '~/.clearmlrc' # create symbolic link to the venv environment = os.path.expanduser('~/environment') # noinspection PyBroadException try: os.symlink(os.path.abspath(os.path.join(os.path.abspath(sys.executable), '..', '..')), environment) print('Virtual environment are available at {}'.format(environment)) except Exception as e: print("Error: Exception while trying to create symlink. The Application will continue...") print(e) # set default user credentials if param.get("user_key") and param.get("user_secret"): os.system("echo 'export CLEARML_API_ACCESS_KEY=\"{}\"' >> {}".format( param.get("user_key", "").replace('$', '\\$'), source_conf)) os.system("echo 'export CLEARML_API_SECRET_KEY=\"{}\"' >> {}".format( param.get("user_secret", "").replace('$', '\\$'), source_conf)) env['CLEARML_API_ACCESS_KEY'] = param.get("user_key") env['CLEARML_API_SECRET_KEY'] = param.get("user_secret") elif os.environ.get("CLEARML_AUTH_TOKEN"): env['CLEARML_AUTH_TOKEN'] = os.environ.get("CLEARML_AUTH_TOKEN") os.system("echo 'export CLEARML_AUTH_TOKEN=\"{}\"' >> {}".format( os.environ.get("CLEARML_AUTH_TOKEN").replace('$', '\\$'), source_conf)) if param.get("default_docker"): os.system("echo 'export CLEARML_DOCKER_IMAGE=\"{}\"' >> {}".format( param.get("default_docker", "").strip() or env.get('CLEARML_DOCKER_IMAGE', ''), source_conf)) if vault_environment: for k, v in vault_environment.items(): os.system("echo 'export {}=\"{}\"' >> {}".format(k, v, source_conf)) env[k] = str(v) if v else "" # make sure we activate the venv in the bash if Path(os.path.join(environment, 'bin', 'activate')).expanduser().exists(): os.system("echo 'source {}' >> {}".format(os.path.join(environment, 'bin', 'activate'), source_conf)) elif Path(os.path.join(environment, 'etc', 'conda', 'activate.d')).expanduser().exists(): # let conda patch the bashrc os.system("conda init") # make sure we activate this environment by default os.system("echo 'conda activate {}' >> {}".format(environment, source_conf)) # set default folder for user if param.get("user_base_directory"): base_dir = param.get("user_base_directory") if ' ' in base_dir: base_dir = '\"{}\"'.format(base_dir) os.system("echo 'cd {}' >> {}".format(base_dir, source_conf)) # make sure we load the source configuration os.system("echo 'source {}' >> ~/.bashrc".format(source_conf)) os.system("echo '. {}' >> ~/.profile".format(source_conf)) # check if we need to create .git-credentials runtime_property_support = Session.check_min_api_version("2.13") if runtime_property_support: # noinspection PyProtectedMember runtime_prop = dict(task._get_runtime_properties()) git_credentials = runtime_prop.pop('_git_credentials', None) git_config = runtime_prop.pop('_git_config', None) # force removing properties # noinspection PyProtectedMember task._edit(runtime=runtime_prop) task.reload() if git_credentials is not None: git_credentials = _b64_decode_file(git_credentials) if git_config is not None: git_config = _b64_decode_file(git_config) else: # noinspection PyProtectedMember git_credentials = task._get_configuration_text('git_credentials') # noinspection PyProtectedMember git_config = task._get_configuration_text('git_config') if git_credentials: git_cred_file = os.path.expanduser('~/.config/git/credentials') # noinspection PyBroadException try: Path(git_cred_file).parent.mkdir(parents=True, exist_ok=True) with open(git_cred_file, 'wt') as f: f.write(git_credentials) except Exception: print('Could not write {} file'.format(git_cred_file)) if git_config: git_config_file = os.path.expanduser('~/.config/git/config') # noinspection PyBroadException try: Path(git_config_file).parent.mkdir(parents=True, exist_ok=True) with open(git_config_file, 'wt') as f: f.write(git_config) except Exception: print('Could not write {} file'.format(git_config_file)) # check if we need to retrieve remote files for the session if "session-files" in task.artifacts: try: target_dir = os.path.expanduser("~/session-files/") cached_files_folder = task.artifacts["session-files"].get_local_copy( extract_archive=True, force_download=True, raise_on_error=True) # noinspection PyBroadException try: # first try a simple, move, if we fail, copy and delete os.replace(cached_files_folder, target_dir) except Exception: import shutil Path(target_dir).mkdir(parents=True, exist_ok=True) if Path(cached_files_folder).is_dir(): shutil.copytree( src=cached_files_folder, dst=target_dir, symlinks=True, ignore_dangling_symlinks=True, dirs_exist_ok=True) shutil.rmtree(cached_files_folder) else: target_file = Path(cached_files_folder).name # we need to remove the taskid prefix from the cache folder target_file = (Path(target_dir) / (".".join(target_file.split(".")[1:]))).as_posix() shutil.copy(cached_files_folder, target_file, follow_symlinks=False) os.unlink(cached_files_folder) except Exception as ex: print("\nWARNING: Failed downloading remote session files! {}\n".format(ex)) return env def get_host_name(task, param): # noinspection PyBroadException try: hostname = socket.gethostname() hostnames = socket.gethostbyname(socket.gethostname()) except Exception: def get_ip_addresses(family): for interface, snics in psutil.net_if_addrs().items(): for snic in snics: if snic.family == family: yield snic.address hostnames = list(get_ip_addresses(socket.AF_INET))[0] hostname = hostnames # try to get external address (if possible) # noinspection PyBroadException try: s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) # noinspection PyBroadException try: # doesn't even have to be reachable s.connect(('8.255.255.255', 1)) hostnames = s.getsockname()[0] except Exception: pass finally: s.close() except Exception: pass # update host name if (not task.get_parameter(name='properties/external_address') and not task.get_parameter(name='properties/k8s-gateway-address')): if task._get_runtime_properties().get("external_address"): external_addr = task._get_runtime_properties().get("external_address") else: external_addr = hostnames if param.get('public_ip'): # noinspection PyBroadException try: external_addr = requests.get('https://checkip.amazonaws.com').text.strip() except Exception: pass # make sure we set it to the runtime properties task._set_runtime_properties({"external_address": external_addr}) # make sure we set it back to the Task user properties task.set_parameter(name='properties/external_address', value=str(external_addr)) return hostname, hostnames def run_user_init_script(task): # run initialization script: # noinspection PyProtectedMember init_script = task._get_configuration_text(config_object_section_bash_init) if not init_script or not str(init_script).strip(): return print("Running user initialization bash script:") init_filename = os_json_filename = None try: fd, init_filename = mkstemp(suffix='.init.sh') os.close(fd) fd, os_json_filename = mkstemp(suffix='.env.json') os.close(fd) with open(init_filename, 'wt') as f: f.write(init_script + '\n{} -c ' '"exec(\\"try:\\n import os\\n import json\\n' ' json.dump(dict(os.environ), open(\\\'{}\\\', \\\'w\\\'))' '\\nexcept: pass\\")"'.format(sys.executable, os_json_filename)) env = dict(**os.environ) # do not pass or update back the PYTHONPATH environment variable env.pop('PYTHONPATH', None) subprocess.call(['/bin/bash', init_filename], env=env) with open(os_json_filename, 'rt') as f: environ = json.load(f) # do not pass or update back the PYTHONPATH environment variable environ.pop('PYTHONPATH', None) # update environment variables os.environ.update(environ) except Exception as ex: print('User initialization script failed: {}'.format(ex)) finally: if init_filename: try: os.unlink(init_filename) except: # noqa pass if os_json_filename: try: os.unlink(os_json_filename) except: # noqa pass os.environ['CLEARML_DOCKER_BASH_SCRIPT'] = str(init_script) def _sync_workspace_snapshot(task, param): workspace_folder = param.get("store_workspace") if not workspace_folder: # nothing to do return print("Syncing workspace {}".format(workspace_folder)) workspace_folder = Path(os.path.expandvars(workspace_folder)).expanduser() if not workspace_folder.is_dir(): print("WARNING: failed to create workspace snapshot from '{}' - " "directory does not exist".format(workspace_folder)) return # build hash of files_desc = "" for f in workspace_folder.rglob("*"): fs = f.stat() files_desc += "{}: {}[{}]\n".format(f.absolute(), fs.st_size, fs.st_mtime) workspace_hash = hash(str(files_desc)) if param.get("workspace_hash") == workspace_hash: print("Skipping workspace snapshot upload, " "already uploaded no files changed since last sync {}".format(param.get(sync_runtime_property))) return print("Uploading workspace: {}".format(workspace_folder)) # force running status - so that we can upload the artifact if task.status not in ("in_progress", ): task.mark_started(force=True) try: # create a tar file of the folder # put a consistent file name into a temp folder because the filename is part of # the compressed artifact, and we want consistency in hash. # After that we rename compressed file to temp file and temp_folder = Path(mkdtemp(prefix='workspace_')) local_gzip = (temp_folder / "workspace_snapshot").as_posix() # notice it will add a ".tar.gz" suffix to the file local_gzip = shutil.make_archive( base_name=local_gzip, format="gztar", root_dir=workspace_folder.as_posix()) if not local_gzip: print("ERROR: Failed compressing workspace [{}]".format(workspace_folder)) raise ValueError("Failed compressing workspace") # list archived files for preview files = list(workspace_folder.rglob("*")) archive_preview = 'Archive content {}:\n'.format(workspace_folder) for filename in sorted(files): if filename.is_file(): relative_file_name = filename.relative_to(workspace_folder) archive_preview += '{} - {:,} B\n'.format(relative_file_name, filename.stat().st_size) # upload actual snapshot tgz task.upload_artifact( name=artifact_workspace_name, artifact_object=Path(local_gzip), delete_after_upload=True, preview=archive_preview, metadata={"timestamp": str(datetime.utcnow()), sync_workspace_creating_id: task.id}, wait_on_upload=True, retries=3 ) try: temp_folder.rmdir() except Exception as ex: print("Warning: Failed removing temp artifact folder: {}".format(ex)) print("Finalizing workspace sync") # change artifact to input artifact task.reload() # find our artifact and update it for a in task.data.execution.artifacts: if a.key != artifact_workspace_name: # nothing to do continue elif a.mode == tasks.ArtifactModeEnum.input: # the old input entry - we are changing to output artifact # the reason is that we want this entry to be deleted with this Task # in contrast to Input entries that are Not deleted when deleting the Task a.mode = tasks.ArtifactModeEnum.output a.key = "old_" + str(a.key) else: # set the new entry as an input artifact a.mode = tasks.ArtifactModeEnum.input # noinspection PyProtectedMember task._edit(execution=task.data.execution, force=True) task.reload() # update our timestamp & hash param[sync_runtime_property] = time() param["workspace_hash"] = workspace_hash # noinspection PyProtectedMember task._set_runtime_properties(runtime_properties={sync_runtime_property: time()}) print("[{}] Workspace '{}' snapshot synced".format(datetime.utcnow(), workspace_folder)) except Exception as ex: print("ERROR: Failed syncing workspace [{}]: {}".format(workspace_folder, ex)) finally: task.mark_stopped(force=True, status_message="workspace shutdown sync completed") def sync_workspace_snapshot(task, param): __poor_lock.append(time()) if len(__poor_lock) != 1: # someone is already in, we should leave __poor_lock.pop(-1) try: return _sync_workspace_snapshot(task, param) finally: __poor_lock.pop(-1) def restore_workspace(task, param): if not param.get("store_workspace"): # check if we have something to restore, show warning if artifact_workspace_name in task.artifacts: print("WARNING: Found workspace snapshot, but ignoring since store_workspace is 'None'") return # add sync callback, timeout 5 min print("Setting workspace snapshot sync callback on session end") task.register_abort_callback( partial(sync_workspace_snapshot, task, param), callback_execution_timeout=60*5) try: workspace_folder = Path(os.path.expandvars(param.get("store_workspace"))).expanduser() workspace_folder.mkdir(parents=True, exist_ok=True) except Exception as ex: print("ERROR: Could not create workspace folder {}: {}".format( param.get("store_workspace"), ex)) return if artifact_workspace_name not in task.artifacts: print("No workspace snapshot was found, a new workspace snapshot [{}] " "will be created when session ends".format(workspace_folder)) return print("Fetching previous workspace snapshot") artifact_zip_file = task.artifacts[artifact_workspace_name].get_local_copy(extract_archive=False) print("Restoring workspace snapshot") try: shutil.unpack_archive(artifact_zip_file, extract_dir=workspace_folder.as_posix()) except Exception as ex: print("ERROR: restoring workspace snapshot failed: {}".format(ex)) return # remove the workspace from the cache try: os.unlink(artifact_zip_file) except Exception as ex: print("WARNING: temp workspace zip could not be removed: {}".format(ex)) print("Successfully restored workspace checkpoint to {}".format(workspace_folder)) # set time stamp # noinspection PyProtectedMember task._set_runtime_properties(runtime_properties={sync_runtime_property: time()}) def main(): param = { "user_base_directory": "~/", "ssh_server": True, "ssh_password": "training", "default_docker": "nvidia/cuda", "user_key": None, "user_secret": None, "vscode_server": True, "vscode_version": '', "vscode_extensions": '', "jupyterlab": True, "public_ip": False, "ssh_ports": None, "force_dropbear": False, "store_workspace": None, "use_ssh_proxy": False, } task = init_task(param, default_ssh_fingerprint) run_user_init_script(task) # restore workspace if exists # notice, if "store_workspace" is not set we will Not restore the workspace try: restore_workspace(task, param) except Exception as ex: print("ERROR: Failed restoring workspace: {}".format(ex)) hostname, hostnames = get_host_name(task, param) env = setup_user_env(param, task) setup_ssh_server(hostname, hostnames, param, task, env) start_vscode_server(hostname, hostnames, param, task, env) start_jupyter_server(hostname, hostnames, param, task, env) print('We are done - sync workspace if needed') sync_workspace_snapshot(task, param) print('Goodbye') if __name__ == '__main__': main()