Add StorageManager

allegroai 2020-04-09 12:03:41 +03:00
parent e1fc9b3dc8
commit e6f29428eb
8 changed files with 151 additions and 27 deletions

View File

@@ -10,7 +10,8 @@ from ..backend_api import Session
 from ..backend_api.services import models
 from .base import IdObjectBase
 from .util import make_message
-from ..storage import StorageHelper
+from ..storage import StorageManager
+from ..storage.helper import StorageHelper
 from ..utilities.async_manager import AsyncManagerMixin
 
 ModelPackage = namedtuple('ModelPackage', 'weights design')
@@ -54,10 +55,6 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
     def model_id(self):
         return self.id
 
-    @property
-    def storage(self):
-        return StorageHelper.get(self.upload_storage_uri)
-
     def __init__(self, upload_storage_uri, cache_dir, model_id=None,
                  upload_storage_suffix='models', session=None, log=None):
         super(Model, self).__init__(id=model_id, session=session, log=log)
@@ -84,10 +81,9 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
     def _upload_model(self, model_file, async_enable=False, target_filename=None, cb=None):
         if not self.upload_storage_uri:
             raise ValueError('Model has no storage URI defined (nowhere to upload to)')
-        helper = self.storage
         target_filename = target_filename or Path(model_file).name
         dest_path = '/'.join((self.upload_storage_uri, self._upload_storage_suffix or '.', target_filename))
-        result = helper.upload(
+        result = StorageHelper.get(dest_path).upload(
             src_path=model_file,
             dest_path=dest_path,
             async_enable=async_enable,
@@ -412,7 +408,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
             # remove non existing model file
             Model._local_model_to_id_uri.pop(dl_file, None)
 
-        local_download = StorageHelper.get(uri).get_local_copy(uri)
+        local_download = StorageManager.get_local_copy(uri)
 
         # save local model, so we can later query what was the original one
         Model._local_model_to_id_uri[str(local_download)] = (self.model_id, uri)

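Taken together, these hunks split the two directions of traffic: uploads still go through a StorageHelper, now resolved from the full destination path, while downloads go through the new cache-backed StorageManager. A minimal sketch of both calls after this commit (the bucket and file names are hypothetical, not from the diff):

from trains.storage import StorageManager
from trains.storage.helper import StorageHelper

dest = 's3://a_bucket/models/model.pkl'  # hypothetical destination URI
# upload: helper resolved from the destination path, as in _upload_model above
StorageHelper.get(dest).upload(src_path='/tmp/model.pkl', dest_path=dest)
# download: cached by default; a repeated call returns the same local path
local_copy = StorageManager.get_local_copy(dest)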
View File

@@ -1,7 +1,7 @@
 from abc import abstractproperty
 
 from ..backend_config.bucket_config import S3BucketConfig
-from ..storage import StorageHelper
+from ..storage.helper import StorageHelper
 
 
 class SetupUploadMixin(object):

View File

@@ -147,15 +147,6 @@ class Artifact(object):
         if self._object is None:
             self._object = local_file
         else:
-            from trains.storage.helper import StorageHelper
-            # only of we are not using cache, we should delete the file
-            if not hasattr(StorageHelper, 'get_cached_disabled'):
-                # delete the temporary file, we already used it
-                try:
-                    local_file.unlink()
-                except Exception:
-                    pass
 
         return self._object
@@ -165,8 +156,8 @@ class Artifact(object):
         The returned path will be a temporary folder containing the archive content
         :return: a local path to a downloaded copy of the artifact
         """
-        from trains.storage.helper import StorageHelper
-        local_path = StorageHelper.get_local_copy(self.url)
+        from trains.storage import StorageManager
+        local_path = StorageManager.get_local_copy(self.url)
         if local_path and extract_archive and self.type == 'archive':
             temp_folder = None
             try:
@@ -179,10 +170,6 @@ class Artifact(object):
                 except Exception:
                     pass
                 return local_path
-            try:
-                Path(local_path).unlink()
-            except Exception:
-                pass
             return temp_folder
         return local_path

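The deleted unlink() blocks are the heart of this file's change: the downloaded file is now a StorageManager cache entry, so deleting it after use would defeat the cache. A small sketch of the resulting behavior, with a hypothetical artifact URL:

from trains.storage import StorageManager

url = 's3://a_bucket/artifacts/data.zip'      # hypothetical artifact URL
first = StorageManager.get_local_copy(url)    # downloaded into the cache folder
second = StorageManager.get_local_copy(url)   # cache hit: same path, no re-download
assert first == second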
View File

@@ -20,7 +20,7 @@ from .backend_interface.util import mutually_exclusive
 from .config import running_remotely, get_cache_dir, config
 from .debugging.log import LoggerRoot
 from .errors import UsageError
-from .storage import StorageHelper
+from .storage.helper import StorageHelper
 from .utilities.plotly_reporter import SeriesInfo
 
 # Make sure that DeprecationWarning within this package always gets printed

View File

@@ -14,7 +14,7 @@ from .utilities.pyhocon import ConfigFactory, HOCONConverter
 from .backend_interface.util import validate_dict, get_single_result, mutually_exclusive
 from .debugging.log import get_logger
-from .storage import StorageHelper
+from .storage.helper import StorageHelper
 from .utilities.enum import Options
 from .backend_interface import Task as _Task
 from .backend_interface.model import create_dummy_model, Model as _Model

View File

@@ -1,2 +1,2 @@
 """ Local and remote storage support """
-from .helper import StorageHelper
+from .manager import StorageManager

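This two-line file is what forces the import fixes above: trains.storage no longer re-exports StorageHelper, only StorageManager. In practice:

from trains.storage import StorageManager        # the new public entry point
from trains.storage.helper import StorageHelper  # still importable via its module path
# from trains.storage import StorageHelper       # would now raise an ImportError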
trains/storage/cache.py (new file, 85 lines added)
View File

@@ -0,0 +1,85 @@
import hashlib

from pathlib2 import Path

from .helper import StorageHelper
from .util import quote_url
from ..config import get_cache_dir


class CacheManager(object):
    __cache_managers = {}
    _default_cache_file_limit = 100
    _storage_manager_folder = 'storage_manager'
    _default_context = 'global'

    class CacheContext(object):
        def __init__(self, cache_context, default_cache_file_limit=10):
            self._context = str(cache_context)
            self._file_limit = int(default_cache_file_limit)

        def set_cache_limit(self, cache_file_limit):
            self._file_limit = max(self._file_limit, int(cache_file_limit))
            return self._file_limit

        def get_local_copy(self, remote_url):
            helper = StorageHelper.get(remote_url)
            if not helper:
                raise ValueError("Remote storage not supported: {}".format(remote_url))
            # check if we need to cache the file
            direct_access = helper._driver.get_direct_access(remote_url)
            if direct_access:
                return direct_access
            # check if we already have the file in our cache
            cached_file, cached_size = self._get_cache_file(remote_url)
            if cached_size is not None:
                return cached_file
            # we need to download the file:
            downloaded_file = helper.download_to_file(remote_url, cached_file)
            if downloaded_file != cached_file:
                # something happened
                return None
            return cached_file

        @staticmethod
        def upload_file(local_file, remote_url, wait_for_upload=True):
            helper = StorageHelper.get(remote_url)
            return helper.upload(local_file, remote_url, async_enable=not wait_for_upload)

        @classmethod
        def _get_hashed_url_file(cls, url):
            str_hash = hashlib.md5(url.encode()).hexdigest()
            filename = url.split('/')[-1]
            return '{}.{}'.format(str_hash, quote_url(filename))

        def _get_cache_file(self, remote_url):
            """
            :param remote_url: check if we have the remote url in our cache
            :return: full path to file name, current file size or None
            """
            folder = Path(get_cache_dir() / CacheManager._storage_manager_folder / self._context)
            folder.mkdir(parents=True, exist_ok=True)
            local_filename = self._get_hashed_url_file(remote_url)
            new_file = folder / local_filename
            if new_file.exists():
                new_file.touch(exist_ok=True)

            # delete old files
            files = sorted(folder.iterdir(), reverse=True, key=lambda x: x.stat().st_atime)
            files = files[self._file_limit:]
            for f in files:
                f.unlink()

            # if file doesn't exist, return file size None
            return new_file.as_posix(), new_file.stat().st_size if new_file.exists() else None

    @classmethod
    def get_cache_manager(cls, cache_context=None, cache_file_limit=None):
        cache_context = cache_context or cls._default_context
        if cache_context not in cls.__cache_managers:
            cls.__cache_managers[cache_context] = cls.CacheContext(
                cache_context, cache_file_limit or cls._default_cache_file_limit)
        if cache_file_limit:
            cls.__cache_managers[cache_context].set_cache_limit(cache_file_limit)
        return cls.__cache_managers[cache_context]

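Two details of CacheContext are easy to miss. Entries are named by the md5 of the full URL plus the quoted original basename, so identical filenames from different locations cannot collide; and eviction keeps only the most recently accessed _file_limit entries by st_atime, which the touch() on a cache hit refreshes. A standalone sketch of both (the URL is hypothetical; quote_url is skipped for brevity):

import hashlib
from pathlib2 import Path

url = 's3://a_bucket/models/model.pkl'  # hypothetical
entry = '{}.{}'.format(hashlib.md5(url.encode()).hexdigest(), url.split('/')[-1])
# unique per URL yet still human-readable, e.g. '<md5-hex>.model.pkl'

def evict(folder, file_limit):
    # same policy as _get_cache_file: newest-accessed first, drop the tail
    files = sorted(Path(folder).iterdir(), reverse=True, key=lambda f: f.stat().st_atime)
    for stale in files[file_limit:]:
        stale.unlink()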
trains/storage/manager.py (new file, 56 lines added)
View File

@@ -0,0 +1,56 @@
from typing import Optional

from .cache import CacheManager


class StorageManager(object):
    """
    StorageManager is a helper interface for downloading & uploading files to supported remote storage
    Supported remote servers: http(s)/S3/GS/Azure/File-System-Folder
    Cache is enabled by default for all downloaded remote urls/files
    """

    @classmethod
    def get_local_copy(cls, remote_url, cache_context=None):  # type: (str, Optional[str]) -> str
        """
        Get a local copy of the remote file. If the remote URL is a direct file access,
        the returned link is the same, otherwise a link to a local copy of the url file is returned.
        Caching is enabled by default, cache limited by number of stored files per cache context.
        Oldest accessed files are deleted when cache is full.

        :param str remote_url: remote url link (string)
        :param str cache_context: Optional caching context identifier (string), default context 'global'
        :return str: full path to local copy of the requested url. Return None on Error.
        """
        return CacheManager.get_cache_manager(cache_context=cache_context).get_local_copy(remote_url=remote_url)

    @classmethod
    def upload_file(cls, local_file, remote_url, wait_for_upload=True):  # type: (str, str, bool) -> str
        """
        Upload a local file to a remote location.
        remote url is the final destination of the uploaded file.

        Examples:
            upload_file('/tmp/artifact.yaml', 'http://localhost:8081/manual_artifacts/my_artifact.yaml')
            upload_file('/tmp/artifact.yaml', 's3://a_bucket/artifacts/my_artifact.yaml')
            upload_file('/tmp/artifact.yaml', '/mnt/share/folder/artifacts/my_artifact.yaml')

        :param str local_file: Full path of a local file to be uploaded
        :param str remote_url: Full path or remote url to upload to (including file name)
        :param bool wait_for_upload: If False, return immediately and upload in the background. Default True.
        :return str: Newly uploaded remote url
        """
        return CacheManager.get_cache_manager().upload_file(
            local_file=local_file, remote_url=remote_url, wait_for_upload=wait_for_upload)

    @classmethod
    def set_cache_file_limit(cls, cache_file_limit, cache_context=None):  # type: (int, Optional[str]) -> int
        """
        Set the cache context file limit. File limit is the maximum number of files the specific cache context holds.
        Notice, there is no limit on the size of these files, only the total number of cached files.

        :param int cache_file_limit: New maximum number of cached files
        :param str cache_context: Optional cache context identifier, default global context
        :return int: Return new cache context file limit
        """
        return CacheManager.get_cache_manager(
            cache_context=cache_context, cache_file_limit=cache_file_limit).set_cache_limit(cache_file_limit)
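An end-to-end sketch of the new public API (bucket names and paths are hypothetical):

from trains.storage import StorageManager

# raise the file-count cap of the default 'global' cache context
StorageManager.set_cache_file_limit(500)

# cached download: a repeated call returns the same local path without re-downloading
local_path = StorageManager.get_local_copy('s3://a_bucket/artifacts/my_artifact.yaml')

# upload: pass wait_for_upload=False to return immediately and upload in the background
remote = StorageManager.upload_file('/tmp/artifact.yaml', 's3://a_bucket/artifacts/my_artifact.yaml')

Note that set_cache_limit() takes the max() of the old and new values, so a context's file limit can grow but never shrink within a process.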