# File: clearml-session/clearml_session/tcp_proxy.py (359 lines, 16 KiB, Python)
# 2020-12-22 19:32:02 +00:00
import hashlib
import sys
import threading
import socket
import time
import select
import errno
from typing import Union
class TcpProxy(object):
    """Local TCP proxy forwarding connections from ``listen_port`` to ``target_port``.

    Both endpoints live on 127.0.0.1. ``__init__`` starts a daemon thread
    (:meth:`daemon`) that accepts connections on the listen port and spawns one
    relay thread (:meth:`start_proxy_thread`) per accepted connection.

    With ``keep_connection`` two cooperating proxies keep a logical session alive
    across broken TCP links:

    * the *client* side sends ``PROXY#<64-char uuid>`` at the start of every
      upstream connection, reconnects (re-sending the uuid) when the upstream link
      drops, and sends ``CLOSE#<uuid>`` on a fresh control connection once its
      local client goes away;
    * the *server* side reads that header; when the incoming link drops it keeps
      the target connection open (for up to ``__wait_timeout`` seconds) until a new
      link announces the same uuid or a ``CLOSE#`` for it arrives.
    """
    __header = 'PROXY#'
    __close_header = 'CLOSE#'
    __uid_length = 64  # length of the sha256 hexdigest used as a session uuid
    __socket_test_timeout = 3
    __max_sockets = 100  # backlog passed to listen()
    __wait_timeout = 300  # make sure we do not collect lost sockets, and drop it after 5 minutes
    __default_packet_size = 4096

    def __init__(self,
                 listen_port=8868, target_port=8878, proxy_state=None, verbose=None,
                 keep_connection=False, is_connection_server=False):
        # type: (int, int, dict, bool, bool, bool) -> None
        """Configure the proxy and immediately start the accept-loop daemon thread.

        :param listen_port: local port to accept connections on.
        :param target_port: local port every accepted connection is relayed to.
        :param proxy_state: shared dict; ``'reconnect'`` is set to True whenever
            connecting to the target is refused or times out.
        :param verbose: print debug messages to stdout when truthy.
        :param keep_connection: enable uuid-based session keep-alive.
        :param is_connection_server: with ``keep_connection``, act as the server side
            (reads uuid headers) instead of the client side (sends them).
        """
        self.listen_ip = '127.0.0.1'
        self.target_ip = '127.0.0.1'
        self.logfile = None  # sys.stdout
        self.listen_port = listen_port
        self.target_port = target_port
        self.proxy_state = proxy_state or {}
        self.verbose = verbose
        self.proxy_socket = None
        # uuid -> {'local_socket': socket}; lets a reconnecting link be handed to
        # the relay thread that is still holding the target connection
        self.active_local_sockets = {}
        # uuids for which a CLOSE# message has been received
        self.close_local_sockets = set()
        self.keep_connection = keep_connection
        self.keep_connection_server = keep_connection and is_connection_server
        self.keep_connection_client = keep_connection and not is_connection_server
        # set max number of open files (best effort: failure is not fatal)
        # noinspection PyBroadException
        try:
            if sys.platform == 'win32':
                import ctypes
                ctypes.windll.msvcrt._setmaxstdio(max(2048, ctypes.windll.msvcrt._getmaxstdio()))  # noqa
            else:
                import resource
                soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
                resource.setrlimit(resource.RLIMIT_NOFILE, (max(4096, soft), hard))
        except Exception:
            pass
        self._proxy_daemon_thread = threading.Thread(target=self.daemon, daemon=True)
        self._proxy_daemon_thread.start()

    def get_thread(self):
        """Return the daemon thread running the accept loop."""
        return self._proxy_daemon_thread

    @staticmethod
    def receive_from(s, size=0):
        # type: (socket.socket, int) -> bytes
        """Read from ``s`` and return the bytes received.

        With ``size`` > 0, keep reading until ``size`` bytes have arrived, or until
        the peer closes the connection (fewer bytes are then returned).
        With ``size`` == 0, read packets of up to ``__default_packet_size`` bytes
        for as long as more data is immediately available; an empty result means
        the peer closed the connection.
        """
        b = b""
        while True:
            data = s.recv(size - len(b) if size else TcpProxy.__default_packet_size)
            b += data
            if not data:
                # peer closed the connection: recv() would return b"" forever
                break
            if size:
                if len(b) >= size:
                    break
            elif len(data) < TcpProxy.__default_packet_size or not select.select([s], [], [], 0)[0]:
                # short read, or a full packet with nothing more queued right now:
                # return what we have instead of blocking in recv() (which would
                # stall the relay, or time out and lose the bytes already read)
                break
        return b

    @staticmethod
    def send_to(s, data):
        # type: (socket.socket, Union[str, bytes]) -> ()
        """Send all of ``data`` (str is encoded with the default UTF-8) on ``s``."""
        # sendall(): send() may transmit only part of the buffer
        s.sendall(data.encode() if isinstance(data, str) else data)

    def start_proxy_thread(self, local_socket, uuid, init_data):
        """Relay traffic between ``local_socket`` and a new connection to the target.

        Runs in its own thread. ``init_data`` (bytes already consumed from
        ``local_socket`` by :meth:`daemon`, or None) is forwarded to the target once,
        before relaying starts. Exceptions raised while relaying are printed (when
        verbose) and the relay is retried, unless one of the sockets has already been
        closed; in that case the sockets are released and the thread ends.
        """
        try:
            remote_socket = self._open_remote_socket(local_socket)
        except Exception as ex:
            self.vprint('Exception {}: {}'.format(type(ex), ex))
            return
        while True:
            try:
                # init_data is only sent on the first attempt
                init_data_, init_data = init_data, None
                self._process_socket_proxy(local_socket, remote_socket, uuid=uuid, init_data=init_data_)
                return
            except Exception as ex:
                self.vprint('Exception {}: {}'.format(type(ex), ex))
            if local_socket.fileno() == -1 or remote_socket.fileno() == -1:
                # select() on a closed socket fails immediately, so retrying would
                # spin forever: release what this relay owns and stop
                self.vprint('Socket closed, stopping proxy thread')
                leftovers = [local_socket, remote_socket]
                if uuid:
                    entry = self.active_local_sockets.pop(uuid, None) or {}
                    if entry.get('local_socket') is not None:
                        leftovers.append(entry['local_socket'])
                for s in leftovers:
                    s.close()
                return
            time.sleep(0.1)

    def _open_remote_socket(self, local_socket):
        """Connect to ``(target_ip, target_port)`` and return the connected socket.

        Refused and timed-out attempts are retried every second, and
        ``proxy_state['reconnect']`` is set to True. Any other socket error closes
        both the new socket and ``local_socket`` and is re-raised.
        """
        remote_socket = None
        while True:
            if remote_socket:
                remote_socket.close()
            remote_socket = socket.socket(family=socket.AF_INET, type=socket.SOCK_STREAM)
            timeout = 60
            try:
                remote_socket.settimeout(timeout)
                remote_socket.connect((self.target_ip, self.target_port))
                msg = 'Connected to {}'.format(remote_socket.getpeername())
                self.vprint(msg)
                self.log(msg)
            except socket.error as serr:
                if serr.errno == errno.ECONNREFUSED:
                    reason = 'Connection refused'
                elif serr.errno == errno.ETIMEDOUT or isinstance(serr, socket.timeout):
                    # with settimeout(), connect() raises socket.timeout (errno None)
                    # rather than ETIMEDOUT; both mean "target not reachable yet"
                    reason = 'Connection connection timed out'
                else:
                    self.vprint("Connection error {}".format(serr.errno))
                    for s in [remote_socket, local_socket]:
                        s.close()
                    raise
                msg = '{}, {}:{} - {}'.format(
                    time.strftime("%Y-%m-%d %H:%M:%S"), self.target_ip, self.target_port, reason)
                self.vprint(msg)
                self.log(msg)
                self.proxy_state['reconnect'] = True
                time.sleep(1)
                continue
            return remote_socket

    def _process_socket_proxy(self, local_socket, remote_socket, uuid=None, init_data=None):
        """Relay data in both directions until the session ends.

        In keep-connection client mode a ``PROXY#<uuid>`` header is sent to the
        remote first (the uuid is generated from the current time and the local peer
        address when not given). In keep-connection server mode a closed local side
        is replaced by a reconnection registered under the same uuid in
        ``active_local_sockets``. On exit the uuid entry is removed and, in client
        mode, a ``CLOSE#<uuid>`` message is sent to the target.
        """
        timeout = 60
        # if we are self.keep_connection_client we need to generate uuid, send it
        if self.keep_connection_client:
            if uuid is None:
                uuid = hashlib.sha256('{}{}'.format(time.time(), local_socket.getpeername()).encode()).hexdigest()
            self.vprint('sending UUID {}'.format(uuid))
            self.send_to(remote_socket, self.__header + uuid)
        # check if we need to send init_data
        if init_data:
            self.vprint('sending init data {}'.format(len(init_data)))
            self.send_to(remote_socket, init_data)
        # This loop ends when no more data is received on either the local or the
        # remote socket
        running = True
        while running:
            read_sockets, _, _ = select.select([remote_socket, local_socket], [], [])
            for sock in read_sockets:
                try:
                    peer = sock.getpeername()
                except socket.error as serr:
                    if serr.errno == errno.ENOTCONN:
                        # kind of a blind shot at fixing issue #15
                        # I don't yet understand how this error can happen,
                        # but if it happens I'll just shut down the thread
                        # the connection is not in a useful state anymore
                        for s in [remote_socket, local_socket]:
                            s.close()
                        running = False
                        break
                    else:
                        self.vprint("{}: Socket exception in start_proxy_thread".format(
                            time.strftime('%Y-%m-%d %H:%M:%S')))
                        raise serr
                data = self.receive_from(sock)
                self.log('Received %d bytes' % len(data))
                if sock == local_socket:
                    if len(data):
                        self.send_to(remote_socket, data)
                    else:
                        msg = "Connection from local client %s:%d closed" % peer
                        self.vprint(msg)
                        self.log(msg)
                        local_socket.close()
                        if not self.keep_connection or not uuid:
                            remote_socket.close()
                            running = False
                        elif self.keep_connection_server:
                            # keep the target connection open and wait until the
                            # daemon registers a new local socket for this uuid, a
                            # CLOSE# for it arrives, or __wait_timeout expires
                            self.vprint('waiting for reconnection, sleep 1 sec')
                            tic = time.time()
                            while uuid not in self.close_local_sockets and \
                                    self.active_local_sockets.get(uuid, {}).get('local_socket') == local_socket:
                                time.sleep(1)
                                self.vprint('wait local reconnect [{}]'.format(uuid))
                                if time.time() - tic > self.__wait_timeout:
                                    remote_socket.close()
                                    running = False
                                    break
                            if not running:
                                break
                            self.vprint('done waiting')
                            if uuid in self.close_local_sockets:
                                self.vprint('client closed connection')
                                remote_socket.close()
                                running = False
                                self.close_local_sockets.remove(uuid)
                            else:
                                self.vprint('reconnecting local client')
                                local_socket = self.active_local_sockets.get(uuid, {}).get('local_socket')
                        elif self.keep_connection_client:
                            # the goodbye (CLOSE#) message is sent after the loop
                            self.vprint('client {} closing socket'.format(uuid))
                            remote_socket.close()
                            running = False
                            break
                elif sock == remote_socket:
                    if len(data):
                        self.send_to(local_socket, data)
                    else:
                        msg = "Connection to remote server %s:%d closed" % peer
                        self.vprint(msg)
                        self.log(msg)
                        remote_socket.close()
                        if self.keep_connection_client and uuid:
                            # reconnect upstream with the same uuid.
                            # NOTE(review): each reconnect nests another
                            # start_proxy_thread call, so stack depth grows with the
                            # number of reconnects in one session.
                            self.vprint('Wait for remote reconnect')
                            time.sleep(1)
                            return self.start_proxy_thread(local_socket, uuid=uuid, init_data=None)
                        else:
                            local_socket.close()
                            running = False
                            break
        # remove the socket from the global list
        if uuid:
            self.active_local_sockets.pop(uuid, None)
        if self.keep_connection_client:
            self._send_remote_close_msg(timeout, uuid)

    def _send_remote_close_msg(self, timeout, uuid):
        """Tell the server-side proxy that session ``uuid`` is over (client mode only).

        Opens a short-lived control connection to the target and sends
        ``CLOSE#<uuid>``. Errors are printed (when verbose), never raised.
        """
        if not self.keep_connection_client or not uuid:
            return
        try:
            self.vprint('create new control socket')
            # the with-block closes the control socket even if connect/send fails
            with socket.socket(family=socket.AF_INET, type=socket.SOCK_STREAM) as control_socket:
                control_socket.settimeout(timeout)
                control_socket.connect((self.target_ip, self.target_port))
                self.vprint('send close header [{}]'.format(uuid))
                self.send_to(control_socket, self.__close_header + uuid)
                self.vprint('close control_socket')
        except Exception as ex:
            self.vprint('Error sending close header, {}'.format(ex))

    def log(self, message, message_only=False):
        """Write ``message`` to ``self.logfile`` (no-op when it is None).

        Unless ``message_only`` is True, the message is prefixed with a timestamp
        line and followed by a separator line to keep the log readable.
        """
        handle = self.logfile
        if handle is None:
            return
        if not isinstance(message, bytes):
            # utf-8 (not ascii) so non-ASCII text does not raise; matches decode() below
            message = message.encode('utf-8')
        if not message_only:
            logentry = bytes("%s %s\n" % (time.strftime("%Y-%m-%d %H:%M:%S"), str(time.time())), 'ascii')
        else:
            logentry = b''
        logentry += message
        if not message_only:
            logentry += b'\n' + b'-' * 20 + b'\n'
        handle.write(logentry.decode())

    def vprint(self, msg):
        """Print ``msg`` only when the proxy is verbose."""
        if self.verbose:
            print(msg)

    def daemon(self):
        """Accept loop: listen on ``(listen_ip, listen_port)`` and spawn relay threads.

        If binding fails, the error text is printed and ``sys.exit(5)`` is called
        from this thread. In connection-server mode the first
        ``len('PROXY#') + 64`` bytes of every connection are inspected:
        ``PROXY#<uuid>`` starts or resumes a session, ``CLOSE#<uuid>`` ends one,
        and anything else is forwarded to the target as initial data.
        """
        # this is the socket we will listen on for incoming connections
        self.proxy_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.proxy_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        try:
            self.proxy_socket.bind((self.listen_ip, self.listen_port))
        except socket.error as e:
            print(e.strerror)
            sys.exit(5)
        self.proxy_socket.listen(self.__max_sockets)
        # endless loop
        while True:
            in_socket = None
            try:
                in_socket, in_addrinfo = self.proxy_socket.accept()
                msg = 'Connection from %s:%d' % in_addrinfo  # noqa
                self.vprint(msg)
                self.log(msg)
                uuid = None
                init_data = None
                if self.keep_connection_server:
                    read_sockets, _, _ = select.select([in_socket], [], [])
                    if read_sockets:
                        data = self.receive_from(in_socket, size=self.__uid_length + len(self.__header))
                        self.vprint('Reading header [{}]'.format(len(data)))
                        header = None
                        if len(data) == self.__uid_length + len(self.__header):
                            # noinspection PyBroadException
                            try:
                                header = data.decode()
                            except Exception:
                                # not text, so not a header: forward it as payload below
                                header = None
                        if header is not None and header.startswith(self.__header):
                            uuid = header[len(self.__header):]
                            self.vprint('Reading UUID [{}] {}'.format(len(data), uuid))
                        elif header is not None and header.startswith(self.__close_header):
                            uuid = header[len(self.__close_header):]
                            self.vprint('Closing UUID [{}] {}'.format(len(data), uuid))
                            self.close_local_sockets.add(uuid)
                            # the control connection carries nothing else
                            in_socket.close()
                            continue
                        else:
                            init_data = data
                if self.active_local_sockets and uuid is not None:
                    self.vprint('Check waiting threads')
                    # noinspection PyBroadException
                    try:
                        if uuid in self.active_local_sockets:
                            # hand the new link to the relay thread waiting on this uuid
                            self.vprint('Updating thread uuid {}'.format(uuid))
                            self.active_local_sockets[uuid]['local_socket'] = in_socket
                            continue
                    except Exception:
                        pass
                if uuid:
                    self.active_local_sockets[uuid] = {'local_socket': in_socket}
                proxy_thread = threading.Thread(
                    target=self.start_proxy_thread, args=(in_socket, uuid, init_data), daemon=True)
                self.log("Starting proxy thread " + proxy_thread.name)
                proxy_thread.start()
            except Exception as ex:
                msg = 'Exception: {}'.format(ex)
                self.vprint(msg)
                self.log(msg)
                # no relay thread took ownership of the accepted socket
                if in_socket is not None:
                    in_socket.close()