mirror of
https://github.com/clearml/clearml
synced 2025-04-25 16:59:46 +00:00
Support reusing Models. Use trains.Model as general purpose registered Model.
This commit is contained in:
parent
63507c82f7
commit
4e2564cd3a
@ -2,6 +2,6 @@
|
|||||||
|
|
||||||
from .version import __version__
|
from .version import __version__
|
||||||
from .task import Task
|
from .task import Task
|
||||||
from .model import InputModel, OutputModel
|
from .model import InputModel, OutputModel, Model
|
||||||
from .logger import Logger
|
from .logger import Logger
|
||||||
from .errors import UsageError
|
from .errors import UsageError
|
||||||
|
176
trains/model.py
176
trains/model.py
@ -189,7 +189,21 @@ class BaseModel(object):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def task(self):
|
def task(self):
|
||||||
return self._task
|
"""
|
||||||
|
Return the creating task id (str)
|
||||||
|
|
||||||
|
:return str: Task ID
|
||||||
|
"""
|
||||||
|
return self._task or self._get_base_model().task
|
||||||
|
|
||||||
|
@property
def url(self):
    """
    URL of the model file (or of the archive that holds the model files).

    The value is read from the ``uri`` attribute of the object returned by
    ``self._get_base_model()``.

    :return str: Model file URL
    """
    base_model = self._get_base_model()
    return base_model.uri
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def published(self):
|
def published(self):
|
||||||
@ -323,7 +337,58 @@ class BaseModel(object):
|
|||||||
return config_text
|
return config_text
|
||||||
|
|
||||||
|
|
||||||
class InputModel(BaseModel):
|
class Model(BaseModel):
    """
    Handle to a model that already exists in the system, located by its model id.
    The Model is read-only and can be used to pre-initialize a network.
    """

    def __init__(self, model_id):
        """
        Create a read-only handle for the model with the given id. The returned object
        can be connected to a task.

        Notice, the input model can be overridden when running remotely.

        :param model_id: id (string)
        """
        super(Model, self).__init__()
        self._base_model_id = model_id
        # the backend model object is created lazily, see _get_base_model()
        self._base_model = None

    def get_local_copy(self, extract_archive=True):
        """
        Retrieve a valid link to the model file(s).
        If the model URL is a file system link, it will be returned directly.
        If the model URL points to a remote location (http/s3/gs etc.),
        the file(s) will be downloaded and the temporary location of the downloaded model is returned.

        :param bool extract_archive: If True and the model is of type 'packaged' (e.g. TensorFlow compressed folder)
            the returned path will be a temporary folder containing the archive content
        :return str: a local path to the model (or a downloaded copy of it)
        """
        wants_extracted_package = extract_archive and self._package_tag in self.tags
        if not wants_extracted_package:
            return self.get_weights()
        return self.get_weights_package(return_path=True)

    def _get_base_model(self):
        """
        Return the underlying ``_Model`` object, creating and caching it on first use.

        :raises Exception: if no model id was set on this object
        """
        cached = self._base_model
        if cached:
            return cached

        if not self._base_model_id:
            # this shouldn't actually happen
            raise Exception('Missing model ID, cannot create an empty model')

        self._base_model = _Model(
            upload_storage_uri=None,
            cache_dir=get_cache_dir(),
            model_id=self._base_model_id,
        )
        return self._base_model

    def _get_model_data(self):
        """Return the ``data`` attribute of the underlying base model."""
        base_model = self._get_base_model()
        return base_model.data
