diff --git a/clearml_session/__main__.py b/clearml_session/__main__.py index 167f578..4a299d8 100644 --- a/clearml_session/__main__.py +++ b/clearml_session/__main__.py @@ -562,6 +562,7 @@ def clone_task(state, project_id=None): task_params["{}/vscode_version".format(section)] = state.get('vscode_version') or '' task_params["{}/vscode_extensions".format(section)] = state.get('vscode_extensions') or '' task_params["{}/force_dropbear".format(section)] = bool(state.get('force_dropbear')) + task_params["{}/tailscale".format(section)] = bool(state.get('tailscale')) if state.get('user_folder'): task_params['{}/user_base_directory'.format(section)] = state.get('user_folder') docker = state.get('docker') or task.get_base_docker() @@ -1087,6 +1088,9 @@ def setup_parser(parser): action='store_true', default=False, help='Automatic yes to prompts; assume \"yes\" as answer ' 'to all prompts and run non-interactively',) + parser.add_argument('--tailscale', + action='store_true', default=False, + help='Use tailscale to network (host and client need tailscale access)',) def get_version(): diff --git a/clearml_session/interactive_session_task.py b/clearml_session/interactive_session_task.py index 0a3819c..8146d0c 100644 --- a/clearml_session/interactive_session_task.py +++ b/clearml_session/interactive_session_task.py @@ -1,4 +1,5 @@ import base64 +import ipaddress import json import os import socket @@ -910,6 +911,33 @@ def setup_user_env(param, task): def get_host_name(task, param): + hostname = [] + hostnames = [] + + + if task.get_parameter(name='interactive_session/tailscale'): + def get_tailscale_interfaces(): + interfaces = psutil.net_if_addrs() + tailscale_address = None + for interface_name, interface_addresses in interfaces.items(): + for address in interface_addresses: + if "tailscale" in interface_name: + print(address) + try: + ipaddress.IPv4Address(address.address) + except ValueError: + pass + else: + tailscale_address = address.address + return tailscale_address + + tl_addr = get_tailscale_interfaces() + if tl_addr: + hostname = tl_addr + hostnames = tl_addr + task.set_parameter(name='properties/external_address', value=str(tl_addr)) + return hostname, hostnames + # noinspection PyBroadException try: hostname = socket.gethostname() @@ -1014,6 +1042,7 @@ def main(): "public_ip": False, "ssh_ports": None, "force_dropbear": False, + "tailscale": False } task = init_task(param, default_ssh_fingerprint)