mirror of
https://github.com/clearml/clearml
synced 2025-04-24 00:07:48 +00:00
Fix output_uri support for local folders
This commit is contained in:
parent
67fc8e3eb0
commit
1a658e9d89
@ -392,6 +392,12 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
||||
fd, local_filename = mkstemp(suffix='.'+ext)
|
||||
os.close(fd)
|
||||
local_download = helper.download_to_file(uri, local_path=local_filename, overwrite_existing=True, verbose=True)
|
||||
# if we ended up without any local copy, delete the temp file
|
||||
if local_download != local_filename:
|
||||
try:
|
||||
Path(local_filename).unlink()
|
||||
except Exception:
|
||||
pass
|
||||
# save local model, so we can later query what was the original one
|
||||
Model._local_model_to_id_uri[str(local_download)] = (self.model_id, uri)
|
||||
return local_download
|
||||
|
@ -7,6 +7,7 @@ import os
|
||||
import shutil
|
||||
import sys
|
||||
import threading
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from collections import namedtuple
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from copy import copy
|
||||
@ -80,6 +81,50 @@ class _DownloadProgressReport(object):
|
||||
(self.downloaded_mb, self._total_size, speed, self._remote_path))
|
||||
|
||||
|
||||
@six.add_metaclass(ABCMeta)
class _Driver(object):
    """ Abstract interface shared by the storage back-end drivers.

    Every method below is abstract and carries no implementation; a subclass
    can only be instantiated once it overrides all of them. Each signature
    ends with ``**kwargs`` so implementations may accept extra keyword
    arguments.
    """

    @abstractmethod
    def get_container(self, container_name, config=None, **kwargs):
        """ Return the driver's container object for ``container_name``.

        ``config`` is optional (defaults to None); presumably a
        driver-specific configuration object -- confirm per implementation.
        """

    @abstractmethod
    def test_upload(self, test_path, config, **kwargs):
        """ Report whether uploading to ``test_path`` with ``config`` is possible.

        NOTE(review): the exact checks are driver specific; some drivers
        simply answer True.
        """

    @abstractmethod
    def upload_object_via_stream(self, iterator, container, object_name, extra, **kwargs):
        """ Upload the data provided by ``iterator`` as ``object_name`` in ``container``. """

    @abstractmethod
    def list_container_objects(self, container, ex_prefix, **kwargs):
        """ List the objects stored in ``container``.

        ``ex_prefix`` presumably restricts the listing by name prefix --
        TODO confirm against the concrete drivers.
        """

    @abstractmethod
    def get_direct_access(self, remote_path, **kwargs):
        """ Return a path through which ``remote_path`` can be used without downloading.

        Drivers that have no such path return None; StorageHelper only uses
        the result when it is truthy.
        """

    @abstractmethod
    def download_object(self, obj, local_path, overwrite_existing, delete_on_failure, callback, **kwargs):
        """ Download ``obj`` into the file ``local_path``.

        ``overwrite_existing``: when false, drivers skip the download if the
        target file already exists.
        ``delete_on_failure`` / ``callback``: semantics are driver specific --
        verify against the implementation in use.
        """

    @abstractmethod
    def download_object_as_stream(self, obj, chunk_size, **kwargs):
        """ Return an iterable over the content of ``obj``, read using ``chunk_size``. """

    @abstractmethod
    def delete_object(self, obj, **kwargs):
        """ Delete ``obj`` from the storage. """

    @abstractmethod
    def upload_object(self, file_path, container, object_name, extra, **kwargs):
        """ Upload the local file ``file_path`` as ``object_name`` in ``container``. """

    @abstractmethod
    def get_object(self, container_name, object_name, **kwargs):
        """ Return the object instance named ``object_name`` in ``container_name``. """
