Fix clearml-data previews being saved to the file server even when output_uri is specified

allegroai 2023-01-19 11:16:21 +02:00
parent 896e949649
commit 812d22a3c7
2 changed files with 18 additions and 8 deletions

View File

@@ -630,6 +630,7 @@ class Dataset(object):
         # set output_url
         if output_url:
             self._task.output_uri = output_url
+            self._task.get_logger().set_default_upload_destination(output_url)
         if not max_workers:
             max_workers = 1 if self._task.output_uri.startswith(tuple(cloud_driver_schemes)) else psutil.cpu_count()
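
For context, a minimal usage sketch (the project name and bucket URI below are placeholders, not part of this commit): with this change, the destination passed to upload() is also set as the logger's default upload destination, so dataset previews follow the same URI instead of defaulting to the file server.

from clearml import Dataset

# Placeholder names/URI for illustration only.
ds = Dataset.create(dataset_name="demo", dataset_project="examples")
ds.add_files(path="data/")
# Dataset chunks *and* the logged previews should now target this destination.
ds.upload(output_url="s3://my-bucket/datasets")
ds.finalize()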
@@ -1253,6 +1254,8 @@ class Dataset(object):
         # noinspection PyProtectedMember
         instance._task.output_uri = output_uri
         # noinspection PyProtectedMember
+        instance._task.get_logger().set_default_upload_destination(output_uri)
+        # noinspection PyProtectedMember
         instance._using_current_task = use_current_task
         # noinspection PyProtectedMember
         instance._dataset_file_entries = dataset_file_entries
@@ -2891,7 +2894,9 @@ class Dataset(object):
         if artifact is not None:
             # noinspection PyBroadException
             try:
-                if isinstance(artifact, pd.DataFrame):
+                # we do not use report_table if default_upload_destination is set because it will
+                # not upload the sample to that destination, use report_media instead
+                if isinstance(artifact, pd.DataFrame) and not self._task.get_logger().get_default_upload_destination():
                     self._task.get_logger().report_table(
                         "Tables", "summary", table_plot=artifact
                     )
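
The added condition skips report_table() whenever a default upload destination is configured, because a table plot is stored by the backend rather than uploaded to that destination; reporting the preview as media sends it to the configured storage instead. A standalone sketch of that distinction (project/task names are placeholders, and the CSV round-trip is just one way to serialize the preview, not necessarily what the library does internally):

from io import BytesIO

import pandas as pd
from clearml import Task

task = Task.init(project_name="examples", task_name="preview-demo")  # placeholder names
logger = task.get_logger()
df = pd.DataFrame({"a": [1, 2], "b": [3, 4]})

if logger.get_default_upload_destination():
    # uploaded to the configured storage (e.g. s3://...), not the file server
    logger.report_media(
        "Tables", "summary",
        stream=BytesIO(df.to_csv(index=False).encode("utf-8")),
        file_extension=".csv",
    )
else:
    # rendered as an interactive table plot by the backend
    logger.report_table("Tables", "summary", table_plot=df)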

View File

@@ -1459,6 +1459,9 @@ class _Stream(object):
         if self._input_iterator:
             try:
                 chunck = next(self._input_iterator)
+                # make sure we always return bytes
+                if isinstance(chunck, six.string_types):
+                    chunck = chunck.encode("utf-8")
                 return chunck
             except StopIteration:
                 self.closed = True
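
The guard above normalizes text chunks to bytes before they reach the storage drivers, since boto3 and the other SDKs expect a bytes stream. The same pattern in isolation (the helper name is illustrative, not from the library):

import six


def iter_bytes_chunks(iterator):
    # Illustrative helper: encode any text chunk, mirroring the guard above.
    for chunk in iterator:
        if isinstance(chunk, six.string_types):
            chunk = chunk.encode("utf-8")
        yield chunk


# A CSV preview yielded as str would otherwise break a bytes-only upload path.
assert list(iter_bytes_chunks(["a,b\n", b"1,2\n"])) == [b"a,b\n", b"1,2\n"]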
@@ -1607,6 +1610,7 @@ class _Boto3Driver(_Driver):
     def upload_object_via_stream(self, iterator, container, object_name, callback=None, extra=None, **kwargs):
         import boto3.s3.transfer
+        stream = _Stream(iterator)
         extra_args = {}
         try:
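
The iterator is wrapped in _Stream because boto3 consumes a file-like object rather than a generator. A rough sketch of the equivalent call with plain boto3 (bucket and key names are placeholders):

import io

import boto3

s3 = boto3.client("s3")
# A file-like object yielding bytes, playing the role _Stream plays above.
payload = io.BytesIO(b"a,b\n1,2\n")
s3.upload_fileobj(payload, Bucket="my-bucket", Key="datasets/preview.csv")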
@@ -2077,7 +2081,7 @@ class _AzureBlobServiceStorageDriver(_Driver):
             client.upload_blob(
                 data, overwrite=True,
                 content_settings=content_settings,
-                **self._get_max_connections_dict(max_connections)
+                **self._get_max_connections_dict(max_connections, key="max_concurrency")
             )

     def create_blob_from_path(
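
Passing key="max_concurrency" matters because the v12 azure-storage-blob SDK renamed the legacy max_connections argument. A minimal sketch of the underlying SDK call, assuming azure-storage-blob >= 12 (the connection string, container, and blob names are placeholders):

from azure.storage.blob import BlobServiceClient, ContentSettings

service = BlobServiceClient.from_connection_string("<connection-string>")
client = service.get_blob_client(container="previews", blob="summary.csv")
with open("summary.csv", "rb") as f:
    client.upload_blob(
        f,
        overwrite=True,
        content_settings=ContentSettings(content_type="text/csv"),
        max_concurrency=2,  # v12 name for the legacy SDK's max_connections
    )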
@@ -2093,10 +2097,11 @@ class _AzureBlobServiceStorageDriver(_Driver):
                 **self._get_max_connections_dict(max_connections)
             )
         else:
-            self.create_blob_from_data(
-                container_name, None, blob_name, open(path, "rb"),
-                content_settings=content_settings,
-                **self._get_max_connections_dict(max_connections)
-            )
+            with open(path, "rb") as f:
+                self.create_blob_from_data(
+                    container_name, None, blob_name, f,
+                    content_settings=content_settings,
+                    max_connections=max_connections
+                )

     def delete_blob(self, container_name, blob_name):
@@ -2154,7 +2159,7 @@ class _AzureBlobServiceStorageDriver(_Driver):
         client = self.__blob_service.get_blob_client(container_name, blob_name)
         with open(path, "wb") as file:
             return client.download_blob(
-                **self._get_max_connections_dict(max_connections, "max_concurrency")
+                **self._get_max_connections_dict(max_connections, key="max_concurrency")
             ).download_to_stream(file)

     def is_legacy(self):