diff --git a/trains/model.py b/trains/model.py index 806aefb7..384254fa 100644 --- a/trains/model.py +++ b/trains/model.py @@ -6,7 +6,7 @@ from tempfile import mkdtemp, mkstemp import pyparsing import six -from typing import List, Dict, Union, Optional, TYPE_CHECKING, Sequence +from typing import List, Dict, Union, Optional, Mapping, TYPE_CHECKING, Sequence from .backend_api import Session from .backend_api.services import models @@ -446,7 +446,7 @@ class InputModel(Model): weights_url, # type: str config_text=None, # type: Optional[str] config_dict=None, # type: Optional[dict] - label_enumeration=None, # type: Optional[Dict[str, int]] + label_enumeration=None, # type: Optional[Mapping[str, int]] name=None, # type: Optional[str] tags=None, # type: Optional[List[str]] comment=None, # type: Optional[str] @@ -598,17 +598,27 @@ class InputModel(Model): def load_model(cls, weights_url, load_archived=False): # type: (str, bool) -> InputModel """ - Load an already registered model based on a pre-existing model file (link must be valid). + Load an already registered model based on a pre-existing model file (link must be valid). If the url to the + weights file already exists, the returned object is a Model representing the loaded Model. If no registered + model with the specified url is found, ``None`` is returned. - If the url to the weights file already exists, the returned object is a Model representing the loaded Model - If there could not be found any registered model Model with the specified url, None is returned. + :param weights_url: The valid url for the weights file (string). - :param weights_url: valid url for the weights file (string). - examples: "https://domain.com/file.bin" or "s3://bucket/file.bin" or "file:///home/user/file.bin". - NOTE: if a model with the exact same URL exists, it will be used, and all other arguments will be ignored. - :param bool load_archived: If True return registered Model with even if they are archived, - otherwise archived models are ignored, - :return Model: InputModel object or None if no model could be found + Examples: + + .. code-block:: py + + "https://domain.com/file.bin" or "s3://bucket/file.bin" or "file:///home/user/file.bin". + + .. note:: + If a model with the exact same URL exists, it will be used, and all other arguments will be ignored. + + :param bool load_archived: Load archived models? + + - ``True`` - Load the registered Model, if it is archived. + - ``False`` - Ignore archive models. + + :return: InputModel object, or ``None`` if no model could be found. """ weights_url = StorageHelper.conform_url(weights_url) if not weights_url: @@ -641,7 +651,7 @@ class InputModel(Model): @classmethod def empty(cls, config_text=None, config_dict=None, label_enumeration=None): - # type: (Optional[str], Optional[dict], Optional[Dict[str, int]]) -> InputModel + # type: (Optional[str], Optional[dict], Optional[Mapping[str, int]]) -> InputModel """ Create an empty model object. Later, you can assign a model to the empty model object. @@ -661,6 +671,8 @@ class InputModel(Model): 'background': 0, 'person': 1 } + + :return: Empty model object. """ design = cls._resolve_config(config_text=config_text, config_dict=config_dict) @@ -822,7 +834,7 @@ class OutputModel(BaseModel): @labels.setter def labels(self, value): - # type: (Dict[str, int]) -> None + # type: (Mapping[str, int]) -> None """ Set the label enumeration. @@ -850,7 +862,7 @@ class OutputModel(BaseModel): task=None, # type: Optional[Task] config_text=None, # type: Optional[str] config_dict=None, # type: Optional[dict] - label_enumeration=None, # type: Optional[Dict[str, int]] + label_enumeration=None, # type: Optional[Mapping[str, int]] name=None, # type: Optional[str] tags=None, # type: Optional[List[str]] comment=None, # type: Optional[str] @@ -914,7 +926,7 @@ class OutputModel(BaseModel): ) if base_model_id: try: - _base_model = InputModel(base_model_id)._get_base_model() + _base_model = self._task._get_output_model(model_id=base_model_id) _base_model.update( labels=self._floating_data.labels, design=self._floating_data.design, @@ -922,8 +934,9 @@ class OutputModel(BaseModel): project_id=self._task.project, name=self._floating_data.name or task.name, comment=('{}\n{}'.format(_base_model.comment, self._floating_data.comment) - if _base_model.comment and self._floating_data.comment else - (_base_model.comment or self._floating_data.comment)), + if (_base_model.comment and self._floating_data.comment and + self._floating_data.comment not in _base_model.comment) + else (_base_model.comment or self._floating_data.comment)), tags=self._floating_data.tags, framework=self._floating_data.framework, upload_storage_uri=self._floating_data.upload_storage_uri @@ -1245,10 +1258,11 @@ class OutputModel(BaseModel): :param dict config_dict: The configuration as a dictionary. Specify ``config_text`` or ``config_dict``, but not both. - :return bool: The status of the update. + :return: The status of the update. - ``True`` - Update successful. - ``False`` - Update not successful. + """ if not self._validate_update(): return False @@ -1269,7 +1283,7 @@ class OutputModel(BaseModel): return result def update_labels(self, labels): - # type: (Dict[str, int]) -> Optional[Waitable] + # type: (Mapping[str, int]) -> Optional[Waitable] """ Update the label enumeration.