|
||||
|
||||
|
||||
class StorageHelper(object):
|
||||
""" Storage helper.
|
||||
Used by the entire system to download/upload files.
|
||||
@ -572,6 +617,11 @@ class StorageHelper(object):
|
||||
remote_path = self._canonize_url(remote_path)
|
||||
verbose = self._verbose if verbose is None else verbose
|
||||
|
||||
# Check if driver type supports direct access:
|
||||
direct_access_path = self._driver.get_direct_access(remote_path)
|
||||
if direct_access_path:
|
||||
return direct_access_path
|
||||
|
||||
temp_local_path = None
|
||||
try:
|
||||
if verbose:
|
||||
@ -893,8 +943,9 @@ class StorageHelper(object):
|
||||
return None
|
||||
|
||||
|
||||
class _HttpDriver(object):
|
||||
class _HttpDriver(_Driver):
|
||||
""" LibCloud http/https adapter (simple, enough for now) """
|
||||
|
||||
timeout = (5.0, 30.)
|
||||
|
||||
class _Container(object):
|
||||
@ -920,7 +971,7 @@ class _HttpDriver(object):
|
||||
self._retries = retries
|
||||
self._containers = {}
|
||||
|
||||
def get_container(self, container_name, *_, **kwargs):
|
||||
def get_container(self, container_name, config=None, **kwargs):
|
||||
if container_name not in self._containers:
|
||||
self._containers[container_name] = self._Container(name=container_name, retries=self._retries, **kwargs)
|
||||
return self._containers[container_name]
|
||||
@ -951,11 +1002,11 @@ class _HttpDriver(object):
|
||||
raise ValueError('Failed getting object %s (%d): %s' % (object_name, res.status_code, res.text))
|
||||
return res
|
||||
|
||||
def download_object_as_stream(self, obj, chunk_size=64 * 1024):
|
||||
def download_object_as_stream(self, obj, chunk_size=64 * 1024, **_):
|
||||
# return iterable object
|
||||
return obj.iter_content(chunk_size=chunk_size)
|
||||
|
||||
def download_object(self, obj, local_path, overwrite_existing=True, delete_on_failure=True, callback=None):
|
||||
def download_object(self, obj, local_path, overwrite_existing=True, delete_on_failure=True, callback=None, **_):
|
||||
p = Path(local_path)
|
||||
if not overwrite_existing and p.is_file():
|
||||
log.warning('failed saving after download: overwrite=False and file exists (%s)' % str(p))
|
||||
@ -974,6 +1025,17 @@ class _HttpDriver(object):
|
||||
|
||||
return length
|
||||
|
||||
def get_direct_access(self, remote_path, **_):
    """ This driver offers no direct access path.

    :param remote_path: remote url (ignored)
    :return: always None
    """
    return None
|
||||
|
||||
def test_upload(self, test_path, config, **kwargs):
|
||||
return True
|
||||
|
||||
def upload_object(self, file_path, container, object_name, extra, **kwargs):
    """ Upload a local file by streaming it through upload_object_via_stream.

    :param file_path: path of the local file, opened in binary read mode
    :param container: forwarded as-is to upload_object_via_stream
    :param object_name: forwarded as-is to upload_object_via_stream
    :param extra: forwarded as-is to upload_object_via_stream
    :param kwargs: any additional keyword arguments are forwarded as well
    :return: whatever upload_object_via_stream returns
    """
    # the file stays open for the duration of the streamed upload and is
    # closed before the result is handed back
    with open(file_path, 'rb') as source:
        outcome = self.upload_object_via_stream(
            iterator=source,
            container=container,
            object_name=object_name,
            extra=extra,
            **kwargs
        )
    return outcome
|
||||
|
||||
|
||||
class _Stream(object):
|
||||
encoding = None
|
||||
@ -1068,8 +1130,9 @@ class _Stream(object):
|
||||
self.write(s)
|
||||
|
||||
|
||||
class _Boto3Driver(object):
|
||||
class _Boto3Driver(_Driver):
|
||||
""" Boto3 storage adapter (simple, enough for now) """
|
||||
|
||||
_max_multipart_concurrency = config.get('aws.boto3.max_multipart_concurrency', 16)
|
||||
|
||||
_min_pool_connections = 512
|
||||
@ -1134,9 +1197,9 @@ class _Boto3Driver(object):
|
||||
self._stream_download_pool = ThreadPoolExecutor(max_workers=self._stream_download_pool_connections)
|
||||
return self._stream_download_pool
|
||||
|
||||
def get_container(self, container_name, *_, **kwargs):
|
||||
def get_container(self, container_name, config=None, **kwargs):
|
||||
if container_name not in self._containers:
|
||||
self._containers[container_name] = self._Container(name=container_name, cfg=kwargs.get('config'))
|
||||
self._containers[container_name] = self._Container(name=container_name, cfg=config)
|
||||
self._containers[container_name].config.retries = kwargs.get('retries', 5)
|
||||
return self._containers[container_name]
|
||||
|
||||
@ -1183,7 +1246,7 @@ class _Boto3Driver(object):
|
||||
obj.container_name = full_container_name
|
||||
return obj
|
||||
|
||||
def download_object_as_stream(self, obj, chunk_size=64 * 1024):
|
||||
def download_object_as_stream(self, obj, chunk_size=64 * 1024, **_):
|
||||
def async_download(a_obj, a_stream, cfg):
|
||||
try:
|
||||
a_obj.download_fileobj(a_stream, Config=cfg)
|
||||
@ -1203,7 +1266,7 @@ class _Boto3Driver(object):
|
||||
|
||||
return stream
|
||||
|
||||
def download_object(self, obj, local_path, overwrite_existing=True, delete_on_failure=True, callback=None):
|
||||
def download_object(self, obj, local_path, overwrite_existing=True, delete_on_failure=True, callback=None, **_):
|
||||
import boto3.s3.transfer
|
||||
p = Path(local_path)
|
||||
if not overwrite_existing and p.is_file():
|
||||
@ -1311,8 +1374,14 @@ class _Boto3Driver(object):
|
||||
|
||||
return None
|
||||
|
||||
def get_direct_access(self, remote_path, **_):
    """ This driver offers no direct access path.

    :param remote_path: remote url (ignored)
    :return: always None
    """
    return None
|
||||
|
||||
class _GoogleCloudStorageDriver(object):
|
||||
def test_upload(self, test_path, config, **_):
|
||||
return True
|
||||
|
||||
|
||||
class _GoogleCloudStorageDriver(_Driver):
|
||||
"""Storage driver for google cloud storage"""
|
||||
|
||||
_stream_download_pool_connections = 128
|
||||
@ -1350,9 +1419,9 @@ class _GoogleCloudStorageDriver(object):
|
||||
self._stream_download_pool = ThreadPoolExecutor(max_workers=self._stream_download_pool_connections)
|
||||
return self._stream_download_pool
|
||||
|
||||
def get_container(self, container_name, *_, **kwargs):
|
||||
def get_container(self, container_name, config=None, **kwargs):
|
||||
if container_name not in self._containers:
|
||||
self._containers[container_name] = self._Container(name=container_name, cfg=kwargs.get('config'))
|
||||
self._containers[container_name] = self._Container(name=container_name, cfg=config)
|
||||
self._containers[container_name].config.retries = kwargs.get('retries', 5)
|
||||
return self._containers[container_name]
|
||||
|
||||
@ -1387,7 +1456,7 @@ class _GoogleCloudStorageDriver(object):
|
||||
obj.container_name = full_container_name
|
||||
return obj
|
||||
|
||||
def download_object_as_stream(self, obj, chunk_size=256 * 1024):
|
||||
def download_object_as_stream(self, obj, chunk_size=256 * 1024, **_):
|
||||
raise NotImplementedError('Unsupported for google storage')
|
||||
|
||||
def async_download(a_obj, a_stream):
|
||||
@ -1404,14 +1473,14 @@ class _GoogleCloudStorageDriver(object):
|
||||
|
||||
return stream
|
||||
|
||||
def download_object(self, obj, local_path, overwrite_existing=True, delete_on_failure=True, callback=None):
|
||||
def download_object(self, obj, local_path, overwrite_existing=True, delete_on_failure=True, callback=None, **_):
    """ Download ``obj`` into the local file ``local_path``.

    :param obj: object exposing ``download_to_filename(filename)``
    :param local_path: destination file path
    :param overwrite_existing: when false and the destination file already
        exists, nothing is downloaded and a warning is logged
    :param delete_on_failure: not used by this implementation
    :param callback: not used by this implementation
    :return: None in every case
    """
    target = Path(local_path)
    # download unless we were told not to overwrite a file that is already there
    if overwrite_existing or not target.is_file():
        obj.download_to_filename(str(target))
        return
    log.warning('failed saving after download: overwrite=False and file exists (%s)' % str(target))
|
||||
|
||||
def test_upload(self, test_path, config):
|
||||
def test_upload(self, test_path, config, **_):
|
||||
bucket_url = str(furl(scheme=self.scheme, netloc=config.bucket, path=config.subdir))
|
||||
bucket = self.get_container(container_name=bucket_url, config=config).bucket
|
||||
|
||||
@ -1429,8 +1498,11 @@ class _GoogleCloudStorageDriver(object):
|
||||
permissions_to_test = ('storage.objects.get', 'storage.objects.update')
|
||||
return set(test_obj.test_iam_permissions(permissions_to_test)) == set(permissions_to_test)
|
||||
|
||||
def get_direct_access(self, remote_path, **_):
    """ This driver offers no direct access path.

    :param remote_path: remote url (ignored)
    :return: always None
    """
    return None
|
||||
|
||||
class _AzureBlobServiceStorageDriver(object):
|
||||
|
||||
class _AzureBlobServiceStorageDriver(_Driver):
|
||||
scheme = 'azure'
|
||||
|
||||
_containers = {}
|
||||
@ -1459,8 +1531,8 @@ class _AzureBlobServiceStorageDriver(object):
|
||||
blob_name = attrib()
|
||||
content_length = attrib()
|
||||
|
||||
def get_container(self, config, *_, **kwargs):
|
||||
container_name = config.container_name
|
||||
def get_container(self, container_name=None, config=None, **kwargs):
|
||||
container_name = container_name or config.container_name
|
||||
if container_name not in self._containers:
|
||||
self._containers[container_name] = self._Container(name=container_name, config=config)
|
||||
# self._containers[container_name].config.retries = kwargs.get('retries', 5)
|
||||
@ -1542,7 +1614,7 @@ class _AzureBlobServiceStorageDriver(object):
|
||||
)
|
||||
return blob.content
|
||||
|
||||
def download_object(self, obj, local_path, overwrite_existing=True, delete_on_failure=True, callback=None):
|
||||
def download_object(self, obj, local_path, overwrite_existing=True, delete_on_failure=True, callback=None, **_):
|
||||
p = Path(local_path)
|
||||
if not overwrite_existing and p.is_file():
|
||||
log.warning('failed saving after download: overwrite=False and file exists (%s)' % str(p))
|
||||
@ -1570,7 +1642,7 @@ class _AzureBlobServiceStorageDriver(object):
|
||||
)
|
||||
download_done.wait()
|
||||
|
||||
def test_upload(self, test_path, config):
|
||||
def test_upload(self, test_path, config, **_):
|
||||
container = self.get_container(config=config)
|
||||
try:
|
||||
container.blob_service.get_container_properties(container.name)
|
||||
@ -1620,8 +1692,11 @@ class _AzureBlobServiceStorageDriver(object):
|
||||
|
||||
return name
|
||||
|
||||
def get_direct_access(self, remote_path, **_):
    """ This driver offers no direct access path.

    :param remote_path: remote url (ignored)
    :return: always None
    """
    return None
|
||||
|
||||
class _FileStorageDriver(object):
|
||||
|
||||
class _FileStorageDriver(_Driver):
|
||||
"""
|
||||
A base StorageDriver to derive from.
|
||||
"""
|
||||
@ -1775,7 +1850,7 @@ class _FileStorageDriver(object):
|
||||
|
||||
return self._get_objects(container)
|
||||
|
||||
def get_container(self, container_name):
|
||||
def get_container(self, container_name, **_):
|
||||
"""
|
||||
Return a container instance.
|
||||
|
||||
@ -1807,7 +1882,7 @@ class _FileStorageDriver(object):
|
||||
|
||||
return path
|
||||
|
||||
def get_object(self, container_name, object_name):
|
||||
def get_object(self, container_name, object_name, **_):
|
||||
"""
|
||||
Return an object instance.
|
||||
|
||||
@ -1835,8 +1910,7 @@ class _FileStorageDriver(object):
|
||||
"""
|
||||
return os.path.realpath(os.path.join(self.base_path, obj.container.name, obj.name))
|
||||
|
||||
def download_object(self, obj, destination_path, overwrite_existing=False,
|
||||
delete_on_failure=True):
|
||||
def download_object(self, obj, destination_path, overwrite_existing=False, delete_on_failure=True, **_):
|
||||
"""
|
||||
Download an object to the specified destination path.
|
||||
|
||||
@ -1888,7 +1962,7 @@ class _FileStorageDriver(object):
|
||||
|
||||
return True
|
||||
|
||||
def download_object_as_stream(self, obj, chunk_size=None):
|
||||
def download_object_as_stream(self, obj, chunk_size=None, **_):
|
||||
"""
|
||||
Return a generator which yields object data.
|
||||
|
||||
@ -1906,8 +1980,7 @@ class _FileStorageDriver(object):
|
||||
for data in self._read_in_chunks(obj_file, chunk_size=chunk_size):
|
||||
yield data
|
||||
|
||||
def upload_object(self, file_path, container, object_name, extra=None,
|
||||
verify_hash=True):
|
||||
def upload_object(self, file_path, container, object_name, extra=None, verify_hash=True, **_):
|
||||
"""
|
||||
Upload an object currently located on a disk.
|
||||
|
||||
@ -1941,9 +2014,7 @@ class _FileStorageDriver(object):
|
||||
|
||||
return self._make_object(container, object_name)
|
||||
|
||||
def upload_object_via_stream(self, iterator, container,
|
||||
object_name,
|
||||
extra=None):
|
||||
def upload_object_via_stream(self, iterator, container, object_name, extra=None, **kwargs):
|
||||
"""
|
||||
Upload an object using an iterator.
|
||||
|
||||
@ -1990,7 +2061,7 @@ class _FileStorageDriver(object):
|
||||
os.chmod(obj_path, int('664', 8))
|
||||
return self._make_object(container, object_name)
|
||||
|
||||
def delete_object(self, obj):
|
||||
def delete_object(self, obj, **_):
|
||||
"""
|
||||
Delete an object.
|
||||
|
||||
@ -2144,3 +2215,12 @@ class _FileStorageDriver(object):
|
||||
else:
|
||||
yield data
|
||||
data = bytes('')
|
||||
|
||||
def get_direct_access(self, remote_path, **_):
    """ Return the local filesystem path that ``remote_path`` refers to.

    :param remote_path: file url or path handled by this driver
    :return: posix-style path string, i.e. the conformed url without its
        first seven characters (the ``file://`` prefix)
    """
    # length of the 'file://' scheme prefix that is cut off below
    prefix_length = 7
    # conform_url is relied upon to yield a full path carrying the file:// prefix
    conformed = StorageHelper.conform_url(remote_path)
    local_part = conformed[prefix_length:]
    return Path(local_part).as_posix()
|
||||
|
||||
def test_upload(self, test_path, config, **kwargs):
|
||||
return True
|
||||
|
Loading…
Reference in New Issue
Block a user