Delete old sessions only after a successful session was created (this ensures we do not delete workspaces unless we have a new successful session)

This commit is contained in:
allegroai 2024-05-20 15:54:40 +03:00
parent 78300d9ffb
commit d3ffdf92e8

View File

@ -7,7 +7,7 @@ import subprocess
import sys
from pathlib import Path
from argparse import ArgumentParser, FileType
from functools import reduce
from functools import reduce, partial
from getpass import getpass
from io import TextIOBase, StringIO
from time import time, sleep
@ -845,9 +845,10 @@ def start_ssh_tunnel(username, remote_address, ssh_port, ssh_password, local_rem
args.extend(['-L', '{}:localhost:{}'.format(local, remote)])
# store SSH output
fd = StringIO() if debug else sys.stdout
fd = sys.stdout if debug else StringIO()
command = None
child = None
# noinspection PyBroadException
try:
command = _check_ssh_executable()
@ -895,13 +896,21 @@ def start_ssh_tunnel(username, remote_address, ssh_port, ssh_password, local_rem
if child:
child.terminate(force=True)
child = None
if child:
# noinspection PyBroadException
try:
child.flush()
except BaseException as ex:
pass # print("Failed to flush: {}".format(ex))
print('\n')
if child:
child.logfile = None
return child, ssh_password
def monitor_ssh_tunnel(state, task):
def monitor_ssh_tunnel(state, task, ssh_setup_completed_callback=None):
def interactive_ssh(p):
import struct, fcntl, termios, signal, sys # noqa
@ -1053,6 +1062,16 @@ def monitor_ssh_tunnel(state, task):
if workspace_header_msg:
msg += "\n\n{}".format(workspace_header_msg)
# we are here, we just connected, if this is the first time run the callback
if ssh_setup_completed_callback and callable(ssh_setup_completed_callback):
print("SSH setup completed calling callback")
try:
ssh_setup_completed_callback()
except Exception as ex:
print("Error executing callback function: {}".format(ex))
# so we only do it once
ssh_setup_completed_callback = None
print(msg)
print(connect_message)
else:
@ -1445,6 +1464,8 @@ def cli():
# get previous session, if it is running
task = _get_previous_session(client, args, state, task_id=args.attach)
delete_old_tasks_callback = None
if task:
state['task_id'] = task.id
save_state(state, state_file)
@ -1472,7 +1493,8 @@ def cli():
ask_launch(args)
# remove old Tasks created by us.
delete_old_tasks(state, client, state.get('base_task_id'))
# (now we do it Only after a successful remote session)
delete_old_tasks_callback = partial(delete_old_tasks, state, client, state.get('base_task_id'))
# Clone the Task and adjust parameters
task = clone_task(state)
@ -1490,7 +1512,7 @@ def cli():
return 1
# launch ssh tunnel
monitor_ssh_tunnel(state, task)
monitor_ssh_tunnel(state, task, ssh_setup_completed_callback=delete_old_tasks_callback)
# we are done
print('Goodbye')