import abc
import os
import zipfile
import shutil
from tempfile import mkstemp

import six
import math
from typing import List, Dict, Union, Optional, Mapping, TYPE_CHECKING, Sequence, Any, Tuple
import numpy as np

try:
    import pandas as pd
except ImportError:
    pd = None

from .backend_api import Session
from .backend_api.services import models, projects
from pathlib2 import Path

from .utilities.config import config_dict_to_text, text_to_config_dict
from .utilities.proxy_object import cast_basic_type
from .utilities.plotly_reporter import SeriesInfo

from .backend_interface.util import (
    validate_dict,
    get_single_result,
    mutually_exclusive,
    exact_match_regex,
    get_or_create_project,
)
from .debugging.log import get_logger
from .errors import UsageError
from .storage.cache import CacheManager
from .storage.helper import StorageHelper
from .storage.util import get_common_path
from .utilities.enum import Options
from .backend_interface import Task as _Task
from .backend_interface.model import create_dummy_model, Model as _Model
from .backend_interface.session import SendError
from .config import running_remotely, get_cache_dir
from .backend_interface.metrics import Reporter, Metrics


if TYPE_CHECKING:
    from .task import Task


class Framework(Options):
    """
    Optional frameworks for output model
    """

    tensorflow = "TensorFlow"
    tensorflowjs = "TensorFlow_js"
    tensorflowlite = "TensorFlow_Lite"
    pytorch = "PyTorch"
    torchscript = "TorchScript"
    caffe = "Caffe"
    caffe2 = "Caffe2"
    onnx = "ONNX"
    keras = "Keras"
    mknet = "MXNet"
    cntk = "CNTK"
    torch = "Torch"
    darknet = "Darknet"
    paddlepaddle = "PaddlePaddle"
    scikitlearn = "ScikitLearn"
    xgboost = "XGBoost"
    lightgbm = "LightGBM"
    parquet = "Parquet"
    megengine = "MegEngine"
    catboost = "CatBoost"
    tensorrt = "TensorRT"
    openvino = "OpenVINO"

    __file_extensions_mapping = {
        ".pb": (
            tensorflow,
            tensorflowjs,
            onnx,
        ),
        ".meta": (tensorflow,),
        ".pbtxt": (
            tensorflow,
            onnx,
        ),
        ".zip": (tensorflow,),
        ".tgz": (tensorflow,),
        ".tar.gz": (tensorflow,),
        "model.json": (tensorflowjs,),
        ".tflite": (tensorflowlite,),
        ".pth": (pytorch,),
        ".pt": (pytorch,),
        ".caffemodel": (caffe,),
        ".prototxt": (caffe,),
        "predict_net.pb": (caffe2,),
        "predict_net.pbtxt": (caffe2,),
        ".onnx": (onnx,),
        ".h5": (keras,),
        ".hdf5": (keras,),
        ".keras": (keras,),
        ".model": (mknet, cntk, xgboost),
        "-symbol.json": (mknet,),
        ".cntk": (cntk,),
        ".t7": (torch,),
        ".cfg": (darknet,),
        "__model__": (paddlepaddle,),
        ".pkl": (scikitlearn, keras, xgboost, megengine),
        ".parquet": (parquet,),
        ".cbm": (catboost,),
        ".plan": (tensorrt,),
    }

    __parent_mapping = {
        "tensorflow": (
            tensorflow,
            tensorflowjs,
            tensorflowlite,
            keras,
        ),
        "pytorch": (pytorch,),
        "xgboost": (xgboost,),
        "lightgbm": (lightgbm,),
        "catboost": (catboost,),
        "joblib": (scikitlearn, xgboost),
    }

    @classmethod
    def get_framework_parents(cls, framework):
        if not framework:
            return []
        parents = []
        for k, v in cls.__parent_mapping.items():
            if framework in v:
                parents.append(k)
        return parents

    @classmethod
    def _get_file_ext(cls, framework, filename):
        mapping = cls.__file_extensions_mapping
        filename = filename.lower()

        def find_framework_by_ext(framework_selector):
            for ext, frameworks in mapping.items():
                if frameworks and filename.endswith(ext):
                    fw = framework_selector(frameworks)
                    if fw:
                        return fw, ext

        # If no framework, try finding first framework matching the extension, otherwise (or if no match) try matching
        # the given extension to the given framework. If no match return an empty extension
        return (
            (
                not framework
                and find_framework_by_ext(lambda frameworks_: frameworks_[0])
            )
            or find_framework_by_ext(
                lambda frameworks_: framework if framework in frameworks_ else None
            )
            or (framework, filename.split(".")[-1] if "." in filename else "")
        )


@six.add_metaclass(abc.ABCMeta)
class BaseModel(object):
    # noinspection PyProtectedMember
    _archived_tag = _Task.archived_tag
    _package_tag = "package"

    @property
    def id(self):
        # type: () -> str
        """
        The ID (system UUID) of the model.

        :return: The model ID.
        """
        return self._get_model_data().id

    @property
    def name(self):
        # type: () -> str
        """
        The name of the model.

        :return: The model name.
        """
        return self._get_model_data().name

    @name.setter
    def name(self, value):
        # type: (str) -> None
        """
        Set the model name.

        :param str value: The model name.
        """
        self._get_base_model().update(name=value)

    @property
    def project(self):
        # type: () -> str
        """
        project ID of the model.

        :return: project ID (str).
        """
        data = self._get_model_data()
        return data.project

    @project.setter
    def project(self, value):
        # type: (str) -> None
        """
        Set the project ID of the model.

        :param value: project ID (str).

        :type value: str
        """
        self._get_base_model().update(project_id=value)

    @property
    def comment(self):
        # type: () -> str
        """
        The comment for the model. Also used as the model description.

        :return: The model comment / description.
        """
        return self._get_model_data().comment

    @comment.setter
    def comment(self, value):
        # type: (str) -> None
        """
        Set the comment for the model. Also used as the model description.

        :param str value: The model comment/description.
        """
        self._get_base_model().update(comment=value)

    @property
    def tags(self):
        # type: () -> List[str]
        """
        A list of tags describing the model.

        :return: The list of tags.
        """
        return self._get_model_data().tags

    @tags.setter
    def tags(self, value):
        # type: (List[str]) -> None
        """
        Set the list of tags describing the model.

        :param value: The tags.

        :type value: list(str)
        """
        self._get_base_model().update(tags=value)

    @property
    def system_tags(self):
        # type: () -> List[str]
        """
        A list of system tags describing the model.

        :return: The list of tags.
        """
        data = self._get_model_data()
        return data.system_tags if Session.check_min_api_version("2.3") else data.tags

    @system_tags.setter
    def system_tags(self, value):
        # type: (List[str]) -> None
        """
        Set the list of system tags describing the model.

        :param value: The tags.

        :type value: list(str)
        """
        self._get_base_model().update(system_tags=value)

    @property
    def config_text(self):
        # type: () -> str
        """
        The configuration as a string. For example, prototxt, an ini file, or Python code to evaluate.

        :return: The configuration.
        """
        # noinspection PyProtectedMember
        return _Model._unwrap_design(self._get_model_data().design)

    @property
    def config_dict(self):
        # type: () -> dict
        """
        The configuration as a dictionary, parsed from the design text. This usually represents the model configuration.
        For example, prototxt, an ini file, or Python code to evaluate.

        :return: The configuration.
        """
        return self._text_to_config_dict(self.config_text)

    @property
    def labels(self):
        # type: () -> Dict[str, int]
        """
        The label enumeration of string (label) to integer (value) pairs.

        :return: A dictionary containing the label enumeration, where the keys are labels and the values are integers.
        """
        return self._get_model_data().labels

    @property
    def task(self):
        # type: () -> str
        """
        Return the creating task ID

        :return: The Task ID (str)
        """
        return self._task.id if self._task else self._get_base_model().task

    @property
    def url(self):
        # type: () -> str
        """
        Return the url of the model file (or archived files)

        :return: The model file URL.
        """
        return self._get_base_model().uri

    @property
    def published(self):
        # type: () -> bool
        return self._get_base_model().locked

    @property
    def framework(self):
        # type: () -> str
        return self._get_model_data().framework

    def __init__(self, task=None):
        # type: (Task) -> None
        super(BaseModel, self).__init__()
        self._log = get_logger()
        self._task = None
        self._reload_required = False
        self._reporter = None
        self._floating_data = None
        self._name = None
        self._task_connect_name = None
        self._set_task(task)

    def get_weights(self, raise_on_error=False, force_download=False, extract_archive=False):
        # type: (bool, bool, bool) -> str
        """
        Download the base model and return the locally stored filename.
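
        For example, assuming ``model`` is an instance of this class (e.g. an ``InputModel``):

        .. code-block:: py

            local_weights_path = model.get_weights(raise_on_error=True)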

        :param bool raise_on_error: If True, and the artifact could not be downloaded,
            raise ValueError, otherwise return None on failure and output log warning.

        :param bool force_download: If True, the base model will be downloaded,
            even if the base model is already cached.

        :param bool extract_archive: If True, the downloaded weights file will be extracted if possible

        :return: The locally stored file.
        """
        # download model (synchronously) and return local file
        return self._get_base_model().download_model_weights(
            raise_on_error=raise_on_error, force_download=force_download, extract_archive=extract_archive
        )

    def get_weights_package(
        self, return_path=False, raise_on_error=False, force_download=False, extract_archive=True
    ):
        # type: (bool, bool, bool, bool) -> Optional[Union[str, List[Path]]]
        """
        Download the base model package into a temporary directory (extract the files), or return a list of the
        locally stored filenames.

        :param bool return_path: Return the model weights or a list of filenames (Optional)

            - ``True`` - Download the model weights into a temporary directory, and return the temporary directory path.
            - ``False`` - Return a list of the locally stored filenames. (Default)

        :param bool raise_on_error: If True, and the artifact could not be downloaded,
            raise ValueError, otherwise return None on failure and output log warning.

        :param bool force_download: If True, the base artifact will be downloaded,
            even if the artifact is already cached.

        :param bool extract_archive: If True, the downloaded weights file will be extracted if possible

        :return: The model weights, or a list of the locally stored filenames.
            if raise_on_error=False, returns None on error.
        """
        # check if model was packaged
        if not self._is_package():
            raise ValueError("Model is not packaged")

        # download packaged model
        model_path = self.get_weights(
            raise_on_error=raise_on_error, force_download=force_download, extract_archive=extract_archive
        )

        if not model_path:
            if raise_on_error:
                raise ValueError(
                    "Model package '{}' could not be downloaded".format(self.url)
                )
            return None

        if return_path:
            return model_path

        target_files = list(Path(model_path).glob("*"))
        return target_files

    def report_scalar(self, title, series, value, iteration):
        # type: (str, str, float, int) -> None
        """
        For explicit reporting, plot a scalar series.
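
        For example, assuming ``model`` is an instance of this class:

        .. code-block:: py

            # report a training loss curve, one point per iteration
            for iteration, loss in enumerate([0.9, 0.7, 0.5]):
                model.report_scalar(title="loss", series="train", value=loss, iteration=iteration)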

        :param str title: The title (metric) of the plot. Plot more than one scalar series on the same plot by using
            the same ``title`` for each call to this method.
        :param str series: The series name (variant) of the reported scalar.
        :param float value: The value to plot per iteration.
        :param int iteration: The reported iteration / step (x-axis of the reported time series)
        """
        self._init_reporter()
        return self._reporter.report_scalar(title=title, series=series, value=float(value), iter=iteration)

    def report_single_value(self, name, value):
        # type: (str, float) -> None
        """
        Reports a single value metric (for example, total experiment accuracy or mAP)
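
        For example, assuming ``model`` is an instance of this class:

        .. code-block:: py

            model.report_single_value(name="mAP", value=0.86)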

        :param name: Metric's name
        :param value: Metric's value
        """
        self._init_reporter()
        return self._reporter.report_scalar(title="Summary", series=name, value=float(value), iter=-2**31)

    def report_histogram(
        self,
        title,  # type: str
        series,  # type: str
        values,  # type: Sequence[Union[int, float]]
        iteration=None,  # type: Optional[int]
        labels=None,  # type: Optional[List[str]]
        xlabels=None,  # type: Optional[List[str]]
        xaxis=None,  # type: Optional[str]
        yaxis=None,  # type: Optional[str]
        mode=None,  # type: Optional[str]
        data_args=None,  # type: Optional[dict]
        extra_layout=None  # type: Optional[dict]
    ):
        """
        For explicit reporting, plot a (default grouped) histogram.
        Notice this function will not calculate the histogram,
        it assumes the histogram was already calculated in `values`

        For example:

        .. code-block:: py

            vector_series = np.random.randint(10, size=10).reshape(2,5)
            model.report_histogram(title='histogram example', series='histogram series',
                values=vector_series, iteration=0, labels=['A','B'], xaxis='X axis label', yaxis='Y axis label')

        :param title: The title (metric) of the plot.
        :param series: The series name (variant) of the reported histogram.
        :param values: The series values. A list of floats, or an N-dimensional Numpy array containing
            data for each histogram bar.
        :param iteration: The reported iteration / step. Each ``iteration`` creates another plot.
        :param labels: Labels for each bar group, creating a plot legend labeling each series. (Optional)
        :param xlabels: Labels per entry in each bucket in the histogram (vector), creating a set of labels
            for each histogram bar on the x-axis. (Optional)
        :param xaxis: The x-axis title. (Optional)
        :param yaxis: The y-axis title. (Optional)
        :param mode: Multiple histograms mode, stack / group / relative. Default is 'group'.
        :param data_args: optional dictionary for data configuration, passed directly to plotly
            See full details on the supported configuration: https://plotly.com/javascript/reference/bar/
            example: data_args={'orientation': 'h', 'marker': {'color': 'blue'}}
        :param extra_layout: optional dictionary for layout configuration, passed directly to plotly
            See full details on the supported configuration: https://plotly.com/javascript/reference/bar/
            example: extra_layout={'xaxis': {'type': 'date', 'range': ['2020-01-01', '2020-01-31']}}
        """
        self._init_reporter()

        if not isinstance(values, np.ndarray):
            values = np.array(values)

        return self._reporter.report_histogram(
            title=title,
            series=series,
            histogram=values,
            iter=iteration or 0,
            labels=labels,
            xlabels=xlabels,
            xtitle=xaxis,
            ytitle=yaxis,
            mode=mode or "group",
            data_args=data_args,
            layout_config=extra_layout
        )

    def report_vector(
        self,
        title,  # type: str
        series,  # type: str
        values,  # type: Sequence[Union[int, float]]
        iteration=None,  # type: Optional[int]
        labels=None,  # type: Optional[List[str]]
        xlabels=None,  # type: Optional[List[str]]
        xaxis=None,  # type: Optional[str]
        yaxis=None,  # type: Optional[str]
        mode=None,  # type: Optional[str]
        extra_layout=None  # type: Optional[dict]
    ):
        """
        For explicit reporting, plot a vector as (default stacked) histogram.

        For example:

        .. code-block:: py

            vector_series = np.random.randint(10, size=10).reshape(2,5)
            model.report_vector(title='vector example', series='vector series', values=vector_series, iteration=0,
                labels=['A','B'], xaxis='X axis label', yaxis='Y axis label')

        :param title: The title (metric) of the plot.
        :param series: The series name (variant) of the reported histogram.
        :param values: The series values. A list of floats, or an N-dimensional Numpy array containing
            data for each histogram bar.
        :param iteration: The reported iteration / step. Each ``iteration`` creates another plot.
        :param labels: Labels for each bar group, creating a plot legend labeling each series. (Optional)
        :param xlabels: Labels per entry in each bucket in the histogram (vector), creating a set of labels
            for each histogram bar on the x-axis. (Optional)
        :param xaxis: The x-axis title. (Optional)
        :param yaxis: The y-axis title. (Optional)
        :param mode: Multiple histograms mode, stack / group / relative. Default is 'group'.
        :param extra_layout: optional dictionary for layout configuration, passed directly to plotly
            See full details on the supported configuration: https://plotly.com/javascript/reference/layout/
            example: extra_layout={'showlegend': False, 'plot_bgcolor': 'yellow'}
        """
        self._init_reporter()
        return self.report_histogram(
            title,
            series,
            values,
            iteration or 0,
            labels=labels,
            xlabels=xlabels,
            xaxis=xaxis,
            yaxis=yaxis,
            mode=mode,
            extra_layout=extra_layout,
        )

    def report_table(
        self,
        title,  # type: str
        series,  # type: str
        iteration=None,  # type: Optional[int]
        table_plot=None,  # type: Optional[pd.DataFrame, Sequence[Sequence]]
        csv=None,  # type: Optional[str]
        url=None,  # type: Optional[str]
        extra_layout=None  # type: Optional[dict]
    ):
        """
        For explicit reporting, report a table plot.

        One and only one of the following parameters must be provided.

        - ``table_plot`` - Pandas DataFrame or Table as list of rows (list)
        - ``csv`` - CSV file
        - ``url`` - URL to CSV file

        For example:

        .. code-block:: py

            df = pd.DataFrame({'num_legs': [2, 4, 8, 0],
                'num_wings': [2, 0, 0, 0],
                'num_specimen_seen': [10, 2, 1, 8]},
                index=['falcon', 'dog', 'spider', 'fish'])

            model.report_table(title='table example', series='pandas DataFrame', iteration=0, table_plot=df)

        :param title: The title (metric) of the table.
        :param series: The series name (variant) of the reported table.
        :param iteration: The reported iteration / step.
        :param table_plot: The output table plot object
        :param csv: path to local csv file
        :param url: A URL to the location of csv file.
        :param extra_layout: optional dictionary for layout configuration, passed directly to plotly
            See full details on the supported configuration: https://plotly.com/javascript/reference/layout/
            example: extra_layout={'height': 600}
        """
        mutually_exclusive(
            UsageError, _check_none=True,
            table_plot=table_plot, csv=csv, url=url
        )
        table = table_plot
        if url or csv:
            if not pd:
                raise UsageError(
                    "pandas is required in order to support reporting tables using CSV or a URL, "
                    "please install the pandas python package"
                )
            if url:
                table = pd.read_csv(url, index_col=[0])
            elif csv:
                table = pd.read_csv(csv, index_col=[0])

        def replace(dst, *srcs):
            for src in srcs:
                reporter_table.replace(src, dst, inplace=True)

        if isinstance(table, (list, tuple)):
            reporter_table = table
        else:
            reporter_table = table.fillna(str(np.nan))
            replace("NaN", np.nan, math.nan if six.PY3 else float("nan"))
            replace("Inf", np.inf, math.inf if six.PY3 else float("inf"))
            replace("-Inf", -np.inf, np.NINF, -math.inf if six.PY3 else -float("inf"))
        self._init_reporter()
        return self._reporter.report_table(
            title=title,
            series=series,
            table=reporter_table,
            iteration=iteration or 0,
            layout_config=extra_layout
        )

    def report_line_plot(
        self,
        title,  # type: str
        series,  # type: Sequence[SeriesInfo]
        xaxis,  # type: str
        yaxis,  # type: str
        mode="lines",  # type: str
        iteration=None,  # type: Optional[int]
        reverse_xaxis=False,  # type: bool
        comment=None,  # type: Optional[str]
        extra_layout=None  # type: Optional[dict]
    ):
        """
        For explicit reporting, plot one or more series as lines.

        :param str title: The title (metric) of the plot.
        :param list series: All the series data, one list element for each line in the plot.
        :param int iteration: The reported iteration / step.
        :param str xaxis: The x-axis title. (Optional)
        :param str yaxis: The y-axis title. (Optional)
        :param str mode: The type of line plot. The values are:

            - ``lines`` (default)
            - ``markers``
            - ``lines+markers``

        :param bool reverse_xaxis: Reverse the x-axis. The values are:

            - ``True`` - The x-axis is high to low (reversed).
            - ``False`` - The x-axis is low to high (not reversed). (default)

        :param str comment: A comment displayed with the plot, underneath the title.
        :param dict extra_layout: optional dictionary for layout configuration, passed directly to plotly
            See full details on the supported configuration: https://plotly.com/javascript/reference/scatter/
            example: extra_layout={'xaxis': {'type': 'date', 'range': ['2020-01-01', '2020-01-31']}}
        """
        self._init_reporter()

        # noinspection PyArgumentList
        series = [SeriesInfo(**s) if isinstance(s, dict) else s for s in series]

        return self._reporter.report_line_plot(
            title=title,
            series=series,
            iter=iteration or 0,
            xtitle=xaxis,
            ytitle=yaxis,
            mode=mode,
            reverse_xaxis=reverse_xaxis,
            comment=comment,
            layout_config=extra_layout
        )

    def report_scatter2d(
        self,
        title,  # type: str
        series,  # type: str
        scatter,  # type: Union[Sequence[Tuple[float, float]], np.ndarray]
        iteration=None,  # type: Optional[int]
        xaxis=None,  # type: Optional[str]
        yaxis=None,  # type: Optional[str]
        labels=None,  # type: Optional[List[str]]
        mode="line",  # type: str
        comment=None,  # type: Optional[str]
        extra_layout=None,  # type: Optional[dict]
    ):
        """
        For explicit reporting, report a 2d scatter plot.

        For example:

        .. code-block:: py

            scatter2d = np.hstack((np.atleast_2d(np.arange(0, 10)).T, np.random.randint(10, size=(10, 1))))
            model.report_scatter2d(title="example_scatter", series="series", iteration=0, scatter=scatter2d,
                xaxis="title x", yaxis="title y")

        Plot multiple 2D scatter series on the same plot by passing the same ``title`` and ``iteration`` values
        to this method:

        .. code-block:: py

            scatter2d_1 = np.hstack((np.atleast_2d(np.arange(0, 10)).T, np.random.randint(10, size=(10, 1))))
            model.report_scatter2d(title="example_scatter", series="series_1", iteration=1, scatter=scatter2d_1,
                xaxis="title x", yaxis="title y")

            scatter2d_2 = np.hstack((np.atleast_2d(np.arange(0, 10)).T, np.random.randint(10, size=(10, 1))))
            model.report_scatter2d("example_scatter", "series_2", iteration=1, scatter=scatter2d_2,
                xaxis="title x", yaxis="title y")

        :param str title: The title (metric) of the plot.
        :param str series: The series name (variant) of the reported scatter plot.
        :param list scatter: The scatter data. numpy.ndarray or list of (pairs of x,y) scatter:
        :param int iteration: The reported iteration / step.
        :param str xaxis: The x-axis title. (Optional)
        :param str yaxis: The y-axis title. (Optional)
        :param list(str) labels: Labels per point in the data assigned to the ``scatter`` parameter. The labels must be
            in the same order as the data.
        :param str mode: The type of scatter plot. The values are:

            - ``lines``
            - ``markers``
            - ``lines+markers``

        :param str comment: A comment displayed with the plot, underneath the title.
        :param dict extra_layout: optional dictionary for layout configuration, passed directly to plotly
            See full details on the supported configuration: https://plotly.com/javascript/reference/scatter/
            example: extra_layout={'xaxis': {'type': 'date', 'range': ['2020-01-01', '2020-01-31']}}
        """
        self._init_reporter()

        if not isinstance(scatter, np.ndarray):
            if not isinstance(scatter, list):
                scatter = list(scatter)
            scatter = np.array(scatter)

        return self._reporter.report_2d_scatter(
            title=title,
            series=series,
            data=scatter,
            iter=iteration or 0,
            mode=mode,
            xtitle=xaxis,
            ytitle=yaxis,
            labels=labels,
            comment=comment,
            layout_config=extra_layout,
        )

    def report_scatter3d(
        self,
        title,  # type: str
        series,  # type: str
        scatter,  # type: Union[Sequence[Tuple[float, float, float]], np.ndarray]
        iteration=None,  # type: Optional[int]
        xaxis=None,  # type: Optional[str]
        yaxis=None,  # type: Optional[str]
        zaxis=None,  # type: Optional[str]
        labels=None,  # type: Optional[List[str]]
        mode="markers",  # type: str
        fill=False,  # type: bool
        comment=None,  # type: Optional[str]
        extra_layout=None  # type: Optional[dict]
    ):
        """
        For explicit reporting, plot a 3d scatter graph (with markers).

        :param str title: The title (metric) of the plot.
        :param str series: The series name (variant) of the reported scatter plot.
        :param Union[numpy.ndarray, list] scatter: The scatter data.
            list of (pairs of x,y,z), list of series [[(x1,y1,z1)...]], or numpy.ndarray
        :param int iteration: The reported iteration / step.
        :param str xaxis: The x-axis title. (Optional)
        :param str yaxis: The y-axis title. (Optional)
        :param str zaxis: The z-axis title. (Optional)
        :param list(str) labels: Labels per point in the data assigned to the ``scatter`` parameter. The labels must be
            in the same order as the data.
        :param str mode: The type of scatter plot. The values are: ``lines``, ``markers``, ``lines+markers``.

            For example:

            .. code-block:: py

                scatter3d = np.random.randint(10, size=(10, 3))
                model.report_scatter3d(title="example_scatter_3d", series="series_xyz", iteration=1, scatter=scatter3d,
                    xaxis="title x", yaxis="title y", zaxis="title z")

        :param bool fill: Fill the area under the curve. The values are:

            - ``True`` - Fill
            - ``False`` - Do not fill (default)

        :param str comment: A comment displayed with the plot, underneath the title.
        :param dict extra_layout: optional dictionary for layout configuration, passed directly to plotly
            See full details on the supported configuration: https://plotly.com/javascript/reference/scatter3d/
            example: extra_layout={'xaxis': {'type': 'date', 'range': ['2020-01-01', '2020-01-31']}}
        """
        self._init_reporter()

        # check if multiple series
        multi_series = (
            isinstance(scatter, list)
            and (
                isinstance(scatter[0], np.ndarray)
                or (
                    scatter[0]
                    and isinstance(scatter[0], list)
                    and isinstance(scatter[0][0], list)
                )
            )
        )

        if not multi_series:
            if not isinstance(scatter, np.ndarray):
                if not isinstance(scatter, list):
                    scatter = list(scatter)
                scatter = np.array(scatter)
            try:
                scatter = scatter.astype(np.float32)
            except ValueError:
                pass

        return self._reporter.report_3d_scatter(
            title=title,
            series=series,
            data=scatter,
            iter=iteration or 0,
            labels=labels,
            mode=mode,
            fill=fill,
            comment=comment,
            xtitle=xaxis,
            ytitle=yaxis,
            ztitle=zaxis,
            layout_config=extra_layout
        )

    def report_confusion_matrix(
        self,
        title,  # type: str
        series,  # type: str
        matrix,  # type: np.ndarray
        iteration=None,  # type: Optional[int]
        xaxis=None,  # type: Optional[str]
        yaxis=None,  # type: Optional[str]
        xlabels=None,  # type: Optional[List[str]]
        ylabels=None,  # type: Optional[List[str]]
        yaxis_reversed=False,  # type: bool
        comment=None,  # type: Optional[str]
        extra_layout=None  # type: Optional[dict]
    ):
        """
        For explicit reporting, plot a heat-map matrix.

        For example:

        .. code-block:: py

            confusion = np.random.randint(10, size=(10, 10))
            model.report_confusion_matrix("example confusion matrix", "ignored", iteration=1, matrix=confusion,
                xaxis="title X", yaxis="title Y")

        :param str title: The title (metric) of the plot.
        :param str series: The series name (variant) of the reported confusion matrix.
        :param numpy.ndarray matrix: A heat-map matrix (example: confusion matrix)
        :param int iteration: The reported iteration / step.
        :param str xaxis: The x-axis title. (Optional)
        :param str yaxis: The y-axis title. (Optional)
        :param list(str) xlabels: Labels for each column of the matrix. (Optional)
        :param list(str) ylabels: Labels for each row of the matrix. (Optional)
        :param bool yaxis_reversed: If False, 0,0 is at the bottom left corner. If True, 0,0 is at the top left corner
        :param str comment: A comment displayed with the plot, underneath the title.
        :param dict extra_layout: optional dictionary for layout configuration, passed directly to plotly
            See full details on the supported configuration: https://plotly.com/javascript/reference/heatmap/
            example: extra_layout={'xaxis': {'type': 'date', 'range': ['2020-01-01', '2020-01-31']}}
        """
        self._init_reporter()

        if not isinstance(matrix, np.ndarray):
            matrix = np.array(matrix)

        return self._reporter.report_value_matrix(
            title=title,
            series=series,
            data=matrix.astype(np.float32),
            iter=iteration or 0,
            xtitle=xaxis,
            ytitle=yaxis,
            xlabels=xlabels,
            ylabels=ylabels,
            yaxis_reversed=yaxis_reversed,
            comment=comment,
            layout_config=extra_layout
        )

    def report_matrix(
        self,
        title,  # type: str
        series,  # type: str
        matrix,  # type: np.ndarray
        iteration=None,  # type: Optional[int]
        xaxis=None,  # type: Optional[str]
        yaxis=None,  # type: Optional[str]
        xlabels=None,  # type: Optional[List[str]]
        ylabels=None,  # type: Optional[List[str]]
        yaxis_reversed=False,  # type: bool
        extra_layout=None  # type: Optional[dict]
    ):
        """
        For explicit reporting, plot a confusion matrix.

        .. note::
            This method is the same as :meth:`Model.report_confusion_matrix`.

        :param str title: The title (metric) of the plot.
        :param str series: The series name (variant) of the reported confusion matrix.
        :param numpy.ndarray matrix: A heat-map matrix (example: confusion matrix)
        :param int iteration: The reported iteration / step.
        :param str xaxis: The x-axis title. (Optional)
        :param str yaxis: The y-axis title. (Optional)
        :param list(str) xlabels: Labels for each column of the matrix. (Optional)
        :param list(str) ylabels: Labels for each row of the matrix. (Optional)
        :param bool yaxis_reversed: If False, 0,0 is at the bottom left corner. If True, 0,0 is at the top left corner
        :param dict extra_layout: optional dictionary for layout configuration, passed directly to plotly
            See full details on the supported configuration: https://plotly.com/javascript/reference/heatmap/
            example: extra_layout={'xaxis': {'type': 'date', 'range': ['2020-01-01', '2020-01-31']}}
        """
        self._init_reporter()
        return self.report_confusion_matrix(
            title,
            series,
            matrix,
            iteration or 0,
            xaxis=xaxis,
            yaxis=yaxis,
            xlabels=xlabels,
            ylabels=ylabels,
            yaxis_reversed=yaxis_reversed,
            extra_layout=extra_layout
        )

    def report_surface(
        self,
        title,  # type: str
        series,  # type: str
        matrix,  # type: np.ndarray
        iteration=None,  # type: Optional[int]
        xaxis=None,  # type: Optional[str]
        yaxis=None,  # type: Optional[str]
        zaxis=None,  # type: Optional[str]
        xlabels=None,  # type: Optional[List[str]]
        ylabels=None,  # type: Optional[List[str]]
        camera=None,  # type: Optional[Sequence[float]]
        comment=None,  # type: Optional[str]
        extra_layout=None  # type: Optional[dict]
    ):
        """
        For explicit reporting, report a 3d surface plot.

        .. note::
            This method plots the same data as :meth:`Model.report_confusion_matrix`, but presents the
            data as a surface diagram not a confusion matrix.

        .. code-block:: py

            surface_matrix = np.random.randint(10, size=(10, 10))
            model.report_surface("example surface", "series", iteration=0, matrix=surface_matrix,
                xaxis="title X", yaxis="title Y", zaxis="title Z")

        :param str title: The title (metric) of the plot.
        :param str series: The series name (variant) of the reported surface.
        :param numpy.ndarray matrix: A heat-map matrix (example: confusion matrix)
        :param int iteration: The reported iteration / step.
        :param str xaxis: The x-axis title. (Optional)
        :param str yaxis: The y-axis title. (Optional)
        :param str zaxis: The z-axis title. (Optional)
        :param list(str) xlabels: Labels for each column of the matrix. (Optional)
        :param list(str) ylabels: Labels for each row of the matrix. (Optional)
        :param list(float) camera: X,Y,Z coordinates indicating the camera position. The default value is ``(1,1,1)``.
        :param str comment: A comment displayed with the plot, underneath the title.
        :param dict extra_layout: optional dictionary for layout configuration, passed directly to plotly
            See full details on the supported configuration: https://plotly.com/javascript/reference/surface/
            example: extra_layout={'xaxis': {'type': 'date', 'range': ['2020-01-01', '2020-01-31']}}
        """
        self._init_reporter()

        if not isinstance(matrix, np.ndarray):
            matrix = np.array(matrix)

        return self._reporter.report_value_surface(
            title=title,
            series=series,
            data=matrix.astype(np.float32),
            iter=iteration or 0,
            xlabels=xlabels,
            ylabels=ylabels,
            xtitle=xaxis,
            ytitle=yaxis,
            ztitle=zaxis,
            camera=camera,
            comment=comment,
            layout_config=extra_layout
        )

    def publish(self):
        # type: () -> ()
        """
        Set the model to the status ``published`` and for public use. If the model's status is already ``published``,
        then this method is a no-op.
        """

        if not self.published:
            self._get_base_model().publish()

    def archive(self):
        # type: () -> ()
        """
        Archive the model. If the model is already archived, this is a no-op
        """
        try:
            self._get_base_model().archive()
        except Exception:
            pass

    def unarchive(self):
        # type: () -> ()
        """
        Unarchive the model. If the model is not archived, this is a no-op
        """
        try:
            self._get_base_model().unarchive()
        except Exception:
            pass

    def _init_reporter(self):
        if self._reporter:
            return
        self._base_model = self._get_force_base_model()
        metrics_manager = Metrics(
            session=_Model._get_default_session(),
            storage_uri=None,
            task=self,  # this is fine, the ID of the model will be fetched here
            for_model=True
        )
        self._reporter = Reporter(metrics=metrics_manager, task=self, for_model=True)

    def _running_remotely(self):
        # type: () -> ()
        return bool(running_remotely() and self._task is not None)

    def _set_task(self, value):
        # type: (_Task) -> ()
        if value is not None and not isinstance(value, _Task):
            raise ValueError("task argument must be of Task type")
        self._task = value

    @abc.abstractmethod
    def _get_model_data(self):
        pass

    @abc.abstractmethod
    def _get_base_model(self):
        pass

    def _set_package_tag(self):
        if self._package_tag not in self.system_tags:
            self.system_tags.append(self._package_tag)
            self._get_base_model().edit(system_tags=self.system_tags)

    def _is_package(self):
        return self._package_tag in (self.system_tags or [])

    @staticmethod
    def _config_dict_to_text(config):
        if not isinstance(config, six.string_types) and not isinstance(config, dict):
            raise ValueError(
                "Model configuration only supports dictionary or string objects"
            )
        return config_dict_to_text(config)

    @staticmethod
    def _text_to_config_dict(text):
        if not isinstance(text, six.string_types):
            raise ValueError("Model configuration parsing only supports string")
        return text_to_config_dict(text)

    @staticmethod
    def _resolve_config(config_text=None, config_dict=None):
        mutually_exclusive(
            config_text=config_text,
            config_dict=config_dict,
            _require_at_least_one=False,
        )
        if config_dict:
            return InputModel._config_dict_to_text(config_dict)

        return config_text

    def set_metadata(self, key, value, v_type=None):
        # type: (str, str, Optional[str]) -> bool
        """
        Set one metadata entry. All parameters must be strings or castable to strings
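
        For example, assuming ``model`` is an instance of this class:

        .. code-block:: py

            model.set_metadata("test_score", "0.93", v_type="float")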

        :param key: Key of the metadata entry
        :param value: Value of the metadata entry
        :param v_type: Type of the metadata entry

        :return: True if the metadata was set and False otherwise
        """
        if not self._base_model:
            self._base_model = self._get_force_base_model()
        self._reload_required = (
            _Model._get_default_session()
            .send(
                models.AddOrUpdateMetadataRequest(
                    metadata=[
                        {
                            "key": str(key),
                            "value": str(value),
                            "type": str(v_type)
                            if str(v_type)
                            in (
                                "float",
                                "int",
                                "bool",
                                "str",
                                "basestring",
                                "list",
                                "tuple",
                                "dict",
                            )
                            else str(None),
                        }
                    ],
                    model=self.id,
                    replace_metadata=False,
                )
            )
            .ok()
        )
        return self._reload_required

    def get_metadata(self, key):
        # type: (str) -> Optional[str]
        """
        Get one metadata entry value (as a string) based on its key. See `Model.get_metadata_casted`
        if you wish to cast the value to its type (if possible)

        :param key: Key of the metadata entry you want to get

        :return: String representation of the value of the metadata entry or None if the entry was not found
        """
        if not self._base_model:
            self._base_model = self._get_force_base_model()
        self._reload_if_required()
        return self.get_all_metadata().get(str(key), {}).get("value")

    def get_metadata_casted(self, key):
        # type: (str) -> Optional[str]
        """
        Get one metadata entry based on its key, casted to its type if possible

        :param key: Key of the metadata entry you want to get

        :return: The value of the metadata entry, casted to its type (if not possible,
            the string representation will be returned) or None if the entry was not found
        """
        if not self._base_model:
            self._base_model = self._get_force_base_model()
        key = str(key)
        metadata = self.get_all_metadata()
        if key not in metadata:
            return None
        return cast_basic_type(metadata[key].get("value"), metadata[key].get("type"))

    def get_all_metadata(self):
        # type: () -> Dict[str, Dict[str, str]]
        """
        See `Model.get_all_metadata_casted` if you wish to cast the value to its type (if possible)

        :return: Get all metadata as a dictionary of format Dict[key, Dict[value, type]]. The key, value and type
            entries are all strings. Note that each entry might have an additional 'key' entry, repeating the key
        """
        if not self._base_model:
            self._base_model = self._get_force_base_model()
        self._reload_if_required()
        return self._get_model_data().metadata or {}

    def get_all_metadata_casted(self):
        # type: () -> Dict[str, Dict[str, Any]]
        """
        :return: Get all metadata as a dictionary of format Dict[key, Dict[value, type]]. The key and type
            entries are strings. The value is cast to its type if possible. Note that each entry might
            have an additional 'key' entry, repeating the key
        """
        if not self._base_model:
            self._base_model = self._get_force_base_model()
        self._reload_if_required()
        result = {}
        metadata = self.get_all_metadata()
        for key, metadata_entry in metadata.items():
            result[key] = cast_basic_type(
                metadata_entry.get("value"), metadata_entry.get("type")
            )
        return result

    def set_all_metadata(self, metadata, replace=True):
        # type: (Dict[str, Dict[str, str]], bool) -> bool
        """
        Set metadata based on the given parameters. Allows replacing all entries or updating the current entries.

        :param metadata: A dictionary of format Dict[key, Dict[value, type]] representing the metadata you want to set
        :param replace: If True, replace all metadata with the entries in the `metadata` parameter. If False,
            keep the old metadata and update it with the entries in the `metadata` parameter (add or change it)

        :return: True if the metadata was set and False otherwise
        """
        if not self._base_model:
            self._base_model = self._get_force_base_model()
        metadata_array = [
            {
                "key": str(k),
                "value": str(v_t.get("value")),
                "type": str(v_t.get("type")),
            }
            for k, v_t in metadata.items()
        ]
        self._reload_required = (
            _Model._get_default_session()
            .send(
                models.AddOrUpdateMetadataRequest(
                    metadata=metadata_array, model=self.id, replace_metadata=replace
                )
            )
            .ok()
        )
        return self._reload_required

    def _reload_if_required(self):
        if not self._reload_required:
            return
        self._get_base_model().reload()
        self._reload_required = False

    def _update_base_model(self, model_name=None, task_model_entry=None):
        if not self._task:
            return self._base_model
        # update the model from the task inputs
        labels = self._task.get_labels_enumeration()
        # noinspection PyProtectedMember
        config_text = self._task._get_model_config_text()
        model_name = (
            model_name or self._name or (self._floating_data.name if self._floating_data else None) or self._task.name
        )
        # noinspection PyBroadException
        try:
            task_model_entry = (
                task_model_entry
                or self._task_connect_name
                or Path(self._get_model_data().uri).stem
            )
        except Exception:
            pass
        parent = self._task.input_models_id.get(task_model_entry)
        self._base_model.update(
            labels=(self._floating_data.labels if self._floating_data else None) or labels,
            design=(self._floating_data.design if self._floating_data else None) or config_text,
            task_id=self._task.id,
            project_id=self._task.project,
            parent_id=parent,
            name=model_name,
            comment=self._floating_data.comment if self._floating_data else None,
            tags=self._floating_data.tags if self._floating_data else None,
            framework=self._floating_data.framework if self._floating_data else None,
            upload_storage_uri=self._floating_data.upload_storage_uri if self._floating_data else None,
        )

        # remove model floating change set, by now they should have matched the task.
        self._floating_data = None

        # now we have to update the creator task so it points to us
        if str(self._task.status) not in (
            str(self._task.TaskStatusEnum.created),
            str(self._task.TaskStatusEnum.in_progress),
        ):
            self._log.warning(
                "Could not update last created model in Task {}, "
                "Task status '{}' cannot be updated".format(
                    self._task.id, self._task.status
                )
            )
        elif task_model_entry:
            self._base_model.update_for_task(
                task_id=self._task.id,
                model_id=self.id,
                type_="output",
                name=task_model_entry,
            )

        return self._base_model

    def _get_force_base_model(self, model_name=None, task_model_entry=None):
        if self._base_model:
            return self._base_model
        if not self._task:
            return None

        # create a new model from the task
        # noinspection PyProtectedMember
        self._base_model = self._task._get_output_model(model_id=None)
        return self._update_base_model(model_name=model_name, task_model_entry=task_model_entry)


class Model(BaseModel):
    """
    Represent an existing model in the system, searched for by model ID.
    The Model will be read-only and can be used to pre-initialize a network
    """

    def __init__(self, model_id):
        # type: (str) -> None
        """
        Load model based on id, returned object is read-only and can be connected to a task

        Notice, we can override the input model when running remotely

        :param model_id: ID (string)
        """
        super(Model, self).__init__()
        self._base_model_id = model_id
        self._base_model = None

    def get_local_copy(
        self, extract_archive=None, raise_on_error=False, force_download=False
    ):
        # type: (Optional[bool], bool, bool) -> str
        """
        Retrieve a valid link to the model file(s).
        If the model URL is a file system link, it will be returned directly.
        If the model URL points to a remote location (http/s3/gs etc.),
        it will download the file(s) and return the temporary location of the downloaded model.
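
        For example (the model ID below is a placeholder):

        .. code-block:: py

            model = Model(model_id="<model_id>")
            local_model_path = model.get_local_copy()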

        :param bool extract_archive: If True, the local copy will be extracted if possible. If False,
            the local copy will not be extracted. If None (default), the downloaded file will be extracted
            if the model is a package.
        :param bool raise_on_error: If True, and the artifact could not be downloaded,
            raise ValueError, otherwise return None on failure and output log warning.
        :param bool force_download: If True, the artifact will be downloaded,
            even if the model artifact is already cached.

        :return: A local path to the model (or a downloaded copy of it).
        """
        if self._is_package():
            return self.get_weights_package(
                return_path=True,
                raise_on_error=raise_on_error,
                force_download=force_download,
                extract_archive=True if extract_archive is None else extract_archive
            )
        return self.get_weights(
            raise_on_error=raise_on_error,
            force_download=force_download,
            extract_archive=False if extract_archive is None else extract_archive
        )

    def _get_base_model(self):
        if self._base_model:
            return self._base_model

        if not self._base_model_id:
            # this shouldn't actually happen
            raise Exception("Missing model ID, cannot create an empty model")
        self._base_model = _Model(
            upload_storage_uri=None,
            cache_dir=get_cache_dir(),
            model_id=self._base_model_id,
        )
        return self._base_model

    def _get_model_data(self):
        return self._get_base_model().data

    @classmethod
    def query_models(
        cls,
        project_name=None,  # type: Optional[str]
        model_name=None,  # type: Optional[str]
        tags=None,  # type: Optional[Sequence[str]]
        only_published=False,  # type: bool
        include_archived=False,  # type: bool
        max_results=None,  # type: Optional[int]
        metadata=None,  # type: Optional[Dict[str, str]]
    ):
        # type: (...) -> List[Model]
        """
        Return Model objects from the project artifactory.
        Filter based on project-name / model-name / tags.
        List is always returned sorted by descending last update time (i.e. latest model is the first in the list)
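
        For example (the project name and tags below are placeholders):

        .. code-block:: py

            production_models = Model.query_models(
                project_name="examples", tags=["production"], only_published=True, max_results=10
            )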

        :param project_name: Optional, filter based on project name string, if not given query models from all projects
        :param model_name: Optional Model name as shown in the model artifactory
        :param tags: Filter based on the requested list of tags (strings).
            To exclude a tag add "-" prefix to the tag. Example: ``["production", "verified", "-qa"]``.
            The default behaviour is to join all tags with a logical "OR" operator.
            To join all tags with a logical "AND" operator instead, use "__$all" as the first string, for example:

            .. code-block:: py

                ["__$all", "best", "model", "ever"]

            To join all tags with AND, but exclude a tag use "__$not" before the excluded tag, for example:

            .. code-block:: py

                ["__$all", "best", "model", "ever", "__$not", "internal", "__$not", "test"]

            The "OR" and "AND" operators apply to all tags that follow them until another operator is specified.
            The NOT operator applies only to the immediately following tag.
            For example:

            .. code-block:: py

                ["__$all", "a", "b", "c", "__$or", "d", "__$not", "e", "__$and", "__$or", "f", "g"]

            This example means ("a" AND "b" AND "c" AND ("d" OR NOT "e") AND ("f" OR "g")).
            See https://clear.ml/docs/latest/docs/clearml_sdk/model_sdk#tag-filters for details.
        :param only_published: If True, only return published models.
        :param include_archived: If True, return archived models.
        :param max_results: Optional, return the last X models,
            sorted by last update time (from the most recent to the least).
        :param metadata: Filter based on metadata. This parameter is a dictionary. Notice that the type of the
            metadata field is not required.

        :return: A list of Model objects
        """
        if project_name:
            # noinspection PyProtectedMember
            res = _Model._get_default_session().send(
                projects.GetAllRequest(
                    name=exact_match_regex(project_name),
                    only_fields=["id", "name", "last_update"],
                )
            )
            project = get_single_result(
                entity="project", query=project_name, results=res.response.projects
            )
        else:
            project = None

        only_fields = ["id", "created", "system_tags"]

        extra_fields = {
            "metadata.{}.value".format(k): v for k, v in (metadata or {}).items()
        }

        models_fetched = []

        page = 0
        page_size = 500
        results_left = max_results if max_results is not None else float("inf")
        while True:
            # noinspection PyProtectedMember
            res = _Model._get_default_session().send(
                models.GetAllRequest(
                    project=[project.id] if project else None,
                    name=exact_match_regex(model_name)
                    if model_name is not None
                    else None,
                    only_fields=only_fields,
                    tags=tags or None,
                    system_tags=["-" + cls._archived_tag]
                    if not include_archived
                    else None,
                    ready=True if only_published else None,
                    order_by=["-created"],
                    page=page,
                    page_size=page_size if results_left > page_size else results_left,
                    _allow_extra_fields_=True,
                    **extra_fields
                )
            )
            if not res.response.models:
                break
            models_fetched.extend(res.response.models)
            results_left -= len(res.response.models)
            if results_left <= 0 or len(res.response.models) < page_size:
                break

            page += 1

        return [Model(model_id=m.id) for m in models_fetched]

    @property
    def id(self):
        # type: () -> str
        return self._base_model_id if self._base_model_id else super(Model, self).id

    @classmethod
    def remove(
        cls, model, delete_weights_file=True, force=False, raise_on_errors=False
    ):
        # type: (Union[str, Model], bool, bool, bool) -> bool
        """
        Remove a model from the model repository.
        Optional, delete the model weights file from the remote storage.
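
        For example (the model ID below is a placeholder):

        .. code-block:: py

            Model.remove(model="<model_id>", delete_weights_file=True, force=False)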

        :param model: Model ID or Model object to remove
        :param delete_weights_file: If True (default), delete the weights file from the remote storage
        :param force: If True, remove model even if other Tasks are using this model. default False.
        :param raise_on_errors: If True, throw ValueError if something went wrong, default False.
        :return: True if Model was removed successfully
            partial removal returns False, i.e. Model was deleted but weights file deletion failed
        """
        if isinstance(model, str):
            model = Model(model_id=model)

        # noinspection PyBroadException
        try:
            weights_url = model.url
        except Exception:
            if raise_on_errors:
                raise ValueError("Could not find model id={}".format(model.id))
            return False

        try:
            # noinspection PyProtectedMember
            res = _Model._get_default_session().send(
                models.DeleteRequest(model.id, force=force),
            )
            response = res.wait()
            if not response.ok():
                if raise_on_errors:
                    raise ValueError(
                        "Could not remove model id={}: {}".format(
                            model.id, response.meta
                        )
                    )
                return False
        except SendError as ex:
            if raise_on_errors:
                raise ValueError(
                    "Could not remove model id={}: {}".format(model.id, ex)
                )
            return False
        except ValueError:
            if raise_on_errors:
                raise
            return False
        except Exception as ex:
            if raise_on_errors:
                raise ValueError(
                    "Could not remove model id={}: {}".format(model.id, ex)
                )
            return False

        if not delete_weights_file:
            return True

        helper = StorageHelper.get(url=weights_url)
        try:
            if not helper.delete(weights_url):
                if raise_on_errors:
                    raise ValueError(
                        "Could not remove model id={} weights file: {}".format(
                            model.id, weights_url
                        )
                    )
                return False
        except Exception as ex:
            if raise_on_errors:
                raise ValueError(
                    "Could not remove model id={} weights file '{}': {}".format(
                        model.id, weights_url, ex
                    )
                )
            return False

        return True


class InputModel(Model):
    """
    Load an existing model in the system, searched for by model ID.
    The Model will be read-only and can be used to pre-initialize a network.
    We can connect the model to a task as input model, then when running remotely override it with the UI.
    """

    # noinspection PyProtectedMember
    _EMPTY_MODEL_ID = _Model._EMPTY_MODEL_ID
    _WARNING_CONNECTED_NAMES = {}

    @classmethod
    def import_model(
        cls,
        weights_url,  # type: str
        config_text=None,  # type: Optional[str]
        config_dict=None,  # type: Optional[dict]
        label_enumeration=None,  # type: Optional[Mapping[str, int]]
        name=None,  # type: Optional[str]
        project=None,  # type: Optional[str]
        tags=None,  # type: Optional[List[str]]
        comment=None,  # type: Optional[str]
        is_package=False,  # type: bool
        create_as_published=False,  # type: bool
        framework=None,  # type: Optional[str]
    ):
        # type: (...) -> InputModel
        """
        Create an InputModel object from a pre-trained model by specifying the URL of an initial weight file.
        Optionally, input a configuration, label enumeration, name for the model, tags describing the model,
        comment as a description of the model, indicate whether the model is a package, specify the model's
        framework, and indicate whether to immediately set the model's status to ``Published``.
        The model is read-only.

        The **ClearML Server** (backend) may already store the model's URL. If the input model's URL is not
        stored, meaning the model is new, then it is imported and ClearML stores its metadata.
        If the URL is already stored, the import process stops, ClearML issues a warning message, and ClearML
        reuses the model.

        In your Python experiment script, after importing the model, you can connect it to the main execution
        Task as an input model using :meth:`InputModel.connect` or :meth:`.Task.connect`. That initializes the
        network.

        .. note::
            Using the **ClearML Web-App** (user interface), you can reuse imported models and switch models in
            experiments.
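
        For example (the weights URL below is a placeholder):

        .. code-block:: py

            input_model = InputModel.import_model(
                weights_url="s3://bucket/model.pt", name="pretrained model", framework="PyTorch"
            )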

        :param str weights_url: A valid URL for the initial weights file. If the **ClearML Web-App** (backend)
            already stores the metadata of a model with the same URL, that existing model is returned
            and ClearML ignores all other parameters. For example:

            - ``https://domain.com/file.bin``
            - ``s3://bucket/file.bin``
            - ``file:///home/user/file.bin``

        :param str config_text: The configuration as a string. This is usually the content of a configuration
            dictionary file. Specify ``config_text`` or ``config_dict``, but not both.
        :type config_text: unconstrained text string
        :param dict config_dict: The configuration as a dictionary. Specify ``config_text`` or ``config_dict``,
            but not both.
        :param dict label_enumeration: Optional label enumeration dictionary of string (label) to integer (value) pairs.

            For example:

            .. code-block:: javascript

                {
                    "background": 0,
                    "person": 1
                }
        :param str name: The name of the newly imported model. (Optional)
        :param str project: The project name to add the model into. (Optional)
        :param tags: The list of tags which describe the model. (Optional)
        :type tags: list(str)
        :param str comment: A comment / description for the model. (Optional)
        :type comment: str
        :param is_package: Whether the imported weights file is a package (Optional)

            - ``True`` - Is a package. Add a package tag to the model.
            - ``False`` - Is not a package. Do not add a package tag. (Default)

        :type is_package: bool
        :param bool create_as_published: Set the model's status to Published (Optional)

            - ``True`` - Set the status to Published.
            - ``False`` - Do not set the status to Published. The status will be Draft. (Default)

        :param str framework: The framework of the model. (Optional)
        :type framework: str or Framework object

        :return: The imported model or existing model (see above).
        """
        config_text = cls._resolve_config(
            config_text=config_text, config_dict=config_dict
        )
        weights_url = StorageHelper.conform_url(weights_url)
        if not weights_url:
            raise ValueError("Please provide a valid weights_url parameter")
        # convert a local file path to a remote one
        weights_url = CacheManager.get_remote_url(weights_url)

        extra = (
            {"system_tags": ["-" + cls._archived_tag]}
            if Session.check_min_api_version("2.3")
            else {"tags": ["-" + cls._archived_tag]}
        )
        # noinspection PyProtectedMember
        result = _Model._get_default_session().send(
            models.GetAllRequest(
                uri=[weights_url], only_fields=["id", "name", "created"], **extra
            )
        )

        if result.response.models:
            logger = get_logger()

            logger.debug(
                'A model with uri "{}" already exists. Selecting it'.format(weights_url)
            )

            model = get_single_result(
                entity="model",
                query=weights_url,
                results=result.response.models,
                log=logger,
                raise_on_error=False,
            )

            logger.info("Selected model id: {}".format(model.id))

            return InputModel(model_id=model.id)

        base_model = _Model(
            upload_storage_uri=None,
            cache_dir=get_cache_dir(),
        )

        from .task import Task

        task = Task.current_task()
        if task:
            comment = "Imported by task id: {}".format(task.id) + (
                "\n" + comment if comment else ""
            )
            project_id = task.project
            name = name or "Imported by {}".format(task.name or "")
            # do not register the Task, because we do not want it listed after as "output model",
            # the Task never actually created the Model
            task_id = None
        else:
            project_id = None
            task_id = None

        if project:
            project_id = get_or_create_project(
                session=task.session if task else Task._get_default_session(),
                project_name=project,
            )

        if not framework:
            # noinspection PyProtectedMember
            framework, file_ext = Framework._get_file_ext(
                framework=framework, filename=weights_url
            )

        base_model.update(
            design=config_text,
            labels=label_enumeration,
            name=name,
            comment=comment,
            tags=tags,
            uri=weights_url,
            framework=framework,
            project_id=project_id,
            task_id=task_id,
        )

        this_model = InputModel(model_id=base_model.id)
        this_model._base_model = base_model

        if is_package:
            this_model._set_package_tag()

        if create_as_published:
            this_model.publish()

        return this_model

    @classmethod
    def load_model(cls, weights_url, load_archived=False):
        # type: (str, bool) -> InputModel
        """
        Load an already registered model based on a pre-existing model file (link must be valid). If the url to the
        weights file already exists, the returned object is a Model representing the loaded Model. If no registered
        model with the specified url is found, ``None`` is returned.

        :param weights_url: The valid url for the weights file (string).

            Examples:

            .. code-block:: py

                "https://domain.com/file.bin" or "s3://bucket/file.bin" or "file:///home/user/file.bin".
|
|
|
|
.. note::
|
|
If a model with the exact same URL exists, it will be used, and all other arguments will be ignored.
|
|
|
|
:param bool load_archived: Load archived models
|
|
|
|
- ``True`` - Load the registered Model, if it is archived.
|
|
- ``False`` - Ignore archive models.
|
|
|
|
:return: The InputModel object, or None if no model could be found.
|
|
"""
|
|
weights_url = StorageHelper.conform_url(weights_url)
|
|
if not weights_url:
|
|
raise ValueError("Please provide a valid weights_url parameter")
|
|
|
|
# convert local to file to remote one
|
|
weights_url = CacheManager.get_remote_url(weights_url)
|
|
|
|
if not load_archived:
|
|
# noinspection PyTypeChecker
|
|
extra = (
|
|
{"system_tags": ["-" + _Task.archived_tag]}
|
|
if Session.check_min_api_version("2.3")
|
|
else {"tags": ["-" + cls._archived_tag]}
|
|
)
|
|
else:
|
|
extra = {}
|
|
|
|
# noinspection PyProtectedMember
|
|
result = _Model._get_default_session().send(
|
|
models.GetAllRequest(
|
|
uri=[weights_url], only_fields=["id", "name", "created"], **extra
|
|
)
|
|
)
|
|
|
|
if not result or not result.response or not result.response.models:
|
|
return None
|
|
|
|
logger = get_logger()
|
|
model = get_single_result(
|
|
entity="model",
|
|
query=weights_url,
|
|
results=result.response.models,
|
|
log=logger,
|
|
raise_on_error=False,
|
|
)
|
|
|
|
return InputModel(model_id=model.id)
|
|
|
|
@classmethod
|
|
def empty(cls, config_text=None, config_dict=None, label_enumeration=None):
|
|
# type: (Optional[str], Optional[dict], Optional[Mapping[str, int]]) -> InputModel
|
|
"""
|
|
Create an empty model object. Later, you can assign a model to the empty model object.
|
|
|
|
:param config_text: The model configuration as a string. This is usually the content of a configuration
|
|
dictionary file. Specify ``config_text`` or ``config_dict``, but not both.
|
|
:type config_text: unconstrained text string
|
|
:param dict config_dict: The model configuration as a dictionary. Specify ``config_text`` or ``config_dict``,
|
|
but not both.
|
|
:param dict label_enumeration: The label enumeration dictionary of string (label) to integer (value) pairs.
|
|
(Optional)
|
|
|
|
For example:
|
|
|
|
.. code-block:: javascript
|
|
|
|
{
|
|
"background": 0,
|
|
"person": 1
|
|
}
|
|
|
|
:return: An empty model object.
|
|
"""
|
|
design = cls._resolve_config(config_text=config_text, config_dict=config_dict)
|
|
|
|
this_model = InputModel(model_id=cls._EMPTY_MODEL_ID)
|
|
this_model._base_model = m = _Model(
|
|
cache_dir=None,
|
|
upload_storage_uri=None,
|
|
model_id=cls._EMPTY_MODEL_ID,
|
|
)
|
|
# noinspection PyProtectedMember
|
|
m._data.design = _Model._wrap_design(design)
|
|
# noinspection PyProtectedMember
|
|
m._data.labels = label_enumeration
|
|
return this_model
|
|
|
|
def __init__(
|
|
self, model_id=None, name=None, project=None, tags=None, only_published=False
|
|
):
|
|
# type: (Optional[str], Optional[str], Optional[str], Optional[Sequence[str]], bool) -> None
|
|
"""
|
|
Load a model from the Model artifactory,
|
|
based on model_id (uuid) or a model name/projects/tags combination.
|
|
|
|
:param model_id: The ClearML ID (system UUID) of the input model whose metadata the **ClearML Server**
|
|
(backend) stores. If provided all other arguments are ignored
|
|
:param name: Model name to search and load
|
|
:param project: Model project name to search model in
|
|
:param tags: Model tags list to filter by
|
|
:param only_published: If True, filter out non-published (draft) models
|
|
"""
|
|
if not model_id:
|
|
found_models = self.query_models(
|
|
project_name=project,
|
|
model_name=name,
|
|
tags=tags,
|
|
only_published=only_published,
|
|
)
|
|
if not found_models:
|
|
raise ValueError(
|
|
"Could not locate model with project={} name={} tags={} published={}".format(
|
|
project, name, tags, only_published
|
|
)
|
|
)
|
|
model_id = found_models[0].id
|
|
super(InputModel, self).__init__(model_id)
|
|
|
|
@property
|
|
def id(self):
|
|
# type: () -> str
|
|
return self._base_model_id
|
|
|
|
def connect(self, task, name=None, ignore_remote_overrides=False):
|
|
# type: (Task, Optional[str], bool) -> None
|
|
"""
|
|
Connect the current model to a Task object, if the model is preexisting. Preexisting models include:
|
|
|
|
- Imported models (InputModel objects created using the :meth:`Logger.import_model` method).
|
|
- Models whose metadata is already in the ClearML platform, meaning the InputModel object is instantiated
|
|
from the ``InputModel`` class specifying the model's ClearML ID as an argument.
|
|
- Models whose origin is not ClearML that are used to create an InputModel object. For example,
|
|
models created using TensorFlow models.
|
|
|
|
When the experiment is executed remotely in a worker, the input model specified in the experiment UI/backend
|
|
is used, unless `ignore_remote_overrides` is set to True.
|
|
|
|
.. note::
|
|
The **ClearML Web-App** allows you to switch one input model for another and then enqueue the experiment
|
|
to execute in a worker.
|
|
|
|
:param object task: A Task object.
|
|
:param ignore_remote_overrides: If True, changing the model in the UI/backend will have no
|
|
effect when running remotely.
|
|
Default is False, meaning that any changes made in the UI/backend will be applied in remote execution.
|
|
:param str name: The model name to be stored on the Task
|
|
(default to filename of the model weights, without the file extension, or to `Input Model`
|
|
if that is not found)
|
|
"""
|
|
self._set_task(task)
|
|
name = name or InputModel._get_connect_name(self)
|
|
InputModel._warn_on_same_name_connect(name)
|
|
ignore_remote_overrides = task._handle_ignore_remote_overrides(
|
|
name + "/_ignore_remote_overrides_input_model_", ignore_remote_overrides
|
|
)
|
|
|
|
model_id = None
|
|
# noinspection PyProtectedMember
|
|
if running_remotely() and (task.is_main_task() or task._is_remote_main_task()) and not ignore_remote_overrides:
|
|
input_models = task.input_models_id
|
|
# noinspection PyBroadException
|
|
try:
|
|
# TODO: (temp fix) At the moment, the UI changes the key of the model hparam
|
|
# when modifying its value... There is no way to tell which model was changed
|
|
# so just take the first one in case `name` is not in `input_models`
|
|
model_id = input_models.get(name, next(iter(input_models.values())))
|
|
self._base_model_id = model_id
|
|
self._base_model = InputModel(model_id=model_id)._get_base_model()
|
|
except Exception:
|
|
model_id = None
|
|
|
|
if not model_id:
|
|
# we should set the task input model to point to us
|
|
model = self._get_base_model()
|
|
# try to store the input model id, if it is not empty
|
|
# (Empty Should not happen)
|
|
if model.id != self._EMPTY_MODEL_ID:
|
|
task.set_input_model(model_id=model.id, name=name)
|
|
# only copy the model design if the task has no design to begin with
|
|
# noinspection PyProtectedMember
|
|
if not self._task._get_model_config_text() and model.model_design:
|
|
# noinspection PyProtectedMember
|
|
task._set_model_config(config_text=model.model_design)
|
|
if not self._task.get_labels_enumeration() and model.data.labels:
|
|
task.set_model_label_enumeration(model.data.labels)
|
|
|
|
@classmethod
|
|
def _warn_on_same_name_connect(cls, name):
|
|
if name not in cls._WARNING_CONNECTED_NAMES:
|
|
cls._WARNING_CONNECTED_NAMES[name] = False
|
|
return
|
|
if cls._WARNING_CONNECTED_NAMES[name]:
|
|
return
|
|
get_logger().warning("Connecting multiple input models with the same name: `{}`. This might result in the wrong model being used when executing remotely".format(name))
|
|
cls._WARNING_CONNECTED_NAMES[name] = True
|
|
|
|
@staticmethod
|
|
def _get_connect_name(model):
|
|
default_name = "Input Model"
|
|
if model is None:
|
|
return default_name
|
|
# noinspection PyBroadException
|
|
try:
|
|
model_uri = getattr(model, "url", getattr(model, "uri", None))
|
|
return Path(model_uri).stem
|
|
except Exception:
|
|
return default_name
|
|
|
|
|
|
class OutputModel(BaseModel):
|
|
"""
|
|
Create an output model for a Task (experiment) to store the training results.
|
|
|
|
The OutputModel object is always connected to a Task object, because it is instantiated with a Task object
|
|
as an argument. It is, therefore, automatically registered as the Task's (experiment's) output model.
|
|
|
|
The OutputModel object is read-write.
|
|
|
|
A common use case is to reuse the OutputModel object, and override the weights after storing a model snapshot.
|
|
Another use case is to create multiple OutputModel objects for a Task (experiment), and after a new high score
|
|
is found, store a model snapshot.
|
|
|
|
If the model configuration and / or the model's label enumeration
|
|
are ``None``, then the output model is initialized with the values from the Task object's input model.
|
|
|
|
.. note::
|
|
When executing a Task (experiment) remotely in a worker, you can modify the model configuration and / or model's
|
|
label enumeration using the **ClearML Web-App**.
|
|
"""
|
|
_default_output_uri = None
|
|
_offline_folder = "models"
|
|
|
|
@property
|
|
def published(self):
|
|
# type: () -> bool
|
|
"""
|
|
Get the published state of this model.
|
|
|
|
:return:
|
|
|
|
"""
|
|
if not self.id:
|
|
return False
|
|
return self._get_base_model().locked
|
|
|
|
@property
|
|
def config_text(self):
|
|
# type: () -> str
|
|
"""
|
|
Get the configuration as a string. For example, prototxt, an ini file, or Python code to evaluate.
|
|
|
|
:return: The configuration.
|
|
"""
|
|
# noinspection PyProtectedMember
|
|
return _Model._unwrap_design(self._get_model_data().design)
|
|
|
|
@config_text.setter
|
|
def config_text(self, value):
|
|
# type: (str) -> None
|
|
"""
|
|
Set the configuration. Store a blob of text for custom usage.
|
|
"""
|
|
self.update_design(config_text=value)
|
|
|
|
@property
|
|
def config_dict(self):
|
|
# type: () -> dict
|
|
"""
|
|
Get the configuration as a dictionary parsed from the ``config_text`` text. This usually represents the model
|
|
configuration. For example, from prototxt to ini file or python code to evaluate.
|
|
|
|
:return: The configuration.
|
|
"""
|
|
return self._text_to_config_dict(self.config_text)
|
|
|
|
@config_dict.setter
|
|
def config_dict(self, value):
|
|
# type: (dict) -> None
|
|
"""
|
|
Set the configuration. Saved in the model object.
|
|
|
|
:param dict value: The configuration parameters.
|
|
"""
|
|
self.update_design(config_dict=value)
|
|
|
|
@property
|
|
def labels(self):
|
|
# type: () -> Dict[str, int]
|
|
"""
|
|
Get the label enumeration as a dictionary of string (label) to integer (value) pairs.
|
|
|
|
For example:
|
|
|
|
.. code-block:: javascript
|
|
|
|
{
|
|
"background": 0,
|
|
"person": 1
|
|
}
|
|
|
|
:return: The label enumeration.
|
|
"""
|
|
return self._get_model_data().labels
|
|
|
|
@labels.setter
|
|
def labels(self, value):
|
|
# type: (Mapping[str, int]) -> None
|
|
"""
|
|
Set the label enumeration.
|
|
|
|
:param dict value: The label enumeration dictionary of string (label) to integer (value) pairs.
|
|
|
|
For example:
|
|
|
|
.. code-block:: javascript
|
|
|
|
{
|
|
"background": 0,
|
|
"person": 1
|
|
}
|
|
|
|
"""
|
|
self.update_labels(labels=value)
|
|
|
|
@property
|
|
def upload_storage_uri(self):
|
|
# type: () -> str
|
|
return self._get_base_model().upload_storage_uri
|
|
|
|
def __init__(
|
|
self,
|
|
task=None, # type: Optional[Task]
|
|
config_text=None, # type: Optional[str]
|
|
config_dict=None, # type: Optional[dict]
|
|
label_enumeration=None, # type: Optional[Mapping[str, int]]
|
|
name=None, # type: Optional[str]
|
|
tags=None, # type: Optional[List[str]]
|
|
comment=None, # type: Optional[str]
|
|
framework=None, # type: Optional[Union[str, Framework]]
|
|
base_model_id=None, # type: Optional[str]
|
|
):
|
|
"""
|
|
Create a new model and immediately connect it to a task.
|
|
|
|
We do not allow for Model creation without a task, so we always keep track on how we created the models
|
|
In remote execution, Model parameters can be overridden by the Task
|
|
(such as model configuration & label enumerator)
|
|
|
|
:param task: The Task object with which the OutputModel object is associated.
|
|
:type task: Task
|
|
:param config_text: The configuration as a string. This is usually the content of a configuration
|
|
dictionary file. Specify ``config_text`` or ``config_dict``, but not both.
|
|
:type config_text: unconstrained text string
|
|
:param dict config_dict: The configuration as a dictionary.
|
|
Specify ``config_dict`` or ``config_text``, but not both.
|
|
:param dict label_enumeration: The label enumeration dictionary of string (label) to integer (value) pairs.
|
|
(Optional)
|
|
|
|
For example:
|
|
|
|
.. code-block:: javascript
|
|
|
|
{
|
|
"background": 0,
|
|
"person": 1
|
|
}
|
|
|
|
:param str name: The name for the newly created model. (Optional)
|
|
:param list(str) tags: A list of strings which are tags for the model. (Optional)
|
|
:param str comment: A comment / description for the model. (Optional)
|
|
:param framework: The framework of the model or a Framework object. (Optional)
|
|
:type framework: str or Framework object
|
|
:param base_model_id: optional, model ID to be reused
|
|
"""
|
|
if not task:
|
|
from .task import Task
|
|
|
|
task = Task.current_task()
|
|
if not task:
|
|
raise ValueError(
|
|
"task object was not provided, and no current task was found"
|
|
)
|
|
|
|
super(OutputModel, self).__init__(task=task)
|
|
|
|
config_text = self._resolve_config(
|
|
config_text=config_text, config_dict=config_dict
|
|
)
|
|
|
|
self._model_local_filename = None
|
|
self._last_uploaded_url = None
|
|
self._base_model = None
|
|
self._base_model_id = None
|
|
self._task_connect_name = None
|
|
self._name = name
|
|
self._label_enumeration = label_enumeration
|
|
# noinspection PyProtectedMember
|
|
self._floating_data = create_dummy_model(
|
|
design=_Model._wrap_design(config_text),
|
|
labels=label_enumeration or task.get_labels_enumeration(),
|
|
name=name or self._task.name,
|
|
tags=tags,
|
|
comment="{} by task id: {}".format(
|
|
"Created" if not base_model_id else "Overwritten", task.id
|
|
)
|
|
+ ("\n" + comment if comment else ""),
|
|
framework=framework,
|
|
upload_storage_uri=task.output_uri,
|
|
)
|
|
# If we have no real model ID, we are done
|
|
if not base_model_id:
|
|
return
|
|
|
|
# noinspection PyBroadException
|
|
try:
|
|
# noinspection PyProtectedMember
|
|
_base_model = self._task._get_output_model(model_id=base_model_id)
|
|
_base_model.update(
|
|
labels=self._floating_data.labels,
|
|
design=self._floating_data.design,
|
|
task_id=self._task.id,
|
|
project_id=self._task.project,
|
|
name=self._floating_data.name or self._task.name,
|
|
comment=(
|
|
"{}\n{}".format(_base_model.comment, self._floating_data.comment)
|
|
if (
|
|
_base_model.comment
|
|
and self._floating_data.comment
|
|
and self._floating_data.comment not in _base_model.comment
|
|
)
|
|
else (_base_model.comment or self._floating_data.comment)
|
|
),
|
|
tags=self._floating_data.tags,
|
|
framework=self._floating_data.framework,
|
|
upload_storage_uri=self._floating_data.upload_storage_uri,
|
|
)
|
|
self._base_model = _base_model
|
|
self._floating_data = None
|
|
name = self._task_connect_name or Path(_base_model.uri).stem
|
|
except Exception:
|
|
pass
|
|
self.connect(task, name=name)
|
|
|
|
def connect(self, task, name=None, **kwargs):
|
|
# type: (Task, Optional[str]) -> None
|
|
"""
|
|
Connect the current model to a Task object, if the model is a preexisting model. Preexisting models include:
|
|
|
|
- Imported models.
|
|
- Models whose metadata the **ClearML Server** (backend) is already storing.
|
|
- Models from another source, such as frameworks like TensorFlow.
|
|
|
|
:param object task: A Task object.
|
|
:param str name: The model name as it would appear on the Task object.
|
|
The model object itself can have a different name,
|
|
this is designed to support multiple models used/created by a single Task.
|
|
Use examples would be GANs or model ensemble
|
|
"""
|
|
if self._task != task:
|
|
raise ValueError(
|
|
"Can only connect preexisting model to task, but this is a fresh model"
|
|
)
|
|
|
|
if name:
|
|
self._task_connect_name = name
|
|
|
|
# we should set the task input model to point to us
|
|
model = self._get_base_model()
|
|
|
|
# only copy the model design if the task has no design to begin with
|
|
# noinspection PyProtectedMember
|
|
if not self._task._get_model_config_text():
|
|
# noinspection PyProtectedMember
|
|
task._set_model_config(
|
|
config_text=model.model_design
|
|
if hasattr(model, "model_design")
|
|
else model.design.get("design", "")
|
|
)
|
|
if not self._task.get_labels_enumeration():
|
|
task.set_model_label_enumeration(
|
|
model.data.labels if hasattr(model, "data") else model.labels
|
|
)
|
|
|
|
if self._base_model:
|
|
self._base_model.update_for_task(
|
|
task_id=self._task.id,
|
|
model_id=self.id,
|
|
type_="output",
|
|
name=self._task_connect_name,
|
|
)
|
|
|
|
def set_upload_destination(self, uri):
|
|
# type: (str) -> None
|
|
"""
|
|
Set the URI of the storage destination for uploaded model weight files.
|
|
Supported storage destinations include S3, Google Cloud Storage, and file locations.
|
|
|
|
Using this method, file uploads are separate and then a link to each is stored in the model object.
|
|
|
|
.. note::
|
|
For storage requiring credentials, the credentials are stored in the ClearML configuration file,
|
|
``~/clearml.conf``.
|
|
|
|
:param str uri: The URI of the upload storage destination.
|
|
|
|
For example:
|
|
|
|
- ``s3://bucket/directory/``
|
|
- ``file:///tmp/debug/``
|
|
|
|
:return bool: The status of whether the storage destination schema is supported.
|
|
|
|
- ``True`` - The storage destination scheme is supported.
|
|
- ``False`` - The storage destination scheme is not supported.
|
|
"""
|
|
if not uri:
|
|
return
|
|
|
|
# Test if we can update the model.
|
|
self._validate_update()
|
|
|
|
# Create the storage helper
|
|
storage = StorageHelper.get(uri)
|
|
|
|
# Verify that we can upload to this destination
|
|
try:
|
|
uri = storage.verify_upload(folder_uri=uri)
|
|
except Exception:
|
|
raise ValueError(
|
|
"Could not set destination uri to: %s [Check write permissions]" % uri
|
|
)
|
|
|
|
# store default uri
|
|
self._get_base_model().upload_storage_uri = uri
|
|
|
|
def update_weights(
|
|
self,
|
|
weights_filename=None, # type: Optional[str]
|
|
upload_uri=None, # type: Optional[str]
|
|
target_filename=None, # type: Optional[str]
|
|
auto_delete_file=True, # type: bool
|
|
register_uri=None, # type: Optional[str]
|
|
iteration=None, # type: Optional[int]
|
|
update_comment=True, # type: bool
|
|
is_package=False, # type: bool
|
|
async_enable=True, # type: bool
|
|
):
|
|
# type: (...) -> str
|
|
"""
|
|
Update the model weights from a locally stored model filename.
|
|
|
|
.. note::
|
|
Uploading the model is a background process. A call to this method returns immediately.
|
|
|
|
:param str weights_filename: The name of the locally stored weights file to upload.
|
|
Specify ``weights_filename`` or ``register_uri``, but not both.
|
|
:param str upload_uri: The URI of the storage destination for model weights upload. The default value
|
|
is the previously used URI. (Optional)
|
|
:param str target_filename: The newly created filename in the storage destination location. The default value
|
|
is the ``weights_filename`` value. (Optional)
|
|
:param bool auto_delete_file: Delete the temporary file after uploading (Optional)
|
|
|
|
- ``True`` - Delete (Default)
|
|
- ``False`` - Do not delete
|
|
|
|
:param str register_uri: The URI of an already uploaded weights file. The URI must be valid. Specify
|
|
``register_uri`` or ``weights_filename``, but not both.
|
|
:param int iteration: The iteration number.
|
|
:param bool update_comment: Update the model comment with the local weights file name (to maintain provenance) (Optional)
|
|
|
|
- ``True`` - Update model comment (Default)
|
|
- ``False`` - Do not update
|
|
:param bool is_package: Mark the weights file as compressed package, usually a zip file.
|
|
:param bool async_enable: Whether to upload model in background or to block.
|
|
Will raise an error in the main thread if the weights failed to be uploaded or not.
|
|
|
|
:return: The uploaded URI.
|
|
"""
|
|
|
|
def delete_previous_weights_file(filename=weights_filename):
|
|
try:
|
|
if filename:
|
|
os.remove(filename)
|
|
except OSError:
|
|
self._log.debug("Failed removing temporary file %s" % filename)
|
|
|
|
# test if we can update the model
|
|
if self.id and self.published:
|
|
raise ValueError("Model is published and cannot be changed")
|
|
|
|
if (not weights_filename and not register_uri) or (
|
|
weights_filename and register_uri
|
|
):
|
|
raise ValueError(
|
|
"Model update must have either local weights file to upload, "
|
|
"or pre-uploaded register_uri, never both"
|
|
)
|
|
|
|
# only upload if we are connected to a task
|
|
if not self._task:
|
|
raise Exception("Missing a task for this model")
|
|
|
|
if self._task.is_offline() and (weights_filename is None or not Path(weights_filename).is_dir()):
|
|
return self._update_weights_offline(
|
|
weights_filename=weights_filename,
|
|
upload_uri=upload_uri,
|
|
target_filename=target_filename,
|
|
register_uri=register_uri,
|
|
iteration=iteration,
|
|
update_comment=update_comment,
|
|
is_package=is_package,
|
|
)
|
|
|
|
if weights_filename is not None:
|
|
# Check if weights_filename is a folder, is package upload
|
|
if Path(weights_filename).is_dir():
|
|
return self.update_weights_package(
|
|
weights_path=weights_filename,
|
|
upload_uri=upload_uri,
|
|
target_filename=target_filename or Path(weights_filename).name,
|
|
auto_delete_file=auto_delete_file,
|
|
iteration=iteration,
|
|
async_enable=async_enable
|
|
)
|
|
|
|
# make sure we delete the previous file, if it exists
|
|
if self._model_local_filename != weights_filename:
|
|
delete_previous_weights_file(self._model_local_filename)
|
|
# store temp filename for deletion next time, if needed
|
|
if auto_delete_file:
|
|
self._model_local_filename = weights_filename
|
|
|
|
# make sure the created model is updated:
|
|
out_model_file_name = target_filename or weights_filename or register_uri
|
|
|
|
# prefer self._task_connect_name if exists
|
|
if self._task_connect_name:
|
|
name = self._task_connect_name
|
|
elif out_model_file_name:
|
|
name = Path(out_model_file_name).stem
|
|
else:
|
|
name = "Output Model"
|
|
|
|
if not self._base_model:
|
|
model = self._get_force_base_model(task_model_entry=name)
|
|
else:
|
|
self._update_base_model(task_model_entry=name)
|
|
model = self._base_model
|
|
if not model:
|
|
raise ValueError("Failed creating internal output model")
|
|
|
|
# select the correct file extension based on the framework,
|
|
# or update the framework based on the file extension
|
|
# noinspection PyProtectedMember
|
|
framework, file_ext = Framework._get_file_ext(
|
|
framework=self._get_model_data().framework,
|
|
filename=target_filename or weights_filename or register_uri,
|
|
)
|
|
|
|
if weights_filename:
|
|
target_filename = target_filename or Path(weights_filename).name
|
|
if not target_filename.lower().endswith(file_ext):
|
|
target_filename += file_ext
|
|
|
|
# set target uri for upload (if specified)
|
|
if upload_uri:
|
|
self.set_upload_destination(upload_uri)
|
|
|
|
# let us know the iteration number, we put it in the comment section for now.
|
|
if update_comment:
|
|
comment = self.comment or ""
|
|
iteration_msg = "snapshot {} stored".format(
|
|
weights_filename or register_uri
|
|
)
|
|
if not comment.startswith("\n"):
|
|
comment = "\n" + comment
|
|
comment = iteration_msg + comment
|
|
else:
|
|
comment = None
|
|
|
|
# if we have no output destination, just register the local model file
|
|
if (
|
|
weights_filename
|
|
and not self.upload_storage_uri
|
|
and not self._task.storage_uri
|
|
):
|
|
register_uri = weights_filename
|
|
weights_filename = None
|
|
auto_delete_file = False
|
|
self._log.info(
|
|
"No output storage destination defined, registering local model %s"
|
|
% register_uri
|
|
)
|
|
|
|
# start the upload
|
|
if weights_filename:
|
|
if not model.upload_storage_uri:
|
|
self.set_upload_destination(
|
|
self.upload_storage_uri or self._task.storage_uri
|
|
)
|
|
|
|
output_uri = model.update_and_upload(
|
|
model_file=weights_filename,
|
|
task_id=self._task.id,
|
|
async_enable=async_enable,
|
|
target_filename=target_filename,
|
|
framework=self.framework or framework,
|
|
comment=comment,
|
|
cb=delete_previous_weights_file if auto_delete_file else None,
|
|
iteration=iteration or self._task.get_last_iteration(),
|
|
)
|
|
elif register_uri:
|
|
register_uri = StorageHelper.conform_url(register_uri)
|
|
output_uri = model.update(
|
|
uri=register_uri,
|
|
task_id=self._task.id,
|
|
framework=framework,
|
|
comment=comment,
|
|
)
|
|
else:
|
|
output_uri = None
|
|
|
|
self._last_uploaded_url = output_uri
|
|
|
|
if is_package:
|
|
self._set_package_tag()
|
|
|
|
return output_uri
|
|
|
|
def update_weights_package(
|
|
self,
|
|
weights_filenames=None, # type: Optional[Sequence[str]]
|
|
weights_path=None, # type: Optional[str]
|
|
upload_uri=None, # type: Optional[str]
|
|
target_filename=None, # type: Optional[str]
|
|
auto_delete_file=True, # type: bool
|
|
iteration=None, # type: Optional[int]
|
|
async_enable=True, # type: bool
|
|
):
|
|
# type: (...) -> str
|
|
"""
|
|
Update the model weights from locally stored model files, or from directory containing multiple files.
|
|
|
|
.. note::
|
|
Uploading the model weights is a background process. A call to this method returns immediately.
|
|
|
|
:param weights_filenames: The file names of the locally stored model files. Specify ``weights_filenames``,
|
|
or ``weights_path``, but not both.
|
|
:type weights_filenames: list(str)
|
|
:param weights_path: The directory path to a package. All the files in the directory will be uploaded.
|
|
Specify ``weights_path`` or ``weights_filenames``, but not both.
|
|
:type weights_path: str
|
|
:param str upload_uri: The URI of the storage destination for the model weights upload. The default
|
|
is the previously used URI. (Optional)
|
|
:param str target_filename: The newly created filename in the storage destination URI location. The default
|
|
is the value specified in the ``weights_filename`` parameter. (Optional)
|
|
:param bool auto_delete_file: Delete temporary file after uploading (Optional)
|
|
|
|
- ``True`` - Delete (Default)
|
|
- ``False`` - Do not delete
|
|
|
|
:param int iteration: The iteration number.
|
|
:param bool async_enable: Whether to upload model in background or to block.
|
|
Will raise an error in the main thread if the weights failed to be uploaded or not.
|
|
|
|
:return: The uploaded URI for the weights package.
|
|
"""
|
|
# create list of files
|
|
if (not weights_filenames and not weights_path) or (
|
|
weights_filenames and weights_path
|
|
):
|
|
raise ValueError(
|
|
"Model update weights package should get either "
|
|
"directory path to pack or a list of files"
|
|
)
|
|
|
|
if not weights_filenames:
|
|
weights_filenames = list(map(six.text_type, Path(weights_path).rglob("*")))
|
|
elif weights_filenames and len(weights_filenames) > 1:
|
|
weights_path = get_common_path(weights_filenames)
|
|
|
|
# create packed model from all the files
|
|
fd, zip_file = mkstemp(prefix="model_package.", suffix=".zip")
|
|
try:
|
|
with zipfile.ZipFile(
|
|
zip_file, "w", allowZip64=True, compression=zipfile.ZIP_STORED
|
|
) as zf:
|
|
for filename in weights_filenames:
|
|
relative_file_name = (
|
|
Path(filename).name
|
|
if not weights_path
|
|
else Path(filename)
|
|
.absolute()
|
|
.relative_to(Path(weights_path).absolute())
|
|
.as_posix()
|
|
)
|
|
zf.write(filename, arcname=relative_file_name)
|
|
finally:
|
|
os.close(fd)
|
|
|
|
# now we can delete the files (or path if provided)
|
|
if auto_delete_file:
|
|
|
|
def safe_remove(path, is_dir=False):
|
|
try:
|
|
(os.rmdir if is_dir else os.remove)(path)
|
|
except OSError:
|
|
self._log.info("Failed removing temporary {}".format(path))
|
|
|
|
for filename in weights_filenames:
|
|
safe_remove(filename)
|
|
if weights_path:
|
|
safe_remove(weights_path, is_dir=True)
|
|
|
|
if target_filename and not target_filename.lower().endswith(".zip"):
|
|
target_filename += ".zip"
|
|
|
|
# and now we should upload the file, always delete the temporary zip file
|
|
comment = self.comment or ""
|
|
iteration_msg = "snapshot {} stored".format(str(weights_filenames))
|
|
if not comment.startswith("\n"):
|
|
comment = "\n" + comment
|
|
comment = iteration_msg + comment
|
|
self.comment = comment
|
|
uploaded_uri = self.update_weights(
|
|
weights_filename=zip_file,
|
|
auto_delete_file=True,
|
|
upload_uri=upload_uri,
|
|
target_filename=target_filename or "model_package.zip",
|
|
iteration=iteration,
|
|
update_comment=False,
|
|
async_enable=async_enable
|
|
)
|
|
# set the model tag (by now we should have a model object) so we know we have packaged file
|
|
self._set_package_tag()
|
|
return uploaded_uri
|
|
|
|
def update_design(self, config_text=None, config_dict=None):
|
|
# type: (Optional[str], Optional[dict]) -> bool
|
|
"""
|
|
Update the model configuration. Store a blob of text for custom usage.
|
|
|
|
.. note::
|
|
This method's behavior is lazy. The design update is only forced when the weights
|
|
are updated.
|
|
|
|
:param config_text: The configuration as a string. This is usually the content of a configuration
|
|
dictionary file. Specify ``config_text`` or ``config_dict``, but not both.
|
|
:type config_text: unconstrained text string
|
|
:param dict config_dict: The configuration as a dictionary. Specify ``config_text`` or ``config_dict``,
|
|
but not both.
|
|
|
|
:return: True, update successful. False, update not successful.
|
|
"""
|
|
if not self._validate_update():
|
|
return False
|
|
|
|
config_text = self._resolve_config(
|
|
config_text=config_text, config_dict=config_dict
|
|
)
|
|
|
|
if self._task and not self._task.get_model_config_text():
|
|
self._task.set_model_config(config_text=config_text)
|
|
|
|
if self.id:
|
|
# update the model object (this will happen if we resumed a training task)
|
|
result = self._get_force_base_model().edit(design=config_text)
|
|
else:
|
|
# noinspection PyProtectedMember
|
|
self._floating_data.design = _Model._wrap_design(config_text)
|
|
result = Waitable()
|
|
|
|
# you can wait on this object
|
|
return result
|
|
|
|
def update_labels(self, labels):
|
|
# type: (Mapping[str, int]) -> Optional[Waitable]
|
|
"""
|
|
Update the label enumeration.
|
|
|
|
:param dict labels: The label enumeration dictionary of string (label) to integer (value) pairs.
|
|
|
|
For example:
|
|
|
|
.. code-block:: javascript
|
|
|
|
{
|
|
"background": 0,
|
|
"person": 1
|
|
}
|
|
|
|
:return:
|
|
"""
|
|
validate_dict(
|
|
labels,
|
|
key_types=six.string_types,
|
|
value_types=six.integer_types,
|
|
desc="label enumeration",
|
|
)
|
|
|
|
if not self._validate_update():
|
|
return
|
|
|
|
if self._task:
|
|
self._task.set_model_label_enumeration(labels)
|
|
|
|
if self.id:
|
|
# update the model object (this will happen if we resumed a training task)
|
|
result = self._get_force_base_model().edit(labels=labels)
|
|
else:
|
|
self._floating_data.labels = labels
|
|
result = Waitable()
|
|
|
|
# you can wait on this object
|
|
return result
|
|
|
|
@classmethod
|
|
def wait_for_uploads(cls, timeout=None, max_num_uploads=None):
|
|
# type: (Optional[float], Optional[int]) -> None
|
|
"""
|
|
Wait for any pending or in-progress model uploads to complete. If no uploads are pending or in-progress,
|
|
then the ``wait_for_uploads`` returns immediately.
|
|
|
|
:param float timeout: The timeout interval to wait for uploads (seconds). (Optional).
|
|
:param int max_num_uploads: The maximum number of uploads to wait for. (Optional).
|
|
"""
|
|
_Model.wait_for_results(timeout=timeout, max_num_uploads=max_num_uploads)
|
|
|
|
@classmethod
|
|
def set_default_upload_uri(cls, output_uri):
|
|
# type: (Optional[str]) -> None
|
|
"""
|
|
Set the default upload uri for all OutputModels
|
|
|
|
:param output_uri: URL for uploading models. examples:
|
|
https://demofiles.demo.clear.ml, s3://bucket/, gs://bucket/, azure://bucket/, file:///mnt/shared/nfs
|
|
"""
|
|
cls._default_output_uri = str(output_uri) if output_uri else None
|
|
|
|
def _update_weights_offline(
|
|
self,
|
|
weights_filename=None, # type: Optional[str]
|
|
upload_uri=None, # type: Optional[str]
|
|
target_filename=None, # type: Optional[str]
|
|
register_uri=None, # type: Optional[str]
|
|
iteration=None, # type: Optional[int]
|
|
update_comment=True, # type: bool
|
|
is_package=False, # type: bool
|
|
):
|
|
# type: (...) -> str
|
|
if (not weights_filename and not register_uri) or (weights_filename and register_uri):
|
|
raise ValueError(
|
|
"Model update must have either local weights file to upload, "
|
|
"or pre-uploaded register_uri, never both"
|
|
)
|
|
if not self._task:
|
|
raise Exception("Missing a task for this model")
|
|
weights_filename_offline = None
|
|
if weights_filename:
|
|
weights_filename_offline = (
|
|
self._task.get_offline_mode_folder() / self._offline_folder / Path(weights_filename).name
|
|
).as_posix()
|
|
os.makedirs(os.path.dirname(weights_filename_offline), exist_ok=True)
|
|
shutil.copyfile(weights_filename, weights_filename_offline)
|
|
# noinspection PyProtectedMember
|
|
self._task._offline_output_models.append(
|
|
dict(
|
|
init=dict(
|
|
config_text=self.config_text,
|
|
config_dict=self.config_dict,
|
|
label_enumeration=self._label_enumeration,
|
|
name=self.name,
|
|
tags=self.tags,
|
|
comment=self.comment,
|
|
framework=self.framework
|
|
),
|
|
weights=dict(
|
|
weights_filename=weights_filename_offline,
|
|
upload_uri=upload_uri,
|
|
target_filename=target_filename,
|
|
register_uri=register_uri,
|
|
iteration=iteration,
|
|
update_comment=update_comment,
|
|
is_package=is_package
|
|
),
|
|
output_uri=self._get_base_model().upload_storage_uri or self._default_output_uri
|
|
)
|
|
)
|
|
return weights_filename_offline or register_uri
|
|
|
|
def _get_base_model(self):
|
|
if self._floating_data:
|
|
return self._floating_data
|
|
return self._get_force_base_model()
|
|
|
|
def _get_model_data(self):
|
|
if self._base_model:
|
|
return self._base_model.data
|
|
return self._floating_data
|
|
|
|
def _validate_update(self):
|
|
# test if we can update the model
|
|
if self.id and self.published:
|
|
raise ValueError("Model is published and cannot be changed")
|
|
|
|
return True
|
|
|
|
def _get_last_uploaded_filename(self):
|
|
if not self._last_uploaded_url and not self.url:
|
|
return None
|
|
return Path(self._last_uploaded_url or self.url).name
|
|
|
|
|
|
class Waitable(object):
|
|
def wait(self, *_, **__):
|
|
return True
|