Add Model and Task type-annotations

This commit is contained in:
allegroai 2020-05-08 22:08:48 +03:00
parent f90f8f06e2
commit 966cd6118a
2 changed files with 127 additions and 58 deletions

View File

@ -6,6 +6,7 @@ from tempfile import mkdtemp, mkstemp
import pyparsing import pyparsing
import six import six
from typing import List, Dict, Union, Optional, TYPE_CHECKING
from .backend_api import Session from .backend_api import Session
from .backend_api.services import models from .backend_api.services import models
@ -20,6 +21,10 @@ from .backend_interface import Task as _Task
from .backend_interface.model import create_dummy_model, Model as _Model from .backend_interface.model import create_dummy_model, Model as _Model
from .config import running_remotely, get_cache_dir from .config import running_remotely, get_cache_dir
if TYPE_CHECKING:
from .task import Task
ARCHIVED_TAG = "archived" ARCHIVED_TAG = "archived"
@ -98,6 +103,7 @@ class BaseModel(object):
@property @property
def id(self): def id(self):
# type: () -> str
""" """
The Id (system UUID) of the model. The Id (system UUID) of the model.
@ -109,6 +115,7 @@ class BaseModel(object):
@property @property
def name(self): def name(self):
# type: () -> str
""" """
The name of the model. The name of the model.
@ -120,6 +127,7 @@ class BaseModel(object):
@name.setter @name.setter
def name(self, value): def name(self, value):
# type: (str) -> None
""" """
Set the model name. Set the model name.
@ -129,6 +137,7 @@ class BaseModel(object):
@property @property
def comment(self): def comment(self):
# type: () -> str
""" """
The comment for the model. Also, use for a model description. The comment for the model. Also, use for a model description.
@ -140,6 +149,7 @@ class BaseModel(object):
@comment.setter @comment.setter
def comment(self, value): def comment(self, value):
# type: (str) -> None
""" """
Set comment for the model. Also, use for a model description. Set comment for the model. Also, use for a model description.
@ -149,6 +159,7 @@ class BaseModel(object):
@property @property
def tags(self): def tags(self):
# type: () -> List[str]
""" """
A list of tags describing the model. A list of tags describing the model.
@ -160,6 +171,7 @@ class BaseModel(object):
@tags.setter @tags.setter
def tags(self, value): def tags(self, value):
# type: (List[str]) -> None
""" """
Set the list of tags describing the model. Set the list of tags describing the model.
@ -171,6 +183,7 @@ class BaseModel(object):
@property @property
def config_text(self): def config_text(self):
# type: () -> str
""" """
The configuration as a string. For example, prototxt, an ini file, or Python code to evaluate. The configuration as a string. For example, prototxt, an ini file, or Python code to evaluate.
@ -182,6 +195,7 @@ class BaseModel(object):
@property @property
def config_dict(self): def config_dict(self):
# type: () -> dict
""" """
The configuration as a dictionary, parsed from the design text. This usually represents the model configuration. The configuration as a dictionary, parsed from the design text. This usually represents the model configuration.
For example, prototxt, an ini file, or Python code to evaluate. For example, prototxt, an ini file, or Python code to evaluate.
@ -194,6 +208,7 @@ class BaseModel(object):
@property @property
def labels(self): def labels(self):
# type: () -> Dict[str, int]
""" """
The label enumeration of string (label) to integer (value) pairs. The label enumeration of string (label) to integer (value) pairs.
@ -206,6 +221,7 @@ class BaseModel(object):
@property @property
def task(self): def task(self):
# type: () -> str
""" """
Return the creating task id (str) Return the creating task id (str)
@ -215,6 +231,7 @@ class BaseModel(object):
@property @property
def url(self): def url(self):
# type: () -> str
""" """
Return the url of the model file (or archived files) Return the url of the model file (or archived files)
@ -224,19 +241,23 @@ class BaseModel(object):
@property @property
def published(self): def published(self):
# type: () -> bool
return self._get_base_model().locked return self._get_base_model().locked
@property @property
def framework(self): def framework(self):
# type: () -> str
return self._get_model_data().framework return self._get_model_data().framework
def __init__(self, task=None): def __init__(self, task=None):
# type: (Task) -> None
super(BaseModel, self).__init__() super(BaseModel, self).__init__()
self._log = get_logger() self._log = get_logger()
self._task = None self._task = None
self._set_task(task) self._set_task(task)
def get_weights(self): def get_weights(self):
# type: () -> str
""" """
Download the base model and return the locally stored filename. Download the base model and return the locally stored filename.
@ -248,6 +269,7 @@ class BaseModel(object):
return self._get_base_model().download_model_weights() return self._get_base_model().download_model_weights()
def get_weights_package(self, return_path=False): def get_weights_package(self, return_path=False):
# type: (bool) -> Union[str, List[Path]]
""" """
Download the base model package into a temporary directory (extract the files), or return a list of the Download the base model package into a temporary directory (extract the files), or return a list of the
locally stored filenames. locally stored filenames.
@ -368,6 +390,7 @@ class Model(BaseModel):
""" """
def __init__(self, model_id): def __init__(self, model_id):
# type: (str) ->None
""" """
Load model based on id, returned object is read-only and can be connected to a task Load model based on id, returned object is read-only and can be connected to a task
@ -380,6 +403,7 @@ class Model(BaseModel):
self._base_model = None self._base_model = None
def get_local_copy(self, extract_archive=True): def get_local_copy(self, extract_archive=True):
# type: (bool) -> str
""" """
Retrieve a valid link to the model file(s). Retrieve a valid link to the model file(s).
If the model URL is a file system link, it will be returned directly. If the model URL is a file system link, it will be returned directly.
@ -424,17 +448,18 @@ class InputModel(Model):
@classmethod @classmethod
def import_model( def import_model(
cls, cls,
weights_url, weights_url, # type: str
config_text=None, config_text=None, # type: Optional[str]
config_dict=None, config_dict=None, # type: Optional[dict]
label_enumeration=None, label_enumeration=None, # type: Optional[Dict[str, int]]
name=None, name=None, # type: Optional[str]
tags=None, tags=None, # type: Optional[List[str]]
comment=None, comment=None, # type: Optional[str]
is_package=False, is_package=False, # type: bool
create_as_published=False, create_as_published=False, # type: bool
framework=None, framework=None, # type: Optional[str]
): ):
# type: (...) -> InputModel
""" """
Create an InputModel object from a pre-trained model by specifying the URL of an initial weight files. Create an InputModel object from a pre-trained model by specifying the URL of an initial weight files.
Optionally, input a configuration, label enumeration, name for the model, tags describing the model, Optionally, input a configuration, label enumeration, name for the model, tags describing the model,
@ -577,11 +602,8 @@ class InputModel(Model):
return this_model return this_model
@classmethod @classmethod
def load_model( def load_model(cls, weights_url, load_archived=False):
cls, # type: (str, bool) -> InputModel
weights_url,
load_archived=False
):
""" """
Load an already registered model based on a pre-existing model file (link must be valid). Load an already registered model based on a pre-existing model file (link must be valid).
@ -625,12 +647,8 @@ class InputModel(Model):
return InputModel(model_id=model.id) return InputModel(model_id=model.id)
@classmethod @classmethod
def empty( def empty(cls, config_text=None, config_dict=None, label_enumeration=None):
cls, # type: (Optional[str], Optional[dict], Optional[Dict[str, int]]) -> InputModel
config_text=None,
config_dict=None,
label_enumeration=None,
):
""" """
Create an empty model object. Later, you can assign a model to the empty model object. Create an empty model object. Later, you can assign a model to the empty model object.
@ -664,6 +682,7 @@ class InputModel(Model):
return this_model return this_model
def __init__(self, model_id): def __init__(self, model_id):
# type: (str) -> None
""" """
:param str model_id: The Trains Id (system UUID) of the input model whose metadata the **Trains Server** :param str model_id: The Trains Id (system UUID) of the input model whose metadata the **Trains Server**
(backend) stores. (backend) stores.
@ -672,9 +691,11 @@ class InputModel(Model):
@property @property
def id(self): def id(self):
# type: () -> str
return self._base_model_id return self._base_model_id
def connect(self, task): def connect(self, task):
# type: (Task) -> None
""" """
Connect the current model to a Task object, if the model is preexisting. Preexisting models include: Connect the current model to a Task object, if the model is preexisting. Preexisting models include:
@ -738,12 +759,21 @@ class OutputModel(BaseModel):
@property @property
def published(self): def published(self):
# type: () -> bool
"""
Get the published state of this model.
:return: ``True`` if the model is published, ``False`` otherwise.
:rtype: bool
"""
if not self.id: if not self.id:
return False return False
return self._get_base_model().locked return self._get_base_model().locked
@property @property
def config_text(self): def config_text(self):
# type: () -> str
""" """
Get the configuration as a string. For example, prototxt, an ini file, or Python code to evaluate. Get the configuration as a string. For example, prototxt, an ini file, or Python code to evaluate.
@ -755,6 +785,7 @@ class OutputModel(BaseModel):
@config_text.setter @config_text.setter
def config_text(self, value): def config_text(self, value):
# type: (str) -> None
""" """
Set the configuration. Store a blob of text for custom usage. Set the configuration. Store a blob of text for custom usage.
""" """
@ -762,6 +793,7 @@ class OutputModel(BaseModel):
@property @property
def config_dict(self): def config_dict(self):
# type: () -> dict
""" """
Get the configuration as a dictionary parsed from the ``config_text`` text. This usually represents the model Get the configuration as a dictionary parsed from the ``config_text`` text. This usually represents the model
configuration. For example, from prototxt to ini file or python code to evaluate. configuration. For example, from prototxt to ini file or python code to evaluate.
@ -774,6 +806,7 @@ class OutputModel(BaseModel):
@config_dict.setter @config_dict.setter
def config_dict(self, value): def config_dict(self, value):
# type: (dict) -> None
""" """
Set the configuration. Saved in the model object. Set the configuration. Saved in the model object.
@ -783,6 +816,7 @@ class OutputModel(BaseModel):
@property @property
def labels(self): def labels(self):
# type: () -> Dict[str, int]
""" """
Get the label enumeration as a dictionary of string (label) to integer (value) pairs. Get the label enumeration as a dictionary of string (label) to integer (value) pairs.
@ -803,6 +837,7 @@ class OutputModel(BaseModel):
@labels.setter @labels.setter
def labels(self, value): def labels(self, value):
# type: (Dict[str, int]) -> None
""" """
Set the label enumeration. Set the label enumeration.
@ -822,19 +857,20 @@ class OutputModel(BaseModel):
@property @property
def upload_storage_uri(self): def upload_storage_uri(self):
# type: () -> str
return self._get_base_model().upload_storage_uri return self._get_base_model().upload_storage_uri
def __init__( def __init__(
self, self,
task, task, # type: Task
config_text=None, config_text=None, # type: Optional[str]
config_dict=None, config_dict=None, # type: Optional[dict]
label_enumeration=None, label_enumeration=None, # type: Optional[Dict[str, int]]
name=None, name=None, # type: Optional[str]
tags=None, tags=None, # type: Optional[List[str]]
comment=None, comment=None, # type: Optional[str]
framework=None, framework=None, # type: Optional[Union[str, Framework]]
base_model_id=None, base_model_id=None, # type: Optional[str]
): ):
""" """
Create a new model and immediately connect it to a task. Create a new model and immediately connect it to a task.
@ -908,6 +944,7 @@ class OutputModel(BaseModel):
self.connect(task) self.connect(task)
def connect(self, task): def connect(self, task):
# type: (Task) -> None
""" """
Connect the current model to a Task object, if the model is a preexisting model. Preexisting models include: Connect the current model to a Task object, if the model is a preexisting model. Preexisting models include:
@ -947,6 +984,7 @@ class OutputModel(BaseModel):
self.task._save_output_model(self) self.task._save_output_model(self)
def set_upload_destination(self, uri): def set_upload_destination(self, uri):
# type: (str) -> None
""" """
Set the URI of the storage destination for uploaded model weight files. Supported storage destinations include Set the URI of the storage destination for uploaded model weight files. Supported storage destinations include
S3, Google Cloud Storage), and file locations. S3, Google Cloud Storage), and file locations.
@ -989,8 +1027,17 @@ class OutputModel(BaseModel):
# store default uri # store default uri
self._get_base_model().upload_storage_uri = uri self._get_base_model().upload_storage_uri = uri
def update_weights(self, weights_filename=None, upload_uri=None, target_filename=None, def update_weights(
auto_delete_file=True, register_uri=None, iteration=None, update_comment=True): self,
weights_filename=None, # type: Optional[str]
upload_uri=None, # type: Optional[str]
target_filename=None, # type: Optional[str]
auto_delete_file=True, # type: bool
register_uri=None, # type: Optional[str]
iteration=None, # type: Optional[int]
update_comment=True # type: bool
):
# type: (...) -> str
""" """
Update the model weights from a locally stored model filename. Update the model weights from a locally stored model filename.
@ -1010,6 +1057,7 @@ class OutputModel(BaseModel):
:param str register_uri: The URI of an already uploaded weights file. The URI must be valid. Specify :param str register_uri: The URI of an already uploaded weights file. The URI must be valid. Specify
``register_uri`` or ``weights_filename``, but not both. ``register_uri`` or ``weights_filename``, but not both.
:param int iteration: The iteration number.
:param bool update_comment: Update the model comment with the local weights file name (to maintain :param bool update_comment: Update the model comment with the local weights file name (to maintain
provenance)? (Optional) provenance)? (Optional)
@ -1111,8 +1159,16 @@ class OutputModel(BaseModel):
return output_uri return output_uri
def update_weights_package(self, weights_filenames=None, weights_path=None, upload_uri=None, def update_weights_package(
target_filename=None, auto_delete_file=True, iteration=None): self,
weights_filenames=None, # type: Optional[str]
weights_path=None, # type: Optional[str]
upload_uri=None, # type: Optional[str]
target_filename=None, # type: Optional[str]
auto_delete_file=True, # type: bool
iteration=None # type: Optional[int]
):
# type: (...) -> str
""" """
Update the model weights from locally stored model files, or from directory containing multiple files. Update the model weights from locally stored model files, or from directory containing multiple files.
@ -1134,6 +1190,8 @@ class OutputModel(BaseModel):
- ``True`` - Delete (Default) - ``True`` - Delete (Default)
- ``False`` - Do not delete - ``False`` - Do not delete
:param int iteration: The iteration number.
:return: The uploaded URI for the weights package. :return: The uploaded URI for the weights package.
:rtype: str :rtype: str
@ -1185,6 +1243,7 @@ class OutputModel(BaseModel):
return uploaded_uri return uploaded_uri
def update_design(self, config_text=None, config_dict=None): def update_design(self, config_text=None, config_dict=None):
# type: (Optional[str], Optional[dict]) -> bool
""" """
Update the model configuration. Store a blob of text for custom usage. Update the model configuration. Store a blob of text for custom usage.
@ -1223,6 +1282,7 @@ class OutputModel(BaseModel):
return result return result
def update_labels(self, labels): def update_labels(self, labels):
# type: (Dict[str, int]) -> Optional[Waitable]
""" """
Update the label enumeration. Update the label enumeration.
@ -1259,6 +1319,7 @@ class OutputModel(BaseModel):
@classmethod @classmethod
def wait_for_uploads(cls, timeout=None, max_num_uploads=None): def wait_for_uploads(cls, timeout=None, max_num_uploads=None):
# type: (Optional[float], Optional[int]) -> None
""" """
Wait for any pending or in-progress model uploads to complete. If no uploads are pending or in-progress, Wait for any pending or in-progress model uploads to complete. If no uploads are pending or in-progress,
then the ``wait_for_uploads`` returns immediately. then the ``wait_for_uploads`` returns immediately.

