Delete old sessions only after a successful session was created (this ensures we do not delete workspaces unless we have a new successful session)

This commit is contained in:
allegroai 2024-05-20 15:54:40 +03:00
parent 78300d9ffb
commit d3ffdf92e8

View File

@ -7,7 +7,7 @@ import subprocess
import sys import sys
from pathlib import Path from pathlib import Path
from argparse import ArgumentParser, FileType from argparse import ArgumentParser, FileType
from functools import reduce from functools import reduce, partial
from getpass import getpass from getpass import getpass
from io import TextIOBase, StringIO from io import TextIOBase, StringIO
from time import time, sleep from time import time, sleep
@ -845,9 +845,10 @@ def start_ssh_tunnel(username, remote_address, ssh_port, ssh_password, local_rem
args.extend(['-L', '{}:localhost:{}'.format(local, remote)]) args.extend(['-L', '{}:localhost:{}'.format(local, remote)])
# store SSH output # store SSH output
fd = StringIO() if debug else sys.stdout fd = sys.stdout if debug else StringIO()
command = None command = None
child = None
# noinspection PyBroadException # noinspection PyBroadException
try: try:
command = _check_ssh_executable() command = _check_ssh_executable()
@ -895,13 +896,21 @@ def start_ssh_tunnel(username, remote_address, ssh_port, ssh_password, local_rem
if child: if child:
child.terminate(force=True) child.terminate(force=True)
child = None child = None
if child:
# noinspection PyBroadException
try:
child.flush()
except BaseException as ex:
pass # print("Failed to flush: {}".format(ex))
print('\n') print('\n')
if child: if child:
child.logfile = None child.logfile = None
return child, ssh_password return child, ssh_password
def monitor_ssh_tunnel(state, task): def monitor_ssh_tunnel(state, task, ssh_setup_completed_callback=None):
def interactive_ssh(p): def interactive_ssh(p):
import struct, fcntl, termios, signal, sys # noqa import struct, fcntl, termios, signal, sys # noqa
@ -1053,6 +1062,16 @@ def monitor_ssh_tunnel(state, task):
if workspace_header_msg: if workspace_header_msg:
msg += "\n\n{}".format(workspace_header_msg) msg += "\n\n{}".format(workspace_header_msg)
# we are here, we just connected, if this is the first time run the callback
if ssh_setup_completed_callback and callable(ssh_setup_completed_callback):
print("SSH setup completed calling callback")
try:
ssh_setup_completed_callback()
except Exception as ex:
print("Error executing callback function: {}".format(ex))
# so we only do it once
ssh_setup_completed_callback = None
print(msg) print(msg)
print(connect_message) print(connect_message)
else: else:
@ -1445,6 +1464,8 @@ def cli():
# get previous session, if it is running # get previous session, if it is running
task = _get_previous_session(client, args, state, task_id=args.attach) task = _get_previous_session(client, args, state, task_id=args.attach)
delete_old_tasks_callback = None
if task: if task:
state['task_id'] = task.id state['task_id'] = task.id
save_state(state, state_file) save_state(state, state_file)
@ -1472,7 +1493,8 @@ def cli():
ask_launch(args) ask_launch(args)
# remove old Tasks created by us. # remove old Tasks created by us.
delete_old_tasks(state, client, state.get('base_task_id')) # (now we do it Only after a successful remote session)
delete_old_tasks_callback = partial(delete_old_tasks, state, client, state.get('base_task_id'))
# Clone the Task and adjust parameters # Clone the Task and adjust parameters
task = clone_task(state) task = clone_task(state)
@ -1490,7 +1512,7 @@ def cli():
return 1 return 1
# launch ssh tunnel # launch ssh tunnel
monitor_ssh_tunnel(state, task) monitor_ssh_tunnel(state, task, ssh_setup_completed_callback=delete_old_tasks_callback)
# we are done # we are done
print('Goodbye') print('Goodbye')