Support reusing Models. Use trains.Model as general purpose registered Model.

This commit is contained in:
allegroai 2020-03-22 18:13:56 +02:00
parent 63507c82f7
commit 4e2564cd3a
2 changed files with 152 additions and 26 deletions

View File

@ -2,6 +2,6 @@
from .version import __version__ from .version import __version__
from .task import Task from .task import Task
from .model import InputModel, OutputModel from .model import InputModel, OutputModel, Model
from .logger import Logger from .logger import Logger
from .errors import UsageError from .errors import UsageError

View File

@ -189,7 +189,21 @@ class BaseModel(object):
@property @property
def task(self): def task(self):
return self._task """
Return the creating task id (str)
:return str: Task ID
"""
return self._task or self._get_base_model().task
@property
def url(self):
"""
Return the url of the model file (or archived files)
:return str: Model file URL
"""
return self._get_base_model().uri
@property @property
def published(self): def published(self):
@ -323,7 +337,58 @@ class BaseModel(object):
return config_text return config_text
class InputModel(BaseModel): class Model(BaseModel):
"""
Represent an existing model in the system, search by model id.
The Model will be read-only and can be used to pre initialize a network
"""
def __init__(self, model_id):
"""
Load model based on id, returned object is read-only and can be connected to a task
Notice, we can override the input model when running remotely
:param model_id: id (string)
"""
super(Model, self).__init__()
self._base_model_id = model_id
self._base_model = None
def get_local_copy(self, extract_archive=True):
"""
Retrieve a valid link to the model file(s).
If the model URL is a file system link, it will be returned directly.
If the model URL is points to a remote location (http/s3/gs etc.),
it will download the file(s) and return the temporary location of the downloaded model.
:param bool extract_archive: If True and the model is of type 'packaged' (e.g. TensorFlow compressed folder)
The returned path will be a temporary folder containing the archive content
:return str: a local path to the model (or a downloaded copy of it)
"""
if extract_archive and self._package_tag in self.tags:
return self.get_weights_package(return_path=True)
return self.get_weights()
def _get_base_model(self):
if self._base_model:
return self._base_model
if not self._base_model_id:
# this shouldn't actually happen
raise Exception('Missing model ID, cannot create an empty model')
self._base_model = _Model(
upload_storage_uri=None,
cache_dir=get_cache_dir(),
model_id=self._base_model_id,
)
return self._base_model
def _get_model_data(self):
return self._get_base_model().data
class InputModel(Model):
""" """
Load an existing model in the system, search by model id. Load an existing model in the system, search by model id.
The Model will be read-only and can be used to pre initialize a network The Model will be read-only and can be used to pre initialize a network
@ -446,6 +511,54 @@ class InputModel(BaseModel):
return this_model return this_model
@classmethod
def load_model(
cls,
weights_url,
load_archived=False
):
"""
Load an already registered model based on a pre-existing model file (link must be valid).
If the url to the weights file already exists, the returned object is a Model representing the loaded Model
If there could not be found any registered model Model with the specified url, None is returned.
:param weights_url: valid url for the weights file (string).
examples: "https://domain.com/file.bin" or "s3://bucket/file.bin" or "file:///home/user/file.bin".
NOTE: if a model with the exact same URL exists, it will be used, and all other arguments will be ignored.
:param bool load_archived: If True return registered Model with even if they are archived,
otherwise archived models are ignored,
:return Model: InputModel object or None if no model could be found
"""
weights_url = StorageHelper.conform_url(weights_url)
if not weights_url:
raise ValueError("Please provide a valid weights_url parameter")
if not load_archived:
extra = {'system_tags': ["-" + ARCHIVED_TAG]} \
if Session.check_min_api_version('2.3') else {'tags': ["-" + ARCHIVED_TAG]}
else:
extra = {}
result = _Model._get_default_session().send(models.GetAllRequest(
uri=[weights_url],
only_fields=["id", "name", "created"],
**extra
))
if not result or not result.response or not result.response.models:
return None
logger = get_logger()
model = get_single_result(
entity='model',
query=weights_url,
results=result.response.models,
log=logger,
raise_on_error=False,
)
return InputModel(model_id=model.id)
@classmethod @classmethod
def empty( def empty(
cls, cls,
@ -484,9 +597,7 @@ class InputModel(BaseModel):
:param model_id: id (string) :param model_id: id (string)
""" """
super(InputModel, self).__init__() super(InputModel, self).__init__(model_id)
self._base_model_id = model_id
self._base_model = None
@property @property
def id(self): def id(self):
@ -526,23 +637,6 @@ class InputModel(BaseModel):
# the newly connected input model # the newly connected input model
self.task._reconnect_output_model() self.task._reconnect_output_model()
def _get_base_model(self):
if self._base_model:
return self._base_model
if not self._base_model_id:
# this shouldn't actually happen
raise Exception('Missing model ID, cannot create an empty model')
self._base_model = _Model(
upload_storage_uri=None,
cache_dir=get_cache_dir(),
model_id=self._base_model_id,
)
return self._base_model
def _get_model_data(self):
return self._get_base_model().data
class OutputModel(BaseModel): class OutputModel(BaseModel):
""" """
@ -624,6 +718,7 @@ class OutputModel(BaseModel):
tags=None, tags=None,
comment=None, comment=None,
framework=None, framework=None,
base_model_id=None,
): ):
""" """
Create a new model and immediately connect it to a task. Create a new model and immediately connect it to a task.
@ -644,6 +739,7 @@ class OutputModel(BaseModel):
:param tags: optional, list of strings as tags :param tags: optional, list of strings as tags
:param comment: optional, string description for the model :param comment: optional, string description for the model
:param framework: optional, string name of the framework of the model or Framework :param framework: optional, string name of the framework of the model or Framework
:param base_model_id: optional, model id to be reused
""" """
super(OutputModel, self).__init__(task=task) super(OutputModel, self).__init__(task=task)
@ -656,10 +752,32 @@ class OutputModel(BaseModel):
labels=label_enumeration or task.get_labels_enumeration(), labels=label_enumeration or task.get_labels_enumeration(),
name=name, name=name,
tags=tags, tags=tags,
comment='Created by task id: {}'.format(task.id) + ('\n' + comment if comment else ''), comment='{} by task id: {}'.format('Created' if not base_model_id else 'Overwritten', task.id) +
('\n' + comment if comment else ''),
framework=framework, framework=framework,
upload_storage_uri=task.output_uri, upload_storage_uri=task.output_uri,
) )
if base_model_id:
try:
_base_model = InputModel(base_model_id)._get_base_model()
_base_model.update(
labels=self._floating_data.labels,
design=self._floating_data.design,
task_id=self._task.id,
project_id=self._task.project,
name=self._floating_data.name or task.name,
comment=('{}\n{}'.format(_base_model.comment, self._floating_data.comment)
if _base_model.comment and self._floating_data.comment else
(_base_model.comment or self._floating_data.comment)),
tags=self._floating_data.tags,
framework=self._floating_data.framework,
upload_storage_uri=self._floating_data.upload_storage_uri
)
self._base_model = _base_model
self._floating_data = None
self._base_model.update_for_task(task_id=self._task.id, override_model_id=self.id)
except Exception:
pass
self.connect(task) self.connect(task)
def connect(self, task): def connect(self, task):
@ -679,8 +797,16 @@ class OutputModel(BaseModel):
raise ValueError('Can only connect preexisting model to task, but this is a fresh model') raise ValueError('Can only connect preexisting model to task, but this is a fresh model')
if running_remotely() and task.is_main_task(): if running_remotely() and task.is_main_task():
self._floating_data.design = _Model._wrap_design(self._task.get_model_config_text()) if self._floating_data:
self._floating_data.labels = self._task.get_labels_enumeration() self._floating_data.design = _Model._wrap_design(self._task.get_model_config_text()) or \
self._floating_data.design
self._floating_data.labels = self._task.get_labels_enumeration() or \
self._floating_data.labels
elif self._base_model:
self._base_model.update(design=_Model._wrap_design(self._task.get_model_config_text()) or
self._base_model.design)
self._base_model.update(labels=self._task.get_labels_enumeration() or self._base_model.labels)
elif self._floating_data is not None: elif self._floating_data is not None:
# we copy configuration / labels if they exist, obviously someone wants them as the output base model # we copy configuration / labels if they exist, obviously someone wants them as the output base model
if _Model._unwrap_design(self._floating_data.design): if _Model._unwrap_design(self._floating_data.design):