View File

@ -12,7 +12,7 @@ try:
except ImportError: except ImportError:
from collections import Callable, Sequence as CollectionsSequence from collections import Callable, Sequence as CollectionsSequence
from typing import Optional, Union, Mapping, Sequence, Any, Dict, List from typing import Optional, Union, Mapping, Sequence, Any, Dict, List, TYPE_CHECKING
import psutil import psutil
import six import six
@ -52,6 +52,12 @@ from .utilities.resource_monitor import ResourceMonitor
from .utilities.seed import make_deterministic from .utilities.seed import make_deterministic
if TYPE_CHECKING:
import pandas
import numpy
from PIL import Image
class Task(_Task): class Task(_Task):
""" """
The ``Task`` class is a code template for a Task object which, together with its connected experiment components, The ``Task`` class is a code template for a Task object which, together with its connected experiment components,
@ -157,17 +163,17 @@ class Task(_Task):
@classmethod @classmethod
def init( def init(
cls, cls,
project_name=None, project_name=None, # type: Optional[str]
task_name=None, task_name=None, # type: Optional[str]
task_type=TaskTypes.training, task_type=TaskTypes.training, # type: Task.TaskTypes
reuse_last_task_id=True, reuse_last_task_id=True, # type: bool
output_uri=None, output_uri=None, # type: Optional[str]
auto_connect_arg_parser=True, auto_connect_arg_parser=True, # type: bool
auto_connect_frameworks=True, auto_connect_frameworks=True, # type: Union[bool, Mapping[str, bool]]
auto_resource_monitoring=True, auto_resource_monitoring=True, # type: bool
): ):
# type: (Optional[str], Optional[str], TaskTypes, bool, Optional[str], bool, Union[bool, Mapping[str, bool]], bool) -> Task # type: (...) -> Task
""" """
Creates a new Task (experiment), or returns the existing Task, depending upon the following: Creates a new Task (experiment), or returns the existing Task, depending upon the following:
@ -464,12 +470,7 @@ class Task(_Task):
return task return task
@classmethod @classmethod
def create( def create(cls, project_name=None, task_name=None, task_type=TaskTypes.training):
cls,
project_name=None,
task_name=None,
task_type=TaskTypes.training,
):
# type: (Optional[str], Optional[str], TaskTypes) -> Task # type: (Optional[str], Optional[str], TaskTypes) -> Task
""" """
Create a new, non-reproducible Task (experiment). This is called a sub-task. Create a new, non-reproducible Task (experiment). This is called a sub-task.
@ -523,8 +524,8 @@ class Task(_Task):
""" """
Get a Task by Id, or project name / task name combination. Get a Task by Id, or project name / task name combination.
:param str task_id: The Id (system UUID) of the experiment to get. If specified, ``project_name`` and ``task_name`` :param str task_id: The Id (system UUID) of the experiment to get.
are ignored. If specified, ``project_name`` and ``task_name`` are ignored.
:param str project_name: The project name of the Task to get. :param str project_name: The project name of the Task to get.
:param str task_name: The name of the Task within ``project_name`` to get. :param str task_name: The name of the Task within ``project_name`` to get.
@ -1028,7 +1029,7 @@ class Task(_Task):
self.__register_at_exit(None) self.__register_at_exit(None)
def register_artifact(self, name, artifact, metadata=None, uniqueness_columns=True): def register_artifact(self, name, artifact, metadata=None, uniqueness_columns=True):
# type: (str, "pandas.DataFrame", Dict, Union[bool, Sequence[str]]) -> None # type: (str, pandas.DataFrame, Dict, Union[bool, Sequence[str]]) -> None
""" """
Register (add) an artifact for the current Task. Registered artifacts are dynamically sychronized with the Register (add) an artifact for the current Task. Registered artifacts are dynamically sychronized with the
**Trains Server** (backend). If a registered artifact is updated, the update is stored in the **Trains Server** (backend). If a registered artifact is updated, the update is stored in the
@ -1089,8 +1090,14 @@ class Task(_Task):
""" """
return self._artifacts_manager.registered_artifacts return self._artifacts_manager.registered_artifacts
def upload_artifact(self, name, artifact_object, metadata=None, delete_after_upload=False): def upload_artifact(
# type: (str, Union[str, Mapping, "pandas.DataFrame", "numpy.ndarray", "PIL.Image.Image"], Optional[Mapping], bool) -> bool self,
name, # type: str
artifact_object, # type: Union[str, Mapping, pandas.DataFrame, numpy.ndarray, Image.Image]
metadata=None, # type: Optional[Mapping]
delete_after_upload=False # type: bool
):
# type: (...) -> bool
""" """
Upload (add) a static artifact to a Task object. The artifact is uploaded in the background. Upload (add) a static artifact to a Task object. The artifact is uploaded in the background.
@ -1416,8 +1423,9 @@ class Task(_Task):
task._dev_worker = None task._dev_worker = None
@classmethod @classmethod
def _create_dev_task(cls, default_project_name, default_task_name, default_task_type, reuse_last_task_id, def _create_dev_task(
detect_repo=True): cls, default_project_name, default_task_name, default_task_type, reuse_last_task_id, detect_repo=True
):
if not default_project_name or not default_task_name: if not default_project_name or not default_task_name:
# get project name and task name from repository name and entry_point # get project name and task name from repository name and entry_point
result, _ = ScriptInfo.get(create_requirements=False, check_uncommitted=False) result, _ = ScriptInfo.get(create_requirements=False, check_uncommitted=False)