Support reusing Models. Use trains.Model as general purpose registered Model.

This commit is contained in:
allegroai 2020-03-22 18:13:56 +02:00
parent 63507c82f7
commit 4e2564cd3a
2 changed files with 152 additions and 26 deletions

View File

@ -2,6 +2,6 @@
from .version import __version__
from .task import Task
from .model import InputModel, OutputModel
from .model import InputModel, OutputModel, Model
from .logger import Logger
from .errors import UsageError

View File

@ -189,7 +189,21 @@ class BaseModel(object):
@property
def task(self):
    """
    Return the creating task id (str)

    When this object holds no task reference of its own, the value is taken
    from the underlying base model instead.

    :return str: Task ID
    """
    # NOTE: the fallback must stay reachable - an unconditional early
    # `return self._task` would hide the base model's task
    return self._task or self._get_base_model().task
@property
def url(self):
    """
    Location of the stored model file (or of the archive holding the model files),
    as reported by the underlying base model.

    :return str: Model file URL
    """
    base_model = self._get_base_model()
    return base_model.uri
@property
def published(self):
@ -323,7 +337,58 @@ class BaseModel(object):
return config_text
class InputModel(BaseModel):
class Model(BaseModel):
    """
    Handle to a model that already exists in the system, looked up by model id.
    The Model is read-only and can be used to pre initialize a network
    """

    def __init__(self, model_id):
        """
        Wrap an existing model by its id; the returned object is read-only and can be connected to a task
        Notice, we can override the input model when running remotely

        :param model_id: id (string)
        """
        super(Model, self).__init__()
        self._base_model_id = model_id
        # the backend model object is only created on first use, see _get_base_model()
        self._base_model = None

    def get_local_copy(self, extract_archive=True):
        """
        Retrieve a valid link to the model file(s).
        If the model URL is a file system link, it will be returned directly.
        If the model URL points to a remote location (http/s3/gs etc.),
        the file(s) will be downloaded and the temporary location of the downloaded model is returned.

        :param bool extract_archive: If True and the model is of type 'packaged' (e.g. TensorFlow compressed folder)
            The returned path will be a temporary folder containing the archive content
        :return str: a local path to the model (or a downloaded copy of it)
        """
        # plain weights are used unless extraction was requested AND the model is tagged as a package
        if not extract_archive or self._package_tag not in self.tags:
            return self.get_weights()
        return self.get_weights_package(return_path=True)

    def _get_base_model(self):
        """
        Return the backend model object for the stored model id, creating and caching it on first use.

        :raise Exception: if no model id was stored on this object
        """
        if not self._base_model:
            if not self._base_model_id:
                # this shouldn't actually happen
                raise Exception('Missing model ID, cannot create an empty model')
            self._base_model = _Model(
                upload_storage_uri=None,
                cache_dir=get_cache_dir(),
                model_id=self._base_model_id,
            )
        return self._base_model

    def _get_model_data(self):
        """
        Return the `data` attribute of the backend model object.
        """
        base_model = self._get_base_model()
        return base_model.data
