Support max_workers in dataset operations

Support a max connections setting for Azure storage via the `sdk.azure.storage.max_connections` configuration key
allegroai 2022-11-21 17:12:02 +02:00
parent 68467d7288
commit ed2b8ed850
7 changed files with 176 additions and 83 deletions
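
In SDK terms, the two features combine as below. This is a minimal sketch, assuming a local `data/` folder; the project and dataset names are illustrative:

```python
from clearml import Dataset

ds = Dataset.create(dataset_name="example", dataset_project="examples")
ds.add_files(path="data/", max_workers=8)  # hash and stage files with 8 threads
ds.upload(max_workers=4, retries=3)        # zip and upload with 4 threads, retrying failed chunks
ds.finalize()
```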

View File

@@ -392,7 +392,11 @@ class AzureContainerConfigurations(object):
         ))
         if configuration is None:
-            return cls(default_container_configs, default_account=default_account, default_key=default_key)
+            return cls(
+                default_container_configs,
+                default_account=default_account,
+                default_key=default_key
+            )
 
         containers = configuration.get("containers", list())
         container_configs = [AzureContainerConfig(**entry) for entry in containers] + default_container_configs

View File

@@ -118,6 +118,12 @@ def cli():
     add.add_argument('--non-recursive', action='store_true', default=False,
                      help='Disable recursive scan of files')
     add.add_argument('--verbose', action='store_true', default=False, help='Verbose reporting')
+    add.add_argument(
+        "--max-workers",
+        type=int,
+        default=None,
+        help="Number of threads to add the files with. Defaults to the number of logical cores",
+    )
     add.set_defaults(func=ds_add)
 
     set_description = subparsers.add_parser("set-description", help="Set description to the dataset")
@@ -195,6 +201,12 @@ def cli():
                         help='Set dataset artifact chunk size in MB. Default 512, (pass -1 for a single chunk). '
                              'Example: 512, dataset will be split and uploaded in 512mb chunks.')
     upload.add_argument('--verbose', default=False, action='store_true', help='Verbose reporting')
+    upload.add_argument(
+        "--max-workers",
+        type=int,
+        default=None,
+        help="Number of threads to upload the files with. Defaults to 1 if uploading to a cloud provider ('s3', 'azure', 'gs') OR to the number of logical cores otherwise",
+    )
     upload.set_defaults(func=ds_upload)
 
     finalize = subparsers.add_parser('close', help='Finalize and close the dataset (implies auto upload)')
@@ -210,6 +222,12 @@ def cli():
                           help='Set dataset artifact chunk size in MB. Default 512, (pass -1 for a single chunk). '
                                'Example: 512, dataset will be split and uploaded in 512mb chunks.')
     finalize.add_argument('--verbose', action='store_true', default=False, help='Verbose reporting')
+    finalize.add_argument(
+        "--max-workers",
+        type=int,
+        default=None,
+        help="Number of threads to upload the files with. Defaults to 1 if uploading to a cloud provider ('s3', 'azure', 'gs') OR to the number of logical cores otherwise",
+    )
     finalize.set_defaults(func=ds_close)
 
     publish = subparsers.add_parser('publish', help='Publish dataset task')
@@ -327,6 +345,12 @@ def cli():
                          'can be divided into 4 parts')
     get.add_argument('--overwrite', action='store_true', default=False, help='If True, overwrite the target folder')
     get.add_argument('--verbose', action='store_true', default=False, help='Verbose reporting')
+    get.add_argument(
+        "--max-workers",
+        type=int,
+        default=None,
+        help="Number of threads to get the files with. Defaults to the number of logical cores",
+    )
     get.set_defaults(func=ds_get)
 
     args = parser.parse_args()
@@ -426,7 +450,9 @@ def ds_get(args):
             pass
     if args.copy:
         ds_folder = args.copy
-        ds.get_mutable_local_copy(target_folder=ds_folder, part=args.part, num_parts=args.num_parts)
+        ds.get_mutable_local_copy(
+            target_folder=ds_folder, part=args.part, num_parts=args.num_parts, max_workers=args.max_workers
+        )
     else:
         if args.link:
             Path(args.link).mkdir(parents=True, exist_ok=True)
@@ -438,7 +464,7 @@ def ds_get(args):
                 Path(args.link).unlink()
             except Exception:
                 raise ValueError("Target directory {} is not empty. Use --overwrite.".format(args.link))
-        ds_folder = ds.get_local_copy(part=args.part, num_parts=args.num_parts)
+        ds_folder = ds.get_local_copy(part=args.part, num_parts=args.num_parts, max_workers=args.max_workers)
         if args.link:
             os.symlink(ds_folder, args.link)
             ds_folder = args.link
@@ -568,10 +594,13 @@ def ds_close(args):
            raise ValueError("Pending uploads, cannot finalize dataset. run `clearml-data upload`")
         # upload the files
         print("Pending uploads, starting dataset upload to {}".format(args.storage or ds.get_default_storage()))
-        ds.upload(show_progress=True,
-                  verbose=args.verbose,
-                  output_url=args.storage or None,
-                  chunk_size=args.chunk_size or -1,)
+        ds.upload(
+            show_progress=True,
+            verbose=args.verbose,
+            output_url=args.storage or None,
+            chunk_size=args.chunk_size or -1,
+            max_workers=args.max_workers,
+        )
 
     ds.finalize()
     print('Dataset closed and finalized')
@@ -598,7 +627,12 @@ def ds_upload(args):
     check_null_id(args)
     print_args(args)
     ds = Dataset.get(dataset_id=args.id)
-    ds.upload(verbose=args.verbose, output_url=args.storage or None, chunk_size=args.chunk_size or -1)
+    ds.upload(
+        verbose=args.verbose,
+        output_url=args.storage or None,
+        chunk_size=args.chunk_size or -1,
+        max_workers=args.max_workers,
+    )
     print('Dataset upload completed')
     return 0
@@ -667,6 +701,7 @@ def ds_add(args):
         verbose=args.verbose,
         dataset_path=args.dataset_folder or None,
         wildcard=args.wildcard,
+        max_workers=args.max_workers
     )
     for link in args.links or []:
         num_files += ds.add_external_files(
@@ -675,6 +710,7 @@ def ds_add(args):
             recursive=not args.non_recursive,
             verbose=args.verbose,
             wildcard=args.wildcard,
+            max_workers=args.max_workers
         )
     message = "{} file{} added".format(num_files, "s" if num_files != 1 else "")
     print(message)
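
All four subcommands (`add`, `upload`, `close`, `get`) gain the same flag, so for example `clearml-data upload --id <dataset_id> --max-workers 8` uploads with eight threads instead of the new single-thread cloud default (the dataset ID is a placeholder).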

View File

@@ -121,6 +121,8 @@
         # ]
     }
     azure.storage {
+        # max_connections: 2
+
         # containers: [
         #     {
         #         account_name: "clearml"
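
Uncommented, the new key reads as below; the value 10 is an arbitrary illustration, not a recommended setting:

```
azure.storage {
    max_connections: 10
}
```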

View File

@@ -8,7 +8,7 @@ import re
 import logging
 from copy import deepcopy, copy
 from multiprocessing.pool import ThreadPool
-from concurrent.futures import ThreadPoolExecutor, as_completed
+from concurrent.futures import ThreadPoolExecutor
 from tempfile import mkdtemp
 from typing import Union, Optional, Sequence, List, Dict, Any, Mapping, Tuple
 from zipfile import ZIP_DEFLATED
@@ -23,7 +23,7 @@ from ..backend_interface.task.development.worker import DevWorker
 from ..backend_interface.util import mutually_exclusive, exact_match_regex, get_or_create_project, rename_project
 from ..config import deferred_config, running_remotely, get_remote_task_id
 from ..debugging.log import LoggerRoot
-from ..storage.helper import StorageHelper
+from ..storage.helper import StorageHelper, cloud_driver_schemes
 from ..storage.cache import CacheManager
 from ..storage.util import sha256sum, is_windows, md5text, format_size
 from ..utilities.matching import matches_any_wildcard
@@ -368,7 +368,8 @@ class Dataset(object):
         local_base_folder=None,  # type: Optional[str]
         dataset_path=None,  # type: Optional[str]
         recursive=True,  # type: bool
-        verbose=False  # type: bool
+        verbose=False,  # type: bool
+        max_workers=None,  # type: Optional[int]
     ):
         # type: (...) -> ()
         """
@@ -382,8 +383,10 @@ class Dataset(object):
         :param dataset_path: where in the dataset the folder/files should be located
         :param recursive: If True match all wildcard files recursively
         :param verbose: If True print to console files added/modified
+        :param max_workers: The number of threads to add the files with. Defaults to the number of logical cores
         :return: number of files added
         """
+        max_workers = max_workers or psutil.cpu_count()
         self._dirty = True
         self._task.get_logger().report_text(
             'Adding files to dataset: {}'.format(
@@ -392,8 +395,14 @@ class Dataset(object):
             print_console=False)
 
         num_added, num_modified = self._add_files(
-            path=path, wildcard=wildcard, local_base_folder=local_base_folder,
-            dataset_path=dataset_path, recursive=recursive, verbose=verbose)
+            path=path,
+            wildcard=wildcard,
+            local_base_folder=local_base_folder,
+            dataset_path=dataset_path,
+            recursive=recursive,
+            verbose=verbose,
+            max_workers=max_workers,
+        )
 
         # update the task script
         self._add_script_call(
@@ -587,9 +596,16 @@ class Dataset(object):
         return num_removed, num_added, num_modified
 
     def upload(
-        self, show_progress=True, verbose=False, output_url=None, compression=None, chunk_size=None, max_workers=None
+        self,
+        show_progress=True,
+        verbose=False,
+        output_url=None,
+        compression=None,
+        chunk_size=None,
+        max_workers=None,
+        retries=3,
     ):
-        # type: (bool, bool, Optional[str], Optional[str], int, Optional[int]) -> ()
+        # type: (bool, bool, Optional[str], Optional[str], int, Optional[int], int) -> ()
         """
         Start file uploading, the function returns when all files are uploaded.
@@ -602,17 +618,22 @@ class Dataset(object):
             if not provided (None) use the default chunk size (512mb).
             If -1 is provided, use a single zip artifact for the entire dataset change-set (old behaviour)
         :param max_workers: Numbers of threads to be spawned when zipping and uploading the files.
-            Defaults to the number of logical cores.
+            If None (default) it will be set to:
+            - 1: if the upload destination is a cloud provider ('s3', 'gs', 'azure')
+            - number of logical cores: otherwise
+        :param int retries: Number of retries before failing to upload each zip. If 0, the upload is not retried.
+
+        :raise: If the upload failed (i.e. at least one zip failed to upload), raise a `ValueError`
         """
         self._report_dataset_preview()
-        if not max_workers:
-            max_workers = psutil.cpu_count()
 
         # set output_url
         if output_url:
             self._task.output_uri = output_url
 
+        if not max_workers:
+            max_workers = 1 if self._task.output_uri.startswith(tuple(cloud_driver_schemes)) else psutil.cpu_count()
+
         self._task.get_logger().report_text(
             "Uploading dataset files: {}".format(
                 dict(show_progress=show_progress, verbose=verbose, output_url=output_url, compression=compression)
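
Note the behavioural shift here: without an explicit `max_workers`, uploads to `s3://`, `gs://`, or `azure://` destinations now run single-threaded, while uploads to the files server still use all logical cores. A hedged sketch, with an illustrative bucket URL:

```python
# assuming `ds` is a clearml Dataset with pending (not yet finalized) changes

# cloud destination: max_workers now defaults to 1 unless overridden
ds.upload(output_url="s3://my-bucket/datasets", max_workers=4)

# files-server (http/https) destination: still defaults to psutil.cpu_count()
ds.upload()
```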
@@ -625,9 +646,9 @@
         total_preview_size = 0
         keep_as_file_entry = set()
         chunk_size = int(self._dataset_chunk_size_mb if not chunk_size else chunk_size)
+        upload_futures = []
 
         with ThreadPoolExecutor(max_workers=max_workers) as pool:
-            futures = []
             parallel_zipper = ParallelZipper(
                 chunk_size,
                 max_workers,
@@ -648,6 +669,15 @@
                 file_paths.append(f.local_path)
                 arcnames[f.local_path] = f.relative_path
             for zip_ in parallel_zipper.zip_iter(file_paths, arcnames=arcnames):
+                running_futures = []
+                for upload_future in upload_futures:
+                    if upload_future.running():
+                        running_futures.append(upload_future)
+                    else:
+                        if not upload_future.result():
+                            raise ValueError("Failed uploading dataset with ID {}".format(self._id))
+                upload_futures = running_futures
+
                 zip_path = Path(zip_.zip_path)
                 artifact_name = self._data_artifact_name
                 self._data_artifact_name = self._get_next_data_artifact_name(self._data_artifact_name)
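
The added loop drains finished upload futures between chunks and fails fast if any completed upload returned `False`. The same bookkeeping pattern in isolation, as a minimal self-contained sketch:

```python
from concurrent.futures import ThreadPoolExecutor
import time

def upload_chunk(i):
    time.sleep(0.05)
    return True  # False would mean this chunk failed to upload

with ThreadPoolExecutor(max_workers=2) as pool:
    futures = []
    for chunk in range(10):
        still_running = []
        for f in futures:
            if f.running():
                still_running.append(f)
            # result() blocks if the future is queued but not yet started,
            # and re-raises any exception raised inside the worker
            elif not f.result():
                raise ValueError("failed uploading chunk")
        futures = still_running
        futures.append(pool.submit(upload_chunk, chunk))
```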
@@ -675,14 +705,17 @@
                     preview = truncated_preview + (truncated_message if add_truncated_message else "")
                     total_preview_size += len(preview)
 
-                futures.append(pool.submit(
-                    self._task.upload_artifact,
-                    name=artifact_name,
-                    artifact_object=Path(zip_path),
-                    preview=preview,
-                    delete_after_upload=True,
-                    wait_on_upload=True,
-                ))
+                upload_futures.append(
+                    pool.submit(
+                        self._task.upload_artifact,
+                        name=artifact_name,
+                        artifact_object=Path(zip_path),
+                        preview=preview,
+                        delete_after_upload=True,
+                        wait_on_upload=True,
+                        retries=retries
+                    )
+                )
                 for file_entry in self._dataset_file_entries.values():
                     if file_entry.local_path is not None and \
                             Path(file_entry.local_path).as_posix() in zip_.files_zipped:
@@ -691,12 +724,6 @@
                         if file_entry.parent_dataset_id == self._id:
                             file_entry.local_path = None
             self._serialize()
-        num_threads_with_errors = 0
-        for future in as_completed(futures):
-            if future.exception():
-                num_threads_with_errors += 1
-        if num_threads_with_errors > 0:
-            raise ValueError(f"errors reported uploading {num_threads_with_errors} chunks")
 
         self._task.get_logger().report_text(
             "File compression and upload completed: total size {}, {} chunk(s) stored (average size {})".format(
@@ -836,8 +863,7 @@
             self._task = Task.get_task(task_id=self._id)
         if not self.is_final():
             raise ValueError("Cannot get a local copy of a dataset that was not finalized/closed")
-        if not max_workers:
-            max_workers = psutil.cpu_count()
+        max_workers = max_workers or psutil.cpu_count()
 
         # now let's merge the parents
         target_folder = self._merge_datasets(
@@ -878,8 +904,7 @@
         :return: The target folder containing the entire dataset
         """
         assert self._id
-        if not max_workers:
-            max_workers = psutil.cpu_count()
+        max_workers = max_workers or psutil.cpu_count()
         target_folder = Path(target_folder).absolute()
         target_folder.mkdir(parents=True, exist_ok=True)
         # noinspection PyBroadException
@@ -1812,14 +1837,16 @@
             for d in datasets
         ]
 
-    def _add_files(self,
-                   path,  # type: Union[str, Path, _Path]
-                   wildcard=None,  # type: Optional[Union[str, Sequence[str]]]
-                   local_base_folder=None,  # type: Optional[str]
-                   dataset_path=None,  # type: Optional[str]
-                   recursive=True,  # type: bool
-                   verbose=False  # type: bool
-                   ):
+    def _add_files(
+        self,
+        path,  # type: Union[str, Path, _Path]
+        wildcard=None,  # type: Optional[Union[str, Sequence[str]]]
+        local_base_folder=None,  # type: Optional[str]
+        dataset_path=None,  # type: Optional[str]
+        recursive=True,  # type: bool
+        verbose=False,  # type: bool
+        max_workers=None,  # type: Optional[int]
+    ):
         # type: (...) -> tuple[int, int]
         """
         Add a folder into the current dataset. calculate file hash,
@@ -1832,7 +1859,9 @@
         :param dataset_path: where in the dataset the folder/files should be located
         :param recursive: If True match all wildcard files recursively
         :param verbose: If True print to console added files
+        :param max_workers: The number of threads to add the files with. Defaults to the number of logical cores
         """
+        max_workers = max_workers or psutil.cpu_count()
         if dataset_path:
             dataset_path = dataset_path.lstrip("/")
         path = Path(path)
@@ -1869,7 +1898,7 @@
             for f in file_entries
         ]
         self._task.get_logger().report_text('Generating SHA2 hash for {} files'.format(len(file_entries)))
-        pool = ThreadPool(psutil.cpu_count())
+        pool = ThreadPool(max_workers)
         try:
             import tqdm  # noqa
             for _ in tqdm.tqdm(pool.imap_unordered(self._calc_file_hash, file_entries), total=len(file_entries)):
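
The hashing pool now honours the caller-supplied `max_workers` instead of always using every logical core. A standalone sketch of the same ThreadPool-plus-`imap_unordered` hashing scheme; the file list and worker count are illustrative:

```python
import hashlib
from multiprocessing.pool import ThreadPool

def sha256_of(path):
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1024 * 1024), b""):
            h.update(chunk)
    return path, h.hexdigest()

files = [__file__]     # placeholder file list
pool = ThreadPool(4)   # 4 stands in for the caller-supplied max_workers
try:
    for path, digest in pool.imap_unordered(sha256_of, files):
        print(path, digest)
finally:
    pool.close()
    pool.join()
```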
@@ -2083,8 +2112,7 @@
         :return: Path to the local storage where the data was downloaded
         """
-        if not max_workers:
-            max_workers = psutil.cpu_count()
+        max_workers = max_workers or psutil.cpu_count()
         local_folder = self._extract_dataset_archive(
             force=force,
             selected_chunks=selected_chunks,
@@ -2189,8 +2217,7 @@
         if not self._task:
             self._task = Task.get_task(task_id=self._id)
-        if not max_workers:
-            max_workers = psutil.cpu_count()
+        max_workers = max_workers or psutil.cpu_count()
 
         data_artifact_entries = self._get_data_artifact_names()
@@ -2305,8 +2332,7 @@
         assert part is None or (isinstance(part, int) and part >= 0)
         assert num_parts is None or (isinstance(num_parts, int) and num_parts >= 1)
-        if max_workers is None:
-            max_workers = psutil.cpu_count()
+        max_workers = max_workers or psutil.cpu_count()
 
         if use_soft_links is None:
             use_soft_links = False if is_windows() else True
@@ -2863,8 +2889,7 @@
     ):
         # type: (Path, List[str], dict, bool, bool, bool, Optional[int]) -> ()
         # create thread pool, for creating soft-links / copying
-        if not max_workers:
-            max_workers = psutil.cpu_count()
+        max_workers = max_workers or psutil.cpu_count()
         pool = ThreadPool(max_workers)
         for dataset_version_id in dependencies_by_order:
             # make sure we skip over empty dependencies
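
All of the download paths above thread the argument through to the same helpers. Hedged usage, with a placeholder ID and folder:

```python
from clearml import Dataset

ds = Dataset.get(dataset_id="<dataset_id>")  # placeholder ID
# read-only cached copy, downloaded and extracted with 8 threads
cached_path = ds.get_local_copy(max_workers=8)
# writable copy extracted into a folder of your choosing
mutable_path = ds.get_mutable_local_copy(target_folder="./ds_copy", max_workers=8)
```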

View File

@@ -1252,7 +1252,8 @@ class _HttpDriver(_Driver):
                 requests_codes.service_unavailable,
                 requests_codes.bandwidth_limit_exceeded,
                 requests_codes.too_many_requests,
-            ]
+            ],
+            config=config
         )
         self.attach_auth_header = any(
             (name.rstrip('/') == host.rstrip('/') or name.startswith(host.rstrip('/') + '/'))
@@ -1921,9 +1922,10 @@ class _GoogleCloudStorageDriver(_Driver):
 
 class _AzureBlobServiceStorageDriver(_Driver):
-    scheme = 'azure'
+    scheme = "azure"
 
     _containers = {}
+    _max_connections = deferred_config("azure.storage.max_connections", None)
 
     class _Container(object):
         def __init__(self, name, config, account_url):
@@ -1965,9 +1967,10 @@ class _AzureBlobServiceStorageDriver(_Driver):
             )
 
         def create_blob_from_data(
-            self, container_name, object_name, blob_name, data, max_connections=2,
+            self, container_name, object_name, blob_name, data, max_connections=None,
             progress_callback=None, content_settings=None
         ):
+            max_connections = max_connections or _AzureBlobServiceStorageDriver._max_connections
             if self.__legacy:
                 self.__blob_service.create_blob_from_bytes(
                     container_name,
@@ -1985,8 +1988,9 @@ class _AzureBlobServiceStorageDriver(_Driver):
             )
 
         def create_blob_from_path(
-            self, container_name, blob_name, path, max_connections=2, content_settings=None, progress_callback=None
+            self, container_name, blob_name, path, max_connections=None, content_settings=None, progress_callback=None
         ):
+            max_connections = max_connections or _AzureBlobServiceStorageDriver._max_connections
             if self.__legacy:
                 self.__blob_service.create_blob_from_path(
                     container_name,
@@ -2045,7 +2049,8 @@ class _AzureBlobServiceStorageDriver(_Driver):
             client = self.__blob_service.get_blob_client(container_name, blob_name)
             return client.download_blob().content_as_bytes()
 
-        def get_blob_to_path(self, container_name, blob_name, path, max_connections=10, progress_callback=None):
+        def get_blob_to_path(self, container_name, blob_name, path, max_connections=None, progress_callback=None):
+            max_connections = max_connections or _AzureBlobServiceStorageDriver._max_connections
             if self.__legacy:
                 return self.__blob_service.get_blob_to_path(
                     container_name,
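
Each transfer method now resolves its connection count as: explicit argument, then the `azure.storage.max_connections` config value, then the Azure SDK's own default (when both are `None`). A minimal sketch of that fallback idiom, with a plain class attribute standing in for the `deferred_config` value:

```python
class AzureLikeDriver:
    # stands in for deferred_config("azure.storage.max_connections", None)
    _max_connections = None

    def get_blob_to_path(self, path, max_connections=None):
        # explicit argument wins; otherwise fall back to the configured default
        return max_connections or AzureLikeDriver._max_connections

driver = AzureLikeDriver()
print(driver.get_blob_to_path("x"))        # None -> Azure SDK default applies
AzureLikeDriver._max_connections = 2
print(driver.get_blob_to_path("x"))        # 2 (from configuration)
print(driver.get_blob_to_path("x", 16))    # 16 (explicit argument wins)
```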
@@ -2078,10 +2083,11 @@ class _AzureBlobServiceStorageDriver(_Driver):
             self._containers[container_name] = self._Container(
                 name=container_name, config=config, account_url=account_url
             )
-        # self._containers[container_name].config.retries = kwargs.get('retries', 5)
         return self._containers[container_name]
 
-    def upload_object_via_stream(self, iterator, container, object_name, callback=None, extra=None, **kwargs):
+    def upload_object_via_stream(
+        self, iterator, container, object_name, callback=None, extra=None, max_connections=None, **kwargs
+    ):
         try:
             from azure.common import AzureHttpError  # noqa
         except ImportError:
@@ -2096,17 +2102,17 @@ class _AzureBlobServiceStorageDriver(_Driver):
                 object_name,
                 blob_name,
                 iterator.read() if hasattr(iterator, "read") else bytes(iterator),
-                max_connections=2,
+                max_connections=max_connections,
                 progress_callback=callback,
             )
             return True
         except AzureHttpError as ex:
-            self.get_logger().error('Failed uploading (Azure error): %s' % ex)
+            self.get_logger().error("Failed uploading (Azure error): %s" % ex)
         except Exception as ex:
-            self.get_logger().error('Failed uploading: %s' % ex)
+            self.get_logger().error("Failed uploading: %s" % ex)
         return False
 
-    def upload_object(self, file_path, container, object_name, callback=None, extra=None, **kwargs):
+    def upload_object(self, file_path, container, object_name, callback=None, extra=None, max_connections=None, **kwargs):
         try:
             from azure.common import AzureHttpError  # noqa
         except ImportError:
@@ -2123,7 +2129,7 @@ class _AzureBlobServiceStorageDriver(_Driver):
                 container.name,
                 blob_name,
                 file_path,
-                max_connections=2,
+                max_connections=max_connections,
                 content_settings=ContentSettings(content_type=get_file_mimetype(object_name or file_path)),
                 progress_callback=callback,
             )
@@ -2177,10 +2183,10 @@ class _AzureBlobServiceStorageDriver(_Driver):
         else:
             return blob
 
-    def download_object(self, obj, local_path, overwrite_existing=True, delete_on_failure=True, callback=None, **_):
+    def download_object(self, obj, local_path, overwrite_existing=True, delete_on_failure=True, callback=None, max_connections=None, **_):
         p = Path(local_path)
         if not overwrite_existing and p.is_file():
-            self.get_logger().warning("failed saving after download: overwrite=False and file exists (%s)" % str(p))
+            self.get_logger().warning("Failed saving after download: overwrite=False and file exists (%s)" % str(p))
             return
 
         download_done = SafeEvent()
@@ -2200,7 +2206,7 @@ class _AzureBlobServiceStorageDriver(_Driver):
                 container.name,
                 obj.blob_name,
                 local_path,
-                max_connections=10,
+                max_connections=max_connections,
                 progress_callback=callback_func,
             )
             if container.is_legacy():
@@ -2836,3 +2842,4 @@ driver_schemes = set(
 )
 remote_driver_schemes = driver_schemes - {_FileStorageDriver.scheme}
+cloud_driver_schemes = remote_driver_schemes - set(_HttpDriver.schemes)
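
`cloud_driver_schemes` is the set `Dataset.upload()` consults for its single-worker default: the remote schemes minus HTTP(S). A toy reconstruction with literal scheme strings (the real sets are built from the driver classes above):

```python
driver_schemes = {"file", "http", "https", "s3", "gs", "azure"}
remote_driver_schemes = driver_schemes - {"file"}
cloud_driver_schemes = remote_driver_schemes - {"http", "https"}
print(sorted(cloud_driver_schemes))  # ['azure', 'gs', 's3']
```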

View File

@@ -1889,7 +1889,8 @@ class Task(_Task):
         preview=None,  # type: Any
         wait_on_upload=False,  # type: bool
         extension_name=None,  # type: Optional[str]
-        serialization_function=None  # type: Optional[Callable[[Any], Union[bytes, bytearray]]]
+        serialization_function=None,  # type: Optional[Callable[[Any], Union[bytes, bytearray]]]
+        retries=0  # type: int
     ):
         # type: (...) -> bool
         """
@@ -1946,6 +1947,8 @@ class Task(_Task):
             (e.g. `pandas.DataFrame.to_csv`), even if possible. To deserialize this artifact when getting
             it using the `Artifact.get` method, use its `deserialization_function` argument.
 
+        :param int retries: Number of retries before failing to upload artifact. If 0, the upload is not retried
+
         :return: The status of the upload.
 
         - ``True`` - Upload succeeded.
@@ -1953,17 +1956,31 @@ class Task(_Task):
         :raise: If the artifact object type is not supported, raise a ``ValueError``.
         """
-        return self._artifacts_manager.upload_artifact(
-            name=name,
-            artifact_object=artifact_object,
-            metadata=metadata,
-            delete_after_upload=delete_after_upload,
-            auto_pickle=auto_pickle,
-            preview=preview,
-            wait_on_upload=wait_on_upload,
-            extension_name=extension_name,
-            serialization_function=serialization_function,
-        )
+        exception_to_raise = None
+        for retry in range(retries + 1):
+            # noinspection PyBroadException
+            try:
+                if self._artifacts_manager.upload_artifact(
+                    name=name,
+                    artifact_object=artifact_object,
+                    metadata=metadata,
+                    delete_after_upload=delete_after_upload,
+                    auto_pickle=auto_pickle,
+                    preview=preview,
+                    wait_on_upload=wait_on_upload,
+                    extension_name=extension_name,
+                    serialization_function=serialization_function,
+                ):
+                    return True
+            except Exception as e:
+                exception_to_raise = e
+            if retry < retries:
+                getLogger().warning(
+                    "Failed uploading artifact '{}'. Retrying... ({}/{})".format(name, retry + 1, retries)
+                )
+        if exception_to_raise:
+            raise exception_to_raise
+        return False
 
     def get_models(self):
         # type: () -> Mapping[str, Sequence[Model]]
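
With the retry loop, a transient failure no longer aborts the call immediately: the last exception is re-raised only after every attempt is exhausted, and `False` is returned when uploads fail without raising. Hedged usage; the project, task, and artifact names are illustrative:

```python
from clearml import Task

task = Task.init(project_name="examples", task_name="artifact-retry")
# retry up to 3 extra times before giving up
task.upload_artifact(name="stats", artifact_object={"rows": 100}, retries=3)
```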

View File

@@ -127,6 +127,8 @@ sdk {
         # ]
     }
     azure.storage {
+        # max_connections: 2
+
         # containers: [
         #     {
         #         account_name: "clearml"