diff --git a/trains/model.py b/trains/model.py index b3e5615e..52c1abdf 100644 --- a/trains/model.py +++ b/trains/model.py @@ -6,6 +6,7 @@ from tempfile import mkdtemp, mkstemp import pyparsing import six +from typing import List, Dict, Union, Optional, TYPE_CHECKING from .backend_api import Session from .backend_api.services import models @@ -20,6 +21,10 @@ from .backend_interface import Task as _Task from .backend_interface.model import create_dummy_model, Model as _Model from .config import running_remotely, get_cache_dir + +if TYPE_CHECKING: + from .task import Task + ARCHIVED_TAG = "archived" @@ -98,6 +103,7 @@ class BaseModel(object): @property def id(self): + # type: () -> str """ The Id (system UUID) of the model. @@ -109,6 +115,7 @@ class BaseModel(object): @property def name(self): + # type: () -> str """ The name of the model. @@ -120,6 +127,7 @@ class BaseModel(object): @name.setter def name(self, value): + # type: (str) -> None """ Set the model name. @@ -129,6 +137,7 @@ class BaseModel(object): @property def comment(self): + # type: () -> str """ The comment for the model. Also, use for a model description. @@ -140,6 +149,7 @@ class BaseModel(object): @comment.setter def comment(self, value): + # type: (str) -> None """ Set comment for the model. Also, use for a model description. @@ -149,6 +159,7 @@ class BaseModel(object): @property def tags(self): + # type: () -> List[str] """ A list of tags describing the model. @@ -160,6 +171,7 @@ class BaseModel(object): @tags.setter def tags(self, value): + # type: (List[str]) -> None """ Set the list of tags describing the model. @@ -171,6 +183,7 @@ class BaseModel(object): @property def config_text(self): + # type: () -> str """ The configuration as a string. For example, prototxt, an ini file, or Python code to evaluate. @@ -182,6 +195,7 @@ class BaseModel(object): @property def config_dict(self): + # type: () -> dict """ The configuration as a dictionary, parsed from the design text. This usually represents the model configuration. For example, prototxt, an ini file, or Python code to evaluate. @@ -194,6 +208,7 @@ class BaseModel(object): @property def labels(self): + # type: () -> Dict[str, int] """ The label enumeration of string (label) to integer (value) pairs. @@ -206,6 +221,7 @@ class BaseModel(object): @property def task(self): + # type: () -> str """ Return the creating task id (str) @@ -215,6 +231,7 @@ class BaseModel(object): @property def url(self): + # type: () -> str """ Return the url of the model file (or archived files) @@ -224,19 +241,23 @@ class BaseModel(object): @property def published(self): + # type: () -> bool return self._get_base_model().locked @property def framework(self): + # type: () -> str return self._get_model_data().framework def __init__(self, task=None): + # type: (Task) -> None super(BaseModel, self).__init__() self._log = get_logger() self._task = None self._set_task(task) def get_weights(self): + # type: () -> str """ Download the base model and return the locally stored filename. @@ -248,6 +269,7 @@ class BaseModel(object): return self._get_base_model().download_model_weights() def get_weights_package(self, return_path=False): + # type: (bool) -> Union[str, List[Path]] """ Download the base model package into a temporary directory (extract the files), or return a list of the locally stored filenames. @@ -368,6 +390,7 @@ class Model(BaseModel): """ def __init__(self, model_id): + # type: (str) ->None """ Load model based on id, returned object is read-only and can be connected to a task @@ -380,6 +403,7 @@ class Model(BaseModel): self._base_model = None def get_local_copy(self, extract_archive=True): + # type: (bool) -> str """ Retrieve a valid link to the model file(s). If the model URL is a file system link, it will be returned directly. @@ -424,17 +448,18 @@ class InputModel(Model): @classmethod def import_model( cls, - weights_url, - config_text=None, - config_dict=None, - label_enumeration=None, - name=None, - tags=None, - comment=None, - is_package=False, - create_as_published=False, - framework=None, + weights_url, # type: str + config_text=None, # type: Optional[str] + config_dict=None, # type: Optional[dict] + label_enumeration=None, # type: Optional[Dict[str, int]] + name=None, # type: Optional[str] + tags=None, # type: Optional[List[str]] + comment=None, # type: Optional[str] + is_package=False, # type: bool + create_as_published=False, # type: bool + framework=None, # type: Optional[str] ): + # type: (...) -> InputModel """ Create an InputModel object from a pre-trained model by specifying the URL of an initial weight files. Optionally, input a configuration, label enumeration, name for the model, tags describing the model, @@ -577,11 +602,8 @@ class InputModel(Model): return this_model @classmethod - def load_model( - cls, - weights_url, - load_archived=False - ): + def load_model(cls, weights_url, load_archived=False): + # type: (str, bool) -> InputModel """ Load an already registered model based on a pre-existing model file (link must be valid). @@ -625,12 +647,8 @@ class InputModel(Model): return InputModel(model_id=model.id) @classmethod - def empty( - cls, - config_text=None, - config_dict=None, - label_enumeration=None, - ): + def empty(cls, config_text=None, config_dict=None, label_enumeration=None): + # type: (Optional[str], Optional[dict], Optional[Dict[str, int]]) -> InputModel """ Create an empty model object. Later, you can assign a model to the empty model object. @@ -664,6 +682,7 @@ class InputModel(Model): return this_model def __init__(self, model_id): + # type: (str) -> None """ :param str model_id: The Trains Id (system UUID) of the input model whose metadata the **Trains Server** (backend) stores. @@ -672,9 +691,11 @@ class InputModel(Model): @property def id(self): + # type: () -> str return self._base_model_id def connect(self, task): + # type: (Task) -> None """ Connect the current model to a Task object, if the model is preexisting. Preexisting models include: @@ -738,12 +759,21 @@ class OutputModel(BaseModel): @property def published(self): + # type: () -> bool + """ + Get the published state of this model. + + :return: ``True`` if the model is published, ``False`` otherwise. + + :rtype: bool + """ if not self.id: return False return self._get_base_model().locked @property def config_text(self): + # type: () -> str """ Get the configuration as a string. For example, prototxt, an ini file, or Python code to evaluate. @@ -755,6 +785,7 @@ class OutputModel(BaseModel): @config_text.setter def config_text(self, value): + # type: (str) -> None """ Set the configuration. Store a blob of text for custom usage. """ @@ -762,6 +793,7 @@ class OutputModel(BaseModel): @property def config_dict(self): + # type: () -> dict """ Get the configuration as a dictionary parsed from the ``config_text`` text. This usually represents the model configuration. For example, from prototxt to ini file or python code to evaluate. @@ -774,6 +806,7 @@ class OutputModel(BaseModel): @config_dict.setter def config_dict(self, value): + # type: (dict) -> None """ Set the configuration. Saved in the model object. @@ -783,6 +816,7 @@ class OutputModel(BaseModel): @property def labels(self): + # type: () -> Dict[str, int] """ Get the label enumeration as a dictionary of string (label) to integer (value) pairs. @@ -803,6 +837,7 @@ class OutputModel(BaseModel): @labels.setter def labels(self, value): + # type: (Dict[str, int]) -> None """ Set the label enumeration. @@ -822,19 +857,20 @@ class OutputModel(BaseModel): @property def upload_storage_uri(self): + # type: () -> str return self._get_base_model().upload_storage_uri def __init__( self, - task, - config_text=None, - config_dict=None, - label_enumeration=None, - name=None, - tags=None, - comment=None, - framework=None, - base_model_id=None, + task, # type: Task + config_text=None, # type: Optional[str] + config_dict=None, # type: Optional[dict] + label_enumeration=None, # type: Optional[Dict[str, int]] + name=None, # type: Optional[str] + tags=None, # type: Optional[List[str]] + comment=None, # type: Optional[str] + framework=None, # type: Optional[Union[str, Framework]] + base_model_id=None, # type: Optional[str] ): """ Create a new model and immediately connect it to a task. @@ -908,6 +944,7 @@ class OutputModel(BaseModel): self.connect(task) def connect(self, task): + # type: (Task) -> None """ Connect the current model to a Task object, if the model is a preexisting model. Preexisting models include: @@ -947,6 +984,7 @@ class OutputModel(BaseModel): self.task._save_output_model(self) def set_upload_destination(self, uri): + # type: (str) -> None """ Set the URI of the storage destination for uploaded model weight files. Supported storage destinations include S3, Google Cloud Storage), and file locations. @@ -989,8 +1027,17 @@ class OutputModel(BaseModel): # store default uri self._get_base_model().upload_storage_uri = uri - def update_weights(self, weights_filename=None, upload_uri=None, target_filename=None, - auto_delete_file=True, register_uri=None, iteration=None, update_comment=True): + def update_weights( + self, + weights_filename=None, # type: Optional[str] + upload_uri=None, # type: Optional[str] + target_filename=None, # type: Optional[str] + auto_delete_file=True, # type: bool + register_uri=None, # type: Optional[str] + iteration=None, # type: Optional[int] + update_comment=True # type: bool + ): + # type: (...) -> str """ Update the model weights from a locally stored model filename. @@ -1010,6 +1057,7 @@ class OutputModel(BaseModel): :param str register_uri: The URI of an already uploaded weights file. The URI must be valid. Specify ``register_uri`` or ``weights_filename``, but not both. + :param int iteration: The iteration number. :param bool update_comment: Update the model comment with the local weights file name (to maintain provenance)? (Optional) @@ -1111,8 +1159,16 @@ class OutputModel(BaseModel): return output_uri - def update_weights_package(self, weights_filenames=None, weights_path=None, upload_uri=None, - target_filename=None, auto_delete_file=True, iteration=None): + def update_weights_package( + self, + weights_filenames=None, # type: Optional[str] + weights_path=None, # type: Optional[str] + upload_uri=None, # type: Optional[str] + target_filename=None, # type: Optional[str] + auto_delete_file=True, # type: bool + iteration=None # type: Optional[int] + ): + # type: (...) -> str """ Update the model weights from locally stored model files, or from directory containing multiple files. @@ -1134,6 +1190,8 @@ class OutputModel(BaseModel): - ``True`` - Delete (Default) - ``False`` - Do not delete + :param int iteration: The iteration number. + :return: The uploaded URI for the weights package. :rtype: str @@ -1185,6 +1243,7 @@ class OutputModel(BaseModel): return uploaded_uri def update_design(self, config_text=None, config_dict=None): + # type: (Optional[str], Optional[dict]) -> bool """ Update the model configuration. Store a blob of text for custom usage. @@ -1223,6 +1282,7 @@ class OutputModel(BaseModel): return result def update_labels(self, labels): + # type: (Dict[str, int]) -> Optional[Waitable] """ Update the label enumeration. @@ -1259,6 +1319,7 @@ class OutputModel(BaseModel): @classmethod def wait_for_uploads(cls, timeout=None, max_num_uploads=None): + # type: (Optional[float], Optional[int]) -> None """ Wait for any pending or in-progress model uploads to complete. If no uploads are pending or in-progress, then the ``wait_for_uploads`` returns immediately. diff --git a/trains/task.py b/trains/task.py index 98f0ce77..bc42077b 100644 --- a/trains/task.py +++ b/trains/task.py @@ -12,7 +12,7 @@ try: except ImportError: from collections import Callable, Sequence as CollectionsSequence -from typing import Optional, Union, Mapping, Sequence, Any, Dict, List +from typing import Optional, Union, Mapping, Sequence, Any, Dict, List, TYPE_CHECKING import psutil import six @@ -52,6 +52,12 @@ from .utilities.resource_monitor import ResourceMonitor from .utilities.seed import make_deterministic +if TYPE_CHECKING: + import pandas + import numpy + from PIL import Image + + class Task(_Task): """ The ``Task`` class is a code template for a Task object which, together with its connected experiment components, @@ -157,17 +163,17 @@ class Task(_Task): @classmethod def init( - cls, - project_name=None, - task_name=None, - task_type=TaskTypes.training, - reuse_last_task_id=True, - output_uri=None, - auto_connect_arg_parser=True, - auto_connect_frameworks=True, - auto_resource_monitoring=True, + cls, + project_name=None, # type: Optional[str] + task_name=None, # type: Optional[str] + task_type=TaskTypes.training, # type: Task.TaskTypes + reuse_last_task_id=True, # type: bool + output_uri=None, # type: Optional[str] + auto_connect_arg_parser=True, # type: bool + auto_connect_frameworks=True, # type: Union[bool, Mapping[str, bool]] + auto_resource_monitoring=True, # type: bool ): - # type: (Optional[str], Optional[str], TaskTypes, bool, Optional[str], bool, Union[bool, Mapping[str, bool]], bool) -> Task + # type: (...) -> Task """ Creates a new Task (experiment), or returns the existing Task, depending upon the following: @@ -464,12 +470,7 @@ class Task(_Task): return task @classmethod - def create( - cls, - project_name=None, - task_name=None, - task_type=TaskTypes.training, - ): + def create(cls, project_name=None, task_name=None, task_type=TaskTypes.training): # type: (Optional[str], Optional[str], TaskTypes) -> Task """ Create a new, non-reproducible Task (experiment). This is called a sub-task. @@ -523,8 +524,8 @@ class Task(_Task): """ Get a Task by Id, or project name / task name combination. - :param str task_id: The Id (system UUID) of the experiment to get. If specified, ``project_name`` and ``task_name`` - are ignored. + :param str task_id: The Id (system UUID) of the experiment to get. + If specified, ``project_name`` and ``task_name`` are ignored. :param str project_name: The project name of the Task to get. :param str task_name: The name of the Task within ``project_name`` to get. @@ -1028,7 +1029,7 @@ class Task(_Task): self.__register_at_exit(None) def register_artifact(self, name, artifact, metadata=None, uniqueness_columns=True): - # type: (str, "pandas.DataFrame", Dict, Union[bool, Sequence[str]]) -> None + # type: (str, pandas.DataFrame, Dict, Union[bool, Sequence[str]]) -> None """ Register (add) an artifact for the current Task. Registered artifacts are dynamically sychronized with the **Trains Server** (backend). If a registered artifact is updated, the update is stored in the @@ -1089,8 +1090,14 @@ class Task(_Task): """ return self._artifacts_manager.registered_artifacts - def upload_artifact(self, name, artifact_object, metadata=None, delete_after_upload=False): - # type: (str, Union[str, Mapping, "pandas.DataFrame", "numpy.ndarray", "PIL.Image.Image"], Optional[Mapping], bool) -> bool + def upload_artifact( + self, + name, # type: str + artifact_object, # type: Union[str, Mapping, pandas.DataFrame, numpy.ndarray, Image.Image] + metadata=None, # type: Optional[Mapping] + delete_after_upload=False # type: bool + ): + # type: (...) -> bool """ Upload (add) a static artifact to a Task object. The artifact is uploaded in the background. @@ -1416,8 +1423,9 @@ class Task(_Task): task._dev_worker = None @classmethod - def _create_dev_task(cls, default_project_name, default_task_name, default_task_type, reuse_last_task_id, - detect_repo=True): + def _create_dev_task( + cls, default_project_name, default_task_name, default_task_type, reuse_last_task_id, detect_repo=True + ): if not default_project_name or not default_task_name: # get project name and task name from repository name and entry_point result, _ = ScriptInfo.get(create_requirements=False, check_uncommitted=False)