mirror of
https://github.com/clearml/clearml
synced 2025-04-02 12:08:33 +00:00
Support num_workers
in dataset operations
Support max connections setting for Azure storage using the `sdk.azure.storage.max_connection` setting
This commit is contained in:
parent
68467d7288
commit
ed2b8ed850
@ -392,7 +392,11 @@ class AzureContainerConfigurations(object):
|
||||
))
|
||||
|
||||
if configuration is None:
|
||||
return cls(default_container_configs, default_account=default_account, default_key=default_key)
|
||||
return cls(
|
||||
default_container_configs,
|
||||
default_account=default_account,
|
||||
default_key=default_key
|
||||
)
|
||||
|
||||
containers = configuration.get("containers", list())
|
||||
container_configs = [AzureContainerConfig(**entry) for entry in containers] + default_container_configs
|
||||
|
@ -118,6 +118,12 @@ def cli():
|
||||
add.add_argument('--non-recursive', action='store_true', default=False,
|
||||
help='Disable recursive scan of files')
|
||||
add.add_argument('--verbose', action='store_true', default=False, help='Verbose reporting')
|
||||
add.add_argument(
|
||||
"--max-workers",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Number of threads to add the files with. Defaults to the number of logical cores",
|
||||
)
|
||||
add.set_defaults(func=ds_add)
|
||||
|
||||
set_description = subparsers.add_parser("set-description", help="Set description to the dataset")
|
||||
@ -195,6 +201,12 @@ def cli():
|
||||
help='Set dataset artifact chunk size in MB. Default 512, (pass -1 for a single chunk). '
|
||||
'Example: 512, dataset will be split and uploaded in 512mb chunks.')
|
||||
upload.add_argument('--verbose', default=False, action='store_true', help='Verbose reporting')
|
||||
upload.add_argument(
|
||||
"--max-workers",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Number of threads to upload the files with. Defaults to 1 if uploading to a cloud provider ('s3', 'azure', 'gs') OR to the number of logical cores otherwise",
|
||||
)
|
||||
upload.set_defaults(func=ds_upload)
|
||||
|
||||
finalize = subparsers.add_parser('close', help='Finalize and close the dataset (implies auto upload)')
|
||||
@ -210,6 +222,12 @@ def cli():
|
||||
help='Set dataset artifact chunk size in MB. Default 512, (pass -1 for a single chunk). '
|
||||
'Example: 512, dataset will be split and uploaded in 512mb chunks.')
|
||||
finalize.add_argument('--verbose', action='store_true', default=False, help='Verbose reporting')
|
||||
finalize.add_argument(
|
||||
"--max-workers",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Number of threads to upload the files with. Defaults to 1 if uploading to a cloud provider ('s3', 'azure', 'gs') OR to the number of logical cores otherwise",
|
||||
)
|
||||
finalize.set_defaults(func=ds_close)
|
||||
|
||||
publish = subparsers.add_parser('publish', help='Publish dataset task')
|
||||
@ -327,6 +345,12 @@ def cli():
|
||||
'can be divided into 4 parts')
|
||||
get.add_argument('--overwrite', action='store_true', default=False, help='If True, overwrite the target folder')
|
||||
get.add_argument('--verbose', action='store_true', default=False, help='Verbose reporting')
|
||||
get.add_argument(
|
||||
"--max-workers",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Number of threads to get the files with. Defaults to the number of logical cores",
|
||||
)
|
||||
get.set_defaults(func=ds_get)
|
||||
|
||||
args = parser.parse_args()
|
||||
@ -426,7 +450,9 @@ def ds_get(args):
|
||||
pass
|
||||
if args.copy:
|
||||
ds_folder = args.copy
|
||||
ds.get_mutable_local_copy(target_folder=ds_folder, part=args.part, num_parts=args.num_parts)
|
||||
ds.get_mutable_local_copy(
|
||||
target_folder=ds_folder, part=args.part, num_parts=args.num_parts, max_workers=args.max_workers
|
||||
)
|
||||
else:
|
||||
if args.link:
|
||||
Path(args.link).mkdir(parents=True, exist_ok=True)
|
||||
@ -438,7 +464,7 @@ def ds_get(args):
|
||||
Path(args.link).unlink()
|
||||
except Exception:
|
||||
raise ValueError("Target directory {} is not empty. Use --overwrite.".format(args.link))
|
||||
ds_folder = ds.get_local_copy(part=args.part, num_parts=args.num_parts)
|
||||
ds_folder = ds.get_local_copy(part=args.part, num_parts=args.num_parts, max_workers=args.max_workers)
|
||||
if args.link:
|
||||
os.symlink(ds_folder, args.link)
|
||||
ds_folder = args.link
|
||||
@ -568,10 +594,13 @@ def ds_close(args):
|
||||
raise ValueError("Pending uploads, cannot finalize dataset. run `clearml-data upload`")
|
||||
# upload the files
|
||||
print("Pending uploads, starting dataset upload to {}".format(args.storage or ds.get_default_storage()))
|
||||
ds.upload(show_progress=True,
|
||||
verbose=args.verbose,
|
||||
output_url=args.storage or None,
|
||||
chunk_size=args.chunk_size or -1,)
|
||||
ds.upload(
|
||||
show_progress=True,
|
||||
verbose=args.verbose,
|
||||
output_url=args.storage or None,
|
||||
chunk_size=args.chunk_size or -1,
|
||||
max_workers=args.max_workers,
|
||||
)
|
||||
|
||||
ds.finalize()
|
||||
print('Dataset closed and finalized')
|
||||
@ -598,7 +627,12 @@ def ds_upload(args):
|
||||
check_null_id(args)
|
||||
print_args(args)
|
||||
ds = Dataset.get(dataset_id=args.id)
|
||||
ds.upload(verbose=args.verbose, output_url=args.storage or None, chunk_size=args.chunk_size or -1)
|
||||
ds.upload(
|
||||
verbose=args.verbose,
|
||||
output_url=args.storage or None,
|
||||
chunk_size=args.chunk_size or -1,
|
||||
max_workers=args.max_workers,
|
||||
)
|
||||
print('Dataset upload completed')
|
||||
return 0
|
||||
|
||||
@ -667,6 +701,7 @@ def ds_add(args):
|
||||
verbose=args.verbose,
|
||||
dataset_path=args.dataset_folder or None,
|
||||
wildcard=args.wildcard,
|
||||
max_workers=args.max_workers
|
||||
)
|
||||
for link in args.links or []:
|
||||
num_files += ds.add_external_files(
|
||||
@ -675,6 +710,7 @@ def ds_add(args):
|
||||
recursive=not args.non_recursive,
|
||||
verbose=args.verbose,
|
||||
wildcard=args.wildcard,
|
||||
max_workers=args.max_workers
|
||||
)
|
||||
message = "{} file{} added".format(num_files, "s" if num_files != 1 else "")
|
||||
print(message)
|
||||
|
@ -121,6 +121,8 @@
|
||||
# ]
|
||||
}
|
||||
azure.storage {
|
||||
# max_connections: 2
|
||||
|
||||
# containers: [
|
||||
# {
|
||||
# account_name: "clearml"
|
||||
|
@ -8,7 +8,7 @@ import re
|
||||
import logging
|
||||
from copy import deepcopy, copy
|
||||
from multiprocessing.pool import ThreadPool
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from tempfile import mkdtemp
|
||||
from typing import Union, Optional, Sequence, List, Dict, Any, Mapping, Tuple
|
||||
from zipfile import ZIP_DEFLATED
|
||||
@ -23,7 +23,7 @@ from ..backend_interface.task.development.worker import DevWorker
|
||||
from ..backend_interface.util import mutually_exclusive, exact_match_regex, get_or_create_project, rename_project
|
||||
from ..config import deferred_config, running_remotely, get_remote_task_id
|
||||
from ..debugging.log import LoggerRoot
|
||||
from ..storage.helper import StorageHelper
|
||||
from ..storage.helper import StorageHelper, cloud_driver_schemes
|
||||
from ..storage.cache import CacheManager
|
||||
from ..storage.util import sha256sum, is_windows, md5text, format_size
|
||||
from ..utilities.matching import matches_any_wildcard
|
||||
@ -368,7 +368,8 @@ class Dataset(object):
|
||||
local_base_folder=None, # type: Optional[str]
|
||||
dataset_path=None, # type: Optional[str]
|
||||
recursive=True, # type: bool
|
||||
verbose=False # type: bool
|
||||
verbose=False, # type: bool
|
||||
max_workers=None, # type: Optional[int]
|
||||
):
|
||||
# type: (...) -> ()
|
||||
"""
|
||||
@ -382,8 +383,10 @@ class Dataset(object):
|
||||
:param dataset_path: where in the dataset the folder/files should be located
|
||||
:param recursive: If True match all wildcard files recursively
|
||||
:param verbose: If True print to console files added/modified
|
||||
:param max_workers: The number of threads to add the files with. Defaults to the number of logical cores
|
||||
:return: number of files added
|
||||
"""
|
||||
max_workers = max_workers or psutil.cpu_count()
|
||||
self._dirty = True
|
||||
self._task.get_logger().report_text(
|
||||
'Adding files to dataset: {}'.format(
|
||||
@ -392,8 +395,14 @@ class Dataset(object):
|
||||
print_console=False)
|
||||
|
||||
num_added, num_modified = self._add_files(
|
||||
path=path, wildcard=wildcard, local_base_folder=local_base_folder,
|
||||
dataset_path=dataset_path, recursive=recursive, verbose=verbose)
|
||||
path=path,
|
||||
wildcard=wildcard,
|
||||
local_base_folder=local_base_folder,
|
||||
dataset_path=dataset_path,
|
||||
recursive=recursive,
|
||||
verbose=verbose,
|
||||
max_workers=max_workers,
|
||||
)
|
||||
|
||||
# update the task script
|
||||
self._add_script_call(
|
||||
@ -587,9 +596,16 @@ class Dataset(object):
|
||||
return num_removed, num_added, num_modified
|
||||
|
||||
def upload(
|
||||
self, show_progress=True, verbose=False, output_url=None, compression=None, chunk_size=None, max_workers=None
|
||||
self,
|
||||
show_progress=True,
|
||||
verbose=False,
|
||||
output_url=None,
|
||||
compression=None,
|
||||
chunk_size=None,
|
||||
max_workers=None,
|
||||
retries=3,
|
||||
):
|
||||
# type: (bool, bool, Optional[str], Optional[str], int, Optional[int]) -> ()
|
||||
# type: (bool, bool, Optional[str], Optional[str], int, Optional[int], int) -> ()
|
||||
"""
|
||||
Start file uploading, the function returns when all files are uploaded.
|
||||
|
||||
@ -602,17 +618,22 @@ class Dataset(object):
|
||||
if not provided (None) use the default chunk size (512mb).
|
||||
If -1 is provided, use a single zip artifact for the entire dataset change-set (old behaviour)
|
||||
:param max_workers: Numbers of threads to be spawned when zipping and uploading the files.
|
||||
Defaults to the number of logical cores.
|
||||
If None (default) it will be set to:
|
||||
- 1: if the upload destination is a cloud provider ('s3', 'gs', 'azure')
|
||||
- number of logical cores: otherwise
|
||||
:param int retries: Number of retries before failing to upload each zip. If 0, the upload is not retried.
|
||||
|
||||
:raise: If the upload failed (i.e. at least one zip failed to upload), raise a `ValueError`
|
||||
"""
|
||||
self._report_dataset_preview()
|
||||
|
||||
if not max_workers:
|
||||
max_workers = psutil.cpu_count()
|
||||
|
||||
# set output_url
|
||||
if output_url:
|
||||
self._task.output_uri = output_url
|
||||
|
||||
if not max_workers:
|
||||
max_workers = 1 if self._task.output_uri.startswith(tuple(cloud_driver_schemes)) else psutil.cpu_count()
|
||||
|
||||
self._task.get_logger().report_text(
|
||||
"Uploading dataset files: {}".format(
|
||||
dict(show_progress=show_progress, verbose=verbose, output_url=output_url, compression=compression)
|
||||
@ -625,9 +646,9 @@ class Dataset(object):
|
||||
total_preview_size = 0
|
||||
keep_as_file_entry = set()
|
||||
chunk_size = int(self._dataset_chunk_size_mb if not chunk_size else chunk_size)
|
||||
upload_futures = []
|
||||
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as pool:
|
||||
futures = []
|
||||
parallel_zipper = ParallelZipper(
|
||||
chunk_size,
|
||||
max_workers,
|
||||
@ -648,6 +669,15 @@ class Dataset(object):
|
||||
file_paths.append(f.local_path)
|
||||
arcnames[f.local_path] = f.relative_path
|
||||
for zip_ in parallel_zipper.zip_iter(file_paths, arcnames=arcnames):
|
||||
running_futures = []
|
||||
for upload_future in upload_futures:
|
||||
if upload_future.running():
|
||||
running_futures.append(upload_future)
|
||||
else:
|
||||
if not upload_future.result():
|
||||
raise ValueError("Failed uploading dataset with ID {}".format(self._id))
|
||||
upload_futures = running_futures
|
||||
|
||||
zip_path = Path(zip_.zip_path)
|
||||
artifact_name = self._data_artifact_name
|
||||
self._data_artifact_name = self._get_next_data_artifact_name(self._data_artifact_name)
|
||||
@ -675,14 +705,17 @@ class Dataset(object):
|
||||
preview = truncated_preview + (truncated_message if add_truncated_message else "")
|
||||
total_preview_size += len(preview)
|
||||
|
||||
futures.append(pool.submit(
|
||||
self._task.upload_artifact,
|
||||
name=artifact_name,
|
||||
artifact_object=Path(zip_path),
|
||||
preview=preview,
|
||||
delete_after_upload=True,
|
||||
wait_on_upload=True,
|
||||
))
|
||||
upload_futures.append(
|
||||
pool.submit(
|
||||
self._task.upload_artifact,
|
||||
name=artifact_name,
|
||||
artifact_object=Path(zip_path),
|
||||
preview=preview,
|
||||
delete_after_upload=True,
|
||||
wait_on_upload=True,
|
||||
retries=retries
|
||||
)
|
||||
)
|
||||
for file_entry in self._dataset_file_entries.values():
|
||||
if file_entry.local_path is not None and \
|
||||
Path(file_entry.local_path).as_posix() in zip_.files_zipped:
|
||||
@ -691,12 +724,6 @@ class Dataset(object):
|
||||
if file_entry.parent_dataset_id == self._id:
|
||||
file_entry.local_path = None
|
||||
self._serialize()
|
||||
num_threads_with_errors = 0
|
||||
for future in as_completed(futures):
|
||||
if future.exception():
|
||||
num_threads_with_errors += 1
|
||||
if num_threads_with_errors > 0:
|
||||
raise ValueError(f"errors reported uploading {num_threads_with_errors} chunks")
|
||||
|
||||
self._task.get_logger().report_text(
|
||||
"File compression and upload completed: total size {}, {} chunk(s) stored (average size {})".format(
|
||||
@ -836,8 +863,7 @@ class Dataset(object):
|
||||
self._task = Task.get_task(task_id=self._id)
|
||||
if not self.is_final():
|
||||
raise ValueError("Cannot get a local copy of a dataset that was not finalized/closed")
|
||||
if not max_workers:
|
||||
max_workers = psutil.cpu_count()
|
||||
max_workers = max_workers or psutil.cpu_count()
|
||||
|
||||
# now let's merge the parents
|
||||
target_folder = self._merge_datasets(
|
||||
@ -878,8 +904,7 @@ class Dataset(object):
|
||||
:return: The target folder containing the entire dataset
|
||||
"""
|
||||
assert self._id
|
||||
if not max_workers:
|
||||
max_workers = psutil.cpu_count()
|
||||
max_workers = max_workers or psutil.cpu_count()
|
||||
target_folder = Path(target_folder).absolute()
|
||||
target_folder.mkdir(parents=True, exist_ok=True)
|
||||
# noinspection PyBroadException
|
||||
@ -1812,14 +1837,16 @@ class Dataset(object):
|
||||
for d in datasets
|
||||
]
|
||||
|
||||
def _add_files(self,
|
||||
path, # type: Union[str, Path, _Path]
|
||||
wildcard=None, # type: Optional[Union[str, Sequence[str]]]
|
||||
local_base_folder=None, # type: Optional[str]
|
||||
dataset_path=None, # type: Optional[str]
|
||||
recursive=True, # type: bool
|
||||
verbose=False # type: bool
|
||||
):
|
||||
def _add_files(
|
||||
self,
|
||||
path, # type: Union[str, Path, _Path]
|
||||
wildcard=None, # type: Optional[Union[str, Sequence[str]]]
|
||||
local_base_folder=None, # type: Optional[str]
|
||||
dataset_path=None, # type: Optional[str]
|
||||
recursive=True, # type: bool
|
||||
verbose=False, # type: bool
|
||||
max_workers=None, # type: Optional[int]
|
||||
):
|
||||
# type: (...) -> tuple[int, int]
|
||||
"""
|
||||
Add a folder into the current dataset. calculate file hash,
|
||||
@ -1832,7 +1859,9 @@ class Dataset(object):
|
||||
:param dataset_path: where in the dataset the folder/files should be located
|
||||
:param recursive: If True match all wildcard files recursively
|
||||
:param verbose: If True print to console added files
|
||||
:param max_workers: The number of threads to add the files with. Defaults to the number of logical cores
|
||||
"""
|
||||
max_workers = max_workers or psutil.cpu_count()
|
||||
if dataset_path:
|
||||
dataset_path = dataset_path.lstrip("/")
|
||||
path = Path(path)
|
||||
@ -1869,7 +1898,7 @@ class Dataset(object):
|
||||
for f in file_entries
|
||||
]
|
||||
self._task.get_logger().report_text('Generating SHA2 hash for {} files'.format(len(file_entries)))
|
||||
pool = ThreadPool(psutil.cpu_count())
|
||||
pool = ThreadPool(max_workers)
|
||||
try:
|
||||
import tqdm # noqa
|
||||
for _ in tqdm.tqdm(pool.imap_unordered(self._calc_file_hash, file_entries), total=len(file_entries)):
|
||||
@ -2083,8 +2112,7 @@ class Dataset(object):
|
||||
|
||||
:return: Path to the local storage where the data was downloaded
|
||||
"""
|
||||
if not max_workers:
|
||||
max_workers = psutil.cpu_count()
|
||||
max_workers = max_workers or psutil.cpu_count()
|
||||
local_folder = self._extract_dataset_archive(
|
||||
force=force,
|
||||
selected_chunks=selected_chunks,
|
||||
@ -2189,8 +2217,7 @@ class Dataset(object):
|
||||
if not self._task:
|
||||
self._task = Task.get_task(task_id=self._id)
|
||||
|
||||
if not max_workers:
|
||||
max_workers = psutil.cpu_count()
|
||||
max_workers = max_workers or psutil.cpu_count()
|
||||
|
||||
data_artifact_entries = self._get_data_artifact_names()
|
||||
|
||||
@ -2305,8 +2332,7 @@ class Dataset(object):
|
||||
assert part is None or (isinstance(part, int) and part >= 0)
|
||||
assert num_parts is None or (isinstance(num_parts, int) and num_parts >= 1)
|
||||
|
||||
if max_workers is None:
|
||||
max_workers = psutil.cpu_count()
|
||||
max_workers = max_workers or psutil.cpu_count()
|
||||
|
||||
if use_soft_links is None:
|
||||
use_soft_links = False if is_windows() else True
|
||||
@ -2863,8 +2889,7 @@ class Dataset(object):
|
||||
):
|
||||
# type: (Path, List[str], dict, bool, bool, bool, Optional[int]) -> ()
|
||||
# create thread pool, for creating soft-links / copying
|
||||
if not max_workers:
|
||||
max_workers = psutil.cpu_count()
|
||||
max_workers = max_workers or psutil.cpu_count()
|
||||
pool = ThreadPool(max_workers)
|
||||
for dataset_version_id in dependencies_by_order:
|
||||
# make sure we skip over empty dependencies
|
||||
|
@ -1252,7 +1252,8 @@ class _HttpDriver(_Driver):
|
||||
requests_codes.service_unavailable,
|
||||
requests_codes.bandwidth_limit_exceeded,
|
||||
requests_codes.too_many_requests,
|
||||
]
|
||||
],
|
||||
config=config
|
||||
)
|
||||
self.attach_auth_header = any(
|
||||
(name.rstrip('/') == host.rstrip('/') or name.startswith(host.rstrip('/') + '/'))
|
||||
@ -1921,9 +1922,10 @@ class _GoogleCloudStorageDriver(_Driver):
|
||||
|
||||
|
||||
class _AzureBlobServiceStorageDriver(_Driver):
|
||||
scheme = 'azure'
|
||||
scheme = "azure"
|
||||
|
||||
_containers = {}
|
||||
_max_connections = deferred_config("azure.storage.max_connections", None)
|
||||
|
||||
class _Container(object):
|
||||
def __init__(self, name, config, account_url):
|
||||
@ -1965,9 +1967,10 @@ class _AzureBlobServiceStorageDriver(_Driver):
|
||||
)
|
||||
|
||||
def create_blob_from_data(
|
||||
self, container_name, object_name, blob_name, data, max_connections=2,
|
||||
self, container_name, object_name, blob_name, data, max_connections=None,
|
||||
progress_callback=None, content_settings=None
|
||||
):
|
||||
max_connections = max_connections or _AzureBlobServiceStorageDriver._max_connections
|
||||
if self.__legacy:
|
||||
self.__blob_service.create_blob_from_bytes(
|
||||
container_name,
|
||||
@ -1985,8 +1988,9 @@ class _AzureBlobServiceStorageDriver(_Driver):
|
||||
)
|
||||
|
||||
def create_blob_from_path(
|
||||
self, container_name, blob_name, path, max_connections=2, content_settings=None, progress_callback=None
|
||||
self, container_name, blob_name, path, max_connections=None, content_settings=None, progress_callback=None
|
||||
):
|
||||
max_connections = max_connections or _AzureBlobServiceStorageDriver._max_connections
|
||||
if self.__legacy:
|
||||
self.__blob_service.create_blob_from_path(
|
||||
container_name,
|
||||
@ -2045,7 +2049,8 @@ class _AzureBlobServiceStorageDriver(_Driver):
|
||||
client = self.__blob_service.get_blob_client(container_name, blob_name)
|
||||
return client.download_blob().content_as_bytes()
|
||||
|
||||
def get_blob_to_path(self, container_name, blob_name, path, max_connections=10, progress_callback=None):
|
||||
def get_blob_to_path(self, container_name, blob_name, path, max_connections=None, progress_callback=None):
|
||||
max_connections = max_connections or _AzureBlobServiceStorageDriver._max_connections
|
||||
if self.__legacy:
|
||||
return self.__blob_service.get_blob_to_path(
|
||||
container_name,
|
||||
@ -2078,10 +2083,11 @@ class _AzureBlobServiceStorageDriver(_Driver):
|
||||
self._containers[container_name] = self._Container(
|
||||
name=container_name, config=config, account_url=account_url
|
||||
)
|
||||
# self._containers[container_name].config.retries = kwargs.get('retries', 5)
|
||||
return self._containers[container_name]
|
||||
|
||||
def upload_object_via_stream(self, iterator, container, object_name, callback=None, extra=None, **kwargs):
|
||||
def upload_object_via_stream(
|
||||
self, iterator, container, object_name, callback=None, extra=None, max_connections=None, **kwargs
|
||||
):
|
||||
try:
|
||||
from azure.common import AzureHttpError # noqa
|
||||
except ImportError:
|
||||
@ -2096,17 +2102,17 @@ class _AzureBlobServiceStorageDriver(_Driver):
|
||||
object_name,
|
||||
blob_name,
|
||||
iterator.read() if hasattr(iterator, "read") else bytes(iterator),
|
||||
max_connections=2,
|
||||
max_connections=max_connections,
|
||||
progress_callback=callback,
|
||||
)
|
||||
)
|
||||
return True
|
||||
except AzureHttpError as ex:
|
||||
self.get_logger().error('Failed uploading (Azure error): %s' % ex)
|
||||
self.get_logger().error("Failed uploading (Azure error): %s" % ex)
|
||||
except Exception as ex:
|
||||
self.get_logger().error('Failed uploading: %s' % ex)
|
||||
self.get_logger().error("Failed uploading: %s" % ex)
|
||||
return False
|
||||
|
||||
def upload_object(self, file_path, container, object_name, callback=None, extra=None, **kwargs):
|
||||
def upload_object(self, file_path, container, object_name, callback=None, extra=None, max_connections=None, **kwargs):
|
||||
try:
|
||||
from azure.common import AzureHttpError # noqa
|
||||
except ImportError:
|
||||
@ -2123,7 +2129,7 @@ class _AzureBlobServiceStorageDriver(_Driver):
|
||||
container.name,
|
||||
blob_name,
|
||||
file_path,
|
||||
max_connections=2,
|
||||
max_connections=max_connections,
|
||||
content_settings=ContentSettings(content_type=get_file_mimetype(object_name or file_path)),
|
||||
progress_callback=callback,
|
||||
)
|
||||
@ -2177,10 +2183,10 @@ class _AzureBlobServiceStorageDriver(_Driver):
|
||||
else:
|
||||
return blob
|
||||
|
||||
def download_object(self, obj, local_path, overwrite_existing=True, delete_on_failure=True, callback=None, **_):
|
||||
def download_object(self, obj, local_path, overwrite_existing=True, delete_on_failure=True, callback=None, max_connections=None, **_):
|
||||
p = Path(local_path)
|
||||
if not overwrite_existing and p.is_file():
|
||||
self.get_logger().warning("failed saving after download: overwrite=False and file exists (%s)" % str(p))
|
||||
self.get_logger().warning("Failed saving after download: overwrite=False and file exists (%s)" % str(p))
|
||||
return
|
||||
|
||||
download_done = SafeEvent()
|
||||
@ -2200,7 +2206,7 @@ class _AzureBlobServiceStorageDriver(_Driver):
|
||||
container.name,
|
||||
obj.blob_name,
|
||||
local_path,
|
||||
max_connections=10,
|
||||
max_connections=max_connections,
|
||||
progress_callback=callback_func,
|
||||
)
|
||||
if container.is_legacy():
|
||||
@ -2836,3 +2842,4 @@ driver_schemes = set(
|
||||
)
|
||||
|
||||
remote_driver_schemes = driver_schemes - {_FileStorageDriver.scheme}
|
||||
cloud_driver_schemes = remote_driver_schemes - set(_HttpDriver.schemes)
|
||||
|
@ -1889,7 +1889,8 @@ class Task(_Task):
|
||||
preview=None, # type: Any
|
||||
wait_on_upload=False, # type: bool
|
||||
extension_name=None, # type: Optional[str]
|
||||
serialization_function=None # type: Optional[Callable[[Any], Union[bytes, bytearray]]]
|
||||
serialization_function=None, # type: Optional[Callable[[Any], Union[bytes, bytearray]]]
|
||||
retries=0 # type: int
|
||||
):
|
||||
# type: (...) -> bool
|
||||
"""
|
||||
@ -1946,6 +1947,8 @@ class Task(_Task):
|
||||
(e.g. `pandas.DataFrame.to_csv`), even if possible. To deserialize this artifact when getting
|
||||
it using the `Artifact.get` method, use its `deserialization_function` argument.
|
||||
|
||||
:param int retries: Number of retries before failing to upload artifact. If 0, the upload is not retried
|
||||
|
||||
:return: The status of the upload.
|
||||
|
||||
- ``True`` - Upload succeeded.
|
||||
@ -1953,17 +1956,31 @@ class Task(_Task):
|
||||
|
||||
:raise: If the artifact object type is not supported, raise a ``ValueError``.
|
||||
"""
|
||||
return self._artifacts_manager.upload_artifact(
|
||||
name=name,
|
||||
artifact_object=artifact_object,
|
||||
metadata=metadata,
|
||||
delete_after_upload=delete_after_upload,
|
||||
auto_pickle=auto_pickle,
|
||||
preview=preview,
|
||||
wait_on_upload=wait_on_upload,
|
||||
extension_name=extension_name,
|
||||
serialization_function=serialization_function,
|
||||
)
|
||||
exception_to_raise = None
|
||||
for retry in range(retries + 1):
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
if self._artifacts_manager.upload_artifact(
|
||||
name=name,
|
||||
artifact_object=artifact_object,
|
||||
metadata=metadata,
|
||||
delete_after_upload=delete_after_upload,
|
||||
auto_pickle=auto_pickle,
|
||||
preview=preview,
|
||||
wait_on_upload=wait_on_upload,
|
||||
extension_name=extension_name,
|
||||
serialization_function=serialization_function,
|
||||
):
|
||||
return True
|
||||
except Exception as e:
|
||||
exception_to_raise = e
|
||||
if retry < retries:
|
||||
getLogger().warning(
|
||||
"Failed uploading artifact '{}'. Retrying... ({}/{})".format(name, retry + 1, retries)
|
||||
)
|
||||
if exception_to_raise:
|
||||
raise exception_to_raise
|
||||
return False
|
||||
|
||||
def get_models(self):
|
||||
# type: () -> Mapping[str, Sequence[Model]]
|
||||
|
@ -127,6 +127,8 @@ sdk {
|
||||
# ]
|
||||
}
|
||||
azure.storage {
|
||||
# max_connections: 2
|
||||
|
||||
# containers: [
|
||||
# {
|
||||
# account_name: "clearml"
|
||||
|
Loading…
Reference in New Issue
Block a user