import abc
import os
import tarfile
import zipfile
from tempfile import mkdtemp, mkstemp

import pyparsing
import six

from .backend_api import Session
from .backend_api.services import models
from pathlib2 import Path
from .utilities.pyhocon import ConfigFactory, HOCONConverter

from .backend_interface.util import validate_dict, get_single_result, mutually_exclusive
from .debugging.log import get_logger
from .storage.helper import StorageHelper
from .utilities.enum import Options
from .backend_interface import Task as _Task
from .backend_interface.model import create_dummy_model, Model as _Model
from .config import running_remotely, get_cache_dir

ARCHIVED_TAG = "archived"

class Framework(Options):
    """
    Optional frameworks for output model
    """
    tensorflow = 'TensorFlow'
    tensorflowjs = 'TensorFlow_js'
    tensorflowlite = 'TensorFlow_Lite'
    pytorch = 'PyTorch'
    caffe = 'Caffe'
    caffe2 = 'Caffe2'
    onnx = 'ONNX'
    keras = 'Keras'
    mknet = 'MXNet'
    cntk = 'CNTK'
    torch = 'Torch'
    darknet = 'Darknet'
    paddlepaddle = 'PaddlePaddle'
    scikitlearn = 'ScikitLearn'
    xgboost = 'XGBoost'

    __file_extensions_mapping = {
        '.pb': (tensorflow, tensorflowjs, onnx, ),
        '.meta': (tensorflow, ),
        '.pbtxt': (tensorflow, onnx, ),
        '.zip': (tensorflow, ),
        '.tgz': (tensorflow, ),
        '.tar.gz': (tensorflow, ),
        'model.json': (tensorflowjs, ),
        '.tflite': (tensorflowlite, ),
        '.pth': (pytorch, ),
        '.pt': (pytorch, ),
        '.caffemodel': (caffe, ),
        '.prototxt': (caffe, ),
        'predict_net.pb': (caffe2, ),
        'predict_net.pbtxt': (caffe2, ),
        '.onnx': (onnx, ),
        '.h5': (keras, ),
        '.hdf5': (keras, ),
        '.keras': (keras, ),
        '.model': (mknet, cntk, xgboost),
        '-symbol.json': (mknet, ),
        '.cntk': (cntk, ),
        '.t7': (torch, ),
        '.cfg': (darknet, ),
        '__model__': (paddlepaddle, ),
        '.pkl': (scikitlearn, keras, xgboost),
    }

    @classmethod
    def _get_file_ext(cls, framework, filename):
        mapping = cls.__file_extensions_mapping
        filename = filename.lower()

        def find_framework_by_ext(framework_selector):
            for ext, frameworks in mapping.items():
                if frameworks and filename.endswith(ext):
                    fw = framework_selector(frameworks)
                    if fw:
                        return (fw, ext)

        # If no framework was given, pick the first framework matching the extension. Otherwise, try to match
        # the given extension to the given framework. If neither matches, return the given framework together with
        # the filename's last extension (or an empty extension if the filename has none).
        return (
            (not framework and find_framework_by_ext(lambda frameworks_: frameworks_[0]))
            or find_framework_by_ext(lambda frameworks_: framework if framework in frameworks_ else None)
            or (framework, filename.split('.')[-1] if '.' in filename else '')
        )


@six.add_metaclass(abc.ABCMeta)
class BaseModel(object):
    _package_tag = "package"

    @property
    def id(self):
        """
        The Id (system UUID) of the model.

        :return: The model id.

        :rtype: str
        """
        return self._get_model_data().id

    @property
    def name(self):
        """
        The name of the model.

        :return: The model name.

        :rtype: str
        """
        return self._get_model_data().name

    @name.setter
    def name(self, value):
        """
        Set the model name.

        :param str value: The model name.
        """
        self._get_base_model().update(name=value)

    @property
    def comment(self):
        """
        The comment for the model. Also used as the model description.

        :return: The model comment / description.

        :rtype: str
        """
        return self._get_model_data().comment

    @comment.setter
    def comment(self, value):
        """
        Set the comment for the model. Also used as the model description.

        :param str value: The model comment / description.
        """
        self._get_base_model().update(comment=value)

    @property
    def tags(self):
        """
        A list of tags describing the model.

        :return: The list of tags.

        :rtype: list(str)
        """
        return self._get_model_data().tags

    @tags.setter
    def tags(self, value):
        """
        Set the list of tags describing the model.

        :param value: The tags.

        :type value: list(str)
        """
        self._get_base_model().update(tags=value)

    @property
    def config_text(self):
        """
        The configuration as a string. For example, prototxt, an ini file, or Python code to evaluate.

        :return: The configuration.

        :rtype: str
        """
        return _Model._unwrap_design(self._get_model_data().design)

    @property
    def config_dict(self):
        """
        The configuration as a dictionary, parsed from the design text. This usually represents the model
        configuration. For example, prototxt, an ini file, or Python code to evaluate.

        :return: The configuration.

        :rtype: dict
        """
        return self._text_to_config_dict(self.config_text)

    @property
    def labels(self):
        """
        The label enumeration of string (label) to integer (value) pairs.

        :return: A dictionary containing the label enumeration, where the keys are labels and the values are integers.

        :rtype: dict
        """
        return self._get_model_data().labels

    @property
    def task(self):
        """
        Return the creating task id (str)

        :return str: Task ID
        """
        return self._task or self._get_base_model().task

    @property
    def url(self):
        """
        Return the url of the model file (or archived files)

        :return str: Model file URL
        """
        return self._get_base_model().uri

    @property
    def published(self):
        return self._get_base_model().locked

    @property
    def framework(self):
        return self._get_model_data().framework

    def __init__(self, task=None):
        super(BaseModel, self).__init__()
        self._log = get_logger()
        self._task = None
        self._set_task(task)

    def get_weights(self):
        """
        Download the base model and return the locally stored filename.

        :return: The locally stored file.

        :rtype: str
        """
        # download model (synchronously) and return local file
        return self._get_base_model().download_model_weights()

    def get_weights_package(self, return_path=False):
        """
        Download the base model package into a temporary directory (extract the files), or return a list of the
        locally stored filenames.

        :param bool return_path: Return the extracted directory path or a list of the extracted filenames? (Optional)

            - ``True`` - Download the model weights into a temporary directory, and return the temporary directory path.
            - ``False`` - Return a list of the locally stored filenames. (Default)

        :return: The temporary directory path, or a list of the locally stored filenames.

        :rtype: str or list(Path)
        """
        # check if model was packaged
        if self._package_tag not in self._get_model_data().tags:
            raise ValueError('Model is not packaged')

        # download packaged model
        packed_file = self.get_weights()

        # unpack
        target_folder = mkdtemp(prefix='model_package_')
        if not target_folder:
            raise ValueError('cannot create temporary directory for packed weight files')

        for func in (zipfile.ZipFile, tarfile.open):
            try:
                obj = func(packed_file)
                obj.extractall(path=target_folder)
                break
            except (zipfile.BadZipfile, tarfile.ReadError):
                pass
        else:
            raise ValueError('cannot extract files from packaged model at %s' % packed_file)

        if return_path:
            return target_folder

        target_files = list(Path(target_folder).glob('*'))
        return target_files

    def publish(self):
        """
        Set the model status to ``published`` and make it available for public use. If the model's status is
        already ``published``, then this method is a no-op.
        """

        if not self.published:
            self._get_base_model().publish()

    def _running_remotely(self):
        return bool(running_remotely() and self._task is not None)

    def _set_task(self, value):
        if value is not None and not isinstance(value, _Task):
            raise ValueError('task argument must be of Task type')
        self._task = value

    @abc.abstractmethod
    def _get_model_data(self):
        pass

    @abc.abstractmethod
    def _get_base_model(self):
        pass

    def _set_package_tag(self):
        if self._package_tag not in self.tags:
            self.tags.append(self._package_tag)
            self._get_base_model().edit(tags=self.tags)

    @staticmethod
    def _config_dict_to_text(config):
        # if already a string, return as-is
        if isinstance(config, six.string_types):
            return config
        if not isinstance(config, dict):
            raise ValueError("Model configuration only supports dictionary objects")
        try:
            try:
                text = HOCONConverter.to_hocon(ConfigFactory.from_dict(config))
            except Exception:
                # fallback json+pyhocon
                # hack: pyhocon is not very good with dict conversion, so we pass through json
                import json
                text = json.dumps(config)
                text = HOCONConverter.to_hocon(ConfigFactory.parse_string(text))

        except Exception:
            raise ValueError("Could not serialize configuration dictionary:\n{}".format(config))
        return text

    @staticmethod
    def _text_to_config_dict(text):
        if not isinstance(text, six.string_types):
            raise ValueError("Model configuration parsing only supports string")
        try:
            return ConfigFactory.parse_string(text).as_plain_ordered_dict()
        except pyparsing.ParseBaseException as ex:
            pos = "at char {}, line:{}, col:{}".format(ex.loc, ex.lineno, ex.column)
            six.raise_from(ValueError("Could not parse configuration text ({}):\n{}".format(pos, text)), None)
        except Exception:
            six.raise_from(ValueError("Could not parse configuration text:\n{}".format(text)), None)

    @staticmethod
    def _resolve_config(config_text=None, config_dict=None):
        mutually_exclusive(config_text=config_text, config_dict=config_dict, _require_at_least_one=False)
        if config_dict:
            return InputModel._config_dict_to_text(config_dict)

        return config_text

class Model(BaseModel):
    """
    Represent an existing model in the system, searched for by model ID.
    The model is read-only and can be used to pre-initialize a network.
    """

    def __init__(self, model_id):
        """
        Load a model by ID. The returned object is read-only and can be connected to a task.

        Note that the input model can be overridden when running remotely.

        :param str model_id: The model ID.
        """
        super(Model, self).__init__()
        self._base_model_id = model_id
        self._base_model = None

    def get_local_copy(self, extract_archive=True):
        """
        Retrieve a valid link to the model file(s).

        If the model URL is a file system link, it will be returned directly.
        If the model URL points to a remote location (http/s3/gs etc.),
        it will download the file(s) and return the temporary location of the downloaded model.

        :param bool extract_archive: If True and the model is of type 'packaged' (e.g. a TensorFlow compressed
            folder), the returned path will be a temporary folder containing the archive content.
        :return str: a local path to the model (or a downloaded copy of it)
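
        A minimal usage sketch (the model ID below is an illustrative placeholder):

        .. code-block:: python

            model = Model(model_id='aabbcc112233')   # illustrative model ID
            local_weights = model.get_local_copy()   # local path to the (possibly extracted) weights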
        """
        if extract_archive and self._package_tag in self.tags:
            return self.get_weights_package(return_path=True)
        return self.get_weights()

    def _get_base_model(self):
        if self._base_model:
            return self._base_model

        if not self._base_model_id:
            # this shouldn't actually happen
            raise Exception('Missing model ID, cannot create an empty model')
        self._base_model = _Model(
            upload_storage_uri=None,
            cache_dir=get_cache_dir(),
            model_id=self._base_model_id,
        )
        return self._base_model

    def _get_model_data(self):
        return self._get_base_model().data


class InputModel(Model):
    """
    Load an existing model in the system, searched for by model ID.
    The model is read-only and can be used to pre-initialize a network.
    The model can be connected to a task as an input model, and can then be overridden from the UI when
    running remotely.
    """

    _EMPTY_MODEL_ID = _Model._EMPTY_MODEL_ID

    @classmethod
    def import_model(
        cls,
        weights_url,
        config_text=None,
        config_dict=None,
        label_enumeration=None,
        name=None,
        tags=None,
        comment=None,
        is_package=False,
        create_as_published=False,
        framework=None,
    ):
        """
        Create an InputModel object from a pre-trained model by specifying the URL of an initial weights file.
        Optionally, input a configuration, label enumeration, name for the model, tags describing the model,
        a comment as a description of the model, indicate whether the model is a package, specify the model's
        framework, and indicate whether to immediately set the model's status to ``Published``.
        The model is read-only.

        The **Trains Server** (backend) may already store the model's URL. If the input model's URL is not
        stored, meaning the model is new, then it is imported and Trains stores its metadata.
        If the URL is already stored, the import process stops, Trains issues a warning message, and Trains
        reuses the model.

        In your Python experiment script, after importing the model, you can connect it to the main execution
        Task as an input model using :meth:`InputModel.connect` or :meth:`.Task.connect`. That initializes the
        network.

        .. note::
            Using the **Trains Web-App** (user interface), you can reuse imported models and switch models in
            experiments.

        :param str weights_url: A valid URL for the initial weights file. If the **Trains Web-App** (backend)
            already stores the metadata of a model with the same URL, that existing model is returned
            and Trains ignores all other parameters.

            For example:

            - ``https://domain.com/file.bin``
            - ``s3://bucket/file.bin``
            - ``file:///home/user/file.bin``

        :param str config_text: The configuration as a string. This is usually the content of a configuration
            dictionary file. Specify ``config_text`` or ``config_dict``, but not both.
        :type config_text: unconstrained text string
        :param dict config_dict: The configuration as a dictionary. Specify ``config_text`` or ``config_dict``,
            but not both.
        :param dict label_enumeration: The label enumeration dictionary of string (label) to integer (value) pairs.
            (Optional)

            For example:

            .. code-block:: javascript

               {
                    'background': 0,
                    'person': 1
               }

        :param str name: The name of the newly imported model. (Optional)
        :param tags: The list of tags which describe the model. (Optional)
        :type tags: list(str)
        :param str comment: A comment / description for the model. (Optional)
        :type comment: str
        :param is_package: Is the imported weights file a package? (Optional)

            - ``True`` - Is a package. Add a package tag to the model.
            - ``False`` - Is not a package. Do not add a package tag. (Default)

        :type is_package: bool
        :param bool create_as_published: Set the model's status to Published? (Optional)

            - ``True`` - Set the status to Published.
            - ``False`` - Do not set the status to Published. The status will be Draft. (Default)

        :param str framework: The framework of the model. (Optional)
        :type framework: str or Framework object

        :return: The imported model or existing model (see above).

        :rtype: A model object.
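
        A minimal usage sketch (the URL and names below are illustrative placeholders):

        .. code-block:: python

            input_model = InputModel.import_model(
                weights_url='s3://bucket/model.pth',   # illustrative, pre-uploaded weights location
                name='imported model',
                framework=Framework.pytorch,
                label_enumeration={'background': 0, 'person': 1},
            )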
        """
        config_text = cls._resolve_config(config_text=config_text, config_dict=config_dict)
        weights_url = StorageHelper.conform_url(weights_url)
        if not weights_url:
            raise ValueError("Please provide a valid weights_url parameter")
        extra = {'system_tags': ["-" + ARCHIVED_TAG]} \
            if Session.check_min_api_version('2.3') else {'tags': ["-" + ARCHIVED_TAG]}
        result = _Model._get_default_session().send(models.GetAllRequest(
            uri=[weights_url],
            only_fields=["id", "name", "created"],
            **extra
        ))

        if result.response.models:
            logger = get_logger()

            logger.debug('A model with uri "{}" already exists. Selecting it'.format(weights_url))

            model = get_single_result(
                entity='model',
                query=weights_url,
                results=result.response.models,
                log=logger,
                raise_on_error=False,
            )

            logger.info("Selected model id: {}".format(model.id))

            return InputModel(model_id=model.id)

        base_model = _Model(
            upload_storage_uri=None,
            cache_dir=get_cache_dir(),
        )

        from .task import Task
        task = Task.current_task()
        if task:
            comment = 'Imported by task id: {}'.format(task.id) + ('\n' + comment if comment else '')
            project_id = task.project
            task_id = task.id
        else:
            project_id = None
            task_id = None

        if not framework:
            framework, file_ext = Framework._get_file_ext(
                framework=framework,
                filename=weights_url
            )

        base_model.update(
            design=config_text,
            labels=label_enumeration,
            name=name,
            comment=comment,
            tags=tags,
            uri=weights_url,
            framework=framework,
            project_id=project_id,
            task_id=task_id,
        )

        this_model = InputModel(model_id=base_model.id)
        this_model._base_model = base_model

        if is_package:
            this_model._set_package_tag()

        if create_as_published:
            this_model.publish()

        return this_model

    @classmethod
    def load_model(
        cls,
        weights_url,
        load_archived=False
    ):
        """
        Load an already registered model based on a pre-existing model file (the link must be valid).

        If the URL of the weights file is already registered, the returned object is an InputModel representing
        the loaded model. If no registered model with the specified URL is found, None is returned.

        :param weights_url: A valid URL for the weights file (string).
            Examples: "https://domain.com/file.bin", "s3://bucket/file.bin" or "file:///home/user/file.bin".
            NOTE: if a model with the exact same URL exists, it will be used, and all other arguments will be ignored.
        :param bool load_archived: If True, return a registered model even if it is archived;
            otherwise, archived models are ignored.
        :return Model: InputModel object or None if no model could be found
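
        A minimal usage sketch (the URL below is an illustrative placeholder):

        .. code-block:: python

            input_model = InputModel.load_model(weights_url='s3://bucket/model.pth')  # illustrative URL
            if input_model is None:
                print('no registered model found for this URL')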
        """
        weights_url = StorageHelper.conform_url(weights_url)
        if not weights_url:
            raise ValueError("Please provide a valid weights_url parameter")
        if not load_archived:
            extra = {'system_tags': ["-" + ARCHIVED_TAG]} \
                if Session.check_min_api_version('2.3') else {'tags': ["-" + ARCHIVED_TAG]}
        else:
            extra = {}

        result = _Model._get_default_session().send(models.GetAllRequest(
            uri=[weights_url],
            only_fields=["id", "name", "created"],
            **extra
        ))

        if not result or not result.response or not result.response.models:
            return None

        logger = get_logger()
        model = get_single_result(
            entity='model',
            query=weights_url,
            results=result.response.models,
            log=logger,
            raise_on_error=False,
        )

        return InputModel(model_id=model.id)

    @classmethod
    def empty(
        cls,
        config_text=None,
        config_dict=None,
        label_enumeration=None,
    ):
        """
        Create an empty model object. Later, you can assign a model to the empty model object.

        :param config_text: The model configuration as a string. This is usually the content of a configuration
            dictionary file. Specify ``config_text`` or ``config_dict``, but not both.
        :type config_text: unconstrained text string
        :param dict config_dict: The model configuration as a dictionary. Specify ``config_text`` or ``config_dict``,
            but not both.
        :param dict label_enumeration: The label enumeration dictionary of string (label) to integer (value) pairs.
            (Optional)

            For example:

            .. code-block:: javascript

               {
                    'background': 0,
                    'person': 1
               }
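
        A minimal usage sketch (the configuration values below are illustrative placeholders):

        .. code-block:: python

            input_model = InputModel.empty(
                config_dict={'num_layers': 3},   # illustrative configuration
                label_enumeration={'background': 0, 'person': 1},
            )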
        """
        design = cls._resolve_config(config_text=config_text, config_dict=config_dict)

        this_model = InputModel(model_id=cls._EMPTY_MODEL_ID)
        this_model._base_model = m = _Model(
            cache_dir=None,
            upload_storage_uri=None,
            model_id=cls._EMPTY_MODEL_ID,
        )
        m._data.design = _Model._wrap_design(design)
        m._data.labels = label_enumeration
        return this_model

    def __init__(self, model_id):
        """
        :param str model_id: The Trains Id (system UUID) of the input model whose metadata the **Trains Server**
            (backend) stores.
        """
        super(InputModel, self).__init__(model_id)

    @property
    def id(self):
        return self._base_model_id

    def connect(self, task):
        """
        Connect the current model to a Task object, if the model is preexisting. Preexisting models include:

        - Imported models (InputModel objects created using the :meth:`InputModel.import_model` method).
        - Models whose metadata is already in the Trains platform, meaning the InputModel object is instantiated
          from the ``InputModel`` class specifying the model's Trains Id as an argument.
        - Models whose origin is not Trains that are used to create an InputModel object. For example,
          models created using frameworks such as TensorFlow.

        When the experiment is executed remotely in a worker, the input model already specified in the experiment is
        used.

        .. note::
            The **Trains Web-App** allows you to switch one input model for another and then enqueue the experiment
            to execute in a worker.

        :param object task: A Task object.
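
        A minimal usage sketch (assumes a Task object was created elsewhere, for example via ``Task.init``):

        .. code-block:: python

            input_model = InputModel(model_id='aabbcc112233')   # illustrative model ID
            input_model.connect(task)   # the task now uses this model as its input model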
        """
        self._set_task(task)

        if running_remotely() and task.input_model and task.is_main_task():
            self._base_model = task.input_model
            self._base_model_id = task.input_model.id
        else:
            # we should set the task input model to point to us
            model = self._get_base_model()
            # try to store the input model id, if it is not empty
            if model.id != self._EMPTY_MODEL_ID:
                task.set_input_model(model_id=model.id)
            # only copy the model design if the task has no design to begin with
            if not self._task._get_model_config_text():
                task._set_model_config(config_text=model.model_design)
            if not self._task.get_labels_enumeration():
                task.set_model_label_enumeration(model.data.labels)

        # If there was an output model connected, it may need to be updated by
        # the newly connected input model
        self.task._reconnect_output_model()


class OutputModel(BaseModel):
    """
    Create an output model for a Task (experiment) to store the training results.

    The OutputModel object is always connected to a Task object, because it is instantiated with a Task object
    as an argument. It is, therefore, automatically registered as the Task's (experiment's) output model.

    The OutputModel object is read-write.

    A common use case is to reuse the OutputModel object, and override the weights after storing a model snapshot.
    Another use case is to create multiple OutputModel objects for a Task (experiment), and after a new high score
    is found, store a model snapshot.

    If the model configuration and / or the model's label enumeration
    are ``None``, then the output model is initialized with the values from the Task object's input model.

    .. note::
        When executing a Task (experiment) remotely in a worker, you can modify the model configuration and / or
        model's label enumeration using the **Trains Web-App**.
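
    A minimal lifecycle sketch (illustrative names; assumes ``Task`` is imported from the same package):

    .. code-block:: python

        task = Task.init(project_name='examples', task_name='train')   # illustrative project / task names
        output_model = OutputModel(task=task, framework=Framework.pytorch)
        # ... train, then store a snapshot of the weights:
        output_model.update_weights(weights_filename='snapshot.pth')   # illustrative local file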
    """

    @property
    def published(self):
        if not self.id:
            return False
        return self._get_base_model().locked

    @property
    def config_text(self):
        """
        Get the configuration as a string. For example, prototxt, an ini file, or Python code to evaluate.

        :return: The configuration.

        :rtype: str
        """
        return _Model._unwrap_design(self._get_model_data().design)

    @config_text.setter
    def config_text(self, value):
        """
        Set the configuration. Store a blob of text for custom usage.
        """
        self.update_design(config_text=value)

    @property
    def config_dict(self):
        """
        Get the configuration as a dictionary parsed from the ``config_text`` text. This usually represents the
        model configuration. For example, prototxt, an ini file, or Python code to evaluate.

        :return: The configuration.

        :rtype: dict
        """
        return self._text_to_config_dict(self.config_text)

    @config_dict.setter
    def config_dict(self, value):
        """
        Set the configuration. Saved in the model object.

        :param dict value: The configuration parameters.
        """
        self.update_design(config_dict=value)

    @property
    def labels(self):
        """
        Get the label enumeration as a dictionary of string (label) to integer (value) pairs.

        For example:

        .. code-block:: javascript

           {
                'background': 0,
                'person': 1
           }

        :return: The label enumeration.

        :rtype: dict
        """
        return self._get_model_data().labels

    @labels.setter
    def labels(self, value):
        """
        Set the label enumeration.

        :param dict value: The label enumeration dictionary of string (label) to integer (value) pairs.

            For example:

            .. code-block:: javascript

               {
                    'background': 0,
                    'person': 1
               }

        """
        self.update_labels(labels=value)

    @property
    def upload_storage_uri(self):
        return self._get_base_model().upload_storage_uri

    def __init__(
        self,
        task,
        config_text=None,
        config_dict=None,
        label_enumeration=None,
        name=None,
        tags=None,
        comment=None,
        framework=None,
        base_model_id=None,
    ):
        """
        Create a new model and immediately connect it to a task.

        We do not allow Model creation without a task, so we always keep track of how the models were created.
        In remote execution, Model parameters can be overridden by the Task (such as the model configuration and
        label enumeration).

        :param task: The Task object with which the OutputModel object is associated.
        :type task: Task
        :param config_text: The configuration as a string. This is usually the content of a configuration
            dictionary file. Specify ``config_text`` or ``config_dict``, but not both.
        :type config_text: unconstrained text string
        :param dict config_dict: The configuration as a dictionary.
            Specify ``config_dict`` or ``config_text``, but not both.
        :param dict label_enumeration: The label enumeration dictionary of string (label) to integer (value) pairs.
            (Optional)

            For example:

            .. code-block:: javascript

               {
                    'background': 0,
                    'person': 1
               }

        :param str name: The name for the newly created model. (Optional)
        :param list(str) tags: A list of strings which are tags for the model. (Optional)
        :param str comment: A comment / description for the model. (Optional)
        :param framework: The framework of the model or a Framework object. (Optional)
        :type framework: str or Framework object
        :param str base_model_id: The ID of a previously created model to be reused. (Optional)
        """
        super(OutputModel, self).__init__(task=task)

        config_text = self._resolve_config(config_text=config_text, config_dict=config_dict)

        self._model_local_filename = None
        self._base_model = None
        self._floating_data = create_dummy_model(
            design=_Model._wrap_design(config_text),
            labels=label_enumeration or task.get_labels_enumeration(),
            name=name,
            tags=tags,
            comment='{} by task id: {}'.format('Created' if not base_model_id else 'Overwritten', task.id) +
                    ('\n' + comment if comment else ''),
            framework=framework,
            upload_storage_uri=task.output_uri,
        )
        if base_model_id:
            try:
                _base_model = InputModel(base_model_id)._get_base_model()
                _base_model.update(
                    labels=self._floating_data.labels,
                    design=self._floating_data.design,
                    task_id=self._task.id,
                    project_id=self._task.project,
                    name=self._floating_data.name or task.name,
                    comment=('{}\n{}'.format(_base_model.comment, self._floating_data.comment)
                             if _base_model.comment and self._floating_data.comment else
                             (_base_model.comment or self._floating_data.comment)),
                    tags=self._floating_data.tags,
                    framework=self._floating_data.framework,
                    upload_storage_uri=self._floating_data.upload_storage_uri
                )
                self._base_model = _base_model
                self._floating_data = None
                self._base_model.update_for_task(task_id=self._task.id, override_model_id=self.id)
            except Exception:
                pass
        self.connect(task)

    def connect(self, task):
        """
        Connect the current model to a Task object, if the model is a preexisting model. Preexisting models include:

        - Imported models.
        - Models whose metadata the **Trains Server** (backend) is already storing.
        - Models from another source, such as frameworks like TensorFlow.

        :param object task: A Task object.
        """
        if self._task != task:
            raise ValueError('Can only connect preexisting model to task, but this is a fresh model')

        if running_remotely() and task.is_main_task():
            if self._floating_data:
                self._floating_data.design = _Model._wrap_design(self._task._get_model_config_text()) or \
                    self._floating_data.design
                self._floating_data.labels = self._task.get_labels_enumeration() or \
                    self._floating_data.labels
            elif self._base_model:
                self._base_model.update(design=_Model._wrap_design(self._task._get_model_config_text()) or
                                        self._base_model.design)
                self._base_model.update(labels=self._task.get_labels_enumeration() or self._base_model.labels)

        elif self._floating_data is not None:
            # we copy configuration / labels if they exist; obviously someone wants them as the output base model
            if _Model._unwrap_design(self._floating_data.design):
                if not task._get_model_config_text():
                    task._set_model_config(config_text=self._floating_data.design)
            else:
                self._floating_data.design = _Model._wrap_design(self._task._get_model_config_text())

            if self._floating_data.labels:
                task.set_model_label_enumeration(self._floating_data.labels)
            else:
                self._floating_data.labels = self._task.get_labels_enumeration()

        self.task._save_output_model(self)

    def set_upload_destination(self, uri):
        """
        Set the URI of the storage destination for uploaded model weight files. Supported storage destinations
        include S3, Google Cloud Storage, and file locations.

        Using this method, file uploads are performed separately, and a link to each uploaded file is stored in the
        model object.

        .. note::
            For storage requiring credentials, the credentials are stored in the Trains configuration file,
            ``~/trains.conf``.

        :param str uri: The URI of the upload storage destination.

            For example:

            - ``s3://bucket/directory/``
            - ``file:///tmp/debug/``

        :return: The status of whether the storage destination scheme is supported.

            - ``True`` - The storage destination scheme is supported.
            - ``False`` - The storage destination scheme is not supported.

        :rtype: bool
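
        A minimal usage sketch (the bucket below is an illustrative placeholder):

        .. code-block:: python

            output_model.set_upload_destination(uri='s3://bucket/training_models/')  # illustrative destination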
        """
        if not uri:
            return

        # Test if we can update the model.
        self._validate_update()

        # Create the storage helper
        storage = StorageHelper.get(uri)

        # Verify that we can upload to this destination
        try:
            uri = storage.verify_upload(folder_uri=uri)
        except Exception:
            raise ValueError("Could not set destination uri to: %s [Check write permissions]" % uri)

        # store default uri
        self._get_base_model().upload_storage_uri = uri

    def update_weights(self, weights_filename=None, upload_uri=None, target_filename=None,
                       auto_delete_file=True, register_uri=None, iteration=None, update_comment=True):
        """
        Update the model weights from a locally stored model filename.

        .. note::
            Uploading the model is a background process. A call to this method returns immediately.

        :param str weights_filename: The name of the locally stored weights file to upload. Specify
            ``weights_filename`` or ``register_uri``, but not both.
        :param str upload_uri: The URI of the storage destination for model weights upload. The default value
            is the previously used URI. (Optional)
        :param str target_filename: The newly created filename in the storage destination location. The default value
            is the ``weights_filename`` value. (Optional)
        :param bool auto_delete_file: Delete the temporary file after uploading? (Optional)

            - ``True`` - Delete (Default)
            - ``False`` - Do not delete

        :param str register_uri: The URI of an already uploaded weights file. The URI must be valid. Specify
            ``register_uri`` or ``weights_filename``, but not both.
        :param int iteration: The iteration number to associate with this weights update. The default value is the
            task's last reported iteration. (Optional)
        :param bool update_comment: Update the model comment with the local weights file name (to maintain
            provenance)? (Optional)

            - ``True`` - Update model comment (Default)
            - ``False`` - Do not update

        :return: The uploaded URI.

        :rtype: str
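
        A minimal usage sketch (file name and bucket below are illustrative placeholders; assumes ``output_model``
        is an OutputModel connected to the current task):

        .. code-block:: python

            uploaded_uri = output_model.update_weights(
                weights_filename='snapshot.pth',            # local file, uploaded in the background
                upload_uri='s3://bucket/training_models/',  # optional override of the upload destination
                iteration=1000,
            )
            OutputModel.wait_for_uploads()  # optionally block until background uploads complete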
        """

        def delete_previous_weights_file(filename=weights_filename):
            try:
                if filename:
                    os.remove(filename)
            except OSError:
                self._log.debug('Failed removing temporary file %s' % filename)

        # test if we can update the model
        if self.id and self.published:
            raise ValueError('Model is published and cannot be changed')

        if (not weights_filename and not register_uri) or (weights_filename and register_uri):
            raise ValueError('Model update must have either local weights file to upload, '
                             'or pre-uploaded register_uri, never both')

        # only upload if we are connected to a task
        if not self._task:
            raise Exception('Missing a task for this model')

        if weights_filename is not None:
            # make sure we delete the previous file, if it exists
            if self._model_local_filename != weights_filename:
                delete_previous_weights_file(self._model_local_filename)
            # store temp filename for deletion next time, if needed
            if auto_delete_file:
                self._model_local_filename = weights_filename

        # make sure the created model is updated:
        model = self._get_force_base_model()
        if not model:
            raise ValueError('Failed creating internal output model')

        # select the correct file extension based on the framework, or update the framework based on the file extension
        framework, file_ext = Framework._get_file_ext(
            framework=self._get_model_data().framework,
            filename=target_filename or weights_filename or register_uri
        )

        if weights_filename:
            target_filename = target_filename or Path(weights_filename).name
            if not target_filename.lower().endswith(file_ext):
                target_filename += file_ext

        # set target uri for upload (if specified)
        if upload_uri:
            self.set_upload_destination(upload_uri)

        # let us know the iteration number; we put it in the comment section for now.
        if update_comment:
            comment = self.comment or ''
            iteration_msg = 'snapshot {} stored'.format(weights_filename or register_uri)
            if not comment.startswith('\n'):
                comment = '\n' + comment
            comment = iteration_msg + comment
        else:
            comment = None

        # if we have no output destination, just register the local model file
        if weights_filename and not self.upload_storage_uri and not self._task.storage_uri:
            register_uri = weights_filename
            weights_filename = None
            auto_delete_file = False
            self._log.info('No output storage destination defined, registering local model %s' % register_uri)

        # start the upload
        if weights_filename:
            if not model.upload_storage_uri:
                self.set_upload_destination(self.upload_storage_uri or self._task.storage_uri)

            output_uri = model.update_and_upload(
                model_file=weights_filename,
                task_id=self._task.id,
                async_enable=True,
                target_filename=target_filename,
                framework=self.framework or framework,
                comment=comment,
                cb=delete_previous_weights_file if auto_delete_file else None,
                iteration=iteration or self._task.get_last_iteration(),
            )
        elif register_uri:
            register_uri = StorageHelper.conform_url(register_uri)
            output_uri = model.update(uri=register_uri, task_id=self._task.id, framework=framework, comment=comment)
        else:
            output_uri = None

        # make sure that if we are in dev mode we report that we are training (not debugging)
        self._task._output_model_updated()

        return output_uri

    def update_weights_package(self, weights_filenames=None, weights_path=None, upload_uri=None,
                               target_filename=None, auto_delete_file=True, iteration=None):
        """
        Update the model weights from locally stored model files, or from a directory containing multiple files.

        .. note::
            Uploading the model weights is a background process. A call to this method returns immediately.

        :param weights_filenames: The file names of the locally stored model files. Specify ``weights_filenames``
            or ``weights_path``, but not both.
        :type weights_filenames: list(str)
        :param weights_path: The directory path to a package. All the files in the directory will be uploaded.
            Specify ``weights_path`` or ``weights_filenames``, but not both.
        :type weights_path: str
        :param str upload_uri: The URI of the storage destination for the model weights upload. The default
            is the previously used URI. (Optional)
        :param str target_filename: The newly created filename in the storage destination URI location. The default
            is ``model_package.zip``. (Optional)
        :param bool auto_delete_file: Delete temporary file after uploading? (Optional)

            - ``True`` - Delete (Default)
            - ``False`` - Do not delete

        :return: The uploaded URI for the weights package.

        :rtype: str
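
        A minimal usage sketch (the directory below is an illustrative placeholder; assumes ``output_model``
        is an OutputModel connected to the current task):

        .. code-block:: python

            uploaded_uri = output_model.update_weights_package(
                weights_path='/tmp/model_dir',        # all files in this directory are zipped and uploaded
                target_filename='model_package.zip',
            )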
        """
        # create list of files
        if (not weights_filenames and not weights_path) or (weights_filenames and weights_path):
            raise ValueError('Model update weights package requires either a directory path to pack or a list of '
                             'files, but not both')

        if not weights_filenames:
            weights_filenames = list(map(six.text_type, Path(weights_path).glob('*')))

        # create packed model from all the files
        fd, zip_file = mkstemp(prefix='model_package.', suffix='.zip')
        try:
            with zipfile.ZipFile(zip_file, 'w', allowZip64=True, compression=zipfile.ZIP_STORED) as zf:
                for filename in weights_filenames:
                    zf.write(filename, arcname=Path(filename).name)
        finally:
            os.close(fd)

        # now we can delete the files (or path if provided)
        if auto_delete_file:
            def safe_remove(path, is_dir=False):
                try:
                    (os.rmdir if is_dir else os.remove)(path)
                except OSError:
                    self._log.info('Failed removing temporary {}'.format(path))

            for filename in weights_filenames:
                safe_remove(filename)
            if weights_path:
                safe_remove(weights_path, is_dir=True)

        if target_filename and not target_filename.lower().endswith('.zip'):
            target_filename += '.zip'

        # and now we should upload the file; always delete the temporary zip file
        comment = self.comment or ''
        iteration_msg = 'snapshot {} stored'.format(str(weights_filenames))
        if not comment.startswith('\n'):
            comment = '\n' + comment
        comment = iteration_msg + comment
        self.comment = comment
        uploaded_uri = self.update_weights(weights_filename=zip_file, auto_delete_file=True, upload_uri=upload_uri,
                                           target_filename=target_filename or 'model_package.zip',
                                           iteration=iteration, update_comment=False)
        # set the model tag (by now we should have a model object) so we know we have a packaged file
        self._set_package_tag()
        return uploaded_uri

    def update_design(self, config_text=None, config_dict=None):
        """
        Update the model configuration. Store a blob of text for custom usage.

        .. note::
            This method's behavior is lazy. The design update is only forced when the weights
            are updated.

        :param config_text: The configuration as a string. This is usually the content of a configuration
            dictionary file. Specify ``config_text`` or ``config_dict``, but not both.
        :type config_text: unconstrained text string
        :param dict config_dict: The configuration as a dictionary. Specify ``config_text`` or ``config_dict``,
            but not both.

        :return: The status of the update.

            - ``True`` - Update successful.
            - ``False`` - Update not successful.

        :rtype: bool
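
        A minimal usage sketch (the configuration values below are illustrative placeholders):

        .. code-block:: python

            output_model.update_design(config_dict={'num_layers': 3})   # illustrative configuration
            # or pass the configuration as free text instead:
            # output_model.update_design(config_text='num_layers: 3')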
        """
        if not self._validate_update():
            return

        config_text = self._resolve_config(config_text=config_text, config_dict=config_dict)

        if self._task and not self._task.get_model_config_text():
            self._task.set_model_config(config_text=config_text)

        if self.id:
            # update the model object (this will happen if we resumed a training task)
            result = self._get_force_base_model().edit(design=config_text)
        else:
            self._floating_data.design = _Model._wrap_design(config_text)
            result = Waitable()

        # you can wait on this object
        return result

    def update_labels(self, labels):
        """
        Update the label enumeration.

        :param dict labels: The label enumeration dictionary of string (label) to integer (value) pairs.

            For example:

            .. code-block:: javascript

               {
                    'background': 0,
                    'person': 1
               }

        :return: The status of the update (an object that can be waited on).
        """
        validate_dict(labels, key_types=six.string_types, value_types=six.integer_types, desc='label enumeration')

        if not self._validate_update():
            return

        if self._task:
            self._task.set_model_label_enumeration(labels)

        if self.id:
            # update the model object (this will happen if we resumed a training task)
            result = self._get_force_base_model().edit(labels=labels)
        else:
            self._floating_data.labels = labels
            result = Waitable()

        # you can wait on this object
        return result

    @classmethod
    def wait_for_uploads(cls, timeout=None, max_num_uploads=None):
        """
        Wait for any pending or in-progress model uploads to complete. If no uploads are pending or in-progress,
        then ``wait_for_uploads`` returns immediately.

        :param float timeout: The timeout interval to wait for uploads (seconds). (Optional).
        :param int max_num_uploads: The maximum number of uploads to wait for. (Optional).
        """
        _Model.wait_for_results(timeout=timeout, max_num_uploads=max_num_uploads)

    def _get_force_base_model(self):
        if self._base_model:
            return self._base_model

        # create a new model from the task
        self._base_model = self._task.create_output_model()
        # update the model from the task inputs
        labels = self._task.get_labels_enumeration()
        config_text = self._task._get_model_config_text()
        parent = self._task.output_model_id or self._task.input_model_id
        self._base_model.update(
            labels=self._floating_data.labels or labels,
            design=self._floating_data.design or config_text,
            task_id=self._task.id,
            project_id=self._task.project,
            parent_id=parent,
            name=self._floating_data.name or self._task.name,
            comment=self._floating_data.comment,
            tags=self._floating_data.tags,
            framework=self._floating_data.framework,
            upload_storage_uri=self._floating_data.upload_storage_uri
        )

        # remove the model floating change set; by now it should have been matched to the task.
        self._floating_data = None

        # now we have to update the creator task so it points to us
        self._base_model.update_for_task(task_id=self._task.id, override_model_id=self.id)

        return self._base_model

    def _get_base_model(self):
        if self._floating_data:
            return self._floating_data
        return self._get_force_base_model()

    def _get_model_data(self):
        if self._base_model:
            return self._base_model.data
        return self._floating_data

    def _validate_update(self):
        # test if we can update the model
        if self.id and self.published:
            raise ValueError('Model is published and cannot be changed')

        return True


class Waitable(object):
    def wait(self, *_, **__):
        return True