Mirror of https://github.com/clearml/clearml (synced 2025-03-03 18:52:12 +00:00)
Fix anonymously named models, and register input models by their remote URL instead of the local file path (when StorageManager is used to download the weights)
commit 9a7850b23d (parent 7ab93e7dba)
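In user terms, this is the flow the commit targets. Below is a minimal, hedged sketch (not part of the commit) assuming the public clearml API (StorageManager.get_local_copy, InputModel.import_model, Task.connect); the bucket URL and names are made up.

# Hypothetical sketch of the behavior this commit targets; assumes the public
# clearml API, and the S3 URL / names below are illustrative only.
from clearml import Task, InputModel, StorageManager

task = Task.init(project_name="examples", task_name="register remote model")

remote_weights = "s3://my-bucket/models/resnet50.pt"  # made-up remote location

# StorageManager downloads the weights and caches them locally...
local_copy = StorageManager.get_local_copy(remote_url=remote_weights)

# ...and with this commit, importing the local copy should register the model
# under its original remote URL (via the CacheManager local->remote lookup),
# instead of the transient local cache path.
input_model = InputModel.import_model(weights_url=local_copy, name="resnet50 weights")
task.connect(input_model)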
@@ -11,10 +11,12 @@ from typing import List, Dict, Union, Optional, Mapping, TYPE_CHECKING, Sequence
 from .backend_api import Session
 from .backend_api.services import models
 from pathlib2 import Path

 from .utilities.pyhocon import ConfigFactory, HOCONConverter

 from .backend_interface.util import validate_dict, get_single_result, mutually_exclusive
 from .debugging.log import get_logger
+from .storage.cache import CacheManager
 from .storage.helper import StorageHelper
 from .utilities.enum import Options
 from .backend_interface import Task as _Task
@@ -526,6 +528,9 @@ class InputModel(Model):
         weights_url = StorageHelper.conform_url(weights_url)
         if not weights_url:
             raise ValueError("Please provide a valid weights_url parameter")
+        # convert local file to remote one
+        weights_url = CacheManager.get_remote_url(weights_url)
+
         extra = {'system_tags': ["-" + ARCHIVED_TAG]} \
             if Session.check_min_api_version('2.3') else {'tags': ["-" + ARCHIVED_TAG]}
         result = _Model._get_default_session().send(models.GetAllRequest(
@@ -561,7 +566,10 @@ class InputModel(Model):
         if task:
             comment = 'Imported by task id: {}'.format(task.id) + ('\n' + comment if comment else '')
             project_id = task.project
-            task_id = task.id
+            name = name or task.name
+            # do not register the Task, because we do not want it listed afterwards as "output model",
+            # the Task never actually created the Model
+            task_id = None
         else:
             project_id = None
             task_id = None
@@ -624,6 +632,10 @@ class InputModel(Model):
         weights_url = StorageHelper.conform_url(weights_url)
         if not weights_url:
             raise ValueError("Please provide a valid weights_url parameter")
+
+        # convert local file to remote one
+        weights_url = CacheManager.get_remote_url(weights_url)
+
         if not load_archived:
             extra = {'system_tags': ["-" + ARCHIVED_TAG]} \
                 if Session.check_min_api_version('2.3') else {'tags': ["-" + ARCHIVED_TAG]}
@@ -919,7 +931,7 @@ class OutputModel(BaseModel):
         self._floating_data = create_dummy_model(
             design=_Model._wrap_design(config_text),
             labels=label_enumeration or task.get_labels_enumeration(),
-            name=name,
+            name=name or task.name,
             tags=tags,
             comment='{} by task id: {}'.format('Created' if not base_model_id else 'Overwritten', task.id) +
                     ('\n' + comment if comment else ''),
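For the OutputModel side of the fix ("anonymously named models"), the dummy model created for a task now falls back to the task's own name when no explicit name is given. A hedged usage sketch, assuming the public OutputModel constructor; project and names below are illustrative.

# Hedged sketch: an OutputModel created without an explicit name should pick
# up the task's name instead of being anonymous (per the hunk above).
from clearml import Task, OutputModel

task = Task.init(project_name="examples", task_name="train resnet50")

# no `name=` given; the floating (dummy) model falls back to task.name
output_model = OutputModel(task=task, framework="PyTorch")

# an explicit name still takes precedence
named_model = OutputModel(task=task, name="resnet50 final", framework="PyTorch")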
@@ -1,6 +1,7 @@
 import hashlib
 import shutil

+from collections import OrderedDict
 from pathlib2 import Path

 from .helper import StorageHelper
@@ -14,6 +15,8 @@ class CacheManager(object):
     _default_cache_file_limit = 100
     _storage_manager_folder = "storage_manager"
     _default_context = "global"
+    _local_to_remote_url_lookup = OrderedDict()
+    __local_to_remote_url_lookup_max_size = 1024

     class CacheContext(object):
         def __init__(self, cache_context, default_cache_file_limit=10):
@@ -30,6 +33,7 @@ class CacheManager(object):
                 raise ValueError("Storage access failed: {}".format(remote_url))
             # check if we need to cache the file
             try:
+                # noinspection PyProtectedMember
                 direct_access = helper._driver.get_direct_access(remote_url)
             except (OSError, ValueError):
                 LoggerRoot.get_base_logger().warning("Failed accessing local file: {}".format(remote_url))
@@ -41,20 +45,24 @@ class CacheManager(object):
             # check if we already have the file in our cache
             cached_file, cached_size = self._get_cache_file(remote_url)
             if cached_size is not None:
+                CacheManager._add_remote_url(remote_url, cached_file)
                 return cached_file
             # we need to download the file:
             downloaded_file = helper.download_to_file(remote_url, cached_file)
             if downloaded_file != cached_file:
                 # something happened
                 return None
+            CacheManager._add_remote_url(remote_url, cached_file)
             return cached_file

         @staticmethod
         def upload_file(local_file, remote_url, wait_for_upload=True):
             helper = StorageHelper.get(remote_url)
-            return helper.upload(
+            result = helper.upload(
                 local_file, remote_url, async_enable=not wait_for_upload
             )
+            CacheManager._add_remote_url(remote_url, local_file)
+            return result

         @classmethod
         def _get_hashed_url_file(cls, url):
@@ -123,3 +131,25 @@ class CacheManager(object):
             cls.__cache_managers[cache_context].set_cache_limit(cache_file_limit)

         return cls.__cache_managers[cache_context]
+
+    @staticmethod
+    def get_remote_url(local_copy_path):
+        if not CacheManager._local_to_remote_url_lookup:
+            return local_copy_path
+        conform_local_copy_path = StorageHelper.conform_url(local_copy_path)
+        return CacheManager._local_to_remote_url_lookup.get(hash(conform_local_copy_path), local_copy_path)
+
+    @staticmethod
+    def _add_remote_url(remote_url, local_copy_path):
+        # so that we can disable the cache lookup altogether
+        if CacheManager._local_to_remote_url_lookup is None:
+            return
+        remote_url = StorageHelper.conform_url(remote_url)
+        if remote_url.startswith('file://'):
+            return
+        local_copy_path = StorageHelper.conform_url(local_copy_path)
+        CacheManager._local_to_remote_url_lookup[hash(local_copy_path)] = remote_url
+        # protect against overuse, so we do not blow up the memory
+        if len(CacheManager._local_to_remote_url_lookup) > CacheManager.__local_to_remote_url_lookup_max_size:
+            # pop the first item (FIFO)
+            CacheManager._local_to_remote_url_lookup.popitem(last=False)
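The new CacheManager lookup above is essentially a small FIFO-bounded table keyed by the hashed local path. A self-contained illustration of that pattern follows; it is independent of clearml, and the class and paths are made up for the example.

# Standalone sketch of the bounded FIFO lookup idea (OrderedDict + popitem(last=False));
# not clearml code, names and paths are illustrative.
from collections import OrderedDict

class UrlLookup(object):
    _max_size = 1024

    def __init__(self):
        self._lookup = OrderedDict()

    def add(self, local_path, remote_url):
        self._lookup[hash(local_path)] = remote_url
        # evict the oldest entry once the table grows past the cap (FIFO)
        if len(self._lookup) > self._max_size:
            self._lookup.popitem(last=False)

    def get(self, local_path):
        # fall back to the local path itself when no remote origin is known
        return self._lookup.get(hash(local_path), local_path)

lookup = UrlLookup()
lookup.add("/home/user/.clearml/cache/model.pt", "s3://my-bucket/models/model.pt")
print(lookup.get("/home/user/.clearml/cache/model.pt"))  # -> s3://my-bucket/models/model.pt
print(lookup.get("/tmp/other.pt"))                       # -> /tmp/other.pt (unknown, unchanged)

Keying on the hash of the conformed path keeps the table compact, and the FIFO cap bounds memory when many files are fetched in a single process, which matches the "protect against overuse" comment in the diff.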