From aeb1a8e64bbeeb6dc561367ef91b080ba6d20d74 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Mon, 12 Oct 2020 11:07:52 +0300 Subject: [PATCH] Automatically increase write (connection) timeouts when session header is large (default 15kb threshold 300 sec timeout) --- trains/backend_api/session/session.py | 2 ++ trains/backend_api/utils.py | 24 +++++++++++++++++++++++- 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/trains/backend_api/session/session.py b/trains/backend_api/session/session.py index 321ca7b7..34a0efe7 100644 --- a/trains/backend_api/session/session.py +++ b/trains/backend_api/session/session.py @@ -154,6 +154,8 @@ class Session(TokenManager): "api.http.retries", ConfigTree()).as_plain_ordered_dict() http_retries_config["status_forcelist"] = self._retry_codes self.__http_session = get_http_session_with_retry(**http_retries_config) + self.__http_session.write_timeout = self._write_session_timeout + self.__http_session.request_size_threshold = self._write_session_data_size self.__worker = worker or self.get_worker_host_name() diff --git a/trains/backend_api/utils.py b/trains/backend_api/utils.py index 14fc7eca..84d19211 100644 --- a/trains/backend_api/utils.py +++ b/trains/backend_api/utils.py @@ -63,6 +63,28 @@ class TLSv1HTTPAdapter(HTTPAdapter): ssl_version=ssl.PROTOCOL_TLSv1_2) +class SessionWithTimeout(requests.Session): + write_timeout = (300., 300.) + request_size_threshold = 15000 + + def __init__(self, *args, **kwargs): + super(SessionWithTimeout, self).__init__(*args, **kwargs) + + def send(self, request, **kwargs): + if isinstance(request, requests.models.PreparedRequest) and \ + request.headers and request.headers.get('Content-Length'): + try: + if int(request.headers['Content-Length']) > self.request_size_threshold: + timeout = kwargs.get('timeout', 0) + kwargs['timeout'] = \ + (max(self.write_timeout[0], timeout[0]), max(self.write_timeout[1], timeout[1])) \ + if isinstance(timeout, (list, tuple)) \ + else max(self.write_timeout[0], timeout) + except (TypeError, ValueError, NameError): + pass + return super(SessionWithTimeout, self).send(request, **kwargs) + + def get_http_session_with_retry( total=0, connect=None, @@ -93,7 +115,7 @@ def get_http_session_with_retry( else get_config().get('api.http.pool_connections', 512) ) - session = requests.Session() + session = SessionWithTimeout() # HACK: with python 2.7 there is a potential race condition that can cause # a deadlock when importing "netrc", inside the get_netrc_auth() function