class InputModel(Model):
"""
Load an existing model in the system, search by model id.
The Model will be read-only and can be used to pre initialize a network
@ -446,6 +511,54 @@ class InputModel(BaseModel):
return this_model
@classmethod
def load_model(
        cls,
        weights_url,
        load_archived=False
):
    """
    Load an already registered model based on a pre-existing model file (link must be valid).

    If a model is already registered with the given weights url, the returned object represents that model.
    If no registered model could be found for the specified url, None is returned.

    :param weights_url: valid url for the weights file (string).
        examples: "https://domain.com/file.bin" or "s3://bucket/file.bin" or "file:///home/user/file.bin".
    :param bool load_archived: If True, registered models are returned even if they are archived,
        otherwise archived models are ignored.
    :return Model: InputModel object or None if no model could be found
    :raise ValueError: if weights_url is empty after being passed through StorageHelper.conform_url
    """
    weights_url = StorageHelper.conform_url(weights_url)
    if not weights_url:
        raise ValueError("Please provide a valid weights_url parameter")

    # archived models are excluded with a negated tag; the field holding it depends on the API version
    tag_filter = {}
    if not load_archived:
        tags_field = 'system_tags' if Session.check_min_api_version('2.3') else 'tags'
        tag_filter[tags_field] = ["-" + ARCHIVED_TAG]

    session = _Model._get_default_session()
    reply = session.send(models.GetAllRequest(
        uri=[weights_url],
        only_fields=["id", "name", "created"],
        **tag_filter
    ))

    response = reply.response if reply else None
    candidates = response.models if response else None
    if not candidates:
        return None

    matched = get_single_result(
        entity='model',
        query=weights_url,
        results=candidates,
        log=get_logger(),
        raise_on_error=False,
    )
    return InputModel(model_id=matched.id)
@classmethod
def empty(
cls,
@ -484,9 +597,7 @@ class InputModel(BaseModel):
:param model_id: id (string)
"""
super(InputModel, self).__init__()
self._base_model_id = model_id
self._base_model = None
super(InputModel, self).__init__(model_id)
@property
def id(self):
@ -526,23 +637,6 @@ class InputModel(BaseModel):
# the newly connected input model
self.task._reconnect_output_model()
def _get_base_model(self):
if self._base_model:
return self._base_model
if not self._base_model_id:
# this shouldn't actually happen
raise Exception('Missing model ID, cannot create an empty model')
self._base_model = _Model(
upload_storage_uri=None,
cache_dir=get_cache_dir(),
model_id=self._base_model_id,
)
return self._base_model
def _get_model_data(self):
return self._get_base_model().data
class OutputModel(BaseModel):
"""
@ -624,6 +718,7 @@ class OutputModel(BaseModel):
tags=None,
comment=None,
framework=None,
base_model_id=None,
):
"""
Create a new model and immediately connect it to a task.
@ -644,6 +739,7 @@ class OutputModel(BaseModel):
:param tags: optional, list of strings as tags
:param comment: optional, string description for the model
:param framework: optional, string name of the framework of the model or Framework
:param base_model_id: optional, model id to be reused
"""
super(OutputModel, self).__init__(task=task)
@ -656,10 +752,32 @@ class OutputModel(BaseModel):
labels=label_enumeration or task.get_labels_enumeration(),
name=name,
tags=tags,
comment='Created by task id: {}'.format(task.id) + ('\n' + comment if comment else ''),
comment='{} by task id: {}'.format('Created' if not base_model_id else 'Overwritten', task.id) +
('\n' + comment if comment else ''),
framework=framework,
upload_storage_uri=task.output_uri,
)
if base_model_id:
try:
_base_model = InputModel(base_model_id)._get_base_model()
_base_model.update(
labels=self._floating_data.labels,
design=self._floating_data.design,
task_id=self._task.id,
project_id=self._task.project,
name=self._floating_data.name or task.name,
comment=('{}\n{}'.format(_base_model.comment, self._floating_data.comment)
if _base_model.comment and self._floating_data.comment else
(_base_model.comment or self._floating_data.comment)),
tags=self._floating_data.tags,
framework=self._floating_data.framework,
upload_storage_uri=self._floating_data.upload_storage_uri
)
self._base_model = _base_model
self._floating_data = None
self._base_model.update_for_task(task_id=self._task.id, override_model_id=self.id)
except Exception:
pass
self.connect(task)
def connect(self, task):
@ -679,8 +797,16 @@ class OutputModel(BaseModel):
raise ValueError('Can only connect preexisting model to task, but this is a fresh model')
if running_remotely() and task.is_main_task():
self._floating_data.design = _Model._wrap_design(self._task.get_model_config_text())
self._floating_data.labels = self._task.get_labels_enumeration()
if self._floating_data:
self._floating_data.design = _Model._wrap_design(self._task.get_model_config_text()) or \
self._floating_data.design
self._floating_data.labels = self._task.get_labels_enumeration() or \
self._floating_data.labels
elif self._base_model:
self._base_model.update(design=_Model._wrap_design(self._task.get_model_config_text()) or
self._base_model.design)
self._base_model.update(labels=self._task.get_labels_enumeration() or self._base_model.labels)
elif self._floating_data is not None:
# we copy configuration / labels if they exist, obviously someone wants them as the output base model
if _Model._unwrap_design(self._floating_data.design):