From d9aee8582115e1166ff50b755d602ee50ef96e01 Mon Sep 17 00:00:00 2001
From: allegroai <>
Date: Wed, 22 Dec 2021 13:27:57 +0200
Subject: [PATCH] Fix forked StorageHelper to use its own ThreadPoolExecutor

---
 clearml/storage/helper.py | 39 +++++++++++++++++++++++++--------------
 1 file changed, 25 insertions(+), 14 deletions(-)

diff --git a/clearml/storage/helper.py b/clearml/storage/helper.py
index 756ee676..b56d993b 100644
--- a/clearml/storage/helper.py
+++ b/clearml/storage/helper.py
@@ -37,6 +37,7 @@ from ..backend_config.bucket_config import S3BucketConfigurations, GSBucketConfigurations
 from ..config import config, deferred_config
 from ..debugging import get_logger
 from ..errors import UsageError
+from ..utilities.process.mp import ForkSafeRLock
 
 
 class StorageError(Exception):
@@ -184,9 +185,10 @@ class StorageHelper(object):
     _helpers = {}  # cache of helper instances
 
     # global terminate event for async upload threads
-    _terminate = threading.Event()
+    # _terminate = threading.Event()
     _async_upload_threads = set()
     _upload_pool = None
+    _upload_pool_pid = None
 
     # collect all bucket credentials that aren't empty (ignore entries with an empty key or secret)
     _s3_configurations = deferred_config('aws.s3', {}, transform=S3BucketConfigurations.from_config)
@@ -365,7 +367,7 @@ class StorageHelper(object):
             # since async uploaders are daemon threads, we can just return and let them close by themselves
             return
         # signal all threads to terminate and give them a chance for 'timeout' seconds (total, not per-thread)
-        cls._terminate.set()
+        # cls._terminate.set()
         remaining_timeout = timeout
         for thread in cls._async_upload_threads:
             t = time()
@@ -980,7 +982,8 @@ class StorageHelper(object):
 
     @staticmethod
     def _initialize_upload_pool():
-        if not StorageHelper._upload_pool:
+        if not StorageHelper._upload_pool or StorageHelper._upload_pool_pid != os.getpid():
+            StorageHelper._upload_pool_pid = os.getpid()
             StorageHelper._upload_pool = ThreadPool(processes=1)
 
     @staticmethod
@@ -1235,9 +1238,12 @@ class _Boto3Driver(_Driver):
     _min_pool_connections = 512
     _max_multipart_concurrency = deferred_config('aws.boto3.max_multipart_concurrency', 16)
     _pool_connections = deferred_config('aws.boto3.pool_connections', 512)
+    _connect_timeout = deferred_config('aws.boto3.connect_timeout', 60)
+    _read_timeout = deferred_config('aws.boto3.read_timeout', 60)
 
     _stream_download_pool_connections = 128
     _stream_download_pool = None
+    _stream_download_pool_pid = None
 
     _containers = {}
 
@@ -1247,7 +1253,7 @@ class _Boto3Driver(_Driver):
     _bucket_location_failure_reported = set()
 
     class _Container(object):
-        _creation_lock = threading.Lock()
+        _creation_lock = ForkSafeRLock()
 
         def __init__(self, name, cfg):
             try:
@@ -1272,8 +1278,10 @@ class _Boto3Driver(_Driver):
                     "verify": cfg.verify,
                     "config": botocore.client.Config(
                         max_pool_connections=max(
-                            _Boto3Driver._min_pool_connections,
-                            _Boto3Driver._pool_connections)
+                            int(_Boto3Driver._min_pool_connections),
+                            int(_Boto3Driver._pool_connections)),
+                        connect_timeout=int(_Boto3Driver._connect_timeout),
+                        read_timeout=int(_Boto3Driver._read_timeout),
                     )
                 }
                 if not cfg.use_credentials_chain:
@@ -1297,7 +1305,8 @@ class _Boto3Driver(_Driver):
                 pass
 
     def _get_stream_download_pool(self):
-        if self._stream_download_pool is None:
+        if self._stream_download_pool is None or self._stream_download_pool_pid != os.getpid():
+            self._stream_download_pool_pid = os.getpid()
             self._stream_download_pool = ThreadPoolExecutor(max_workers=self._stream_download_pool_connections)
         return self._stream_download_pool
 
@@ -1390,12 +1399,12 @@ class _Boto3Driver(_Driver):
             self.get_logger().warning('failed saving after download: overwrite=False and file exists (%s)' % str(p))
             return
         container = self._containers[obj.container_name]
-        obj.download_file(str(p),
-                          Callback=callback,
-                          Config=boto3.s3.transfer.TransferConfig(
-                              use_threads=container.config.multipart,
-                              max_concurrency=self._max_multipart_concurrency if container.config.multipart else 1,
-                              num_download_attempts=container.config.retries))
+        Config = boto3.s3.transfer.TransferConfig(
+            use_threads=container.config.multipart,
+            max_concurrency=self._max_multipart_concurrency if container.config.multipart else 1,
+            num_download_attempts=container.config.retries
+        )
+        obj.download_file(str(p), Callback=callback, Config=Config)
 
     @classmethod
     def _test_bucket_config(cls, conf, log, test_path='', raise_on_error=True, log_on_error=True):
@@ -1503,6 +1512,7 @@ class _GoogleCloudStorageDriver(_Driver):
 
     _stream_download_pool_connections = 128
     _stream_download_pool = None
+    _stream_download_pool_pid = None
 
     _containers = {}
 
@@ -1538,7 +1548,8 @@ class _GoogleCloudStorageDriver(_Driver):
         self.bucket = self.client.bucket(self.name)
 
     def _get_stream_download_pool(self):
-        if self._stream_download_pool is None:
+        if self._stream_download_pool is None or self._stream_download_pool_pid != os.getpid():
+            self._stream_download_pool_pid = os.getpid()
             self._stream_download_pool = ThreadPoolExecutor(max_workers=self._stream_download_pool_connections)
         return self._stream_download_pool
 
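Note (editorial, not part of the patch): every pool-related hunk above applies the same fork-safety pattern. A ThreadPool or ThreadPoolExecutor created before os.fork() is unusable in the child process, because fork duplicates only the calling thread; the pool object is copied but its worker threads are not, so work submitted to an inherited pool never executes. The patch therefore records the PID that created each pool (_upload_pool_pid, _stream_download_pool_pid) and rebuilds the pool whenever the current os.getpid() differs. Below is a minimal standalone sketch of the pattern; the PoolOwner class and its attribute names are illustrative stand-ins for StorageHelper._upload_pool / _upload_pool_pid, not clearml API:

import os
from multiprocessing.pool import ThreadPool


class PoolOwner(object):
    _pool = None       # stand-in for StorageHelper._upload_pool
    _pool_pid = None   # stand-in for StorageHelper._upload_pool_pid

    @classmethod
    def get_pool(cls):
        # Rebuild the pool if it was never created, or if it was inherited
        # from a parent process via fork: the inherited object exists, but
        # its worker threads do not, so queued work would never run.
        if cls._pool is None or cls._pool_pid != os.getpid():
            cls._pool_pid = os.getpid()
            cls._pool = ThreadPool(processes=1)
        return cls._pool

The switch from threading.Lock() to ForkSafeRLock for _Container._creation_lock follows the same reasoning: a plain lock held by another thread at fork time stays locked forever in the child, whereas ForkSafeRLock (from clearml.utilities.process.mp) is designed to re-initialize its underlying lock when it detects a PID change. The global _terminate event is commented out rather than made fork-safe; as the surrounding comment notes, the async uploaders are daemon threads and can simply be left to exit on their own.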