From 30f49a65c53c2be4df172d1331a9d679ae9ba4ea Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Tue, 22 Dec 2020 21:32:02 +0200 Subject: [PATCH] Initial release --- .gitignore | 147 +--- README.md | 294 ++++++++ clearml_session/__init__.py | 0 clearml_session/__main__.py | 772 ++++++++++++++++++++ clearml_session/interactive_session_task.py | 616 ++++++++++++++++ clearml_session/single_thread_proxy.py | 126 ++++ clearml_session/tcp_proxy.py | 359 +++++++++ clearml_session/version.py | 1 + requirements.txt | 3 + setup.py | 79 ++ 10 files changed, 2284 insertions(+), 113 deletions(-) create mode 100644 README.md create mode 100644 clearml_session/__init__.py create mode 100644 clearml_session/__main__.py create mode 100644 clearml_session/interactive_session_task.py create mode 100644 clearml_session/single_thread_proxy.py create mode 100644 clearml_session/tcp_proxy.py create mode 100644 clearml_session/version.py create mode 100644 requirements.txt create mode 100644 setup.py diff --git a/.gitignore b/.gitignore index b6e4761..b24d71e 100644 --- a/.gitignore +++ b/.gitignore @@ -1,129 +1,50 @@ -# Byte-compiled / optimized / DLL files -__pycache__/ -*.py[cod] -*$py.class +# These are some examples of commonly ignored file patterns. +# You should customize this list as applicable to your project. +# Learn more about .gitignore: +# https://www.atlassian.com/git/tutorials/saving-changes/gitignore -# C extensions -*.so - -# Distribution / packaging -.Python -build/ -develop-eggs/ +# Node artifact files +node_modules/ dist/ -downloads/ -eggs/ -.eggs/ -lib/ -lib64/ -parts/ -sdist/ -var/ -wheels/ -pip-wheel-metadata/ -share/python-wheels/ -*.egg-info/ -.installed.cfg -*.egg -MANIFEST -# PyInstaller -# Usually these files are written by a python script from a template -# before PyInstaller builds the exe, so as to inject date/other infos into it. 
-*.manifest -*.spec +# Compiled Java class files +*.class -# Installer logs -pip-log.txt -pip-delete-this-directory.txt +# Compiled Python bytecode +*.py[cod] -# Unit test / coverage reports -htmlcov/ -.tox/ -.nox/ -.coverage -.coverage.* -.cache -nosetests.xml -coverage.xml -*.cover -*.py,cover -.hypothesis/ -.pytest_cache/ - -# Translations -*.mo -*.pot - -# Django stuff: +# Log files *.log -local_settings.py -db.sqlite3 -db.sqlite3-journal -# Flask stuff: -instance/ -.webassets-cache +# Package files +*.jar -# Scrapy stuff: -.scrapy - -# Sphinx documentation -docs/_build/ - -# PyBuilder +# Maven target/ +dist/ -# Jupyter Notebook -.ipynb_checkpoints +# JetBrains IDE +.idea/ -# IPython -profile_default/ -ipython_config.py +# Unit test reports +TEST*.xml -# pyenv -.python-version +# Generated by MacOS +.DS_Store -# pipenv -# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. -# However, in case of collaboration, if having platform-specific dependencies or dependencies -# having no cross-platform support, pipenv may install dependencies that don't work, or not -# install all needed dependencies. -#Pipfile.lock +# Generated by Windows +Thumbs.db -# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow -__pypackages__/ +# Applications +*.app +*.exe +*.war -# Celery stuff -celerybeat-schedule -celerybeat.pid +# Large media files +*.mp4 +*.tiff +*.avi +*.flv +*.mov +*.wmv -# SageMath parsed files -*.sage.py - -# Environments -.env -.venv -env/ -venv/ -ENV/ -env.bak/ -venv.bak/ - -# Spyder project settings -.spyderproject -.spyproject - -# Rope project settings -.ropeproject - -# mkdocs documentation -/site - -# mypy -.mypy_cache/ -.dmypy.json -dmypy.json - -# Pyre type checker -.pyre/ diff --git a/README.md b/README.md new file mode 100644 index 0000000..1f7f6ea --- /dev/null +++ b/README.md @@ -0,0 +1,294 @@ +# `clearml-session` - CLI for launching JupyterLab / VSCode on a remote machine + +**`clearml-session`** is a utility for launching detachable remote interactive sessions (MacOS, Windows, Linux) + +### tl;dr +CLI to launch remote sessions for JupyterLab / VSCode-server / SSH, inside any docker image! + +### What does it do? +Starting a clearml (ob)session from your local machine triggers the following: +- ClearML allocates a remote instance (GPU) from your dedicated pool +- On the allocated instance it will spin **jupyter-lab** + **vscode server** + **SSH** access for +interactive usage (i.e., development) +- Clearml will start monitoring machine performance, allowing DevOps to detect stale instances and spin them down + +### Use-cases for remote interactive sessions: +1. Development requires resources not available on the current developer's machines +2. Team resource sharing (e.g. how to dynamically assign GPUs to developers) +3. Spin a copy of a previously executed experiment for remote debugging purposes (:open_mouth:!) +4. Scale-out development to multiple clouds, assign development machines on AWS/GCP/Azure in a seamless way + +## Prerequisites: +* **An SSH client installed on your machine** - To verify open your terminal and execute `ssh`, if you did not receive an error, we are good to go. 
+* At least one `clearml-agent` running on a remote host. See installation [details](https://github.com/allegroai/clearml-agent).
+
+Supported OS: MacOS, Windows, Linux
+
+
+## Secure & Stable
+**clearml-session** creates a single, secure, and encrypted connection to the remote machine over SSH.
+SSH credentials are automatically generated by the CLI and contain a fully random 32-byte password.
+
+All http connections are tunneled over the SSH connection,
+allowing users to add additional services on the remote machine (!)
+
+Furthermore, all tunneled connections have a special stable network layer allowing you to refresh the underlying SSH
+connection without breaking any network sockets!
+
+This means that if the network connection is unstable, you can refresh
+the base SSH network tunnel, without breaking JupyterLab/VSCode-server or your own SSH connection
+(e.g. debugging over SSH with PyCharm)
+
+---
+
+## How to use: Interactive Session
+
+
+1. run `clearml-session`
+2. select the requested queue (resource)
+3. wait until a machine is up and ready
+4. click on the link to the remote JupyterLab/VSCode OR connect with the provided SSH details
+
+**Notice! You can also**: Select a **docker image** to execute in, install required **python packages**, run **bash script**,
+pass **git credentials**, etc.
+See below for full CLI options.
+
+## Frequently Asked Questions:
+
+#### How Does Clearml enable this?
+
+The `clearml-session` creates a new interactive `Task` in the system (default project: DevOps).
+
+This `Task` is responsible for setting up the SSH and JupyterLab/VSCode services on the host machine.
+
+The local `clearml-session` waits for the interactive Task to finish the initial setup, then
+it connects via SSH to the host machine (see "Secure & Stable" above), and tunnels
+both SSH and JupyterLab over the SSH connection.
+
+The end result is a local link which you can use to access the JupyterLab/VSCode on the remote machine, over a **secure and encrypted** connection!
+
+#### How can this be used to scale up/out development resources?
+
+**Clearml** has a cloud autoscaler, so you can easily and automatically spin machines for development!
+
+There is also a default docker image to use when initiating a task.
+
+This means that using **clearml-session**s
+with the autoscaler enabled, allows for a turn-key secure development environment inside a docker of your choosing.
+
+Learn more about it [here]()
+
+#### Does this fit Work From Home situations?
+**YES**. Install `clearml-agent` on target machines inside the organization, connect over your company VPN
+and use `clearml-session` to gain access to a dedicated on-prem machine with the docker of your choosing
+(with out-of-the-box support for any internal docker artifactory).
+
+Learn more about how to utilize your office workstations and on-prem machines [here]()
+
+## Tutorials
+
+### Getting started
+
+Requires the `clearml` python package installed and configured (see detailed [instructions]())
+``` bash
+pip install clearml-session
+clearml-session --docker nvcr.io/nvidia/pytorch:20.11-py3 --git-credentials
+```
+
+Wait for the machine to spin up:
+Expected CLI output would look something like:
+``` console
+Creating new session
+New session created [id=3d38e738c5ff458a9ec465e77e19da23]
+Waiting for remote machine allocation [id=3d38e738c5ff458a9ec465e77e19da23]
+.Status [queued]
+....Status [in_progress]
+Remote machine allocated
+Setting remote environment [Task id=3d38e738c5ff458a9ec465e77e19da23]
+Setup process details: https://app.community.clear.ml/projects/64ae77968db24b27abf86a501667c330/experiments/3d38e738c5ff458a9ec465e77e19da23/output/log
+Waiting for environment setup to complete [usually about 20-30 seconds]
+..............
+Remote machine is ready
+Setting up connection to remote session
+Starting SSH tunnel
+Warning: Permanently added '[192.168.0.17]:10022' (ECDSA) to the list of known hosts.
+root@192.168.0.17's password: f7bae03235ff2a62b6bfbc6ab9479f9e28640a068b1208b63f60cb097b3a1784
+
+
+Interactive session is running:
+SSH: ssh root@localhost -p 8022 [password: f7bae03235ff2a62b6bfbc6ab9479f9e28640a068b1208b63f60cb097b3a1784]
+Jupyter Lab URL: http://localhost:8878/?token=df52806d36ad30738117937507b213ac14ed638b8c336a7e
+VSCode server available at http://localhost:8898/
+
+Connection is up and running
+Enter "r" (or "reconnect") to reconnect the session (for example after suspend)
+Ctrl-C (or "quit") to abort (remote session remains active)
+or "Shutdown" to shutdown remote interactive session
+```
+
+Click on the JupyterLab link (http://localhost:8878/?token=xyz)
+Open your terminal, clone your code & start working :)
+
+### Leaving a session and reconnecting from the same machine
+
+On the `clearml-session` CLI terminal, enter 'quit' or press Ctrl-C
+It will close the CLI but leave the remote session running
+
+When you want to reconnect to it, execute:
+``` bash
+clearml-session
+```
+
+Then press "Y" (or enter) to reconnect to the already running session
+``` console
+clearml-session - launch interactive session
+Checking previous session
+Connect to active session id=3d38e738c5ff458a9ec465e77e19da23 [Y]/n?
+```
+
+### Shutting down a remote session
+
+On the `clearml-session` CLI terminal, enter 'shutdown' (case-insensitive)
+It will shut down the remote session, free the resource and close the CLI
+
+``` console
+Enter "r" (or "reconnect") to reconnect the session (for example after suspend)
+Ctrl-C (or "quit") to abort (remote session remains active)
+or "Shutdown" to shutdown remote interactive session
+
+shutdown
+
+Shutting down interactive session
+Interactive session ended
+Leaving interactive session
+```
+
+### Connecting to a running interactive session from a different machine
+
+Continue working on an interactive session from **any** machine.
+In the `clearml` web UI, go to DevOps project, and find your interactive session.
+Click on the ID button next to the Task name, and copy the unique ID.
+
+``` bash
+clearml-session --attach <session-id>
+```
+
+Click on the JupyterLab/VSCode link, or connect directly to the SSH session
+
+### Debug a previously executed experiment
+
+If you have a previously executed experiment in the system,
+you can create an exact copy of the experiment and debug it on the remote interactive session.
+`clearml-session` will replicate the exact remote environment, add JupyterLab/VSCode/SSH and allow you to interactively
+execute and debug the experiment, on the allocated remote machine.
+
+In the `clearml` web UI, find the experiment (Task) you wish to debug.
+Click on the ID button next to the Task name, and copy the unique ID.
+ +``` bash +clearml-session --debugging +``` + +Click on the JupyterLab/VSCode link, or connect directly to the SSH session + +## CLI options + +``` bash +clearml-session --help +``` + +``` console +clearml-session - CLI for launching JupyterLab / VSCode on a remote machine +usage: clearml-session [-h] [--version] [--attach [ATTACH]] + [--debugging DEBUGGING] [--queue QUEUE] + [--docker DOCKER] [--public-ip [true/false]] + [--vscode-server [true/false]] + [--git-credentials [true/false]] + [--user-folder USER_FOLDER] + [--packages [PACKAGES [PACKAGES ...]]] + [--requirements REQUIREMENTS] + [--init-script [INIT_SCRIPT]] + [--config-file CONFIG_FILE] + [--remote-gateway [REMOTE_GATEWAY]] + [--base-task-id BASE_TASK_ID] [--project PROJECT] + [--disable-keepalive] + [--queue-excluded-tag [QUEUE_EXCLUDED_TAG [QUEUE_EXCLUDED_TAG ...]]] + [--queue-include-tag [QUEUE_INCLUDE_TAG [QUEUE_INCLUDE_TAG ...]]] + [--skip-docker-network] [--password PASSWORD] + +clearml-session - CLI for launching JupyterLab / VSCode on a remote machine + +optional arguments: + -h, --help show this help message and exit + --version Display the clearml-session utility version + --attach [ATTACH] Attach to running interactive session (default: + previous session) + --debugging DEBUGGING + Pass existing Task id (experiment), create a copy of + the experiment on a remote machine, and launch + jupyter/ssh for interactive access. Example + --debugging + --queue QUEUE Select the queue to launch the interactive session on + (default: previously used queue) + --docker DOCKER Select the docker image to use in the interactive + session on (default: previously used docker image or + `nvidia/cuda:10.1-runtime-ubuntu18.04`) + --public-ip [true/false] + If True register the public IP of the remote machine. + Set if running on the cloud. 
Default: false (use for + local / on-premises) + --vscode-server [true/false] + Installing vscode server (code-server) on interactive + session (default: true) + --git-credentials [true/false] + If true, local .git-credentials file is sent to the + interactive session. (default: false) + --user-folder USER_FOLDER + Advanced: Set the remote base folder (default: ~/) + --packages [PACKAGES [PACKAGES ...]] + Additional packages to add, supports version numbers + (default: previously added packages). examples: + --packages torch==1.7 tqdm + --requirements REQUIREMENTS + Specify requirements.txt file to install when setting + the interactive session. Requirements file is read and + stored in `packages` section as default for the next + sessions. Can be overridden by calling `--packages` + --init-script [INIT_SCRIPT] + Specify BASH init script file to be executed when + setting the interactive session. Script content is + read and stored as default script for the next + sessions. To clear the init-script do not pass a file + --config-file CONFIG_FILE + Advanced: Change the configuration file used to store + the previous state (default: ~/.clearml_session.json + --remote-gateway [REMOTE_GATEWAY] + Advanced: Specify gateway ip/address to be passed to + interactive session (for use with k8s ingestion / ELB + --base-task-id BASE_TASK_ID + Advanced: Set the base task ID for the interactive + session. (default: previously used Task). Use `none` + for the default interactive session + --project PROJECT Advanced: Set the project name for the interactive + session Task + --disable-keepalive Advanced: If set, disable the transparent proxy always + keeping the sockets alive. Default: false, use + transparent socket mitigating connection drops. 
+ --queue-excluded-tag [QUEUE_EXCLUDED_TAG [QUEUE_EXCLUDED_TAG ...]] + Advanced: Excluded queues with this specific tag from + the selection + --queue-include-tag [QUEUE_INCLUDE_TAG [QUEUE_INCLUDE_TAG ...]] + Advanced: Only include queues with this specific tag + from the selection + --skip-docker-network + Advanced: If set, `--network host` is **not** passed + to docker (assumes k8s network ingestion) (default: + false) + --password PASSWORD Advanced: Select ssh password for the interactive + session (default: previously used one) + +Notice! all arguments are stored as new defaults for the next session + +``` diff --git a/clearml_session/__init__.py b/clearml_session/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/clearml_session/__main__.py b/clearml_session/__main__.py new file mode 100644 index 0000000..3eef0fd --- /dev/null +++ b/clearml_session/__main__.py @@ -0,0 +1,772 @@ +import hashlib +import json +import logging +import os +import subprocess +import sys +from argparse import ArgumentParser, FileType +from functools import reduce +from io import TextIOBase +from time import time, sleep + +if sys.platform == 'win32': + import msvcrt # noqa + import wexpect as pexpect # noqa +else: + import select # noqa + import pexpect # noqa + +import psutil +from clearml import Task +from clearml.backend_api.session.client import APIClient +from clearml.config import config_obj +from .tcp_proxy import TcpProxy +from .single_thread_proxy import SingleThreadProxy + + +system_tag = 'interactive' +default_docker_image = 'nvidia/cuda:10.1-runtime-ubuntu18.04' + + +def _read_std_input(timeout): + # wait for user input with timeout, return None if timeout or user input + if sys.platform == 'win32': + start_time = time() + input_str = '' + while True: + if msvcrt.kbhit(): + char = msvcrt.getche() + if ord(char) == 13: # enter_key + print('') + return input_str.strip() + input_str += char.decode() + if len(input_str) == 0 and (time() - start_time) > 
timeout: + return None + else: + i, o, e = select.select([sys.stdin], [], [], timeout) + if not i: + return None + line = sys.stdin.readline().strip() + # flush stdin buffer + while i: + i, o, e = select.select([sys.stdin], [], [], 0) + if i: + sys.stdin.readline() + return line + + +def _get_config_section_name(): + org_path = [p for p in sys.path] + # noinspection PyBroadException + try: + sys.path.append(os.path.abspath(os.path.join(__file__, '..',))) + from interactive_session_task import ( # noqa + config_section_name, config_object_section_ssh, config_object_section_bash_init) # noqa + return config_section_name, config_object_section_ssh, config_object_section_bash_init + except Exception: + return None, None, None + finally: + sys.path = org_path + + +def _check_ssh_executable(): + # check Windows 32bit version is not supported + if sys.platform == 'win32' and getattr(sys, 'winver', '').endswith('-32'): + raise ValueError("Python 32-bit version detected. Only Python 64-bit is supported!") + + # noinspection PyBroadException + try: + if sys.platform == 'win32': + ssh_exec = subprocess.check_output('where ssh.exe'.split()).decode('utf-8').split('\n')[0].strip() + else: + ssh_exec = subprocess.check_output('which ssh'.split()).decode('utf-8').split('\n')[0].strip() + return ssh_exec + except Exception: + return None + + +def create_base_task(state, project_name=None, task_name=None): + task = Task.create(project_name=project_name or 'DevOps', + task_name=task_name or 'Interactive Session', + task_type=Task.TaskTypes.application) + task_state = task.export_task() + base_script_file = os.path.abspath(os.path.join(__file__, '..', 'tcp_proxy.py')) + with open(base_script_file, 'rt') as f: + task_state['script']['diff'] = f.read() + base_script_file = os.path.abspath(os.path.join(__file__, '..', 'interactive_session_task.py')) + with open(base_script_file, 'rt') as f: + task_state['script']['diff'] += '\n\n' + f.read() + + task_state['script']['working_dir'] = '.' 
+ task_state['script']['entry_point'] = 'interactive_session.py' + task_state['script']['requirements'] = {'pip': '\n'.join( + ["clearml", "jupyter", "jupyterlab", "jupyterlab_git"] + + (['pylint'] if state.get('vscode_server') else []))} + task.update_task(task_state) + section, _, _ = _get_config_section_name() + task.set_parameters({ + "{}/user_base_directory".format(section): "~/", + "{}/ssh_server".format(section): True, + "{}/ssh_password".format(section): "training", + "{}/default_docker".format(section): "nvidia/cuda", + "{}/user_key".format(section): '', + "{}/user_secret".format(section): '', + "properties/external_address": '', + "properties/internal_ssh_port": 10022, + "properties/jupyter_token": '', + "properties/jupyter_port": '', + }) + task.set_system_tags([system_tag]) + task.reset(force=True) + return task + + +def create_debugging_task(state, debug_task_id): + debug_task = Task.get_task(task_id=debug_task_id) + # if there is no git repository, we cannot debug it + if not debug_task.data.script.repository: + raise ValueError("Debugging task has no git repository, single script debugging is not supported.") + + task = Task.clone(source_task=debug_task_id, parent=debug_task_id) + + task_state = task.export_task() + + base_script_file = os.path.abspath(os.path.join(__file__, '..', 'interactive_session_task.py')) + with open(base_script_file, 'rt') as f: + entry_diff = ['+'+line.rstrip() for line in f.readlines()] + entry_diff_header = \ + "diff --git a/__interactive_session__.py b/__interactive_session__.py\n" \ + "--- a/__interactive_session__.py\n" \ + "+++ b/__interactive_session__.py\n" \ + "@@ -0,0 +1,{} @@\n".format(len(entry_diff)) + + task_state['script']['diff'] = \ + entry_diff_header + '\n'.join(entry_diff) + '\n' + (task_state['script']['diff'] or '') + task_state['script']['working_dir'] = '.' 
+ task_state['script']['entry_point'] = '__interactive_session__.py' + state['packages'] = (state.get('packages') or []) + ["clearml", "jupyter", "jupyterlab", "jupyterlab_git"] + ( + ['pylint'] if state.get('vscode_server') else []) + task.update_task(task_state) + section, _, _ = _get_config_section_name() + task.set_parameters({ + "{}/user_base_directory".format(section): "~/", + "{}/ssh_server".format(section): True, + "{}/ssh_password".format(section): "training", + "{}/default_docker".format(section): "nvidia/cuda", + "{}/user_key".format(section): '', + "{}/user_secret".format(section): '', + "properties/external_address": '', + "properties/internal_ssh_port": 10022, + "properties/jupyter_token": '', + "properties/jupyter_port": '', + }) + task.set_system_tags([system_tag]) + task.reset(force=True) + return task + + +def delete_old_tasks(client, base_task_id): + print('Removing stale interactive sessions') + res = client.session.send_request(service='users', action='get_current_user', async_enable=False) + assert res.ok + current_user_id = res.json()['data']['user']['id'] + previous_tasks = client.tasks.get_all(**{ + 'status': ['failed', 'stopped', 'completed'], + 'parent': base_task_id or None, + 'system_tags': None if base_task_id else [system_tag], + 'page_size': 100, 'page': 0, + 'user': [current_user_id], 'only_fields': ['id'] + }) + for t in previous_tasks: + try: + client.tasks.delete(task=t.id, force=True) + except Exception as ex: + logging.getLogger().warning('{}\nFailed deleting old session {}'.format(ex, t.id)) + + +def get_project_id(state): + project_id = None + project_name = state.get('project') or None + if project_name: + projects = Task.get_projects() + project_id = [p for p in projects if p.name == project_name] + if project_id: + project_id = project_id[0] + else: + logging.getLogger().warning("could not locate project by the named '{}'".format(project_name)) + project_id = None + return project_id + + +def get_user_inputs(args, parser, 
state, client): + default_needed_args = tuple() + + user_args = sorted([a for a in args.__dict__ if not a.startswith('_')]) + # clear some states if we replace the base_task_id + if 'base_task_id' in user_args and getattr(args, 'base_task_id', None) != state.get('base_task_id'): + print('New base_task_id \'{}\', clearing previous packages & init_script'.format( + getattr(args, 'base_task_id', None))) + state.pop('init_script', None) + state.pop('packages', None) + state.pop('base_task_id', None) + + if str(getattr(args, 'base_task_id', '')).lower() == 'none': + args.base_task_id = None + state['base_task_id'] = None + + for a in user_args: + v = getattr(args, a, None) + if a in ('requirements', 'packages', 'attach', 'config_file'): + continue + if isinstance(v, TextIOBase): + state[a] = v.read() + elif not v and a == 'init_script': + if v is None: + state[a] = '' + else: + pass # keep as is + elif not v and a == 'remote_gateway': + state.pop(a, None) + elif v is not None: + state[a] = v + + if a in default_needed_args and not state.get(a): + # noinspection PyProtectedMember + state[a] = input( + "\nCould not locate previously used value of '{}', please provide it?" + "\n Help: {}\n> ".format( + a, parser._option_string_actions['--{}'.format(a.replace('_', '-'))].help)) + # if no password was set, create a new random one + if not state.get('password'): + state['password'] = hashlib.sha256("seed me Seymour {}".format(time()).encode()).hexdigest() + + # store the requirements from the requirements.txt + # override previous requirements + if args.requirements: + state['packages'] = (args.packages or []) + [ + p.strip() for p in args.requirements.readlines() if not p.strip().startswith('#')] + elif args.packages is not None: + state['packages'] = args.packages or [] + + # allow to select queue + ask_queues = not state.get('queue') + if state.get('queue'): + choice = input('Use previous queue (resource) \'{}\' [Y]/n? 
'.format(state['queue'])) + if choice in ('n', 'N', 'no', 'No', 'NO'): + ask_queues = True + if ask_queues: + print('Select the queue (resource) you request:') + queues = sorted([q.name for q in client.queues.get_all( + system_tags=['-{}'.format(t) for t in state.get('queue_excluded_tag', ['internal'])] + + ['{}'.format(t) for t in state.get('queue_include_tag', [])])]) + queues_list = '\n'.join('{}] {}'.format(i, q) for i, q in enumerate(queues)) + while True: + try: + choice = int(input(queues_list+'\nSelect a queue [0-{}] '.format(len(queues)))) + break + except (TypeError, ValueError): + pass + state['queue'] = queues[int(choice)] + + print("\nInteractive session config:\n{}\n".format( + json.dumps({k: v for k, v in state.items() if not str(k).startswith('__')}, indent=4, sort_keys=True))) + + choice = input('Launch interactive session [Y]/n? ') + if choice in ('n', 'N', 'no', 'No', 'NO'): + print('User aborted') + exit(0) + + return state + + +def save_state(state, state_file): + # if we are running in debugging mode, + # only store the current task (do not change the defaults) + if state.get('debugging'): + # noinspection PyBroadException + base_state = load_state(state_file) + base_state['task_id'] = state.get('task_id') + state = base_state + + state['__version__'] = get_version() + # save new state + with open(state_file, 'wt') as f: + json.dump(state, f, sort_keys=True) + + +def load_state(state_file): + # noinspection PyBroadException + try: + with open(state_file, 'rt') as f: + state = json.load(f) + except Exception: + state = {} + return state + + +def clone_task(state, project_id): + new_task = False + if state.get('debugging'): + print('Starting new debugging session to {}'.format(state.get('debugging'))) + task = create_debugging_task(state, state.get('debugging')) + elif state.get('base_task_id'): + print('Cloning base session {}'.format(state['base_task_id'])) + task = Task.clone(source_task=state['base_task_id'], project=project_id, 
parent=state['base_task_id']) + task.set_system_tags([system_tag]) + else: + print('Creating new session') + task = create_base_task(state, project_name=state.get('project')) + new_task = True + + task_params = task.get_parameters(backwards_compatibility=False) + if 'General/ssh_server' in task_params: + section = 'General' + init_section = 'init_script' + else: + section, _, init_section = _get_config_section_name() + task_params['properties/jupyter_token'] = '' + task_params['properties/jupyter_port'] = '' + if state.get('remote_gateway') is not None: + task_params['properties/external_address'] = str(state.get('remote_gateway')) + task_params['{}/ssh_server'.format(section)] = str(True) + task_params['{}/ssh_password'.format(section)] = state['password'] + task_params['{}/user_key'.format(section)] = config_obj.get("api.credentials.access_key") + task_params['{}/user_secret'.format(section)] = config_obj.get("api.credentials.secret_key") + task_params["{}/vscode_server".format(section)] = bool(state.get('vscode_server')) + task_params["{}/public_ip".format(section)] = bool(state.get('public_ip')) + if state.get('user_folder'): + task_params['{}/user_base_directory'.format(section)] = state.get('user_folder') + docker = state.get('docker') or task.data.execution.docker_cmd + if not state.get('skip_docker_network') and not docker: + docker = default_docker_image + if docker: + task_params['{}/default_docker'.format(section)] = docker.replace('--network host', '').strip() + task.set_base_docker(docker + ( + ' --network host' if not state.get('skip_docker_network') and '--network host' not in docker else '')) + # set the bash init script + if state.get('init_script') is not None and (not new_task or state.get('init_script').strip()): + # noinspection PyProtectedMember + task._set_configuration(name=init_section, config_type='bash', config_text=state.get('init_script') or '') + + # store the .git-credentials + if state.get('git_credentials'): + git_cred_file = 
os.path.join(os.path.expanduser('~'), '.git-credentials') + if os.path.isfile(git_cred_file): + task.connect_configuration( + configuration=git_cred_file, name='git_credentials', description='git credentials') + git_conf_file = os.path.join(os.path.expanduser('~'), '.gitconfig') + if os.path.isfile(git_conf_file): + task.connect_configuration( + configuration=git_conf_file, name='git_config', description='git config') + + if state.get('packages'): + requirements = task.data.script.requirements or {} + # notice split order is important! + packages = [p for p in state['packages'] if p.strip() and not p.strip().startswith('#')] + packages_id = set(reduce(lambda a, b: a.split(b)[0], "#;@=~<>", p).strip() for p in packages) + if isinstance(requirements.get('pip'), str): + requirements['pip'] = requirements['pip'].split('\n') + for p in (requirements.get('pip') or []): + if not p.strip() or p.strip().startswith('#'): + continue + p_id = reduce(lambda a, b: a.split(b)[0], "#;@=~<>", p).strip() + if p_id not in packages_id: + packages += [p] + + requirements['pip'] = '\n'.join(sorted(packages)) + task.update_task({'script': {'requirements': requirements}}) + task.set_parameters(task_params) + print('New session created [id={}]'.format(task.id)) + return task + + +def wait_for_machine(task): + # wait until task is running + print('Waiting for remote machine allocation [id={}]'.format(task.id)) + last_status = None + while last_status != 'in_progress' and last_status in (None, 'created', 'queued', 'unknown',): + print('.', end='', flush=True) + if last_status is not None: + sleep(2.) 
+ status = task.get_status() + if last_status != status: + # noinspection PyProtectedMember + last_status = task._get_status()[1] + print('Status [{}]{}'.format(status, ' - {}'.format(last_status) if last_status else '')) + last_status = status + print('Remote machine allocated') + print('Setting remote environment [Task id={}]'.format(task.id)) + print('Setup process details: {}'.format(task.get_output_log_web_page())) + print('Waiting for environment setup to complete [usually about 20-30 seconds]') + # monitor progress, until we get the new jupyter, then we know it is working + task.reload() + while not task.get_parameter('properties/jupyter_port') and task.get_status() == 'in_progress': + print('.', end='', flush=True) + sleep(3.) + task.reload() + if task.get_status() != 'in_progress': + raise ValueError("Remote setup failed (status={}) see details: {}".format( + task.get_status(), task.get_output_log_web_page())) + print('\nRemote machine is ready') + + return task + + +def start_ssh_tunnel(remote_address, ssh_port, ssh_password, local_remote_pair_list): + print('Starting SSH tunnel') + child = None + args = ['-N', '-C', + 'root@{}'.format(remote_address), '-p', '{}'.format(ssh_port), + '-o', 'UserKnownHostsFile=/dev/null', + '-o', 'StrictHostKeyChecking=no', + '-o', 'ServerAliveInterval=10', + '-o', 'ServerAliveCountMax=10', ] + + for local, remote in local_remote_pair_list: + args.extend(['-L', '{}:localhost:{}'.format(local, remote)]) + + # noinspection PyBroadException + try: + child = pexpect.spawn( + command=_check_ssh_executable(), + args=args, + logfile=sys.stdout, timeout=20, encoding='utf-8') + i = child.expect(['password:', r'\(yes\/no\)', r'.*[$#] ', pexpect.EOF]) + if i == 0: + child.sendline(ssh_password) + try: + child.expect(['password:'], timeout=5) + print('Incorrect password') + raise ValueError('Incorrect password') + except pexpect.TIMEOUT: + pass + + elif i == 1: + child.sendline("yes") + ret1 = child.expect(["password:", pexpect.EOF]) + 
if ret1 == 0: + child.sendline(ssh_password) + try: + child.expect(['password:'], timeout=5) + print('Incorrect password') + raise ValueError('Incorrect password') + except pexpect.TIMEOUT: + pass + except Exception: + child.terminate(force=True) + child = None + print('\n') + return child + + +def monitor_ssh_tunnel(state, task): + print('Setting up connection to remote session') + local_jupyter_port = 8878 + local_ssh_port = 8022 + local_vscode_port = 8898 + ssh_process = None + sleep_period = 3 + ssh_port = jupyter_token = jupyter_port = internal_ssh_port = ssh_password = remote_address = None + vscode_port = None + + connect_state = {'reconnect': False} + if not state.get('disable_keepalive'): + local_jupyter_port_ = local_jupyter_port + 1 + SingleThreadProxy(local_jupyter_port, local_jupyter_port_) + local_vscode_port_ = local_vscode_port + 1 + if state.get('vscode_server'): + SingleThreadProxy(local_vscode_port, local_vscode_port_) + local_ssh_port_ = local_ssh_port + 1 + TcpProxy(local_ssh_port, local_ssh_port_, connect_state, verbose=False, + keep_connection=True, is_connection_server=False) + else: + local_jupyter_port_ = local_jupyter_port + local_ssh_port_ = local_ssh_port + local_vscode_port_ = local_vscode_port + + default_section = _get_config_section_name()[0] + local_remote_pair_list = [] + try: + while task.get_status() == 'in_progress': + if not all([ssh_port, jupyter_token, jupyter_port, internal_ssh_port, ssh_password, remote_address]): + task.reload() + task_parameters = task.get_parameters() + section = 'General' if 'General/ssh_server' in task_parameters else default_section + remote_address = \ + task_parameters.get('properties/k8s-gateway-address') or \ + task_parameters.get('properties/external_address') + ssh_password = task_parameters.get('{}/ssh_password'.format(section)) or state['password'] + internal_ssh_port = task_parameters.get('properties/internal_ssh_port') + jupyter_port = task_parameters.get('properties/jupyter_port') + 
jupyter_token = task_parameters.get('properties/jupyter_token') + ssh_port = \ + task_parameters.get('properties/k8s-pod-port') or \ + task_parameters.get('properties/external_ssh_port') or internal_ssh_port + if not state.get('disable_keepalive'): + internal_ssh_port = task_parameters.get('properties/internal_stable_ssh_port') or internal_ssh_port + local_remote_pair_list = [(local_jupyter_port_, jupyter_port), (local_ssh_port_, internal_ssh_port)] + if state.get('vscode_server'): + vscode_port = task_parameters.get('properties/vscode_port') + if vscode_port: + local_remote_pair_list += [(local_vscode_port_, vscode_port)] + + if not jupyter_port: + print('Waiting for Jupyter server...') + continue + + if connect_state.get('reconnect'): + # noinspection PyBroadException + try: + ssh_process.close(**({'force': True} if sys.platform != 'win32' else {})) + ssh_process = None + except Exception: + pass + + if not ssh_process or not ssh_process.isalive(): + ssh_process = start_ssh_tunnel( + remote_address, ssh_port, ssh_password, + local_remote_pair_list=local_remote_pair_list) + + if ssh_process and ssh_process.isalive(): + msg = \ + 'Interactive session is running:\n'\ + 'SSH: ssh root@localhost -p {local_ssh_port} [password: {ssh_password}]\n'\ + 'Jupyter Lab URL: http://localhost:{local_jupyter_port}/?token={jupyter_token}'.format( + local_jupyter_port=local_jupyter_port, local_ssh_port=local_ssh_port, + ssh_password=ssh_password, jupyter_token=jupyter_token) + if vscode_port: + msg += 'VSCode server available at http://localhost:{local_vscode_port}/'.format( + local_vscode_port=local_vscode_port) + print(msg) + + print('\nConnection is up and running\n' + 'Enter \"r\" (or \"reconnect\") to reconnect the session (for example after suspend)\n' + 'Ctrl-C (or "quit") to abort (remote session remains active)\n' + 'or \"Shutdown\" to shutdown remote interactive session') + else: + logging.getLogger().warning('SSH tunneling failed, retrying in {} seconds'.format(3)) + 
sleep(3.) + continue + + connect_state['reconnect'] = False + + # wait for user input + user_input = _read_std_input(timeout=sleep_period) + if user_input is None: + # noinspection PyBroadException + try: + # check open connections + proc = psutil.Process(ssh_process.pid) + open_ports = [p.laddr.port for p in proc.connections(kind='tcp4') if p.status == 'LISTEN'] + remote_ports = [p.raddr.port for p in proc.connections(kind='tcp4') if p.status == 'ESTABLISHED'] + if int(local_jupyter_port_) not in open_ports or \ + int(local_ssh_port_) not in open_ports or \ + int(ssh_port) not in remote_ports: + connect_state['reconnect'] = True + except Exception: + pass + continue + + if user_input.lower() == 'shutdown': + print('Shutting down interactive session') + task.mark_stopped() + break + elif user_input.lower() in ('r', 'reconnect', ): + print('Reconnecting to interactive session') + # noinspection PyBroadException + try: + ssh_process.close(**({'force': True} if sys.platform != 'win32' else {})) + except Exception: + pass + elif user_input.lower() in ('q', 'quit',): + raise KeyboardInterrupt() + else: + print('unknown command: \'{}\''.format(user_input)) + + print('Interactive session ended') + except KeyboardInterrupt: + print('\nUser aborted') + + # kill the ssh process + # noinspection PyBroadException + try: + ssh_process.close(**({'force': True} if sys.platform != 'win32' else {})) + except Exception: + pass + # noinspection PyBroadException + try: + ssh_process.kill(9 if sys.platform != 'win32' else 15) + except Exception: + pass + + +def setup_parser(parser): + parser.add_argument('--version', action='store_true', default=None, + help='Display the clearml-session utility version') + parser.add_argument('--attach', default=False, nargs='?', + help='Attach to running interactive session (default: previous session)') + parser.add_argument('--debugging', type=str, default=None, + help='Pass existing Task id (experiment), create a copy of the experiment on a remote 
machine, ' + 'and launch jupyter/ssh for interactive access. Example --debugging ') + parser.add_argument('--queue', type=str, default=None, + help='Select the queue to launch the interactive session on (default: previously used queue)') + parser.add_argument('--docker', type=str, default=None, + help='Select the docker image to use in the interactive session on ' + '(default: previously used docker image or `{}`)'.format(default_docker_image)) + parser.add_argument('--public-ip', default=None, nargs='?', const='true', metavar='true/false', + type=lambda x: (str(x).strip().lower() in ('true', 'yes')), + help='If True register the public IP of the remote machine. Set if running on the cloud. ' + 'Default: false (use for local / on-premises)') + parser.add_argument('--vscode-server', default=True, nargs='?', const='true', metavar='true/false', + type=lambda x: (str(x).strip().lower() in ('true', 'yes')), + help='Installing vscode server (code-server) on interactive session (default: true)') + parser.add_argument('--git-credentials', default=False, nargs='?', const='true', metavar='true/false', + type=lambda x: (str(x).strip().lower() in ('true', 'yes')), + help='If true, local .git-credentials file is sent to the interactive session. ' + '(default: false)') + parser.add_argument('--user-folder', type=str, default=None, + help='Advanced: Set the remote base folder (default: ~/)') + parser.add_argument('--packages', type=str, nargs='*', + help='Additional packages to add, supports version numbers ' + '(default: previously added packages). ' + 'examples: --packages torch==1.7 tqdm') + parser.add_argument('--requirements', type=FileType('r'), default=None, + help='Specify requirements.txt file to install when setting the interactive session. ' + 'Requirements file is read and stored in `packages` section as default for ' + 'the next sessions. 
Can be overridden by calling `--packages`') + parser.add_argument('--init-script', type=FileType('r'), default=False, nargs='?', + help='Specify BASH init script file to be executed when setting the interactive session. ' + 'Script content is read and stored as default script for the next sessions. ' + 'To clear the init-script do not pass a file') + parser.add_argument('--config-file', type=str, default='~/.clearml_session.json', + help='Advanced: Change the configuration file used to store the previous state ' + '(default: ~/.clearml_session.json') + parser.add_argument('--remote-gateway', default=None, nargs='?', + help='Advanced: Specify gateway ip/address to be passed to interactive session ' + '(for use with k8s ingestion / ELB') + parser.add_argument('--base-task-id', type=str, default=None, + help='Advanced: Set the base task ID for the interactive session. ' + '(default: previously used Task). Use `none` for the default interactive session') + parser.add_argument('--project', type=str, default=None, + help='Advanced: Set the project name for the interactive session Task') + parser.add_argument('--disable-keepalive', action='store_true', default=None, + help='Advanced: If set, disable the transparent proxy always keeping the sockets alive. 
' + 'Default: false, use transparent socket mitigating connection drops.') + parser.add_argument('--queue-excluded-tag', default=None, nargs='*', + help='Advanced: Excluded queues with this specific tag from the selection') + parser.add_argument('--queue-include-tag', default=None, nargs='*', + help='Advanced: Only include queues with this specific tag from the selection') + parser.add_argument('--skip-docker-network', action='store_true', default=None, + help='Advanced: If set, `--network host` is **not** passed to docker ' + '(assumes k8s network ingestion) (default: false)') + parser.add_argument('--password', type=str, default=None, + help='Advanced: Select ssh password for the interactive session ' + '(default: previously used one)') + + +def get_version(): + from .version import __version__ + return __version__ + + +def cli(): + title = 'clearml-session - CLI for launching JupyterLab / VSCode on a remote machine' + print(title) + parser = ArgumentParser( + prog='clearml-session', description=title, + epilog='Notice! all arguments are stored as new defaults for the next session') + setup_parser(parser) + + # get the args + args = parser.parse_args() + + if args.version: + print('Version {}'.format(get_version())) + exit(0) + + # check ssh + if not _check_ssh_executable(): + raise ValueError("Could not locate SSH executable") + + # load previous state + state_file = os.path.abspath(os.path.expandvars(os.path.expanduser(args.config_file))) + state = load_state(state_file) + + task = None + if not args.debugging and (args.attach or state.get('task_id')): + task_id = args.attach or state.get('task_id') + print('Checking previous session') + try: + task = Task.get_task(task_id=task_id) + except ValueError: + task = None + previous_status = task.get_status() if task else None + if previous_status == 'in_progress': + # only ask if we were not requested directly + if args.attach is False: + choice = input('Connect to active session id={} [Y]/n? 
'.format(task_id)) + if choice in ('n', 'N', 'no', 'No', 'NO'): + task = None + else: + print('Using active session id={}'.format(task_id)) + else: + print('Previous session is unavailable [status={}], starting a new session.'.format(previous_status)) + task = None + + if task: + state['task_id'] = task.id + save_state(state, state_file) + else: + state.pop('task_id', None) + save_state(state, state_file) + + print('Verifying credentials') + client = APIClient() + + # update state with new args + # and make sure we have all the required fields + state = get_user_inputs(args, parser, state, client) + + # save state + save_state(state, state_file) + + # get project name + project_id = get_project_id(state) + + # remove old Tasks created by us. + delete_old_tasks(client, state.get('base_task_id')) + + # Clone the Task and adjust parameters + task = clone_task(state, project_id) + state['task_id'] = task.id + save_state(state, state_file) + + # launch + Task.enqueue(task=task, queue_name=state['queue']) + + # wait for machine to become available + try: + wait_for_machine(task) + except ValueError as ex: + print('\nERROR: {}'.format(ex)) + return 1 + + # launch ssh tunnel + monitor_ssh_tunnel(state, task) + + # we are done + print('Leaving interactive session') + + +def main(): + try: + cli() + except KeyboardInterrupt: + print('\nUser aborted') + except Exception as ex: + print('\nError: {}'.format(ex)) + exit(1) + + +if __name__ == '__main__': + main() diff --git a/clearml_session/interactive_session_task.py b/clearml_session/interactive_session_task.py new file mode 100644 index 0000000..330ecc9 --- /dev/null +++ b/clearml_session/interactive_session_task.py @@ -0,0 +1,616 @@ +import json +import os +import socket +import subprocess +import sys +import requests +from copy import deepcopy +from tempfile import mkstemp + +import psutil +from pathlib2 import Path + +from clearml import Task, StorageManager + + +# noinspection SpellCheckingInspection 
# Baked-in SSH host keys so the remote container presents a stable host
# fingerprint across sessions (these are deliberately public/shared keys for
# the throw-away session container; key material is connected as a Task
# configuration and may be overridden there).
default_ssh_fingerprint = {
    'ssh_host_ecdsa_key':
        r"-----BEGIN EC PRIVATE KEY-----"+"\n"
        r"MHcCAQEEIOCAf3KEN9Hrde53rqQM4eR8VfCnO0oc4XTEBw0w6lCfoAoGCCqGSM49"+"\n"
        r"AwEHoUQDQgAEn/LlC/1UN1q6myfjs03LJdHY2LB0b1hBjAsLvQnDMt8QE6Rml3UF"+"\n"
        r"QK/UFw4mEqCFCD+dcbyWqFsKxTm6WtFStg=="+"\n"
        r"-----END EC PRIVATE KEY-----"+"\n",

    'ssh_host_ed25519_key':
        r"-----BEGIN OPENSSH PRIVATE KEY-----"+"\n"
        r"b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW"+"\n"
        r"QyNTUxOQAAACDvweeJHnUKtwY7/WRqDJEZTDk8AajWKFt/BXmEI3+A8gAAAJiEMTXOhDE1"+"\n"
        r"zgAAAAtzc2gtZWQyNTUxOQAAACDvweeJHnUKtwY7/WRqDJEZTDk8AajWKFt/BXmEI3+A8g"+"\n"
        r"AAAEBCHpidTBUN3+W8s3qRNkyaJpA/So4vEqDvOhseSqJeH+/B54kedQq3Bjv9ZGoMkRlM"+"\n"
        r"OTwBqNYoW38FeYQjf4DyAAAAEXJvb3RAODQ1NmQ5YTdlYTk4AQIDBA=="+"\n"
        r"-----END OPENSSH PRIVATE KEY-----"+"\n",

    'ssh_host_rsa_key':
        r"-----BEGIN RSA PRIVATE KEY-----"+"\n"
        r"MIIEowIBAAKCAQEAs8R3BrinMM/k9Jak7UqsoONqLQoasYgkeBVOOfRJ6ORYWW5R"+"\n"
        r"WLkYnPPUGRpbcoM1Imh7ODBgKzs0mh5/j3y0SKP/MpvT4bf38e+QGjuC+6fR4Ah0"+"\n"
        r"L5ohGIMyqhAiBoXgj0k2BE6en/4Rb3BwNPMocCTus82SwajzMNgWneRC6GCq2M0n"+"\n"
        r"0PWenhS0IQz7jUlw3JU8z6T3ROPiMBPU7ubBhiNlAzMYPr76Z7J6ZNrCclAvdGkI"+"\n"
        r"YxK7RNq0HwfoUj0UFD9iaEHswDIlNc34p93lP6GIAbh7uVYfGhg4z7HdBoN2qweN"+"\n"
        r"szo7iQX9N8EFP4WfpLzNFteThzgN/bdso8iv0wIDAQABAoIBAQCPvbF64110b1dg"+"\n"
        r"p7AauVINl6oHd4PensCicE7LkmUi3qsyXz6WVfKzVVgr9mJWz0lGSQr14+CR0NZ/"+"\n"
        r"wZE393vkdZWSLv2eB88vWeH8x8c1WHw9yiS1B2YdRpLVXu8GDjh/+gdCLGc0ASCJ"+"\n"
        r"3fsqq5+TBEUF6oPFbEWAsdhryeAiFAokeIVEKkxRnIDvPCP6i0evUHAxEP+wOngu"+"\n"
        r"4XONkixNmATNa1jP2YAjmh3uQbAf2BvDZuywJmqV8fqZa/BwuK3W+R/92t0ySZ5Q"+"\n"
        r"Z7RCZzPzFvWY683/Cfx5+BH3XcIetbcZ/HKuc+TdBvvFgqrLNIJ4OXMp3osjZDMO"+"\n"
        r"YZIE6DdBAoGBAOG8cgm2N+Kl2dl0q1r4S+hf//zPaDorNasvcXJcj/ypy1MdmDXt"+"\n"
        r"whLSAuTN4r8axgbuws2Z870pIGd28koqg78U+pOPabkphloo8Fc97RO28ZJCK2g0"+"\n"
        r"/prPgwSYymkhrvwdzIbI11BPL/rr9cLJ1eYDnzGDSqvXJDL79XxrzwMzAoGBAMve"+"\n"
        r"ULkfqaYVlgY58d38XruyCpSmRSq39LTeTYRWkJTNFL6rkqL9A69z/ITdpSStEuR8"+"\n"
        r"8MXQSsPz8xUhFrA2bEjW7AT0r6OqGbjljKeh1whYOfgGfMKQltTfikkrf5w0UrLw"+"\n"
        r"NQ8USfpwWdFnBGQG0yE/AFknyLH14/pqfRlLzaDhAoGAcN3IJxL03l4OjqvHAbUk"+"\n"
        r"PwvA8qbBdlQkgXM3RfcCB1LeVrB1aoF2h/J5f+1xchvw54Z54FMZi3sEuLbAblTT"+"\n"
        r"irbyktUiB3K7uli90uEjqLfQEVEEYxYcN0uKNsIucmJlG6nKmZnSDlWJp+xS9RH1"+"\n"
        r"4QvujNMYgtMPRm60T4GYAAECgYB6J9LMqik4CDUls/C2J7MH2m22lk5Zg3JQMefW"+"\n"
        r"xRvK3XtxqFKr8NkVd3U2k6yRZlcsq6SFkwJJmdHsti/nFCUcHBO+AHOBqLnS7VCz"+"\n"
        r"XSkAqgTKFfEJkCOgl/U/VJ4ZFcz7xSy1xV1yf4GCFK0v1lsJz7tAsLLz1zdsZARj"+"\n"
        r"dOVYYQKBgC3IQHfd++r9kcL3+vU7bDVU4aKq0JFDA79DLhKDpSTVxqTwBT+/BIpS"+"\n"
        r"8z79zBTjNy5gMqxZp/SWBVWmsO8d7IUk9O2L/bMhHF0lOKbaHQQ9oveCzIwDewcf"+"\n"
        r"5I45LjjGPJS84IBYv4NElptRk/2eFFejr75xdm4lWfpLb1SXPOPB"+"\n"
        r"-----END RSA PRIVATE KEY-----"+"\n",

    'ssh_host_rsa_key__pub':
        r'ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQCzxHcGuKcwz+T0lqTtSqyg42otChqxiCR4FU459Eno5FhZblFYuRic89QZGlt'
        r'ygzUiaHs4MGArOzSaHn+PfLRIo/8ym9Pht/fx75AaO4L7p9HgCHQvmiEYgzKqECIGheCPSTYETp6f/hFvcHA08yhwJO6zzZLBqPM'
        r'w2Bad5ELoYKrYzSfQ9Z6eFLQhDPuNSXDclTzPpPdE4+IwE9Tu5sGGI2UDMxg+vvpnsnpk2sJyUC90aQhjErtE2rQfB+hSPRQUP2Jo'
        r'QezAMiU1zfin3eU/oYgBuHu5Vh8aGDjPsd0Gg3arB42zOjuJBf03wQU/hZ+kvM0W15OHOA39t2yjyK/T',
    'ssh_host_ecdsa_key__pub':
        r'ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBJ/y5Qv9VDdaupsn47NNyyXR2Niwd'
        r'G9YQYwLC70JwzLfEBOkZpd1BUCv1BcOJhKghQg/nXG8lqhbCsU5ulrRUrY=',
    'ssh_host_ed25519_key__pub': None,
}
# Task section names used for connected parameters / configuration objects
config_section_name = 'interactive_session'
config_object_section_ssh = 'SSH'
config_object_section_bash_init = 'interactive_init_script'


# ports already handed out by get_free_port() in this process
__allocated_ports = []


def get_free_port(range_min, range_max):
    """Return the first TCP port in [range_min, range_max) not in use and not already allocated here."""
    global __allocated_ports
    used_ports = [i.laddr.port for i in psutil.net_connections()]
    port = [i for i in range(range_min, range_max) if i not in used_ports and i not in __allocated_ports][0]
    __allocated_ports.append(port)
    return port


def init_task(param, a_default_ssh_fingerprint):
    """
    Create/initialize the ClearML Task for the interactive session and switch
    to remote execution (the local process stops at execute_remotely()).

    :param param: mutable dict of session parameters, connected to the Task
    :param a_default_ssh_fingerprint: dict of SSH host keys, connected as Task configuration
    :return: the initialized clearml.Task
    """
    # initialize ClearML
    Task.add_requirements('jupyter')
    Task.add_requirements('jupyterlab')
    Task.add_requirements('jupyterlab_git')
    task = Task.init(
        project_name="DevOps", task_name="Allocate Jupyter Notebook Instance", task_type=Task.TaskTypes.service)

    # Add jupyter server base folder
    task.connect(param, name=config_section_name)
    # connect ssh finger print configuration (with fallback if section is missing)
    old_default_ssh_fingerprint = deepcopy(a_default_ssh_fingerprint)
    try:
        task.connect_configuration(configuration=a_default_ssh_fingerprint, name=config_object_section_ssh)
    except (TypeError, ValueError):
        # restore the defaults if the server-side configuration could not be connected
        a_default_ssh_fingerprint.clear()
        a_default_ssh_fingerprint.update(old_default_ssh_fingerprint)
    if param.get('default_docker'):
        task.set_base_docker("{} --network host".format(param['default_docker']))
    # leave local process, only run remotely
    task.execute_remotely()
    return task


def setup_os_env(param):
    """
    Return a copy of os.environ stripped of runtime TRAINS*/CLEARML* variables,
    keeping only the connection/credential keys listed in `preserve`.
    """
    # get rid of all the runtime ClearML
    preserve = (
        "_API_HOST",
        "_WEB_HOST",
        "_FILES_HOST",
        "_CONFIG_FILE",
        "_API_ACCESS_KEY",
        "_API_SECRET_KEY",
        "_API_HOST_VERIFY_CERT",
        "_DOCKER_IMAGE",
    )
    # set default docker image, with network configuration
    if param.get('default_docker', '').strip():
        os.environ["TRAINS_DOCKER_IMAGE"] = param['default_docker'].strip()
        os.environ["CLEARML_DOCKER_IMAGE"] = param['default_docker'].strip()

    # setup os environment
    env = deepcopy(os.environ)
    for key in os.environ:
        if (key.startswith("TRAINS") or key.startswith("CLEARML")) and not any(key.endswith(p) for p in preserve):
            env.pop(key, None)

    return env


def monitor_jupyter_server(fd, local_filename, process, task, jupyter_port, hostnames):
    """
    Tail the jupyter server log (fd/local_filename), extract the login token,
    publish token/port/URL as Task properties, and block until the server exits.
    """
    # todo: add auto spin down see: https://tljh.jupyter.org/en/latest/topic/idle-culler.html
    # print stdout/stderr
    prev_line_count = 0
    process_running = True
    token = None
    while process_running:
        # (loop body of monitor_jupyter_server -- `while process_running:` is above this chunk)
        process_running = False
        try:
            # short poll before the token is known, longer once we are just monitoring
            process.wait(timeout=2.0 if not token else 15.0)
        except subprocess.TimeoutExpired:
            process_running = True

        # noinspection PyBroadException
        try:
            with open(local_filename, "rt") as f:
                # read new lines
                new_lines = f.readlines()
            if not new_lines:
                continue
            # truncate the log file so next iteration only sees fresh output
            os.lseek(fd, 0, 0)
            os.ftruncate(fd, 0)
        except Exception:
            continue

        print("".join(new_lines))
        prev_line_count += len(new_lines)
        # if we already have the token, do nothing, just monitor
        if token:
            continue

        # update task with jupyter notebook server links (port / token)
        line = ''
        for line in new_lines:
            if "http://" not in line and "https://" not in line:
                continue
            parts = line.split('/?token=', 1)
            if len(parts) != 2:
                continue
            token = parts[1]
            port = parts[0].split(':')[-1]
            # try to cast to int
            try:
                port = int(port)  # noqa
            except (TypeError, ValueError):
                continue
            break
        # we could not locate the token, try again
        if not token:
            continue
        # update the task with the correct links and token
        task.set_parameter(name='properties/jupyter_token', value=str(token))
        # we ignore the reported port, because jupyter server will get confused
        # if we have multiple servers running and will point to the wrong port/server
        task.set_parameter(name='properties/jupyter_port', value=str(jupyter_port))
        jupyter_url = '{}://{}:{}?token={}'.format(
            'https' if "https://" in line else 'http',
            hostnames, jupyter_port, token
        )
        print('\nJupyter Lab URL: {}\n'.format(jupyter_url))
        task.set_parameter(name='properties/jupyter_url', value=jupyter_url)

    # cleanup
    # noinspection PyBroadException
    try:
        os.close(fd)
    except Exception:
        pass
    # noinspection PyBroadException
    try:
        os.unlink(local_filename)
    except Exception:
        pass


def start_vscode_server(hostname, hostnames, param, task, env):
    """
    Install code-server (VSCode in the browser) and launch it on a free local
    port; publishes the port as a Task property (continues below this chunk).
    """
    if not param.get("vscode_server"):
        return

    # make a copy of env and remove the pythonpath from it.
    # (continuation of start_vscode_server -- def line above this chunk)
    env = dict(**env)
    env.pop('PYTHONPATH', None)

    # find a free tcp port
    port = get_free_port(9000, 9100)

    # installing VSCODE:
    try:
        python_ext = StorageManager.get_local_copy(
            'https://github.com/microsoft/vscode-python/releases/download/2020.10.332292344/ms-python-release.vsix',
            extract_archive=False)
        code_server_deb = StorageManager.get_local_copy(
            'https://github.com/cdr/code-server/releases/download/v3.7.4/code-server_3.7.4_amd64.deb',
            extract_archive=False)
        os.system("dpkg -i {}".format(code_server_deb))
    except Exception as ex:
        print("Failed installing vscode server: {}".format(ex))
        return

    cwd = (
        os.path.expandvars(os.path.expanduser(param["user_base_directory"]))
        if param["user_base_directory"]
        else os.getcwd()
    )
    # make sure we have the needed cwd
    # noinspection PyBroadException
    try:
        Path(cwd).mkdir(parents=True, exist_ok=True)
    except Exception:
        pass
    print("Running VSCode Server on {} [{}] port {} at {}".format(hostname, hostnames, port, cwd))
    print("VSCode Server available: http://{}:{}/\n".format(hostnames, port))
    user_folder = os.path.join(cwd, ".vscode/user/")
    exts_folder = os.path.join(cwd, ".vscode/exts/")

    try:
        fd, local_filename = mkstemp()
        # first invocation installs the extensions (code-server exits after
        # --install-extension); the actual server is launched further below
        subprocess.Popen(
            [
                "code-server",
                "--auth",
                "none",
                "--bind-addr",
                "127.0.0.1:{}".format(port),
                "--user-data-dir", user_folder,
                "--extensions-dir", exts_folder,
                "--install-extension", python_ext,
                "--install-extension", "ms-toolsai.jupyter",
                # "--install-extension", "donjayamanne.python-extension-pack"
            ],
            env=env,
            stdout=fd,
            stderr=fd,
        )
        # merge our defaults into the user's settings.json (keep existing keys)
        settings = Path(os.path.expanduser(os.path.join(user_folder, 'User/settings.json')))
        settings.parent.mkdir(parents=True, exist_ok=True)
        # noinspection PyBroadException
        try:
            with open(settings.as_posix(), 'rt') as f:
                base_json = json.load(f)
        except Exception:
            base_json = {}
        # noinspection PyBroadException
        try:
            base_json.update({
                "extensions.autoCheckUpdates": False,
                "extensions.autoUpdate": False,
                "python.pythonPath": sys.executable,
                "terminal.integrated.shell.linux": "/bin/bash" if Path("/bin/bash").is_file() else None,
            })
            with open(settings.as_posix(), 'wt') as f:
                json.dump(base_json, f)
        except Exception:
            pass
        # launch the actual code-server process
        proc = subprocess.Popen(
            ['bash', '-c',
             'code-server --auth none --bind-addr 127.0.0.1:{} --disable-update-check '
             '--user-data-dir {} --extensions-dir {}'.format(port, user_folder, exts_folder)],
            env=env,
            stdout=fd,
            stderr=fd,
            cwd=cwd,
        )
        try:
            # if the server exits within 1 second, it failed to start
            error_code = proc.wait(timeout=1)
            raise ValueError("code-server failed starting, return code {}".format(error_code))
        except subprocess.TimeoutExpired:
            pass

    except Exception as ex:
        print('Failed running vscode server: {}'.format(ex))
        return

    task.set_parameter(name='properties/vscode_port', value=str(port))


def start_jupyter_server(hostname, hostnames, param, task, env):
    """
    Launch JupyterLab on a free local port and hand its log over to
    monitor_jupyter_server() (which blocks until the server exits).
    """
    # execute jupyter notebook
    fd, local_filename = mkstemp()
    cwd = (
        os.path.expandvars(os.path.expanduser(param["user_base_directory"]))
        if param["user_base_directory"]
        else os.getcwd()
    )

    # find a free tcp port
    port = get_free_port(8888, 9000)

    # make sure we have the needed cwd
    # noinspection PyBroadException
    try:
        Path(cwd).mkdir(parents=True, exist_ok=True)
    except Exception:
        pass
    print(
        "Running Jupyter Notebook Server on {} [{}] port {} at {}".format(hostname, hostnames, port, cwd)
    )
    process = subprocess.Popen(
        [
            sys.executable,
            "-m",
            "jupyter",
            "lab",
            "--no-browser",
            "--allow-root",
            "--ip",
            "0.0.0.0",
            "--port",
            str(port),
        ],
        env=env,
        stdout=fd,
        stderr=fd,
        cwd=cwd,
    )
    return monitor_jupyter_server(fd, local_filename, process, task, port, hostnames)


def setup_ssh_server(hostname, hostnames, param, task):
    """
    Install and start sshd inside the session container, write the baked-in
    host keys, optionally wrap the port with a keep-alive TcpProxy, and
    publish the port(s) as Task properties.
    """
    if not param.get("ssh_server"):
        return

    print("Installing SSH Server on {} [{}]".format(hostname, hostnames))
    ssh_password = param.get("ssh_password", "training")
    # noinspection PyBroadException
    try:
        port = get_free_port(10022, 15000)
        proxy_port = get_free_port(10022, 15000)

        # one shell pipeline: install sshd, set root password, allow root login,
        # enable keep-alives, pass ClearML credentials through AcceptEnv
        # noinspection SpellCheckingInspection
        os.system(
            "export PYTHONPATH=\"\" && "
            "apt-get install -y openssh-server && "
            "mkdir -p /var/run/sshd && "
            "echo 'root:{password}' | chpasswd && "
            "echo 'PermitRootLogin yes' >> /etc/ssh/sshd_config && "
            "sed -i 's/PermitRootLogin prohibit-password/PermitRootLogin yes/' /etc/ssh/sshd_config && "
            "sed 's@session\s*required\s*pam_loginuid.so@session optional pam_loginuid.so@g' -i /etc/pam.d/sshd "
            "&& "  # noqa: W605
            "echo 'ClientAliveInterval 10' >> /etc/ssh/sshd_config && "
            "echo 'ClientAliveCountMax 20' >> /etc/ssh/sshd_config && "
            "echo 'AcceptEnv TRAINS_API_ACCESS_KEY TRAINS_API_SECRET_KEY "
            "CLEARML_API_ACCESS_KEY CLEARML_API_SECRET_KEY' >> /etc/ssh/sshd_config && "
            'echo "export VISIBLE=now" >> /etc/profile && '
            'echo "export PATH=$PATH" >> /etc/profile && '
            'echo "ldconfig" >> /etc/profile && '
            'echo "export TRAINS_CONFIG_FILE={trains_config_file}" >> /etc/profile'.format(
                password=ssh_password,
                port=port,
                trains_config_file=os.environ.get("CLEARML_CONFIG_FILE") or os.environ.get("TRAINS_CONFIG_FILE"),
            )
        )

        # create fingerprint files
        Path('/etc/ssh/').mkdir(parents=True, exist_ok=True)
        for k, v in default_ssh_fingerprint.items():
            # '__pub' key-name suffix maps to the '.pub' file extension
            filename = '/etc/ssh/{}'.format(k.replace('__pub', '.pub'))
            try:
                os.unlink(filename)
            except Exception:  # noqa
                pass
            if v:
                with open(filename, 'wt') as f:
                    f.write(v + (' root@{}'.format(hostname) if filename.endswith('.pub') else ''))
                # private keys must be 0600 or sshd refuses them
                os.chmod(filename, 0o644 if filename.endswith('.pub') else 0o600)

        # run server
        result = os.system("/usr/sbin/sshd -p {port}".format(port=port))

        if result == 0:
            # noinspection PyBroadException
            try:
                # keep-alive proxy so client reconnects reach the same sshd
                TcpProxy(listen_port=proxy_port, target_port=port, proxy_state={}, verbose=False,  # noqa
                         keep_connection=True, is_connection_server=True)
            except Exception as ex:
                print('Warning: Could not setup stable ssh port, {}'.format(ex))
                proxy_port = None

            if task:
                if proxy_port:
                    task.set_parameter(name='properties/internal_stable_ssh_port', value=str(proxy_port))
                task.set_parameter(name='properties/internal_ssh_port', value=str(port))

            print(
                "\n#\n# SSH Server running on {} [{}] port {}\n# LOGIN u:root p:{}\n#\n".format(
                    hostname, hostnames, port, ssh_password
                )
            )
        else:
            raise ValueError()
    except Exception as ex:
        print("{}\n\n#\n# Error: SSH server could not be launched\n#\n".format(ex))


def setup_user_env(param, task):
    """
    Prepare the remote user environment: venv symlink, credentials in
    .bashrc/.profile, default cwd and git credentials (continues below this chunk).
    """
    env = setup_os_env(param)
    # create symbolic link to the venv
    environment = os.path.expanduser('~/environment')
    # noinspection PyBroadException
    try:
        os.symlink(os.path.abspath(os.path.join(os.path.abspath(sys.executable), '..', '..')), environment)
        print('Virtual environment are available at {}'.format(environment))
    except Exception:
        pass
    # set default user credentials
    # NOTE(review): .replace('\\$', '\\$') is a no-op (search equals replacement);
    # presumably meant to escape '$' as '\$' before embedding in the shell line -- confirm.
    if param.get("user_key") and param.get("user_secret"):
        os.system("echo 'export TRAINS_API_ACCESS_KEY=\"{}\"' >> ~/.bashrc".format(
            param.get("user_key", "").replace('\\$', '\\$')))
        os.system("echo 'export TRAINS_API_SECRET_KEY=\"{}\"' >> ~/.bashrc".format(
            param.get("user_secret", "").replace('\\$', '\\$')))
        os.system("echo 'export TRAINS_DOCKER_IMAGE=\"{}\"' >> ~/.bashrc".format(
            param.get("default_docker", "").strip() or env.get('TRAINS_DOCKER_IMAGE', '')))
        os.system("echo 'export TRAINS_API_ACCESS_KEY=\"{}\"' >> ~/.profile".format(
            param.get("user_key", "").replace('\\$', '\\$')))
        os.system("echo 'export TRAINS_API_SECRET_KEY=\"{}\"' >> ~/.profile".format(
            param.get("user_secret", "").replace('\\$', '\\$')))
        os.system("echo 'export TRAINS_DOCKER_IMAGE=\"{}\"' >> ~/.profile".format(
            param.get("default_docker", "").strip() or env.get('TRAINS_DOCKER_IMAGE', '')))
        env['TRAINS_API_ACCESS_KEY'] = param.get("user_key")
        # (continuation of setup_user_env -- credentials block starts above this chunk)
        env['TRAINS_API_SECRET_KEY'] = param.get("user_secret")
    # set default folder for user
    if param.get("user_base_directory"):
        base_dir = param.get("user_base_directory")
        if ' ' in base_dir:
            base_dir = '\"{}\"'.format(base_dir)
        os.system("echo 'cd {}' >> ~/.bashrc".format(base_dir))
        os.system("echo 'cd {}' >> ~/.profile".format(base_dir))

    # make sure we activate the venv in the bash
    os.system("echo 'source {}' >> ~/.bashrc".format(os.path.join(environment, 'bin', 'activate')))
    os.system("echo '. {}' >> ~/.profile".format(os.path.join(environment, 'bin', 'activate')))

    # check if we need to create .git-credentials
    # noinspection PyProtectedMember
    git_credentials = task._get_configuration_text('git_credentials')
    if git_credentials:
        git_cred_file = os.path.expanduser('~/.config/git/credentials')
        # noinspection PyBroadException
        try:
            Path(git_cred_file).parent.mkdir(parents=True, exist_ok=True)
            with open(git_cred_file, 'wt') as f:
                f.write(git_credentials)
        except Exception:
            print('Could not write {} file'.format(git_cred_file))
    # noinspection PyProtectedMember
    git_config = task._get_configuration_text('git_config')
    if git_config:
        git_config_file = os.path.expanduser('~/.config/git/config')
        # noinspection PyBroadException
        try:
            Path(git_config_file).parent.mkdir(parents=True, exist_ok=True)
            with open(git_config_file, 'wt') as f:
                f.write(git_config)
        except Exception:
            print('Could not write {} file'.format(git_config_file))

    return env


def get_host_name(task, param):
    """
    Resolve this machine's hostname and best-guess reachable IP address, and
    publish it as the Task's 'properties/external_address' (public IP via
    checkip.amazonaws.com when param['public_ip'] is set).

    :return: (hostname, ip_address) tuple
    """
    # noinspection PyBroadException
    try:
        hostname = socket.gethostname()
        hostnames = socket.gethostbyname(socket.gethostname())
    except Exception:
        # fall back to the first non-loopback IPv4 address of any interface
        def get_ip_addresses(family):
            for interface, snics in psutil.net_if_addrs().items():
                for snic in snics:
                    if snic.family == family:
                        yield snic.address

        hostnames = list(get_ip_addresses(socket.AF_INET))[0]
        hostname = hostnames

    # try to get external address (if possible)
    # noinspection PyBroadException
    try:
        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        # noinspection PyBroadException
        try:
            # doesn't even have to be reachable
            s.connect(('8.255.255.255', 1))
            hostnames = s.getsockname()[0]
        except Exception:
            pass
        finally:
            s.close()
    except Exception:
        pass

    # update host name
    if not task.get_parameter(name='properties/external_address'):
        external_addr = hostnames
        if param.get('public_ip'):
            # noinspection PyBroadException
            try:
                external_addr = requests.get('https://checkip.amazonaws.com').text.strip()
            except Exception:
                pass
        task.set_parameter(name='properties/external_address', value=str(external_addr))

    return hostname, hostnames


def run_user_init_script(task):
    """
    Execute the user-provided BASH init script (stored as Task configuration),
    then import any environment variables it exported back into os.environ
    via a temporary JSON dump (cleanup continues below this chunk).
    """
    # run initialization script:
    # noinspection PyProtectedMember
    init_script = task._get_configuration_text(config_object_section_bash_init)
    if not init_script or not str(init_script).strip():
        return
    print("Running user initialization bash script:")
    init_filename = os_json_filename = None
    try:
        fd, init_filename = mkstemp(suffix='.init.sh')
        os.close(fd)
        fd, os_json_filename = mkstemp(suffix='.env.json')
        os.close(fd)
        # append a python one-liner that dumps the post-script environment to JSON
        with open(init_filename, 'wt') as f:
            f.write(init_script +
                    '\n{} -c '
                    '"exec(\\"try:\\n import os\\n import json\\n'
                    ' json.dump(dict(os.environ), open(\\\'{}\\\', \\\'w\\\'))'
                    '\\nexcept: pass\\")"'.format(sys.executable, os_json_filename))
        env = dict(**os.environ)
        # do not pass or update back the PYTHONPATH environment variable
        env.pop('PYTHONPATH', None)
        subprocess.call(['/bin/bash', init_filename], env=env)
        with open(os_json_filename, 'rt') as f:
            environ = json.load(f)
        # do not pass or update back the PYTHONPATH environment variable
        environ.pop('PYTHONPATH', None)
        # update environment variables
        os.environ.update(environ)
    except Exception as ex:
        print('User initialization script failed: {}'.format(ex))
    finally:
        if init_filename:
            try:
def main():
    """Entry point: set up the interactive session on this machine.

    Builds the default parameter set, connects to the ClearML task, then
    brings up the user environment, SSH server, VSCode server and Jupyter
    server in order.
    """
    param = dict(
        user_base_directory="~/",
        ssh_server=True,
        ssh_password="training",
        default_docker="nvidia/cuda",
        user_key=None,
        user_secret=None,
        vscode_server=True,
        public_ip=False,
    )
    task = init_task(param, default_ssh_fingerprint)

    run_user_init_script(task)
    hostname, hostnames = get_host_name(task, param)
    env = setup_user_env(param, task)

    setup_ssh_server(hostname, hostnames, param, task)
    start_vscode_server(hostname, hostnames, param, task, env)
    start_jupyter_server(hostname, hostnames, param, task, env)

    print('We are done')


if __name__ == '__main__':
    main()
except Exception: + pass + + self.server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + self.server.bind((host, port)) + self.server.listen(100) + self.tgthost, self.tgtport = tgthost, tgtport + self.buffer_size, self.delay = buffer_size, delay + self._proxy_daemon_thread = threading.Thread(target=self.main_loop) + self._proxy_daemon_thread.setDaemon(True) + self._proxy_daemon_thread.start() + + def main_loop(self): + self.input_list.append(self.server) + ss = select.select + while 1: + time.sleep(self.delay) + try: + inputready, outputready, exceptready = ss(self.input_list, [], []) + except: + continue + for self.s in inputready: + if self.s == self.server: + try: + self.on_accept() + except: + pass + break + + try: + self.data = self.s.recv(self.buffer_size) + except: + continue + if len(self.data) == 0: + try: + self.on_close() + except: + pass + break + else: + try: + self.on_recv() + except: + pass + + def on_accept(self): + clientsock, clientaddr = self.server.accept() + for i in range(self.max_timeout_for_remote_connection): + forward = self.Forward().start(self.tgthost, self.tgtport) + if forward: + break + # print('waiting for remote...') + time.sleep(1) + + if forward: + # logger.info("{0} has connected".format(clientaddr)) + self.input_list.append(clientsock) + self.input_list.append(forward) + self.channel[clientsock] = forward + self.channel[forward] = clientsock + _sidbase = "{0}_{1}_{2}_{3}".format(self.tgthost, self.tgtport, clientaddr[0], clientaddr[1]) + self.sidmap[clientsock] = (_sidbase, 1) + self.sidmap[forward] = (_sidbase, -1) + else: + # logger.warn("Can't establish connection with remote server.\n" + # "Closing connection with client side{0}".format(clientaddr)) + clientsock.close() + + def on_close(self): + # logger.info("{0} has disconnected".format(self.s.getpeername())) + + self.input_list.remove(self.s) + self.input_list.remove(self.channel[self.s]) + out = 
self.channel[self.s] + self.channel[out].close() + self.channel[self.s].close() + del self.channel[out] + del self.channel[self.s] + del self.sidmap[out] + del self.sidmap[self.s] + + def on_recv(self): + _sidbase = self.sidmap[self.s][0] + _c_or_s = self.sidmap[self.s][1] + data = self.data + # logger.debug(ctrl_less(data.strip())) + self.channel[self.s].send(data) diff --git a/clearml_session/tcp_proxy.py b/clearml_session/tcp_proxy.py new file mode 100644 index 0000000..2f274a5 --- /dev/null +++ b/clearml_session/tcp_proxy.py @@ -0,0 +1,359 @@ +import hashlib +import sys +import threading +import socket +import time +import select +import errno +from typing import Union + + +class TcpProxy(object): + __header = 'PROXY#' + __close_header = 'CLOSE#' + __uid_length = 64 + __socket_test_timeout = 3 + __max_sockets = 100 + __wait_timeout = 300 # make sure we do not collect lost sockets, and drop it after 5 minutes + __default_packet_size = 4096 + + def __init__(self, + listen_port=8868, target_port=8878, proxy_state=None, verbose=None, + keep_connection=False, is_connection_server=False): + # type: (int, int, dict, bool, bool, bool) -> () + self.listen_ip = '127.0.0.1' + self.target_ip = '127.0.0.1' + self.logfile = None # sys.stdout + self.listen_port = listen_port + self.target_port = target_port + self.proxy_state = proxy_state or {} + self.verbose = verbose + self.proxy_socket = None + self.active_local_sockets = {} + self.close_local_sockets = set() + self.keep_connection = keep_connection + self.keep_connection_server = keep_connection and is_connection_server + self.keep_connection_client = keep_connection and not is_connection_server + # set max number of open files + # noinspection PyBroadException + try: + if sys.platform == 'win32': + import ctypes + ctypes.windll.msvcrt._setmaxstdio(max(2048, ctypes.windll.msvcrt._getmaxstdio())) # noqa + else: + import resource + soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE) + 
resource.setrlimit(resource.RLIMIT_NOFILE, (max(4096, soft), hard)) + except Exception: + pass + self._proxy_daemon_thread = threading.Thread(target=self.daemon) + self._proxy_daemon_thread.setDaemon(True) + self._proxy_daemon_thread.start() + + def get_thread(self): + return self._proxy_daemon_thread + + @staticmethod + def receive_from(s, size=0): + # type: (socket.socket, int) -> bytes + # receive data from a socket until no more data is there + b = b"" + while True: + data = s.recv(size-len(b) if size else TcpProxy.__default_packet_size) + b += data + if size and len(b) < size: + continue + if size or not data or len(data) < TcpProxy.__default_packet_size: + break + return b + + @staticmethod + def send_to(s, data): + # type: (socket.socket, Union[str, bytes]) -> () + s.send(data.encode() if isinstance(data, str) else data) + + def start_proxy_thread(self, local_socket, uuid, init_data): + try: + remote_socket = self._open_remote_socket(local_socket) + except Exception as ex: + self.vprint('Exception {}: {}'.format(type(ex), ex)) + return + while True: + try: + init_data_ = init_data + init_data = None + self._process_socket_proxy(local_socket, remote_socket, uuid=uuid, init_data=init_data_) + return + except Exception as ex: + self.vprint('Exception {}: {}'.format(type(ex), ex)) + time.sleep(0.1) + + def _open_remote_socket(self, local_socket): + # This method is executed in a thread. It will relay data between the local + # host and the remote host, while letting modules work on the data before + # passing it on. 
+ remote_socket = None + while True: + if remote_socket: + remote_socket.close() + remote_socket = socket.socket(family=socket.AF_INET, type=socket.SOCK_STREAM) + timeout = 60 + try: + remote_socket.settimeout(timeout) + remote_socket.connect((self.target_ip, self.target_port)) + msg = 'Connected to {}'.format(remote_socket.getpeername()) + self.vprint(msg) + self.log(msg) + except socket.error as serr: + if serr.errno == errno.ECONNREFUSED: + # for s in [remote_socket, local_socket]: + # s.close() + msg = '{}, {}:{} - Connection refused'.format( + time.strftime("%Y-%m-%d %H:%M:%S"), self.target_ip, self.target_port) + self.vprint(msg) + self.log(msg) + # return None + self.proxy_state['reconnect'] = True + time.sleep(1) + continue + elif serr.errno == errno.ETIMEDOUT: + # for s in [remote_socket, local_socket]: + # s.close() + msg = '{}, {}:{} - Connection connection timed out'.format( + time.strftime("%Y-%m-%d %H:%M:%S"), self.target_ip, self.target_port) + self.vprint(msg) + self.log(msg) + # return None + self.proxy_state['reconnect'] = True + time.sleep(1) + continue + else: + self.vprint("Connection error {}".format(serr.errno)) + for s in [remote_socket, local_socket]: + s.close() + raise serr + break + + return remote_socket + + def _process_socket_proxy(self, local_socket, remote_socket, uuid=None, init_data=None): + # This method is executed in a thread. It will relay data between the local + # host and the remote host, while letting modules work on the data before + # passing it on. 
+ timeout = 60 + + # if we are self.keep_connection_client we need to generate uuid, send it + if self.keep_connection_client: + if uuid is None: + uuid = hashlib.sha256('{}{}'.format(time.time(), local_socket.getpeername()).encode()).hexdigest() + self.vprint('sending UUID {}'.format(uuid)) + self.send_to(remote_socket, self.__header + uuid) + + # check if we need to send init_data + if init_data: + self.vprint('sending init data {}'.format(len(init_data))) + self.send_to(remote_socket, init_data) + + # This loop ends when no more data is received on either the local or the + # remote socket + running = True + while running: + read_sockets, _, _ = select.select([remote_socket, local_socket], [], []) + + for sock in read_sockets: + try: + peer = sock.getpeername() + except socket.error as serr: + if serr.errno == errno.ENOTCONN: + # kind of a blind shot at fixing issue #15 + # I don't yet understand how this error can happen, + # but if it happens I'll just shut down the thread + # the connection is not in a useful state anymore + for s in [remote_socket, local_socket]: + s.close() + running = False + break + else: + self.vprint("{}: Socket exception in start_proxy_thread".format( + time.strftime('%Y-%m-%d %H:%M:%S'))) + raise serr + + data = self.receive_from(sock) + self.log('Received %d bytes' % len(data)) + + if sock == local_socket: + if len(data): + # log(args.logfile, b'< < < out\n' + data) + self.send_to(remote_socket, data) + else: + msg = "Connection from local client %s:%d closed" % peer + self.vprint(msg) + self.log(msg) + local_socket.close() + if not self.keep_connection or not uuid: + remote_socket.close() + running = False + elif self.keep_connection_server: + # test remote socket + self.vprint('waiting for reconnection, sleep 1 sec') + tic = time.time() + while uuid not in self.close_local_sockets and \ + self.active_local_sockets.get(uuid, {}).get('local_socket') == local_socket: + time.sleep(1) + self.vprint('wait local reconnect 
[{}]'.format(uuid)) + if time.time() - tic > self.__wait_timeout: + remote_socket.close() + running = False + break + if not running: + break + + self.vprint('done waiting') + if uuid in self.close_local_sockets: + self.vprint('client closed connection') + remote_socket.close() + running = False + self.close_local_sockets.remove(uuid) + else: + self.vprint('reconnecting local client') + local_socket = self.active_local_sockets.get(uuid, {}).get('local_socket') + + elif self.keep_connection_client: + # send UUID goodbye message + self.vprint('client {} closing socket'.format(uuid)) + remote_socket.close() + running = False + + break + + elif sock == remote_socket: + if len(data): + # log(args.logfile, b'> > > in\n' + data) + self.send_to(local_socket, data) + else: + msg = "Connection to remote server %s:%d closed" % peer + self.vprint(msg) + self.log(msg) + remote_socket.close() + if self.keep_connection_client and uuid: + # self.proxy_state['reconnect'] = True + self.vprint('Wait for remote reconnect') + time.sleep(1) + return self.start_proxy_thread(local_socket, uuid=uuid, init_data=None) + else: + local_socket.close() + running = False + break + + # remove the socket from the global list + if uuid: + self.active_local_sockets.pop(uuid, None) + if self.keep_connection_client: + self._send_remote_close_msg(timeout, uuid) + + def _send_remote_close_msg(self, timeout, uuid): + if not self.keep_connection_client or not uuid: + return + try: + self.vprint('create new control socket') + control_socket = socket.socket(family=socket.AF_INET, type=socket.SOCK_STREAM) + control_socket.settimeout(timeout) + control_socket.connect((self.target_ip, self.target_port)) + self.vprint('send close header [{}]'.format(uuid)) + self.send_to(control_socket, self.__close_header + uuid) + self.vprint('close control_socket') + control_socket.close() + except Exception as ex: + self.vprint('Error sending close header, '.format(ex)) + + def log(self, message, message_only=False): + # if 
message_only is True, only the message will be logged + # otherwise the message will be prefixed with a timestamp and a line is + # written after the message to make the log file easier to read + handle = self.logfile + if handle is None: + return + if not isinstance(message, bytes): + message = bytes(message, 'ascii') + if not message_only: + logentry = bytes("%s %s\n" % (time.strftime("%Y-%m-%d %H:%M:%S"), str(time.time())), 'ascii') + else: + logentry = b'' + logentry += message + if not message_only: + logentry += b'\n' + b'-' * 20 + b'\n' + handle.write(logentry.decode()) + + def vprint(self, msg): + # this will print msg, but only if is_verbose is True + if self.verbose: + print(msg) + + def daemon(self): + # this is the socket we will listen on for incoming connections + self.proxy_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.proxy_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + + try: + self.proxy_socket.bind((self.listen_ip, self.listen_port)) + except socket.error as e: + print(e.strerror) + sys.exit(5) + + self.proxy_socket.listen(self.__max_sockets) + # endless loop + while True: + try: + in_socket, in_addrinfo = self.proxy_socket.accept() + msg = 'Connection from %s:%d' % in_addrinfo # noqa + self.vprint(msg) + self.log(msg) + uuid = None + init_data = None + if self.keep_connection_server: + read_sockets, _, _ = select.select([in_socket], [], []) + if read_sockets: + data = self.receive_from(in_socket, size=self.__uid_length + len(self.__header)) + self.vprint('Reading header [{}]'.format(len(data))) + if len(data) == self.__uid_length + len(self.__header): + # noinspection PyBroadException + try: + header = data.decode() + except Exception: + header = None + if header.startswith(self.__header): + uuid = header[len(self.__header):] + self.vprint('Reading UUID [{}] {}'.format(len(data), uuid)) + elif header.startswith(self.__close_header): + uuid = header[len(self.__close_header):] + self.vprint('Closing UUID [{}] 
def read_text(filepath):
    """Return the full contents of *filepath* decoded as UTF-8."""
    with open(filepath, mode="r", encoding="utf-8") as fh:
        contents = fh.read()
    return contents
def read_version_string(version_file):
    """Extract the ``__version__`` value from *version_file*.

    Scans the file line by line for an assignment starting with
    ``__version__`` and returns the quoted value (single or double quotes).
    Raises RuntimeError if no such line exists.
    """
    for line in read_text(version_file).splitlines():
        if line.startswith('__version__'):
            # pick whichever quote character the assignment uses
            delim = '"' if '"' in line else "'"
            return line.split(delim)[1]
    raise RuntimeError("Unable to find version string.")
Entry points provide cross-platform support and allow + # pip to create the appropriate form of executable for the target platform. + entry_points={ + 'console_scripts': [ + 'clearml-session = clearml_session.__main__:main', + ], + }, +)