|
||||||
|
|
||||||
|
|
||||||
|
class InputModel(Model):
|
||||||
"""
|
"""
|
||||||
Load an existing model in the system, search by model id.
|
Load an existing model in the system, search by model id.
|
||||||
The Model will be read-only and can be used to pre initialize a network
|
The Model will be read-only and can be used to pre initialize a network
|
||||||
@ -446,6 +511,54 @@ class InputModel(BaseModel):
|
|||||||
|
|
||||||
return this_model
|
return this_model
|
||||||
|
|
||||||
|
@classmethod
def load_model(
        cls,
        weights_url,
        load_archived=False
):
    """
    Load an already registered model based on a pre-existing model file (link must be valid).

    If a model is already registered with the given weights file URL, the returned object
    represents that model. If no registered model could be found with the specified URL,
    None is returned.

    :param weights_url: valid url for the weights file (string).
        examples: "https://domain.com/file.bin" or "s3://bucket/file.bin" or "file:///home/user/file.bin".
    :param bool load_archived: If True, registered models are returned even if they are archived,
        otherwise archived models are ignored.
    :return: InputModel object or None if no model could be found
    :raises ValueError: if ``weights_url`` is empty after being conformed
    """
    weights_url = StorageHelper.conform_url(weights_url)
    if not weights_url:
        raise ValueError("Please provide a valid weights_url parameter")

    extra = {}
    if not load_archived:
        # filter on the archived tag; the tag field is 'system_tags' from API version 2.3
        # and 'tags' on older servers
        tags_field = 'system_tags' if Session.check_min_api_version('2.3') else 'tags'
        extra = {tags_field: ["-" + ARCHIVED_TAG]}

    result = _Model._get_default_session().send(models.GetAllRequest(
        uri=[weights_url],
        only_fields=["id", "name", "created"],
        **extra
    ))

    if not result or not result.response or not result.response.models:
        return None

    logger = get_logger()
    model = get_single_result(
        entity='model',
        query=weights_url,
        results=result.response.models,
        log=logger,
        raise_on_error=False,
    )
    # NOTE(review): with raise_on_error=False the helper presumably reports a failed
    # selection through its return value instead of raising - treat a missing result
    # as "not found" rather than failing on `model.id` below.
    if model is None:
        return None

    return InputModel(model_id=model.id)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def empty(
|
def empty(
|
||||||
cls,
|
cls,
|
||||||
@ -484,9 +597,7 @@ class InputModel(BaseModel):
|
|||||||
|
|
||||||
:param model_id: id (string)
|
:param model_id: id (string)
|
||||||
"""
|
"""
|
||||||
super(InputModel, self).__init__()
|
super(InputModel, self).__init__(model_id)
|
||||||
self._base_model_id = model_id
|
|
||||||
self._base_model = None
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def id(self):
|
def id(self):
|
||||||
@ -526,23 +637,6 @@ class InputModel(BaseModel):
|
|||||||
# the newly connected input model
|
# the newly connected input model
|
||||||
self.task._reconnect_output_model()
|
self.task._reconnect_output_model()
|
||||||
|
|
||||||
def _get_base_model(self):
|
|
||||||
if self._base_model:
|
|
||||||
return self._base_model
|
|
||||||
|
|
||||||
if not self._base_model_id:
|
|
||||||
# this shouldn't actually happen
|
|
||||||
raise Exception('Missing model ID, cannot create an empty model')
|
|
||||||
self._base_model = _Model(
|
|
||||||
upload_storage_uri=None,
|
|
||||||
cache_dir=get_cache_dir(),
|
|
||||||
model_id=self._base_model_id,
|
|
||||||
)
|
|
||||||
return self._base_model
|
|
||||||
|
|
||||||
def _get_model_data(self):
|
|
||||||
return self._get_base_model().data
|
|
||||||
|
|
||||||
|
|
||||||
class OutputModel(BaseModel):
|
class OutputModel(BaseModel):
|
||||||
"""
|
"""
|
||||||
@ -624,6 +718,7 @@ class OutputModel(BaseModel):
|
|||||||
tags=None,
|
tags=None,
|
||||||
comment=None,
|
comment=None,
|
||||||
framework=None,
|
framework=None,
|
||||||
|
base_model_id=None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Create a new model and immediately connect it to a task.
|
Create a new model and immediately connect it to a task.
|
||||||
@ -644,6 +739,7 @@ class OutputModel(BaseModel):
|
|||||||
:param tags: optional, list of strings as tags
|
:param tags: optional, list of strings as tags
|
||||||
:param comment: optional, string description for the model
|
:param comment: optional, string description for the model
|
||||||
:param framework: optional, string name of the framework of the model or Framework
|
:param framework: optional, string name of the framework of the model or Framework
|
||||||
|
:param base_model_id: optional, model id to be reused
|
||||||
"""
|
"""
|
||||||
super(OutputModel, self).__init__(task=task)
|
super(OutputModel, self).__init__(task=task)
|
||||||
|
|
||||||
@ -656,10 +752,32 @@ class OutputModel(BaseModel):
|
|||||||
labels=label_enumeration or task.get_labels_enumeration(),
|
labels=label_enumeration or task.get_labels_enumeration(),
|
||||||
name=name,
|
name=name,
|
||||||
tags=tags,
|
tags=tags,
|
||||||
comment='Created by task id: {}'.format(task.id) + ('\n' + comment if comment else ''),
|
comment='{} by task id: {}'.format('Created' if not base_model_id else 'Overwritten', task.id) +
|
||||||
|
('\n' + comment if comment else ''),
|
||||||
framework=framework,
|
framework=framework,
|
||||||
upload_storage_uri=task.output_uri,
|
upload_storage_uri=task.output_uri,
|
||||||
)
|
)
|
||||||
|
if base_model_id:
|
||||||
|
try:
|
||||||
|
_base_model = InputModel(base_model_id)._get_base_model()
|
||||||
|
_base_model.update(
|
||||||
|
labels=self._floating_data.labels,
|
||||||
|
design=self._floating_data.design,
|
||||||
|
task_id=self._task.id,
|
||||||
|
project_id=self._task.project,
|
||||||
|
name=self._floating_data.name or task.name,
|
||||||
|
comment=('{}\n{}'.format(_base_model.comment, self._floating_data.comment)
|
||||||
|
if _base_model.comment and self._floating_data.comment else
|
||||||
|
(_base_model.comment or self._floating_data.comment)),
|
||||||
|
tags=self._floating_data.tags,
|
||||||
|
framework=self._floating_data.framework,
|
||||||
|
upload_storage_uri=self._floating_data.upload_storage_uri
|
||||||
|
)
|
||||||
|
self._base_model = _base_model
|
||||||
|
self._floating_data = None
|
||||||
|
self._base_model.update_for_task(task_id=self._task.id, override_model_id=self.id)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
self.connect(task)
|
self.connect(task)
|
||||||
|
|
||||||
def connect(self, task):
|
def connect(self, task):
|
||||||
@ -679,8 +797,16 @@ class OutputModel(BaseModel):
|
|||||||
raise ValueError('Can only connect preexisting model to task, but this is a fresh model')
|
raise ValueError('Can only connect preexisting model to task, but this is a fresh model')
|
||||||
|
|
||||||
if running_remotely() and task.is_main_task():
|
if running_remotely() and task.is_main_task():
|
||||||
self._floating_data.design = _Model._wrap_design(self._task.get_model_config_text())
|
if self._floating_data:
|
||||||
self._floating_data.labels = self._task.get_labels_enumeration()
|
self._floating_data.design = _Model._wrap_design(self._task.get_model_config_text()) or \
|
||||||
|
self._floating_data.design
|
||||||
|
self._floating_data.labels = self._task.get_labels_enumeration() or \
|
||||||
|
self._floating_data.labels
|
||||||
|
elif self._base_model:
|
||||||
|
self._base_model.update(design=_Model._wrap_design(self._task.get_model_config_text()) or
|
||||||
|
self._base_model.design)
|
||||||
|
self._base_model.update(labels=self._task.get_labels_enumeration() or self._base_model.labels)
|
||||||
|
|
||||||
elif self._floating_data is not None:
|
elif self._floating_data is not None:
|
||||||
# we copy configuration / labels if they exist, obviously someone wants them as the output base model
|
# we copy configuration / labels if they exist, obviously someone wants them as the output base model
|
||||||
if _Model._unwrap_design(self._floating_data.design):
|
if _Model._unwrap_design(self._floating_data.design):
|
||||||
|
Loading…
Reference in New Issue
Block a user