Add multi machine session detection. Add multi session support. Add venv agent support.

This commit is contained in:
allegroai 2021-01-10 18:30:37 +02:00
parent b30e7b6a03
commit 75c47f6b73

View File

@ -85,6 +85,53 @@ def _check_ssh_executable():
return None
def _check_configuration():
from clearml.backend_api import Session
return Session.get_api_server_host() != Session.default_demo_host
def _check_available_port(port, ipv6=True):
""" True -- it's possible to listen on this port for TCP/IPv4 or TCP/IPv6
connections. False -- otherwise.
"""
import socket
# noinspection PyBroadException
try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.bind(('127.0.0.1', port))
sock.listen(1)
sock.close()
if ipv6:
sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
sock.bind(('::1', port))
sock.listen(1)
sock.close()
except Exception:
return False
return True
def _get_available_ports(list_initial_ports):
# noinspection PyBroadException
try:
used_ports = [i.laddr.port for i in psutil.net_connections()]
except Exception:
used_ports = None
available_ports = []
for p in list_initial_ports:
port = next(
i for i in range(p, 65000)
if i not in available_ports and
((used_ports is not None and i not in used_ports) or (used_ports is None and _check_available_port(i)))
)
available_ports.append(port)
return available_ports
def create_base_task(state, project_name=None, task_name=None):
task = Task.create(project_name=project_name or 'DevOps',
task_name=task_name or 'Interactive Session',
@ -100,7 +147,7 @@ def create_base_task(state, project_name=None, task_name=None):
task_state['script']['working_dir'] = '.'
task_state['script']['entry_point'] = 'interactive_session.py'
task_state['script']['requirements'] = {'pip': '\n'.join(
["clearml", "jupyter", "jupyterlab", "jupyterlab_git"] +
["clearml"] + (["jupyter", "jupyterlab", "jupyterlab_git"] if state.get('jupyter_lab') else []) +
(['pylint'] if state.get('vscode_server') else []))}
task.update_task(task_state)
section, _, _ = _get_config_section_name()
@ -112,7 +159,7 @@ def create_base_task(state, project_name=None, task_name=None):
"{}/user_key".format(section): '',
"{}/user_secret".format(section): '',
"properties/external_address": '',
"properties/internal_ssh_port": 10022,
"properties/internal_ssh_port": '',
"properties/jupyter_token": '',
"properties/jupyter_port": '',
})
@ -144,8 +191,10 @@ def create_debugging_task(state, debug_task_id):
entry_diff_header + '\n'.join(entry_diff) + '\n' + (task_state['script']['diff'] or '')
task_state['script']['working_dir'] = '.'
task_state['script']['entry_point'] = '__interactive_session__.py'
state['packages'] = (state.get('packages') or []) + ["clearml", "jupyter", "jupyterlab", "jupyterlab_git"] + (
['pylint'] if state.get('vscode_server') else [])
state['packages'] = \
(state.get('packages') or []) + ["clearml"] + \
(["jupyter", "jupyterlab", "jupyterlab_git"] if state.get('jupyter_lab') else []) + \
(['pylint'] if state.get('vscode_server') else [])
task.update_task(task_state)
section, _, _ = _get_config_section_name()
task.set_parameters({
@ -156,7 +205,7 @@ def create_debugging_task(state, debug_task_id):
"{}/user_key".format(section): '',
"{}/user_secret".format(section): '',
"properties/external_address": '',
"properties/internal_ssh_port": 10022,
"properties/internal_ssh_port": '',
"properties/jupyter_token": '',
"properties/jupyter_port": '',
})
@ -167,9 +216,7 @@ def create_debugging_task(state, debug_task_id):
def delete_old_tasks(client, base_task_id):
print('Removing stale interactive sessions')
res = client.session.send_request(service='users', action='get_current_user', async_enable=False)
assert res.ok
current_user_id = res.json()['data']['user']['id']
current_user_id = _get_user_id(client)
previous_tasks = client.tasks.get_all(**{
'status': ['failed', 'stopped', 'completed'],
'parent': base_task_id or None,
@ -184,6 +231,40 @@ def delete_old_tasks(client, base_task_id):
logging.getLogger().warning('{}\nFailed deleting old session {}'.format(ex, t.id))
def _get_running_tasks(client, prev_task_id):
current_user_id = _get_user_id(client)
previous_tasks = client.tasks.get_all(**{
'status': ['in_progress'],
'system_tags': [system_tag],
'page_size': 10, 'page': 0,
'order_by': ['-last_update'],
'user': [current_user_id], 'only_fields': ['id', 'created', 'parent']
})
tasks_id_created = [(t.id, t.created, t.parent) for t in previous_tasks]
if prev_task_id and prev_task_id not in (t[0] for t in tasks_id_created):
# manually check the last task.id
prev_tasks = client.tasks.get_all(**{
'status': ['in_progress'],
'id': [prev_task_id],
'page_size': 10, 'page': 0,
'order_by': ['-last_update'],
'only_fields': ['id', 'created', 'parent']
})
if prev_tasks:
tasks_id_created += [(prev_tasks[0].id, prev_tasks[0].created, prev_tasks[0].parent)]
return tasks_id_created
def _get_user_id(client):
if not client:
client = APIClient()
res = client.session.send_request(service='users', action='get_current_user', async_enable=False)
assert res.ok
current_user_id = res.json()['data']['user']['id']
return current_user_id
def get_project_id(state):
project_id = None
project_name = state.get('project') or None
@ -252,7 +333,7 @@ def get_user_inputs(args, parser, state, client):
ask_queues = not state.get('queue')
if state.get('queue'):
choice = input('Use previous queue (resource) \'{}\' [Y]/n? '.format(state['queue']))
if choice in ('n', 'N', 'no', 'No', 'NO'):
if str(choice).strip().lower() in ('n', 'no'):
ask_queues = True
if ask_queues:
print('Select the queue (resource) you request:')
@ -262,9 +343,10 @@ def get_user_inputs(args, parser, state, client):
queues_list = '\n'.join('{}] {}'.format(i, q) for i, q in enumerate(queues))
while True:
try:
choice = int(input(queues_list+'\nSelect a queue [0-{}] '.format(len(queues))))
choice = int(input(queues_list+'\nSelect a queue [0-{}] '.format(len(queues)-1)))
assert 0 <= choice < len(queues)
break
except (TypeError, ValueError):
except (TypeError, ValueError, AssertionError):
pass
state['queue'] = queues[int(choice)]
@ -272,7 +354,7 @@ def get_user_inputs(args, parser, state, client):
json.dumps({k: v for k, v in state.items() if not str(k).startswith('__')}, indent=4, sort_keys=True)))
choice = input('Launch interactive session [Y]/n? ')
if choice in ('n', 'N', 'no', 'No', 'NO'):
if str(choice).strip().lower() in ('n', 'no'):
print('User aborted')
exit(0)
@ -332,6 +414,7 @@ def clone_task(state, project_id):
task_params['{}/ssh_password'.format(section)] = state['password']
task_params['{}/user_key'.format(section)] = config_obj.get("api.credentials.access_key")
task_params['{}/user_secret'.format(section)] = config_obj.get("api.credentials.secret_key")
task_params["{}/jupyterlab".format(section)] = bool(state.get('jupyter_lab'))
task_params["{}/vscode_server".format(section)] = bool(state.get('vscode_server'))
task_params["{}/public_ip".format(section)] = bool(state.get('public_ip'))
if state.get('user_folder'):
@ -380,7 +463,7 @@ def clone_task(state, project_id):
return task
def wait_for_machine(task):
def wait_for_machine(state, task):
# wait until task is running
print('Waiting for remote machine allocation [id={}]'.format(task.id))
last_status = None
@ -400,7 +483,24 @@ def wait_for_machine(task):
print('Waiting for environment setup to complete [usually about 20-30 seconds]')
# monitor progress, until we get the new jupyter, then we know it is working
task.reload()
while not task.get_parameter('properties/jupyter_port') and task.get_status() == 'in_progress':
section, _, _ = _get_config_section_name()
jupyterlab = \
task.get_parameter("{}/jupyterlab".format(section)) or \
task.get_parameter("General/jupyterlab") or ''
state['jupyter_lab'] = jupyterlab.strip().lower() != 'false'
vscode_server = \
task.get_parameter("{}/vscode_server".format(section)) or \
task.get_parameter("General/vscode_server") or ''
state['vscode_server'] = vscode_server.strip().lower() != 'false'
wait_properties = ['properties/internal_ssh_port']
if state.get('jupyter_lab'):
wait_properties += ['properties/jupyter_port']
if state.get('vscode_server'):
wait_properties += ['properties/vscode_port']
while any(bool(not task.get_parameter(p)) for p in wait_properties) and task.get_status() == 'in_progress':
print('.', end='', flush=True)
sleep(3.)
task.reload()
@ -412,11 +512,11 @@ def wait_for_machine(task):
return task
def start_ssh_tunnel(remote_address, ssh_port, ssh_password, local_remote_pair_list):
def start_ssh_tunnel(username, remote_address, ssh_port, ssh_password, local_remote_pair_list):
print('Starting SSH tunnel')
child = None
args = ['-N', '-C',
'root@{}'.format(remote_address), '-p', '{}'.format(ssh_port),
'{}@{}'.format(username, remote_address), '-p', '{}'.format(ssh_port),
'-o', 'UserKnownHostsFile=/dev/null',
'-o', 'StrictHostKeyChecking=no',
'-o', 'ServerAliveInterval=10',
@ -431,39 +531,46 @@ def start_ssh_tunnel(remote_address, ssh_port, ssh_password, local_remote_pair_l
command=_check_ssh_executable(),
args=args,
logfile=sys.stdout, timeout=20, encoding='utf-8')
i = child.expect(['password:', r'\(yes\/no\)', r'.*[$#] ', pexpect.EOF])
i = child.expect([r'(?i)password:', r'\(yes\/no\)', r'.*[$#] ', pexpect.EOF])
if i == 0:
child.sendline(ssh_password)
try:
child.expect(['password:'], timeout=5)
print('Incorrect password')
child.expect([r'(?i)password:'], timeout=5)
print('Error: incorrect password')
ssh_password = input('Please enter password manually: ')
child.sendline(ssh_password)
child.expect([r'(?i)password:'], timeout=5)
print('Error: incorrect user input password')
raise ValueError('Incorrect password')
except pexpect.TIMEOUT:
pass
elif i == 1:
child.sendline("yes")
ret1 = child.expect(["password:", pexpect.EOF])
ret1 = child.expect([r"(?i)password:", pexpect.EOF])
if ret1 == 0:
child.sendline(ssh_password)
try:
child.expect(['password:'], timeout=5)
print('Incorrect password')
child.expect([r'(?i)password:'], timeout=5)
print('Error: incorrect password')
ssh_password = input('Please enter password manually: ')
child.sendline(ssh_password)
child.expect([r'(?i)password:'], timeout=5)
print('Error: incorrect user input password')
raise ValueError('Incorrect password')
except pexpect.TIMEOUT:
pass
except Exception:
except Exception as ex:
child.terminate(force=True)
child = None
print('\n')
return child
return child, ssh_password
def monitor_ssh_tunnel(state, task):
print('Setting up connection to remote session')
local_jupyter_port = 8878
local_ssh_port = 8022
local_vscode_port = 8898
local_jupyter_port, local_jupyter_port_, local_ssh_port, local_ssh_port_, local_vscode_port, local_vscode_port_ = \
_get_available_ports([8878, 8878+1, 8022, 8022+1, 8898, 8898+1])
ssh_process = None
sleep_period = 3
ssh_port = jupyter_token = jupyter_port = internal_ssh_port = ssh_password = remote_address = None
@ -471,12 +578,10 @@ def monitor_ssh_tunnel(state, task):
connect_state = {'reconnect': False}
if not state.get('disable_keepalive'):
local_jupyter_port_ = local_jupyter_port + 1
SingleThreadProxy(local_jupyter_port, local_jupyter_port_)
local_vscode_port_ = local_vscode_port + 1
if state.get('jupyter_lab'):
SingleThreadProxy(local_jupyter_port, local_jupyter_port_)
if state.get('vscode_server'):
SingleThreadProxy(local_vscode_port, local_vscode_port_)
local_ssh_port_ = local_ssh_port + 1
TcpProxy(local_ssh_port, local_ssh_port_, connect_state, verbose=False,
keep_connection=True, is_connection_server=False)
else:
@ -495,7 +600,7 @@ def monitor_ssh_tunnel(state, task):
remote_address = \
task_parameters.get('properties/k8s-gateway-address') or \
task_parameters.get('properties/external_address')
ssh_password = task_parameters.get('{}/ssh_password'.format(section)) or state['password']
ssh_password = task_parameters.get('{}/ssh_password'.format(section)) or state.get('password', '')
internal_ssh_port = task_parameters.get('properties/internal_ssh_port')
jupyter_port = task_parameters.get('properties/jupyter_port')
jupyter_token = task_parameters.get('properties/jupyter_token')
@ -504,13 +609,20 @@ def monitor_ssh_tunnel(state, task):
task_parameters.get('properties/external_ssh_port') or internal_ssh_port
if not state.get('disable_keepalive'):
internal_ssh_port = task_parameters.get('properties/internal_stable_ssh_port') or internal_ssh_port
local_remote_pair_list = [(local_jupyter_port_, jupyter_port), (local_ssh_port_, internal_ssh_port)]
local_remote_pair_list = [(local_ssh_port_, internal_ssh_port)]
if state.get('jupyter_lab'):
local_remote_pair_list += [(local_jupyter_port_, jupyter_port)]
if state.get('vscode_server'):
vscode_port = task_parameters.get('properties/vscode_port')
try:
if vscode_port and int(vscode_port) <= 0:
vscode_port = None
except (ValueError, TypeError):
pass
if vscode_port:
local_remote_pair_list += [(local_vscode_port_, vscode_port)]
if not jupyter_port:
if not jupyter_port and state.get('jupyter_lab'):
print('Waiting for Jupyter server...')
continue
@ -523,19 +635,23 @@ def monitor_ssh_tunnel(state, task):
pass
if not ssh_process or not ssh_process.isalive():
ssh_process = start_ssh_tunnel(
ssh_process, ssh_password = start_ssh_tunnel(
state.get('username') or 'root',
remote_address, ssh_port, ssh_password,
local_remote_pair_list=local_remote_pair_list)
if ssh_process and ssh_process.isalive():
msg = \
'Interactive session is running:\n'\
'SSH: ssh root@localhost -p {local_ssh_port} [password: {ssh_password}]\n'\
'Jupyter Lab URL: http://localhost:{local_jupyter_port}/?token={jupyter_token}'.format(
local_jupyter_port=local_jupyter_port, local_ssh_port=local_ssh_port,
ssh_password=ssh_password, jupyter_token=jupyter_token)
'SSH: ssh {username}@localhost -p {local_ssh_port} [password: {ssh_password}]'.format(
username=state.get('username') or 'root',
local_ssh_port=local_ssh_port, ssh_password=ssh_password)
if jupyter_port:
msg += \
'\nJupyter Lab URL: http://localhost:{local_jupyter_port}/?token={jupyter_token}'.format(
local_jupyter_port=local_jupyter_port, jupyter_token=jupyter_token.rstrip())
if vscode_port:
msg += 'VSCode server available at http://localhost:{local_vscode_port}/'.format(
msg += '\nVSCode server available at http://localhost:{local_vscode_port}/'.format(
local_vscode_port=local_vscode_port)
print(msg)
@ -559,7 +675,7 @@ def monitor_ssh_tunnel(state, task):
proc = psutil.Process(ssh_process.pid)
open_ports = [p.laddr.port for p in proc.connections(kind='tcp4') if p.status == 'LISTEN']
remote_ports = [p.raddr.port for p in proc.connections(kind='tcp4') if p.status == 'ESTABLISHED']
if int(local_jupyter_port_) not in open_ports or \
if (state.get('jupyter_lab') and int(local_jupyter_port_) not in open_ports) or \
int(local_ssh_port_) not in open_ports or \
int(ssh_port) not in remote_ports:
connect_state['reconnect'] = True
@ -619,7 +735,10 @@ def setup_parser(parser):
'Default: false (use for local / on-premises)')
parser.add_argument('--vscode-server', default=True, nargs='?', const='true', metavar='true/false',
type=lambda x: (str(x).strip().lower() in ('true', 'yes')),
help='Installing vscode server (code-server) on interactive session (default: true)')
help='Install vscode server (code-server) on interactive session (default: true)')
parser.add_argument('--jupyter-lab', default=True, nargs='?', const='true', metavar='true/false',
type=lambda x: (str(x).strip().lower() in ('true', 'yes')),
help='Install Jupyter-Lab on interactive session (default: true)')
parser.add_argument('--git-credentials', default=False, nargs='?', const='true', metavar='true/false',
type=lambda x: (str(x).strip().lower() in ('true', 'yes')),
help='If true, local .git-credentials file is sent to the interactive session. '
@ -661,7 +780,10 @@ def setup_parser(parser):
'(assumes k8s network ingestion) (default: false)')
parser.add_argument('--password', type=str, default=None,
help='Advanced: Select ssh password for the interactive session '
'(default: previously used one)')
'(default: `randomly-generated` or previously used one)')
parser.add_argument('--username', type=str, default=None,
help='Advanced: Select ssh username for the interactive session '
'(default: `root` or previously used one)')
def get_version():
@ -688,40 +810,31 @@ def cli():
if not _check_ssh_executable():
raise ValueError("Could not locate SSH executable")
# check clearml.conf
if not _check_configuration():
raise ValueError("ClearML configuration not found. Please run `clearml-init`")
# load previous state
state_file = os.path.abspath(os.path.expandvars(os.path.expanduser(args.config_file)))
state = load_state(state_file)
task = None
if not args.debugging and (args.attach or state.get('task_id')):
task_id = args.attach or state.get('task_id')
print('Checking previous session')
try:
task = Task.get_task(task_id=task_id)
except ValueError:
task = None
previous_status = task.get_status() if task else None
if previous_status == 'in_progress':
# only ask if we were not requested directly
if args.attach is False:
choice = input('Connect to active session id={} [Y]/n? '.format(task_id))
if choice in ('n', 'N', 'no', 'No', 'NO'):
task = None
else:
print('Using active session id={}'.format(task_id))
else:
print('Previous session is unavailable [status={}], starting a new session.'.format(previous_status))
task = None
client = APIClient()
# get previous session, if it is running
task = _check_previous_session(client, args, state)
if task:
state['task_id'] = task.id
save_state(state, state_file)
if args.username:
state['username'] = args.username
if args.password:
state['password'] = args.password
else:
state.pop('task_id', None)
save_state(state, state_file)
print('Verifying credentials')
client = APIClient()
# update state with new args
# and make sure we have all the required fields
@ -746,7 +859,7 @@ def cli():
# wait for machine to become available
try:
wait_for_machine(task)
wait_for_machine(state, task)
except ValueError as ex:
print('\nERROR: {}'.format(ex))
return 1
@ -758,6 +871,72 @@ def cli():
print('Leaving interactive session')
def _check_previous_session(client, args, state):
# now let's see if we have the requested Task
if args.attach:
task_id = args.attach
print('Checking previous session')
try:
task = Task.get_task(task_id=task_id)
except ValueError:
task = None
status = task.get_status() if task else None
if status == 'in_progress':
if not args.debugging or task.parent == args.debugging:
# only ask if we were not requested directly
print('Using active session id={}'.format(task_id))
return task
raise ValueError('Could not connect to requested session id={} - status \'{}\''.format(
task_id, status or 'Not Found'))
# let's see if we have any other running sessions
running_task_ids_created = _get_running_tasks(client, state.get('task_id'))
if not running_task_ids_created:
return None
if args.debugging:
running_task_ids_created = [t for t in running_task_ids_created if t[2] == args.debugging]
if not running_task_ids_created:
print('No active task={} debugging session found'.format(args.debugging))
return None
# a single running session
if len(running_task_ids_created) == 1:
task_id = running_task_ids_created[0][0]
choice = input('Connect to active session id={} [Y]/n? '.format(task_id))
if str(choice).strip().lower() not in ('n', 'no'):
return Task.get_task(task_id=task_id)
# multiple sessions running
print('Active sessions:')
try:
prev_task_id = state.get('task_id')
default_i = next(i for i, (tid, _, _) in enumerate(running_task_ids_created) if prev_task_id == tid)
except StopIteration:
default_i = None
session_list = '\n'.join(
'{}{}] {} id={}'.format(i, '*' if i == default_i else '', dt.strftime("%Y-%m-%d %H:%M:%S"), tid)
for i, (tid, dt, _) in enumerate(running_task_ids_created))
while True:
try:
choice = input(session_list+'\nConnect to session [0-{}] or \'N\' to skip'.format(
len(running_task_ids_created)-1))
if choice.strip().lower().startswith('n'):
choice = None
elif default_i is not None and not choice.strip():
choice = default_i
else:
choice = int(choice)
assert 0 <= choice < len(running_task_ids_created)
break
except (TypeError, ValueError, AssertionError):
pass
if choice is None:
return None
return Task.get_task(task_id=running_task_ids_created[choice][0])
def main():
try:
cli()