diff --git a/trains/__init__.py b/trains/__init__.py
index 2b549bd0..053cb431 100644
--- a/trains/__init__.py
+++ b/trains/__init__.py
@@ -2,6 +2,6 @@
 
 from .version import __version__
 from .task import Task
-from .model import InputModel, OutputModel
+from .model import InputModel, OutputModel, Model
 from .logger import Logger
 from .errors import UsageError
diff --git a/trains/model.py b/trains/model.py
index 293ea154..0a00b46e 100644
--- a/trains/model.py
+++ b/trains/model.py
@@ -189,7 +189,21 @@ class BaseModel(object):
 
     @property
     def task(self):
-        return self._task
+        """
+        Return the creating task id (str)
+
+        :return str: Task ID
+        """
+        return self._task or self._get_base_model().task
+
+    @property
+    def url(self):
+        """
+        Return the url of the model file (or archived files)
+
+        :return str: Model file URL
+        """
+        return self._get_base_model().uri
 
     @property
     def published(self):
@@ -323,7 +337,58 @@ class BaseModel(object):
         return config_text
 
 
-class InputModel(BaseModel):
+class Model(BaseModel):
+    """
+    Represent an existing model in the system, identified by model id.
+    The Model will be read-only and can be used to pre-initialize a network
+    """
+
+    def __init__(self, model_id):
+        """
+        Load a model based on its id; the returned object is read-only and can be connected to a task
+
+        Note: the input model can be overridden when running remotely
+
+        :param model_id: id (string)
+        """
+        super(Model, self).__init__()
+        self._base_model_id = model_id
+        self._base_model = None
+
+    def get_local_copy(self, extract_archive=True):
+        """
+        Retrieve a valid link to the model file(s).
+        If the model URL is a file system link, it will be returned directly.
+        If the model URL points to a remote location (http/s3/gs etc.),
+        it will download the file(s) and return the temporary location of the downloaded model.
+
+        :param bool extract_archive: If True and the model is of type 'packaged' (e.g. TensorFlow compressed folder),
+            the returned path will be a temporary folder containing the archive content
+        :return str: a local path to the model (or a downloaded copy of it)
+        """
+        if extract_archive and self._package_tag in self.tags:
+            return self.get_weights_package(return_path=True)
+        return self.get_weights()
+
+    def _get_base_model(self):
+        if self._base_model:
+            return self._base_model
+
+        if not self._base_model_id:
+            # this shouldn't actually happen
+            raise Exception('Missing model ID, cannot create an empty model')
+        self._base_model = _Model(
+            upload_storage_uri=None,
+            cache_dir=get_cache_dir(),
+            model_id=self._base_model_id,
+        )
+        return self._base_model
+
+    def _get_model_data(self):
+        return self._get_base_model().data
+
+
+class InputModel(Model):
     """
     Load an existing model in the system, search by model id.
     The Model will be read-only and can be used to pre initialize a network
@@ -446,6 +511,54 @@ class InputModel(BaseModel):
 
         return this_model
 
+    @classmethod
+    def load_model(
+            cls,
+            weights_url,
+            load_archived=False
+    ):
+        """
+        Load an already registered model based on a pre-existing model file (link must be valid).
+
+        If a registered model with the specified weights URL exists, the returned object represents that model.
+        If no registered model with the specified URL can be found, None is returned.
+
+        :param weights_url: valid url for the weights file (string).
+            examples: "https://domain.com/file.bin" or "s3://bucket/file.bin" or "file:///home/user/file.bin".
+            NOTE: if a model with the exact same URL exists, it will be used, and all other arguments will be ignored.
+        :param bool load_archived: If True, return the registered model even if it is archived;
+            otherwise archived models are ignored.
+        :return Model: InputModel object or None if no model could be found
+        """
+        weights_url = StorageHelper.conform_url(weights_url)
+        if not weights_url:
+            raise ValueError("Please provide a valid weights_url parameter")
+        if not load_archived:
+            extra = {'system_tags': ["-" + ARCHIVED_TAG]} \
+                if Session.check_min_api_version('2.3') else {'tags': ["-" + ARCHIVED_TAG]}
+        else:
+            extra = {}
+
+        result = _Model._get_default_session().send(models.GetAllRequest(
+            uri=[weights_url],
+            only_fields=["id", "name", "created"],
+            **extra
+        ))
+
+        if not result or not result.response or not result.response.models:
+            return None
+
+        logger = get_logger()
+        model = get_single_result(
+            entity='model',
+            query=weights_url,
+            results=result.response.models,
+            log=logger,
+            raise_on_error=False,
+        )
+
+        return InputModel(model_id=model.id)
+
     @classmethod
     def empty(
         cls,
@@ -484,9 +597,7 @@ class InputModel(BaseModel):
 
         :param model_id: id (string)
         """
-        super(InputModel, self).__init__()
-        self._base_model_id = model_id
-        self._base_model = None
+        super(InputModel, self).__init__(model_id)
 
     @property
     def id(self):
@@ -526,23 +637,6 @@ class InputModel(BaseModel):
             # the newly connected input model
            self.task._reconnect_output_model()
 
-    def _get_base_model(self):
-        if self._base_model:
-            return self._base_model
-
-        if not self._base_model_id:
-            # this shouldn't actually happen
-            raise Exception('Missing model ID, cannot create an empty model')
-        self._base_model = _Model(
-            upload_storage_uri=None,
-            cache_dir=get_cache_dir(),
-            model_id=self._base_model_id,
-        )
-        return self._base_model
-
-    def _get_model_data(self):
-        return self._get_base_model().data
-
 
 class OutputModel(BaseModel):
     """
@@ -624,6 +718,7 @@ class OutputModel(BaseModel):
             tags=None,
             comment=None,
             framework=None,
+            base_model_id=None,
     ):
         """
         Create a new model and immediately connect it to a task.
@@ -644,6 +739,7 @@ class OutputModel(BaseModel):
         :param tags: optional, list of strings as tags
         :param comment: optional, string description for the model
         :param framework: optional, string name of the framework of the model or Framework
+        :param base_model_id: optional, model id to be reused
         """
         super(OutputModel, self).__init__(task=task)
 
@@ -656,10 +752,32 @@ class OutputModel(BaseModel):
             labels=label_enumeration or task.get_labels_enumeration(),
             name=name,
             tags=tags,
-            comment='Created by task id: {}'.format(task.id) + ('\n' + comment if comment else ''),
+            comment='{} by task id: {}'.format('Created' if not base_model_id else 'Overwritten', task.id) +
+                    ('\n' + comment if comment else ''),
             framework=framework,
             upload_storage_uri=task.output_uri,
         )
+        if base_model_id:
+            try:
+                _base_model = InputModel(base_model_id)._get_base_model()
+                _base_model.update(
+                    labels=self._floating_data.labels,
+                    design=self._floating_data.design,
+                    task_id=self._task.id,
+                    project_id=self._task.project,
+                    name=self._floating_data.name or task.name,
+                    comment=('{}\n{}'.format(_base_model.comment, self._floating_data.comment)
+                             if _base_model.comment and self._floating_data.comment else
+                             (_base_model.comment or self._floating_data.comment)),
+                    tags=self._floating_data.tags,
+                    framework=self._floating_data.framework,
+                    upload_storage_uri=self._floating_data.upload_storage_uri
+                )
+                self._base_model = _base_model
+                self._floating_data = None
+                self._base_model.update_for_task(task_id=self._task.id, override_model_id=self.id)
+            except Exception:
+                pass
         self.connect(task)
 
     def connect(self, task):
@@ -679,8 +797,16 @@ class OutputModel(BaseModel):
             raise ValueError('Can only connect preexisting model to task, but this is a fresh model')
 
         if running_remotely() and task.is_main_task():
-            self._floating_data.design = _Model._wrap_design(self._task.get_model_config_text())
-            self._floating_data.labels = self._task.get_labels_enumeration()
+            if self._floating_data:
+                self._floating_data.design = _Model._wrap_design(self._task.get_model_config_text()) or \
+                    self._floating_data.design
+                self._floating_data.labels = self._task.get_labels_enumeration() or \
+                    self._floating_data.labels
+            elif self._base_model:
+                self._base_model.update(design=_Model._wrap_design(self._task.get_model_config_text()) or
+                                        self._base_model.design)
+                self._base_model.update(labels=self._task.get_labels_enumeration() or self._base_model.labels)
+
         elif self._floating_data is not None:
             # we copy configuration / labels if they exist, obviously someone wants them as the output base model
             if _Model._unwrap_design(self._floating_data.design):
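Usage note (not part of the patch): the sketch below shows how the APIs added in this diff might be exercised together, namely the new generic Model class with get_local_copy(), InputModel.load_model(), and the new base_model_id argument of OutputModel. It is a minimal sketch that assumes a configured trains server; the project/task names, the model id and the weights URL are placeholders, not values taken from this patch.

# Minimal usage sketch, assuming a configured trains server; the model id and
# weights URL below are placeholders and must refer to existing registered models.
from trains import Task, Model, InputModel, OutputModel

task = Task.init(project_name='examples', task_name='model reuse sketch')

# Read-only handle on an existing model, looked up by id; get_local_copy()
# downloads the weights (extracting a packaged archive when requested) and
# returns a local path.
generic_model = Model(model_id='<existing-model-id>')
print(generic_model.url)
local_weights = generic_model.get_local_copy(extract_archive=True)

# Look up an already registered model by its weights URL; returns None when
# no registered model matches the URL.
registered = InputModel.load_model(weights_url='s3://bucket/file.bin')
if registered is not None:
    registered.connect(task)

# Create an output model that reuses (overwrites) an existing model entry
# instead of registering a brand new one.
output_model = OutputModel(task=task, base_model_id='<existing-model-id>')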