diff --git a/trains/model.py b/trains/model.py
index 20bcd621..8c49abbe 100644
--- a/trains/model.py
+++ b/trains/model.py
@@ -11,10 +11,12 @@ from typing import List, Dict, Union, Optional, Mapping, TYPE_CHECKING, Sequence
 from .backend_api import Session
 from .backend_api.services import models
 from pathlib2 import Path
+
 from .utilities.pyhocon import ConfigFactory, HOCONConverter
 
 from .backend_interface.util import validate_dict, get_single_result, mutually_exclusive
 from .debugging.log import get_logger
+from .storage.cache import CacheManager
 from .storage.helper import StorageHelper
 from .utilities.enum import Options
 from .backend_interface import Task as _Task
@@ -526,6 +528,9 @@ class InputModel(Model):
         weights_url = StorageHelper.conform_url(weights_url)
         if not weights_url:
             raise ValueError("Please provide a valid weights_url parameter")
+        # convert a local file path to its remote url
+        weights_url = CacheManager.get_remote_url(weights_url)
+
         extra = {'system_tags': ["-" + ARCHIVED_TAG]} \
             if Session.check_min_api_version('2.3') else {'tags': ["-" + ARCHIVED_TAG]}
         result = _Model._get_default_session().send(models.GetAllRequest(
@@ -561,7 +566,10 @@ class InputModel(Model):
         if task:
             comment = 'Imported by task id: {}'.format(task.id) + ('\n' + comment if comment else '')
             project_id = task.project
-            task_id = task.id
+            name = name or task.name
+            # do not register the Task, because we do not want it listed afterwards as an "output model",
+            # the Task never actually created the Model
+            task_id = None
         else:
             project_id = None
             task_id = None
@@ -624,6 +632,10 @@ class InputModel(Model):
         weights_url = StorageHelper.conform_url(weights_url)
         if not weights_url:
             raise ValueError("Please provide a valid weights_url parameter")
+
+        # convert a local file path to its remote url
+        weights_url = CacheManager.get_remote_url(weights_url)
+
         if not load_archived:
             extra = {'system_tags': ["-" + ARCHIVED_TAG]} \
                 if Session.check_min_api_version('2.3') else {'tags': ["-" + ARCHIVED_TAG]}
@@ -919,7 +931,7 @@ class OutputModel(BaseModel):
         self._floating_data = create_dummy_model(
             design=_Model._wrap_design(config_text),
             labels=label_enumeration or task.get_labels_enumeration(),
-            name=name,
+            name=name or task.name,
             tags=tags,
             comment='{} by task id: {}'.format('Created' if not base_model_id else 'Overwritten', task.id) +
                     ('\n' + comment if comment else ''),
diff --git a/trains/storage/cache.py b/trains/storage/cache.py
index 07be8b5c..120a9ddd 100644
--- a/trains/storage/cache.py
+++ b/trains/storage/cache.py
@@ -1,6 +1,7 @@
 import hashlib
 import shutil
+from collections import OrderedDict
 
 from pathlib2 import Path
 
 from .helper import StorageHelper
@@ -14,6 +15,8 @@ class CacheManager(object):
     _default_cache_file_limit = 100
     _storage_manager_folder = "storage_manager"
     _default_context = "global"
+    _local_to_remote_url_lookup = OrderedDict()
+    __local_to_remote_url_lookup_max_size = 1024
 
     class CacheContext(object):
         def __init__(self, cache_context, default_cache_file_limit=10):
@@ -30,6 +33,7 @@ class CacheManager(object):
                 raise ValueError("Storage access failed: {}".format(remote_url))
             # check if we need to cache the file
             try:
+                # noinspection PyProtectedMember
                 direct_access = helper._driver.get_direct_access(remote_url)
             except (OSError, ValueError):
                 LoggerRoot.get_base_logger().warning("Failed accessing local file: {}".format(remote_url))
@@ -41,20 +45,24 @@ class CacheManager(object):
             # check if we already have the file in our cache
             cached_file, cached_size = self._get_cache_file(remote_url)
             if cached_size is not None:
+                CacheManager._add_remote_url(remote_url, cached_file)
                 return cached_file
             # we need to download the file:
             downloaded_file = helper.download_to_file(remote_url, cached_file)
             if downloaded_file != cached_file:
                 # something happened
                 return None
+            CacheManager._add_remote_url(remote_url, cached_file)
             return cached_file
 
         @staticmethod
         def upload_file(local_file, remote_url, wait_for_upload=True):
             helper = StorageHelper.get(remote_url)
-            return helper.upload(
+            result = helper.upload(
                 local_file, remote_url, async_enable=not wait_for_upload
             )
+            CacheManager._add_remote_url(remote_url, local_file)
+            return result
 
         @classmethod
         def _get_hashed_url_file(cls, url):
@@ -123,3 +131,25 @@ class CacheManager(object):
             cls.__cache_managers[cache_context].set_cache_limit(cache_file_limit)
 
         return cls.__cache_managers[cache_context]
+
+    @staticmethod
+    def get_remote_url(local_copy_path):
+        if not CacheManager._local_to_remote_url_lookup:
+            return local_copy_path
+        conform_local_copy_path = StorageHelper.conform_url(local_copy_path)
+        return CacheManager._local_to_remote_url_lookup.get(hash(conform_local_copy_path), local_copy_path)
+
+    @staticmethod
+    def _add_remote_url(remote_url, local_copy_path):
+        # setting the lookup table to None disables the url lookup altogether
+        if CacheManager._local_to_remote_url_lookup is None:
+            return
+        remote_url = StorageHelper.conform_url(remote_url)
+        if remote_url.startswith('file://'):
+            return
+        local_copy_path = StorageHelper.conform_url(local_copy_path)
+        CacheManager._local_to_remote_url_lookup[hash(local_copy_path)] = remote_url
+        # protect against overuse, so we do not blow up the memory
+        if len(CacheManager._local_to_remote_url_lookup) > CacheManager.__local_to_remote_url_lookup_max_size:
+            # pop the first item (FIFO)
+            CacheManager._local_to_remote_url_lookup.popitem(last=False)
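
Note (not part of the patch): the new CacheManager helpers keep an OrderedDict that maps hashed local cache paths back to the remote urls they were downloaded from or uploaded to, evicting the oldest entry once the table grows past its size limit. The standalone Python sketch below illustrates that FIFO-bounded lookup pattern under simplified assumptions; the names add_remote_url, get_remote_url and _max_size are illustrative only and do not mirror the patched API exactly.

from collections import OrderedDict

# illustrative stand-in for CacheManager._local_to_remote_url_lookup
_lookup = OrderedDict()
_max_size = 4  # the patch uses 1024

def add_remote_url(remote_url, local_copy_path):
    # local file:// urls need no translation, so they are never recorded
    if remote_url.startswith('file://'):
        return
    _lookup[hash(local_copy_path)] = remote_url
    # bound memory use: once the table grows too large, drop the oldest mapping (FIFO)
    if len(_lookup) > _max_size:
        _lookup.popitem(last=False)

def get_remote_url(local_copy_path):
    # fall back to the local path when no remote url was ever registered for it
    return _lookup.get(hash(local_copy_path), local_copy_path)

if __name__ == '__main__':
    add_remote_url('s3://bucket/models/model.pkl', '/tmp/cache/model.pkl')
    print(get_remote_url('/tmp/cache/model.pkl'))       # -> s3://bucket/models/model.pkl
    print(get_remote_url('/tmp/some/other/file.bin'))   # -> unchanged local path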