Fix clearml-data previews being saved to the file server even when output_uri is specified

commit 812d22a3c7 (parent 896e949649)
Author: allegroai
Date: 2023-01-19 11:16:21 +02:00

2 changed files with 18 additions and 8 deletions

File 1 of 2:

@@ -630,6 +630,7 @@ class Dataset(object):
         # set output_url
         if output_url:
             self._task.output_uri = output_url
+            self._task.get_logger().set_default_upload_destination(output_url)
         if not max_workers:
             max_workers = 1 if self._task.output_uri.startswith(tuple(cloud_driver_schemes)) else psutil.cpu_count()
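
For context, a minimal sketch of the user-facing behavior this hunk fixes, assuming an S3 destination (project, dataset, and bucket names are illustrative):

from clearml import Dataset

ds = Dataset.create(dataset_name="example", dataset_project="demo")
ds.add_files("./data")  # any local folder with files to version
# Before this fix, dataset artifacts honored output_url but the preview/debug
# samples were still uploaded to the default file server; setting the logger's
# default upload destination (the hunk above) routes both to the same storage.
ds.upload(output_url="s3://my-bucket/datasets")
ds.finalize()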
@@ -1252,6 +1253,8 @@ class Dataset(object):
         if output_uri and not Task._offline_mode:
             # noinspection PyProtectedMember
             instance._task.output_uri = output_uri
+            # noinspection PyProtectedMember
+            instance._task.get_logger().set_default_upload_destination(output_uri)
         # noinspection PyProtectedMember
         instance._using_current_task = use_current_task
         # noinspection PyProtectedMember
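
The same routing can be exercised directly on a Logger; a small sketch (project and task names are illustrative):

from clearml import Task

task = Task.init(project_name="demo", task_name="uri-routing")
logger = task.get_logger()
# every media/debug sample reported after this call is uploaded to the given
# destination instead of the default file server
logger.set_default_upload_destination("s3://my-bucket/previews")
assert logger.get_default_upload_destination() == "s3://my-bucket/previews"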
@@ -2891,7 +2894,9 @@ class Dataset(object):
         if artifact is not None:
             # noinspection PyBroadException
             try:
-                if isinstance(artifact, pd.DataFrame):
+                # we do not use report_table if default_upload_destination is set because it will
+                # not upload the sample to that destination, use report_media instead
+                if isinstance(artifact, pd.DataFrame) and not self._task.get_logger().get_default_upload_destination():
                     self._task.get_logger().report_table(
                         "Tables", "summary", table_plot=artifact
                     )
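
A sketch of the resulting branch logic: with a default upload destination set, the DataFrame preview is shipped as media so the sample actually lands on that destination; the CSV serialization here is an illustrative choice, not necessarily what clearml uses internally:

import io
import pandas as pd

artifact = pd.DataFrame({"a": [1, 2], "b": [3, 4]})
logger = task.get_logger()  # task as in the sketch above

if logger.get_default_upload_destination():
    # report_media uploads the serialized sample to the configured destination
    logger.report_media(
        "Tables", "summary",
        stream=io.BytesIO(artifact.to_csv(index=False).encode("utf-8")),
        file_extension=".csv",
    )
else:
    # report_table stores the table as a plot on the ClearML server
    logger.report_table("Tables", "summary", table_plot=artifact)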

File 2 of 2:

@@ -1459,6 +1459,9 @@ class _Stream(object):
         if self._input_iterator:
             try:
                 chunck = next(self._input_iterator)
+                # make sure we always return bytes
+                if isinstance(chunck, six.string_types):
+                    chunck = chunck.encode("utf-8")
                 return chunck
             except StopIteration:
                 self.closed = True
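
The added guard is the standard bytes-normalization pattern; a self-contained equivalent:

import six

def to_bytes(chunck):
    # str (unicode on Python 2) chunks are encoded to utf-8; bytes pass
    # through untouched, so consumers can rely on always receiving bytes
    if isinstance(chunck, six.string_types):
        chunck = chunck.encode("utf-8")
    return chunck

assert to_bytes("abc") == b"abc"
assert to_bytes(b"abc") == b"abc"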
@@ -1607,6 +1610,7 @@ class _Boto3Driver(_Driver):
     def upload_object_via_stream(self, iterator, container, object_name, callback=None, extra=None, **kwargs):
         import boto3.s3.transfer
+        stream = _Stream(iterator)
         extra_args = {}
         try:
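
Wrapping the iterator matters because boto3's transfer API consumes file-like objects, not iterators; a rough (fully buffered) equivalent using the plain client, with bucket and key illustrative:

import io
import boto3

s3 = boto3.client("s3")
chunks = (b"part-1", b"part-2")  # stand-in for the upload iterator
# upload_fileobj expects an object exposing read(); _Stream above plays the
# same role for arbitrary iterators without first joining all the chunks
s3.upload_fileobj(io.BytesIO(b"".join(chunks)), "my-bucket", "object-name")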
@@ -2077,7 +2081,7 @@ class _AzureBlobServiceStorageDriver(_Driver):
             client.upload_blob(
                 data, overwrite=True,
                 content_settings=content_settings,
-                **self._get_max_connections_dict(max_connections)
+                **self._get_max_connections_dict(max_connections, key="max_concurrency")
             )

     def create_blob_from_path(
@@ -2093,11 +2097,12 @@ class _AzureBlobServiceStorageDriver(_Driver):
                 **self._get_max_connections_dict(max_connections)
             )
         else:
-            self.create_blob_from_data(
-                container_name, None, blob_name, open(path, "rb"),
-                content_settings=content_settings,
-                **self._get_max_connections_dict(max_connections)
-            )
+            with open(path, "rb") as f:
+                self.create_blob_from_data(
+                    container_name, None, blob_name, f,
+                    content_settings=content_settings,
+                    max_connections=max_connections
+                )

     def delete_blob(self, container_name, blob_name):
         if self.__legacy:
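
The with-block guarantees the handle is closed even if the upload raises, unlike the previous inline open() that leaked the descriptor on failure. A minimal illustration, with upload_stream a hypothetical stand-in for create_blob_from_data:

import io

def upload_stream(stream):  # hypothetical stand-in for create_blob_from_data
    return stream.read()

f = io.BytesIO(b"payload")
with f:
    upload_stream(f)
assert f.closed  # released even if upload_stream had raised inside the block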
@@ -2154,7 +2159,7 @@ class _AzureBlobServiceStorageDriver(_Driver):
             client = self.__blob_service.get_blob_client(container_name, blob_name)
             with open(path, "wb") as file:
                 return client.download_blob(
-                    **self._get_max_connections_dict(max_connections, "max_concurrency")
+                    **self._get_max_connections_dict(max_connections, key="max_concurrency")
                 ).download_to_stream(file)

     def is_legacy(self):
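
Both Azure hunks map the legacy max_connections knob onto the current SDK's max_concurrency keyword; a sketch against azure-storage-blob v12 (connection string, container, and blob names are illustrative):

from azure.storage.blob import BlobServiceClient

service = BlobServiceClient.from_connection_string("<connection-string>")
client = service.get_blob_client("my-container", "my-blob")

# upload path: the v12 SDK names the parallelism argument max_concurrency
client.upload_blob(b"payload", overwrite=True, max_concurrency=4)

# download path: the same keyword applies to download_blob
with open("local-copy.bin", "wb") as fh:
    client.download_blob(max_concurrency=4).download_to_stream(fh)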