mirror of
https://github.com/clearml/clearml
synced 2025-04-25 00:37:52 +00:00
Support reusing Models. Use trains.Model as general purpose registered Model.
This commit is contained in:
parent
63507c82f7
commit
4e2564cd3a
@ -2,6 +2,6 @@
|
||||
|
||||
from .version import __version__
|
||||
from .task import Task
|
||||
from .model import InputModel, OutputModel
|
||||
from .model import InputModel, OutputModel, Model
|
||||
from .logger import Logger
|
||||
from .errors import UsageError
|
||||
|
176
trains/model.py
176
trains/model.py
@ -189,7 +189,21 @@ class BaseModel(object):
|
||||
|
||||
@property
|
||||
def task(self):
|
||||
return self._task
|
||||
"""
|
||||
Return the creating task id (str)
|
||||
|
||||
:return str: Task ID
|
||||
"""
|
||||
return self._task or self._get_base_model().task
|
||||
|
||||
@property
|
||||
def url(self):
|
||||
"""
|
||||
Return the url of the model file (or archived files)
|
||||
|
||||
:return str: Model file URL
|
||||
"""
|
||||
return self._get_base_model().uri
|
||||
|
||||
@property
|
||||
def published(self):
|
||||
@ -323,7 +337,58 @@ class BaseModel(object):
|
||||
return config_text
|
||||
|
||||
|
||||
class InputModel(BaseModel):
|
||||
class Model(BaseModel):
|
||||
"""
|
||||
Represent an existing model in the system, search by model id.
|
||||
The Model will be read-only and can be used to pre initialize a network
|
||||
"""
|
||||
|
||||
def __init__(self, model_id):
|
||||
"""
|
||||
Load model based on id, returned object is read-only and can be connected to a task
|
||||
|
||||
Notice, we can override the input model when running remotely
|
||||
|
||||
:param model_id: id (string)
|
||||
"""
|
||||
super(Model, self).__init__()
|
||||
self._base_model_id = model_id
|
||||
self._base_model = None
|
||||
|
||||
def get_local_copy(self, extract_archive=True):
|
||||
"""
|
||||
Retrieve a valid link to the model file(s).
|
||||
If the model URL is a file system link, it will be returned directly.
|
||||
If the model URL points to a remote location (http/s3/gs etc.),
|
||||
it will download the file(s) and return the temporary location of the downloaded model.
|
||||
|
||||
:param bool extract_archive: If True and the model is of type 'packaged' (e.g. TensorFlow compressed folder)
|
||||
The returned path will be a temporary folder containing the archive content
|
||||
:return str: a local path to the model (or a downloaded copy of it)
|
||||
"""
|
||||
if extract_archive and self._package_tag in self.tags:
|
||||
return self.get_weights_package(return_path=True)
|
||||
return self.get_weights()
|
||||
|
||||
def _get_base_model(self):
|
||||
if self._base_model:
|
||||
return self._base_model
|
||||
|
||||
if not self._base_model_id:
|
||||
# this shouldn't actually happen
|
||||
raise Exception('Missing model ID, cannot create an empty model')
|
||||
self._base_model = _Model(
|
||||
upload_storage_uri=None,
|
||||
cache_dir=get_cache_dir(),
|
||||
model_id=self._base_model_id,
|
||||
)
|
||||
return self._base_model
|
||||
|
||||
def _get_model_data(self):
|
||||
return self._get_base_model().data
|
||||
|
||||
|
||||
class InputModel(Model):
|
||||
"""
|
||||
Load an existing model in the system, search by model id.
|
||||
The Model will be read-only and can be used to pre initialize a network
|
||||
@ -446,6 +511,54 @@ class InputModel(BaseModel):
|
||||
|
||||
return this_model
|
||||
|
||||
@classmethod
|
||||
def load_model(
|
||||
cls,
|
||||
weights_url,
|
||||
load_archived=False
|
||||
):
|
||||
"""
|
||||
Load an already registered model based on a pre-existing model file (link must be valid).
|
||||
|
||||
If the url to the weights file already exists, the returned object is a Model representing the loaded Model
|
||||
If no registered Model with the specified url could be found, None is returned.
|
||||
|
||||
:param weights_url: valid url for the weights file (string).
|
||||
examples: "https://domain.com/file.bin" or "s3://bucket/file.bin" or "file:///home/user/file.bin".
|
||||
NOTE: if a model with the exact same URL exists, it will be used, and all other arguments will be ignored.
|
||||
:param bool load_archived: If True, return the registered Model even if it is archived,
|
||||
otherwise archived models are ignored.
|
||||
:return Model: InputModel object or None if no model could be found
|
||||
"""
|
||||
weights_url = StorageHelper.conform_url(weights_url)
|
||||
if not weights_url:
|
||||
raise ValueError("Please provide a valid weights_url parameter")
|
||||
if not load_archived:
|
||||
extra = {'system_tags': ["-" + ARCHIVED_TAG]} \
|
||||
if Session.check_min_api_version('2.3') else {'tags': ["-" + ARCHIVED_TAG]}
|
||||
else:
|
||||
extra = {}
|
||||
|
||||
result = _Model._get_default_session().send(models.GetAllRequest(
|
||||
uri=[weights_url],
|
||||
only_fields=["id", "name", "created"],
|
||||
**extra
|
||||
))
|
||||
|
||||
if not result or not result.response or not result.response.models:
|
||||
return None
|
||||
|
||||
logger = get_logger()
|
||||
model = get_single_result(
|
||||
entity='model',
|
||||
query=weights_url,
|
||||
results=result.response.models,
|
||||
log=logger,
|
||||
raise_on_error=False,
|
||||
)
|
||||
|
||||
return InputModel(model_id=model.id)
|
||||
|
||||
@classmethod
|
||||
def empty(
|
||||
cls,
|
||||
@ -484,9 +597,7 @@ class InputModel(BaseModel):
|
||||
|
||||
:param model_id: id (string)
|
||||
"""
|
||||
super(InputModel, self).__init__()
|
||||
self._base_model_id = model_id
|
||||
self._base_model = None
|
||||
super(InputModel, self).__init__(model_id)
|
||||
|
||||
@property
|
||||
def id(self):
|
||||
@ -526,23 +637,6 @@ class InputModel(BaseModel):
|
||||
# the newly connected input model
|
||||
self.task._reconnect_output_model()
|
||||
|
||||
def _get_base_model(self):
|
||||
if self._base_model:
|
||||
return self._base_model
|
||||
|
||||
if not self._base_model_id:
|
||||
# this shouldn't actually happen
|
||||
raise Exception('Missing model ID, cannot create an empty model')
|
||||
self._base_model = _Model(
|
||||
upload_storage_uri=None,
|
||||
cache_dir=get_cache_dir(),
|
||||
model_id=self._base_model_id,
|
||||
)
|
||||
return self._base_model
|
||||
|
||||
def _get_model_data(self):
|
||||
return self._get_base_model().data
|
||||
|
||||
|
||||
class OutputModel(BaseModel):
|
||||
"""
|
||||
@ -624,6 +718,7 @@ class OutputModel(BaseModel):
|
||||
tags=None,
|
||||
comment=None,
|
||||
framework=None,
|
||||
base_model_id=None,
|
||||
):
|
||||
"""
|
||||
Create a new model and immediately connect it to a task.
|
||||
@ -644,6 +739,7 @@ class OutputModel(BaseModel):
|
||||
:param tags: optional, list of strings as tags
|
||||
:param comment: optional, string description for the model
|
||||
:param framework: optional, string name of the framework of the model or Framework
|
||||
:param base_model_id: optional, model id to be reused
|
||||
"""
|
||||
super(OutputModel, self).__init__(task=task)
|
||||
|
||||
@ -656,10 +752,32 @@ class OutputModel(BaseModel):
|
||||
labels=label_enumeration or task.get_labels_enumeration(),
|
||||
name=name,
|
||||
tags=tags,
|
||||
comment='Created by task id: {}'.format(task.id) + ('\n' + comment if comment else ''),
|
||||
comment='{} by task id: {}'.format('Created' if not base_model_id else 'Overwritten', task.id) +
|
||||
('\n' + comment if comment else ''),
|
||||
framework=framework,
|
||||
upload_storage_uri=task.output_uri,
|
||||
)
|
||||
if base_model_id:
|
||||
try:
|
||||
_base_model = InputModel(base_model_id)._get_base_model()
|
||||
_base_model.update(
|
||||
labels=self._floating_data.labels,
|
||||
design=self._floating_data.design,
|
||||
task_id=self._task.id,
|
||||
project_id=self._task.project,
|
||||
name=self._floating_data.name or task.name,
|
||||
comment=('{}\n{}'.format(_base_model.comment, self._floating_data.comment)
|
||||
if _base_model.comment and self._floating_data.comment else
|
||||
(_base_model.comment or self._floating_data.comment)),
|
||||
tags=self._floating_data.tags,
|
||||
framework=self._floating_data.framework,
|
||||
upload_storage_uri=self._floating_data.upload_storage_uri
|
||||
)
|
||||
self._base_model = _base_model
|
||||
self._floating_data = None
|
||||
self._base_model.update_for_task(task_id=self._task.id, override_model_id=self.id)
|
||||
except Exception:
|
||||
pass
|
||||
self.connect(task)
|
||||
|
||||
def connect(self, task):
|
||||
@ -679,8 +797,16 @@ class OutputModel(BaseModel):
|
||||
raise ValueError('Can only connect preexisting model to task, but this is a fresh model')
|
||||
|
||||
if running_remotely() and task.is_main_task():
|
||||
self._floating_data.design = _Model._wrap_design(self._task.get_model_config_text())
|
||||
self._floating_data.labels = self._task.get_labels_enumeration()
|
||||
if self._floating_data:
|
||||
self._floating_data.design = _Model._wrap_design(self._task.get_model_config_text()) or \
|
||||
self._floating_data.design
|
||||
self._floating_data.labels = self._task.get_labels_enumeration() or \
|
||||
self._floating_data.labels
|
||||
elif self._base_model:
|
||||
self._base_model.update(design=_Model._wrap_design(self._task.get_model_config_text()) or
|
||||
self._base_model.design)
|
||||
self._base_model.update(labels=self._task.get_labels_enumeration() or self._base_model.labels)
|
||||
|
||||
elif self._floating_data is not None:
|
||||
# we copy configuration / labels if they exist, obviously someone wants them as the output base model
|
||||
if _Model._unwrap_design(self._floating_data.design):
|
||||
|
Loading…
Reference in New Issue
Block a user