Fix anonymously named models, and register input models based on the remote url instead of the local file path (when StorageManager is used to download)

This commit is contained in:
allegroai 2020-06-19 00:50:02 +03:00
parent 7ab93e7dba
commit 9a7850b23d
2 changed files with 45 additions and 3 deletions

View File

@ -11,10 +11,12 @@ from typing import List, Dict, Union, Optional, Mapping, TYPE_CHECKING, Sequence
from .backend_api import Session from .backend_api import Session
from .backend_api.services import models from .backend_api.services import models
from pathlib2 import Path from pathlib2 import Path
from .utilities.pyhocon import ConfigFactory, HOCONConverter from .utilities.pyhocon import ConfigFactory, HOCONConverter
from .backend_interface.util import validate_dict, get_single_result, mutually_exclusive from .backend_interface.util import validate_dict, get_single_result, mutually_exclusive
from .debugging.log import get_logger from .debugging.log import get_logger
from .storage.cache import CacheManager
from .storage.helper import StorageHelper from .storage.helper import StorageHelper
from .utilities.enum import Options from .utilities.enum import Options
from .backend_interface import Task as _Task from .backend_interface import Task as _Task
@ -526,6 +528,9 @@ class InputModel(Model):
weights_url = StorageHelper.conform_url(weights_url) weights_url = StorageHelper.conform_url(weights_url)
if not weights_url: if not weights_url:
raise ValueError("Please provide a valid weights_url parameter") raise ValueError("Please provide a valid weights_url parameter")
# convert a local file path to its remote url counterpart (if known)
weights_url = CacheManager.get_remote_url(weights_url)
extra = {'system_tags': ["-" + ARCHIVED_TAG]} \ extra = {'system_tags': ["-" + ARCHIVED_TAG]} \
if Session.check_min_api_version('2.3') else {'tags': ["-" + ARCHIVED_TAG]} if Session.check_min_api_version('2.3') else {'tags': ["-" + ARCHIVED_TAG]}
result = _Model._get_default_session().send(models.GetAllRequest( result = _Model._get_default_session().send(models.GetAllRequest(
@ -561,7 +566,10 @@ class InputModel(Model):
if task: if task:
comment = 'Imported by task id: {}'.format(task.id) + ('\n' + comment if comment else '') comment = 'Imported by task id: {}'.format(task.id) + ('\n' + comment if comment else '')
project_id = task.project project_id = task.project
task_id = task.id name = name or task.name
# do not register the Task, because we do not want it listed after as "output model",
# the Task never actually created the Model
task_id = None
else: else:
project_id = None project_id = None
task_id = None task_id = None
@ -624,6 +632,10 @@ class InputModel(Model):
weights_url = StorageHelper.conform_url(weights_url) weights_url = StorageHelper.conform_url(weights_url)
if not weights_url: if not weights_url:
raise ValueError("Please provide a valid weights_url parameter") raise ValueError("Please provide a valid weights_url parameter")
# convert a local file path to its remote url counterpart (if known)
weights_url = CacheManager.get_remote_url(weights_url)
if not load_archived: if not load_archived:
extra = {'system_tags': ["-" + ARCHIVED_TAG]} \ extra = {'system_tags': ["-" + ARCHIVED_TAG]} \
if Session.check_min_api_version('2.3') else {'tags': ["-" + ARCHIVED_TAG]} if Session.check_min_api_version('2.3') else {'tags': ["-" + ARCHIVED_TAG]}
@ -919,7 +931,7 @@ class OutputModel(BaseModel):
self._floating_data = create_dummy_model( self._floating_data = create_dummy_model(
design=_Model._wrap_design(config_text), design=_Model._wrap_design(config_text),
labels=label_enumeration or task.get_labels_enumeration(), labels=label_enumeration or task.get_labels_enumeration(),
name=name, name=name or task.name,
tags=tags, tags=tags,
comment='{} by task id: {}'.format('Created' if not base_model_id else 'Overwritten', task.id) + comment='{} by task id: {}'.format('Created' if not base_model_id else 'Overwritten', task.id) +
('\n' + comment if comment else ''), ('\n' + comment if comment else ''),

View File

@ -1,6 +1,7 @@
import hashlib import hashlib
import shutil import shutil
from collections import OrderedDict
from pathlib2 import Path from pathlib2 import Path
from .helper import StorageHelper from .helper import StorageHelper
@ -14,6 +15,8 @@ class CacheManager(object):
_default_cache_file_limit = 100 _default_cache_file_limit = 100
_storage_manager_folder = "storage_manager" _storage_manager_folder = "storage_manager"
_default_context = "global" _default_context = "global"
_local_to_remote_url_lookup = OrderedDict()
__local_to_remote_url_lookup_max_size = 1024
class CacheContext(object): class CacheContext(object):
def __init__(self, cache_context, default_cache_file_limit=10): def __init__(self, cache_context, default_cache_file_limit=10):
@ -30,6 +33,7 @@ class CacheManager(object):
raise ValueError("Storage access failed: {}".format(remote_url)) raise ValueError("Storage access failed: {}".format(remote_url))
# check if we need to cache the file # check if we need to cache the file
try: try:
# noinspection PyProtectedMember
direct_access = helper._driver.get_direct_access(remote_url) direct_access = helper._driver.get_direct_access(remote_url)
except (OSError, ValueError): except (OSError, ValueError):
LoggerRoot.get_base_logger().warning("Failed accessing local file: {}".format(remote_url)) LoggerRoot.get_base_logger().warning("Failed accessing local file: {}".format(remote_url))
@ -41,20 +45,24 @@ class CacheManager(object):
# check if we already have the file in our cache # check if we already have the file in our cache
cached_file, cached_size = self._get_cache_file(remote_url) cached_file, cached_size = self._get_cache_file(remote_url)
if cached_size is not None: if cached_size is not None:
CacheManager._add_remote_url(remote_url, cached_file)
return cached_file return cached_file
# we need to download the file: # we need to download the file:
downloaded_file = helper.download_to_file(remote_url, cached_file) downloaded_file = helper.download_to_file(remote_url, cached_file)
if downloaded_file != cached_file: if downloaded_file != cached_file:
# something happened # something happened
return None return None
CacheManager._add_remote_url(remote_url, cached_file)
return cached_file return cached_file
@staticmethod @staticmethod
def upload_file(local_file, remote_url, wait_for_upload=True): def upload_file(local_file, remote_url, wait_for_upload=True):
helper = StorageHelper.get(remote_url) helper = StorageHelper.get(remote_url)
return helper.upload( result = helper.upload(
local_file, remote_url, async_enable=not wait_for_upload local_file, remote_url, async_enable=not wait_for_upload
) )
CacheManager._add_remote_url(remote_url, local_file)
return result
@classmethod @classmethod
def _get_hashed_url_file(cls, url): def _get_hashed_url_file(cls, url):
@ -123,3 +131,25 @@ class CacheManager(object):
cls.__cache_managers[cache_context].set_cache_limit(cache_file_limit) cls.__cache_managers[cache_context].set_cache_limit(cache_file_limit)
return cls.__cache_managers[cache_context] return cls.__cache_managers[cache_context]
@staticmethod
def get_remote_url(local_copy_path):
    """Return the remote url that *local_copy_path* was fetched from, if registered.

    Falls back to returning *local_copy_path* unchanged when the lookup table
    is empty (or disabled) or when the path was never registered.
    """
    lookup = CacheManager._local_to_remote_url_lookup
    # empty or None lookup -> feature unused/disabled, nothing to translate
    if not lookup:
        return local_copy_path
    key = hash(StorageHelper.conform_url(local_copy_path))
    return lookup.get(key, local_copy_path)
@staticmethod
def _add_remote_url(remote_url, local_copy_path):
    """Record that *local_copy_path* is the local copy of *remote_url*.

    No-op when the lookup table is explicitly disabled (set to None) or when
    the "remote" url is itself a local file:// url.
    """
    lookup = CacheManager._local_to_remote_url_lookup
    # a None table means the cache lookup was disabled altogether
    if lookup is None:
        return
    remote_url = StorageHelper.conform_url(remote_url)
    # file:// urls are not real remotes - nothing worth registering
    if remote_url.startswith('file://'):
        return
    key = hash(StorageHelper.conform_url(local_copy_path))
    lookup[key] = remote_url
    # bound the table size (FIFO eviction) so memory cannot grow unchecked
    if len(lookup) > CacheManager.__local_to_remote_url_lookup_max_size:
        lookup.popitem(last=False)