mirror of
https://github.com/clearml/clearml-session
synced 2025-03-13 07:08:08 +00:00
Add multi machine session detection. Add multi session support. Add venv agent support.
This commit is contained in:
parent
b30e7b6a03
commit
75c47f6b73
@ -85,6 +85,53 @@ def _check_ssh_executable():
|
||||
return None
|
||||
|
||||
|
||||
def _check_configuration():
|
||||
from clearml.backend_api import Session
|
||||
|
||||
return Session.get_api_server_host() != Session.default_demo_host
|
||||
|
||||
|
||||
def _check_available_port(port, ipv6=True):
|
||||
""" True -- it's possible to listen on this port for TCP/IPv4 or TCP/IPv6
|
||||
connections. False -- otherwise.
|
||||
"""
|
||||
import socket
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
sock.bind(('127.0.0.1', port))
|
||||
sock.listen(1)
|
||||
sock.close()
|
||||
if ipv6:
|
||||
sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
|
||||
sock.bind(('::1', port))
|
||||
sock.listen(1)
|
||||
sock.close()
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _get_available_ports(list_initial_ports):
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
used_ports = [i.laddr.port for i in psutil.net_connections()]
|
||||
except Exception:
|
||||
used_ports = None
|
||||
|
||||
available_ports = []
|
||||
for p in list_initial_ports:
|
||||
port = next(
|
||||
i for i in range(p, 65000)
|
||||
if i not in available_ports and
|
||||
((used_ports is not None and i not in used_ports) or (used_ports is None and _check_available_port(i)))
|
||||
)
|
||||
available_ports.append(port)
|
||||
return available_ports
|
||||
|
||||
|
||||
def create_base_task(state, project_name=None, task_name=None):
|
||||
task = Task.create(project_name=project_name or 'DevOps',
|
||||
task_name=task_name or 'Interactive Session',
|
||||
@ -100,7 +147,7 @@ def create_base_task(state, project_name=None, task_name=None):
|
||||
task_state['script']['working_dir'] = '.'
|
||||
task_state['script']['entry_point'] = 'interactive_session.py'
|
||||
task_state['script']['requirements'] = {'pip': '\n'.join(
|
||||
["clearml", "jupyter", "jupyterlab", "jupyterlab_git"] +
|
||||
["clearml"] + (["jupyter", "jupyterlab", "jupyterlab_git"] if state.get('jupyter_lab') else []) +
|
||||
(['pylint'] if state.get('vscode_server') else []))}
|
||||
task.update_task(task_state)
|
||||
section, _, _ = _get_config_section_name()
|
||||
@ -112,7 +159,7 @@ def create_base_task(state, project_name=None, task_name=None):
|
||||
"{}/user_key".format(section): '',
|
||||
"{}/user_secret".format(section): '',
|
||||
"properties/external_address": '',
|
||||
"properties/internal_ssh_port": 10022,
|
||||
"properties/internal_ssh_port": '',
|
||||
"properties/jupyter_token": '',
|
||||
"properties/jupyter_port": '',
|
||||
})
|
||||
@ -144,8 +191,10 @@ def create_debugging_task(state, debug_task_id):
|
||||
entry_diff_header + '\n'.join(entry_diff) + '\n' + (task_state['script']['diff'] or '')
|
||||
task_state['script']['working_dir'] = '.'
|
||||
task_state['script']['entry_point'] = '__interactive_session__.py'
|
||||
state['packages'] = (state.get('packages') or []) + ["clearml", "jupyter", "jupyterlab", "jupyterlab_git"] + (
|
||||
['pylint'] if state.get('vscode_server') else [])
|
||||
state['packages'] = \
|
||||
(state.get('packages') or []) + ["clearml"] + \
|
||||
(["jupyter", "jupyterlab", "jupyterlab_git"] if state.get('jupyter_lab') else []) + \
|
||||
(['pylint'] if state.get('vscode_server') else [])
|
||||
task.update_task(task_state)
|
||||
section, _, _ = _get_config_section_name()
|
||||
task.set_parameters({
|
||||
@ -156,7 +205,7 @@ def create_debugging_task(state, debug_task_id):
|
||||
"{}/user_key".format(section): '',
|
||||
"{}/user_secret".format(section): '',
|
||||
"properties/external_address": '',
|
||||
"properties/internal_ssh_port": 10022,
|
||||
"properties/internal_ssh_port": '',
|
||||
"properties/jupyter_token": '',
|
||||
"properties/jupyter_port": '',
|
||||
})
|
||||
@ -167,9 +216,7 @@ def create_debugging_task(state, debug_task_id):
|
||||
|
||||
def delete_old_tasks(client, base_task_id):
|
||||
print('Removing stale interactive sessions')
|
||||
res = client.session.send_request(service='users', action='get_current_user', async_enable=False)
|
||||
assert res.ok
|
||||
current_user_id = res.json()['data']['user']['id']
|
||||
current_user_id = _get_user_id(client)
|
||||
previous_tasks = client.tasks.get_all(**{
|
||||
'status': ['failed', 'stopped', 'completed'],
|
||||
'parent': base_task_id or None,
|
||||
@ -184,6 +231,40 @@ def delete_old_tasks(client, base_task_id):
|
||||
logging.getLogger().warning('{}\nFailed deleting old session {}'.format(ex, t.id))
|
||||
|
||||
|
||||
def _get_running_tasks(client, prev_task_id):
|
||||
current_user_id = _get_user_id(client)
|
||||
previous_tasks = client.tasks.get_all(**{
|
||||
'status': ['in_progress'],
|
||||
'system_tags': [system_tag],
|
||||
'page_size': 10, 'page': 0,
|
||||
'order_by': ['-last_update'],
|
||||
'user': [current_user_id], 'only_fields': ['id', 'created', 'parent']
|
||||
})
|
||||
tasks_id_created = [(t.id, t.created, t.parent) for t in previous_tasks]
|
||||
if prev_task_id and prev_task_id not in (t[0] for t in tasks_id_created):
|
||||
# manually check the last task.id
|
||||
prev_tasks = client.tasks.get_all(**{
|
||||
'status': ['in_progress'],
|
||||
'id': [prev_task_id],
|
||||
'page_size': 10, 'page': 0,
|
||||
'order_by': ['-last_update'],
|
||||
'only_fields': ['id', 'created', 'parent']
|
||||
})
|
||||
if prev_tasks:
|
||||
tasks_id_created += [(prev_tasks[0].id, prev_tasks[0].created, prev_tasks[0].parent)]
|
||||
|
||||
return tasks_id_created
|
||||
|
||||
|
||||
def _get_user_id(client):
|
||||
if not client:
|
||||
client = APIClient()
|
||||
res = client.session.send_request(service='users', action='get_current_user', async_enable=False)
|
||||
assert res.ok
|
||||
current_user_id = res.json()['data']['user']['id']
|
||||
return current_user_id
|
||||
|
||||
|
||||
def get_project_id(state):
|
||||
project_id = None
|
||||
project_name = state.get('project') or None
|
||||
@ -252,7 +333,7 @@ def get_user_inputs(args, parser, state, client):
|
||||
ask_queues = not state.get('queue')
|
||||
if state.get('queue'):
|
||||
choice = input('Use previous queue (resource) \'{}\' [Y]/n? '.format(state['queue']))
|
||||
if choice in ('n', 'N', 'no', 'No', 'NO'):
|
||||
if str(choice).strip().lower() in ('n', 'no'):
|
||||
ask_queues = True
|
||||
if ask_queues:
|
||||
print('Select the queue (resource) you request:')
|
||||
@ -262,9 +343,10 @@ def get_user_inputs(args, parser, state, client):
|
||||
queues_list = '\n'.join('{}] {}'.format(i, q) for i, q in enumerate(queues))
|
||||
while True:
|
||||
try:
|
||||
choice = int(input(queues_list+'\nSelect a queue [0-{}] '.format(len(queues))))
|
||||
choice = int(input(queues_list+'\nSelect a queue [0-{}] '.format(len(queues)-1)))
|
||||
assert 0 <= choice < len(queues)
|
||||
break
|
||||
except (TypeError, ValueError):
|
||||
except (TypeError, ValueError, AssertionError):
|
||||
pass
|
||||
state['queue'] = queues[int(choice)]
|
||||
|
||||
@ -272,7 +354,7 @@ def get_user_inputs(args, parser, state, client):
|
||||
json.dumps({k: v for k, v in state.items() if not str(k).startswith('__')}, indent=4, sort_keys=True)))
|
||||
|
||||
choice = input('Launch interactive session [Y]/n? ')
|
||||
if choice in ('n', 'N', 'no', 'No', 'NO'):
|
||||
if str(choice).strip().lower() in ('n', 'no'):
|
||||
print('User aborted')
|
||||
exit(0)
|
||||
|
||||
@ -332,6 +414,7 @@ def clone_task(state, project_id):
|
||||
task_params['{}/ssh_password'.format(section)] = state['password']
|
||||
task_params['{}/user_key'.format(section)] = config_obj.get("api.credentials.access_key")
|
||||
task_params['{}/user_secret'.format(section)] = config_obj.get("api.credentials.secret_key")
|
||||
task_params["{}/jupyterlab".format(section)] = bool(state.get('jupyter_lab'))
|
||||
task_params["{}/vscode_server".format(section)] = bool(state.get('vscode_server'))
|
||||
task_params["{}/public_ip".format(section)] = bool(state.get('public_ip'))
|
||||
if state.get('user_folder'):
|
||||
@ -380,7 +463,7 @@ def clone_task(state, project_id):
|
||||
return task
|
||||
|
||||
|
||||
def wait_for_machine(task):
|
||||
def wait_for_machine(state, task):
|
||||
# wait until task is running
|
||||
print('Waiting for remote machine allocation [id={}]'.format(task.id))
|
||||
last_status = None
|
||||
@ -400,7 +483,24 @@ def wait_for_machine(task):
|
||||
print('Waiting for environment setup to complete [usually about 20-30 seconds]')
|
||||
# monitor progress, until we get the new jupyter, then we know it is working
|
||||
task.reload()
|
||||
while not task.get_parameter('properties/jupyter_port') and task.get_status() == 'in_progress':
|
||||
|
||||
section, _, _ = _get_config_section_name()
|
||||
jupyterlab = \
|
||||
task.get_parameter("{}/jupyterlab".format(section)) or \
|
||||
task.get_parameter("General/jupyterlab") or ''
|
||||
state['jupyter_lab'] = jupyterlab.strip().lower() != 'false'
|
||||
vscode_server = \
|
||||
task.get_parameter("{}/vscode_server".format(section)) or \
|
||||
task.get_parameter("General/vscode_server") or ''
|
||||
state['vscode_server'] = vscode_server.strip().lower() != 'false'
|
||||
|
||||
wait_properties = ['properties/internal_ssh_port']
|
||||
if state.get('jupyter_lab'):
|
||||
wait_properties += ['properties/jupyter_port']
|
||||
if state.get('vscode_server'):
|
||||
wait_properties += ['properties/vscode_port']
|
||||
|
||||
while any(bool(not task.get_parameter(p)) for p in wait_properties) and task.get_status() == 'in_progress':
|
||||
print('.', end='', flush=True)
|
||||
sleep(3.)
|
||||
task.reload()
|
||||
@ -412,11 +512,11 @@ def wait_for_machine(task):
|
||||
return task
|
||||
|
||||
|
||||
def start_ssh_tunnel(remote_address, ssh_port, ssh_password, local_remote_pair_list):
|
||||
def start_ssh_tunnel(username, remote_address, ssh_port, ssh_password, local_remote_pair_list):
|
||||
print('Starting SSH tunnel')
|
||||
child = None
|
||||
args = ['-N', '-C',
|
||||
'root@{}'.format(remote_address), '-p', '{}'.format(ssh_port),
|
||||
'{}@{}'.format(username, remote_address), '-p', '{}'.format(ssh_port),
|
||||
'-o', 'UserKnownHostsFile=/dev/null',
|
||||
'-o', 'StrictHostKeyChecking=no',
|
||||
'-o', 'ServerAliveInterval=10',
|
||||
@ -431,39 +531,46 @@ def start_ssh_tunnel(remote_address, ssh_port, ssh_password, local_remote_pair_l
|
||||
command=_check_ssh_executable(),
|
||||
args=args,
|
||||
logfile=sys.stdout, timeout=20, encoding='utf-8')
|
||||
i = child.expect(['password:', r'\(yes\/no\)', r'.*[$#] ', pexpect.EOF])
|
||||
i = child.expect([r'(?i)password:', r'\(yes\/no\)', r'.*[$#] ', pexpect.EOF])
|
||||
if i == 0:
|
||||
child.sendline(ssh_password)
|
||||
try:
|
||||
child.expect(['password:'], timeout=5)
|
||||
print('Incorrect password')
|
||||
child.expect([r'(?i)password:'], timeout=5)
|
||||
print('Error: incorrect password')
|
||||
ssh_password = input('Please enter password manually: ')
|
||||
child.sendline(ssh_password)
|
||||
child.expect([r'(?i)password:'], timeout=5)
|
||||
print('Error: incorrect user input password')
|
||||
raise ValueError('Incorrect password')
|
||||
except pexpect.TIMEOUT:
|
||||
pass
|
||||
|
||||
elif i == 1:
|
||||
child.sendline("yes")
|
||||
ret1 = child.expect(["password:", pexpect.EOF])
|
||||
ret1 = child.expect([r"(?i)password:", pexpect.EOF])
|
||||
if ret1 == 0:
|
||||
child.sendline(ssh_password)
|
||||
try:
|
||||
child.expect(['password:'], timeout=5)
|
||||
print('Incorrect password')
|
||||
child.expect([r'(?i)password:'], timeout=5)
|
||||
print('Error: incorrect password')
|
||||
ssh_password = input('Please enter password manually: ')
|
||||
child.sendline(ssh_password)
|
||||
child.expect([r'(?i)password:'], timeout=5)
|
||||
print('Error: incorrect user input password')
|
||||
raise ValueError('Incorrect password')
|
||||
except pexpect.TIMEOUT:
|
||||
pass
|
||||
except Exception:
|
||||
except Exception as ex:
|
||||
child.terminate(force=True)
|
||||
child = None
|
||||
print('\n')
|
||||
return child
|
||||
return child, ssh_password
|
||||
|
||||
|
||||
def monitor_ssh_tunnel(state, task):
|
||||
print('Setting up connection to remote session')
|
||||
local_jupyter_port = 8878
|
||||
local_ssh_port = 8022
|
||||
local_vscode_port = 8898
|
||||
local_jupyter_port, local_jupyter_port_, local_ssh_port, local_ssh_port_, local_vscode_port, local_vscode_port_ = \
|
||||
_get_available_ports([8878, 8878+1, 8022, 8022+1, 8898, 8898+1])
|
||||
ssh_process = None
|
||||
sleep_period = 3
|
||||
ssh_port = jupyter_token = jupyter_port = internal_ssh_port = ssh_password = remote_address = None
|
||||
@ -471,12 +578,10 @@ def monitor_ssh_tunnel(state, task):
|
||||
|
||||
connect_state = {'reconnect': False}
|
||||
if not state.get('disable_keepalive'):
|
||||
local_jupyter_port_ = local_jupyter_port + 1
|
||||
SingleThreadProxy(local_jupyter_port, local_jupyter_port_)
|
||||
local_vscode_port_ = local_vscode_port + 1
|
||||
if state.get('jupyter_lab'):
|
||||
SingleThreadProxy(local_jupyter_port, local_jupyter_port_)
|
||||
if state.get('vscode_server'):
|
||||
SingleThreadProxy(local_vscode_port, local_vscode_port_)
|
||||
local_ssh_port_ = local_ssh_port + 1
|
||||
TcpProxy(local_ssh_port, local_ssh_port_, connect_state, verbose=False,
|
||||
keep_connection=True, is_connection_server=False)
|
||||
else:
|
||||
@ -495,7 +600,7 @@ def monitor_ssh_tunnel(state, task):
|
||||
remote_address = \
|
||||
task_parameters.get('properties/k8s-gateway-address') or \
|
||||
task_parameters.get('properties/external_address')
|
||||
ssh_password = task_parameters.get('{}/ssh_password'.format(section)) or state['password']
|
||||
ssh_password = task_parameters.get('{}/ssh_password'.format(section)) or state.get('password', '')
|
||||
internal_ssh_port = task_parameters.get('properties/internal_ssh_port')
|
||||
jupyter_port = task_parameters.get('properties/jupyter_port')
|
||||
jupyter_token = task_parameters.get('properties/jupyter_token')
|
||||
@ -504,13 +609,20 @@ def monitor_ssh_tunnel(state, task):
|
||||
task_parameters.get('properties/external_ssh_port') or internal_ssh_port
|
||||
if not state.get('disable_keepalive'):
|
||||
internal_ssh_port = task_parameters.get('properties/internal_stable_ssh_port') or internal_ssh_port
|
||||
local_remote_pair_list = [(local_jupyter_port_, jupyter_port), (local_ssh_port_, internal_ssh_port)]
|
||||
local_remote_pair_list = [(local_ssh_port_, internal_ssh_port)]
|
||||
if state.get('jupyter_lab'):
|
||||
local_remote_pair_list += [(local_jupyter_port_, jupyter_port)]
|
||||
if state.get('vscode_server'):
|
||||
vscode_port = task_parameters.get('properties/vscode_port')
|
||||
try:
|
||||
if vscode_port and int(vscode_port) <= 0:
|
||||
vscode_port = None
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
if vscode_port:
|
||||
local_remote_pair_list += [(local_vscode_port_, vscode_port)]
|
||||
|
||||
if not jupyter_port:
|
||||
if not jupyter_port and state.get('jupyter_lab'):
|
||||
print('Waiting for Jupyter server...')
|
||||
continue
|
||||
|
||||
@ -523,19 +635,23 @@ def monitor_ssh_tunnel(state, task):
|
||||
pass
|
||||
|
||||
if not ssh_process or not ssh_process.isalive():
|
||||
ssh_process = start_ssh_tunnel(
|
||||
ssh_process, ssh_password = start_ssh_tunnel(
|
||||
state.get('username') or 'root',
|
||||
remote_address, ssh_port, ssh_password,
|
||||
local_remote_pair_list=local_remote_pair_list)
|
||||
|
||||
if ssh_process and ssh_process.isalive():
|
||||
msg = \
|
||||
'Interactive session is running:\n'\
|
||||
'SSH: ssh root@localhost -p {local_ssh_port} [password: {ssh_password}]\n'\
|
||||
'Jupyter Lab URL: http://localhost:{local_jupyter_port}/?token={jupyter_token}'.format(
|
||||
local_jupyter_port=local_jupyter_port, local_ssh_port=local_ssh_port,
|
||||
ssh_password=ssh_password, jupyter_token=jupyter_token)
|
||||
'SSH: ssh {username}@localhost -p {local_ssh_port} [password: {ssh_password}]'.format(
|
||||
username=state.get('username') or 'root',
|
||||
local_ssh_port=local_ssh_port, ssh_password=ssh_password)
|
||||
if jupyter_port:
|
||||
msg += \
|
||||
'\nJupyter Lab URL: http://localhost:{local_jupyter_port}/?token={jupyter_token}'.format(
|
||||
local_jupyter_port=local_jupyter_port, jupyter_token=jupyter_token.rstrip())
|
||||
if vscode_port:
|
||||
msg += 'VSCode server available at http://localhost:{local_vscode_port}/'.format(
|
||||
msg += '\nVSCode server available at http://localhost:{local_vscode_port}/'.format(
|
||||
local_vscode_port=local_vscode_port)
|
||||
print(msg)
|
||||
|
||||
@ -559,7 +675,7 @@ def monitor_ssh_tunnel(state, task):
|
||||
proc = psutil.Process(ssh_process.pid)
|
||||
open_ports = [p.laddr.port for p in proc.connections(kind='tcp4') if p.status == 'LISTEN']
|
||||
remote_ports = [p.raddr.port for p in proc.connections(kind='tcp4') if p.status == 'ESTABLISHED']
|
||||
if int(local_jupyter_port_) not in open_ports or \
|
||||
if (state.get('jupyter_lab') and int(local_jupyter_port_) not in open_ports) or \
|
||||
int(local_ssh_port_) not in open_ports or \
|
||||
int(ssh_port) not in remote_ports:
|
||||
connect_state['reconnect'] = True
|
||||
@ -619,7 +735,10 @@ def setup_parser(parser):
|
||||
'Default: false (use for local / on-premises)')
|
||||
parser.add_argument('--vscode-server', default=True, nargs='?', const='true', metavar='true/false',
|
||||
type=lambda x: (str(x).strip().lower() in ('true', 'yes')),
|
||||
help='Installing vscode server (code-server) on interactive session (default: true)')
|
||||
help='Install vscode server (code-server) on interactive session (default: true)')
|
||||
parser.add_argument('--jupyter-lab', default=True, nargs='?', const='true', metavar='true/false',
|
||||
type=lambda x: (str(x).strip().lower() in ('true', 'yes')),
|
||||
help='Install Jupyter-Lab on interactive session (default: true)')
|
||||
parser.add_argument('--git-credentials', default=False, nargs='?', const='true', metavar='true/false',
|
||||
type=lambda x: (str(x).strip().lower() in ('true', 'yes')),
|
||||
help='If true, local .git-credentials file is sent to the interactive session. '
|
||||
@ -661,7 +780,10 @@ def setup_parser(parser):
|
||||
'(assumes k8s network ingestion) (default: false)')
|
||||
parser.add_argument('--password', type=str, default=None,
|
||||
help='Advanced: Select ssh password for the interactive session '
|
||||
'(default: previously used one)')
|
||||
'(default: `randomly-generated` or previously used one)')
|
||||
parser.add_argument('--username', type=str, default=None,
|
||||
help='Advanced: Select ssh username for the interactive session '
|
||||
'(default: `root` or previously used one)')
|
||||
|
||||
|
||||
def get_version():
|
||||
@ -688,40 +810,31 @@ def cli():
|
||||
if not _check_ssh_executable():
|
||||
raise ValueError("Could not locate SSH executable")
|
||||
|
||||
# check clearml.conf
|
||||
if not _check_configuration():
|
||||
raise ValueError("ClearML configuration not found. Please run `clearml-init`")
|
||||
|
||||
# load previous state
|
||||
state_file = os.path.abspath(os.path.expandvars(os.path.expanduser(args.config_file)))
|
||||
state = load_state(state_file)
|
||||
|
||||
task = None
|
||||
if not args.debugging and (args.attach or state.get('task_id')):
|
||||
task_id = args.attach or state.get('task_id')
|
||||
print('Checking previous session')
|
||||
try:
|
||||
task = Task.get_task(task_id=task_id)
|
||||
except ValueError:
|
||||
task = None
|
||||
previous_status = task.get_status() if task else None
|
||||
if previous_status == 'in_progress':
|
||||
# only ask if we were not requested directly
|
||||
if args.attach is False:
|
||||
choice = input('Connect to active session id={} [Y]/n? '.format(task_id))
|
||||
if choice in ('n', 'N', 'no', 'No', 'NO'):
|
||||
task = None
|
||||
else:
|
||||
print('Using active session id={}'.format(task_id))
|
||||
else:
|
||||
print('Previous session is unavailable [status={}], starting a new session.'.format(previous_status))
|
||||
task = None
|
||||
client = APIClient()
|
||||
|
||||
# get previous session, if it is running
|
||||
task = _check_previous_session(client, args, state)
|
||||
|
||||
if task:
|
||||
state['task_id'] = task.id
|
||||
save_state(state, state_file)
|
||||
if args.username:
|
||||
state['username'] = args.username
|
||||
if args.password:
|
||||
state['password'] = args.password
|
||||
else:
|
||||
state.pop('task_id', None)
|
||||
save_state(state, state_file)
|
||||
|
||||
print('Verifying credentials')
|
||||
client = APIClient()
|
||||
|
||||
# update state with new args
|
||||
# and make sure we have all the required fields
|
||||
@ -746,7 +859,7 @@ def cli():
|
||||
|
||||
# wait for machine to become available
|
||||
try:
|
||||
wait_for_machine(task)
|
||||
wait_for_machine(state, task)
|
||||
except ValueError as ex:
|
||||
print('\nERROR: {}'.format(ex))
|
||||
return 1
|
||||
@ -758,6 +871,72 @@ def cli():
|
||||
print('Leaving interactive session')
|
||||
|
||||
|
||||
def _check_previous_session(client, args, state):
|
||||
# now let's see if we have the requested Task
|
||||
if args.attach:
|
||||
task_id = args.attach
|
||||
print('Checking previous session')
|
||||
try:
|
||||
task = Task.get_task(task_id=task_id)
|
||||
except ValueError:
|
||||
task = None
|
||||
status = task.get_status() if task else None
|
||||
if status == 'in_progress':
|
||||
if not args.debugging or task.parent == args.debugging:
|
||||
# only ask if we were not requested directly
|
||||
print('Using active session id={}'.format(task_id))
|
||||
return task
|
||||
raise ValueError('Could not connect to requested session id={} - status \'{}\''.format(
|
||||
task_id, status or 'Not Found'))
|
||||
|
||||
# let's see if we have any other running sessions
|
||||
running_task_ids_created = _get_running_tasks(client, state.get('task_id'))
|
||||
if not running_task_ids_created:
|
||||
return None
|
||||
|
||||
if args.debugging:
|
||||
running_task_ids_created = [t for t in running_task_ids_created if t[2] == args.debugging]
|
||||
if not running_task_ids_created:
|
||||
print('No active task={} debugging session found'.format(args.debugging))
|
||||
return None
|
||||
|
||||
# a single running session
|
||||
if len(running_task_ids_created) == 1:
|
||||
task_id = running_task_ids_created[0][0]
|
||||
choice = input('Connect to active session id={} [Y]/n? '.format(task_id))
|
||||
if str(choice).strip().lower() not in ('n', 'no'):
|
||||
return Task.get_task(task_id=task_id)
|
||||
|
||||
# multiple sessions running
|
||||
print('Active sessions:')
|
||||
try:
|
||||
prev_task_id = state.get('task_id')
|
||||
default_i = next(i for i, (tid, _, _) in enumerate(running_task_ids_created) if prev_task_id == tid)
|
||||
except StopIteration:
|
||||
default_i = None
|
||||
|
||||
session_list = '\n'.join(
|
||||
'{}{}] {} id={}'.format(i, '*' if i == default_i else '', dt.strftime("%Y-%m-%d %H:%M:%S"), tid)
|
||||
for i, (tid, dt, _) in enumerate(running_task_ids_created))
|
||||
while True:
|
||||
try:
|
||||
choice = input(session_list+'\nConnect to session [0-{}] or \'N\' to skip'.format(
|
||||
len(running_task_ids_created)-1))
|
||||
if choice.strip().lower().startswith('n'):
|
||||
choice = None
|
||||
elif default_i is not None and not choice.strip():
|
||||
choice = default_i
|
||||
else:
|
||||
choice = int(choice)
|
||||
assert 0 <= choice < len(running_task_ids_created)
|
||||
break
|
||||
except (TypeError, ValueError, AssertionError):
|
||||
pass
|
||||
if choice is None:
|
||||
return None
|
||||
return Task.get_task(task_id=running_task_ids_created[choice][0])
|
||||
|
||||
|
||||
def main():
|
||||
try:
|
||||
cli()
|
||||
|
Loading…
Reference in New Issue
Block a user