Fix forked StorageHelper to use its own thread executor

Author: allegroai
Date:   2021-12-22 13:27:57 +02:00
Parent: 7e32278ebf
Commit: d9aee85821


@@ -37,6 +37,7 @@ from ..backend_config.bucket_config import S3BucketConfigurations, GSBucketConfi
 from ..config import config, deferred_config
 from ..debugging import get_logger
 from ..errors import UsageError
+from ..utilities.process.mp import ForkSafeRLock


 class StorageError(Exception):
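
The import above is the heart of the fix: a `threading.Lock()` that happens to be held at the moment the process forks stays locked forever in the child, deadlocking the first thread that touches it. clearml's `ForkSafeRLock` avoids this; the idea can be sketched roughly as below (a minimal reimplementation for illustration only, assuming Python 3.7+ on a POSIX platform for `os.register_at_fork`; the real class may differ):

    import os
    import threading

    class ForkSafeRLockSketch(object):
        """Illustrative stand-in for clearml's ForkSafeRLock (not the real code)."""

        def __init__(self):
            self._lock = threading.RLock()
            # swap in a fresh, unlocked lock in the child right after every fork
            os.register_at_fork(after_in_child=self._reset)

        def _reset(self):
            self._lock = threading.RLock()

        def __enter__(self):
            self._lock.acquire()
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            self._lock.release()
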
@@ -184,9 +185,10 @@ class StorageHelper(object):
     _helpers = {}  # cache of helper instances
     # global terminate event for async upload threads
-    _terminate = threading.Event()
+    # _terminate = threading.Event()
     _async_upload_threads = set()
     _upload_pool = None
+    _upload_pool_pid = None

     # collect all bucket credentials that aren't empty (ignore entries with an empty key or secret)
     _s3_configurations = deferred_config('aws.s3', {}, transform=S3BucketConfigurations.from_config)
@@ -365,7 +367,7 @@ class StorageHelper(object):
             # since async uploaders are daemon threads, we can just return and let them close by themselves
             return
         # signal all threads to terminate and give them a chance for 'timeout' seconds (total, not per-thread)
-        cls._terminate.set()
+        # cls._terminate.set()
         remaining_timeout = timeout
         for thread in cls._async_upload_threads:
             t = time()
@@ -980,7 +982,8 @@ class StorageHelper(object):
     @staticmethod
     def _initialize_upload_pool():
-        if not StorageHelper._upload_pool:
+        if not StorageHelper._upload_pool or StorageHelper._upload_pool_pid != os.getpid():
+            StorageHelper._upload_pool_pid = os.getpid()
             StorageHelper._upload_pool = ThreadPool(processes=1)

     @staticmethod
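
The `_upload_pool_pid` check above is a classic fork guard: `fork()` copies the pool object into the child but not the parent's worker threads, so anything submitted to the inherited pool would never run. Comparing the stored pid against `os.getpid()` makes a forked child lazily rebuild its own pool. A self-contained sketch of the pattern (module-level names here are illustrative, not clearml's API):

    import os
    from multiprocessing.pool import ThreadPool

    _pool = None
    _pool_pid = None

    def get_upload_pool():
        global _pool, _pool_pid
        # a pool inherited across fork() has no live worker threads in the
        # child, so rebuild whenever the current pid differs from the creator's
        if _pool is None or _pool_pid != os.getpid():
            _pool_pid = os.getpid()
            _pool = ThreadPool(processes=1)
        return _pool
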
@@ -1235,9 +1238,12 @@ class _Boto3Driver(_Driver):
     _min_pool_connections = 512
     _max_multipart_concurrency = deferred_config('aws.boto3.max_multipart_concurrency', 16)
     _pool_connections = deferred_config('aws.boto3.pool_connections', 512)
+    _connect_timeout = deferred_config('aws.boto3.connect_timeout', 60)
+    _read_timeout = deferred_config('aws.boto3.read_timeout', 60)

     _stream_download_pool_connections = 128
     _stream_download_pool = None
+    _stream_download_pool_pid = None

     _containers = {}
@@ -1247,7 +1253,7 @@ class _Boto3Driver(_Driver):
     _bucket_location_failure_reported = set()

     class _Container(object):
-        _creation_lock = threading.Lock()
+        _creation_lock = ForkSafeRLock()

         def __init__(self, name, cfg):
             try:
@@ -1272,8 +1278,10 @@ class _Boto3Driver(_Driver):
                 "verify": cfg.verify,
                 "config": botocore.client.Config(
                     max_pool_connections=max(
-                        _Boto3Driver._min_pool_connections,
-                        _Boto3Driver._pool_connections)
+                        int(_Boto3Driver._min_pool_connections),
+                        int(_Boto3Driver._pool_connections)),
+                    connect_timeout=int(_Boto3Driver._connect_timeout),
+                    read_timeout=int(_Boto3Driver._read_timeout),
                 )
             }

             if not cfg.use_credentials_chain:
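
Two things change in this hunk: the pool-size values are wrapped in `int()` because `deferred_config` returns lazy wrapper objects rather than plain integers, and the new connect/read timeouts are forwarded to botocore. As a standalone illustration of the resulting client configuration (the literal values stand in for the `aws.boto3.*` settings):

    import boto3
    import botocore.client

    config = botocore.client.Config(
        max_pool_connections=512,  # aws.boto3.pool_connections
        connect_timeout=60,        # aws.boto3.connect_timeout, in seconds
        read_timeout=60,           # aws.boto3.read_timeout, in seconds
    )
    # every S3 client/resource built with this config honors the timeouts
    s3 = boto3.session.Session().resource("s3", config=config)
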
@@ -1297,7 +1305,8 @@ class _Boto3Driver(_Driver):
             pass

     def _get_stream_download_pool(self):
-        if self._stream_download_pool is None:
+        if self._stream_download_pool is None or self._stream_download_pool_pid != os.getpid():
+            self._stream_download_pool_pid = os.getpid()
             self._stream_download_pool = ThreadPoolExecutor(max_workers=self._stream_download_pool_connections)
         return self._stream_download_pool
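
The same pid guard protects the cached `ThreadPoolExecutor`, which is just as unusable after a fork as the upload pool. An alternative design (not what this commit does) is to drop the cached executor from a fork hook, sketched below with illustrative names:

    import os
    from concurrent.futures import ThreadPoolExecutor

    _executor = None

    def _drop_executor_in_child():
        # runs in the child immediately after fork(); the stale executor
        # still references the parent's (now nonexistent) worker threads
        global _executor
        _executor = None

    os.register_at_fork(after_in_child=_drop_executor_in_child)

    def get_stream_download_pool(max_workers=128):
        global _executor
        if _executor is None:
            _executor = ThreadPoolExecutor(max_workers=max_workers)
        return _executor

The pid comparison used in the commit needs no hook registration and behaves the same on every platform and code path, which makes it the simpler choice here.
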
@@ -1390,12 +1399,12 @@ class _Boto3Driver(_Driver):
             self.get_logger().warning('failed saving after download: overwrite=False and file exists (%s)' % str(p))
             return
         container = self._containers[obj.container_name]
-        obj.download_file(str(p),
-                          Callback=callback,
-                          Config=boto3.s3.transfer.TransferConfig(
-                              use_threads=container.config.multipart,
-                              max_concurrency=self._max_multipart_concurrency if container.config.multipart else 1,
-                              num_download_attempts=container.config.retries))
+        Config = boto3.s3.transfer.TransferConfig(
+            use_threads=container.config.multipart,
+            max_concurrency=self._max_multipart_concurrency if container.config.multipart else 1,
+            num_download_attempts=container.config.retries
+        )
+        obj.download_file(str(p), Callback=callback, Config=Config)

     @classmethod
     def _test_bucket_config(cls, conf, log, test_path='', raise_on_error=True, log_on_error=True):
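
The download hunk is a pure refactor: the `TransferConfig` is hoisted into a local variable instead of being built inline in the call. Functionally equivalent standalone usage looks like this (bucket, key, and local path are placeholders; the literal values stand in for the container settings in the diff):

    import boto3
    from boto3.s3.transfer import TransferConfig

    obj = boto3.resource("s3").Object("my-bucket", "path/to/object")

    transfer_config = TransferConfig(
        use_threads=True,         # container.config.multipart in the diff
        max_concurrency=16,       # aws.boto3.max_multipart_concurrency
        num_download_attempts=5,  # container.config.retries
    )
    obj.download_file("/tmp/object.bin", Config=transfer_config)
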
@@ -1503,6 +1512,7 @@ class _GoogleCloudStorageDriver(_Driver):
     _stream_download_pool_connections = 128
     _stream_download_pool = None
+    _stream_download_pool_pid = None

     _containers = {}
@@ -1538,7 +1548,8 @@ class _GoogleCloudStorageDriver(_Driver):
         self.bucket = self.client.bucket(self.name)

     def _get_stream_download_pool(self):
-        if self._stream_download_pool is None:
+        if self._stream_download_pool is None or self._stream_download_pool_pid != os.getpid():
+            self._stream_download_pool_pid = os.getpid()
             self._stream_download_pool = ThreadPoolExecutor(max_workers=self._stream_download_pool_connections)
         return self._stream_download_pool