Add Model and Task type-annotations

This commit is contained in:
allegroai 2020-05-08 22:08:48 +03:00
parent f90f8f06e2
commit 966cd6118a
2 changed files with 127 additions and 58 deletions

View File

@ -6,6 +6,7 @@ from tempfile import mkdtemp, mkstemp
import pyparsing
import six
from typing import List, Dict, Union, Optional, TYPE_CHECKING
from .backend_api import Session
from .backend_api.services import models
@ -20,6 +21,10 @@ from .backend_interface import Task as _Task
from .backend_interface.model import create_dummy_model, Model as _Model
from .config import running_remotely, get_cache_dir
if TYPE_CHECKING:
from .task import Task
ARCHIVED_TAG = "archived"
@ -98,6 +103,7 @@ class BaseModel(object):
@property
def id(self):
# type: () -> str
"""
The Id (system UUID) of the model.
@ -109,6 +115,7 @@ class BaseModel(object):
@property
def name(self):
# type: () -> str
"""
The name of the model.
@ -120,6 +127,7 @@ class BaseModel(object):
@name.setter
def name(self, value):
# type: (str) -> None
"""
Set the model name.
@ -129,6 +137,7 @@ class BaseModel(object):
@property
def comment(self):
# type: () -> str
"""
The comment for the model. Also, use for a model description.
@ -140,6 +149,7 @@ class BaseModel(object):
@comment.setter
def comment(self, value):
# type: (str) -> None
"""
Set comment for the model. Also, use for a model description.
@ -149,6 +159,7 @@ class BaseModel(object):
@property
def tags(self):
# type: () -> List[str]
"""
A list of tags describing the model.
@ -160,6 +171,7 @@ class BaseModel(object):
@tags.setter
def tags(self, value):
# type: (List[str]) -> None
"""
Set the list of tags describing the model.
@ -171,6 +183,7 @@ class BaseModel(object):
@property
def config_text(self):
# type: () -> str
"""
The configuration as a string. For example, prototxt, an ini file, or Python code to evaluate.
@ -182,6 +195,7 @@ class BaseModel(object):
@property
def config_dict(self):
# type: () -> dict
"""
The configuration as a dictionary, parsed from the design text. This usually represents the model configuration.
For example, prototxt, an ini file, or Python code to evaluate.
@ -194,6 +208,7 @@ class BaseModel(object):
@property
def labels(self):
# type: () -> Dict[str, int]
"""
The label enumeration of string (label) to integer (value) pairs.
@ -206,6 +221,7 @@ class BaseModel(object):
@property
def task(self):
# type: () -> str
"""
Return the creating task id (str)
@ -215,6 +231,7 @@ class BaseModel(object):
@property
def url(self):
# type: () -> str
"""
Return the url of the model file (or archived files)
@ -224,19 +241,23 @@ class BaseModel(object):
@property
def published(self):
    # type: () -> bool
    """
    Get the published state of this model.

    :return: The ``locked`` attribute of the object returned by ``_get_base_model()``.
    :rtype: bool
    """
    return self._get_base_model().locked
@property
def framework(self):
    # type: () -> str
    """
    The framework recorded for this model.

    :return: The ``framework`` attribute of the object returned by ``_get_model_data()``.
    :rtype: str
    """
    return self._get_model_data().framework
def __init__(self, task=None):
    # type: (Optional[Task]) -> None
    """
    Initialize the state shared by all model objects.

    :param task: Optional Task object that is handed to ``_set_task``. Defaults to ``None``.
    """
    super(BaseModel, self).__init__()
    self._log = get_logger()
    # Start with no task; ``_set_task`` receives the caller's value right after.
    self._task = None
    self._set_task(task)
def get_weights(self):
# type: () -> str
"""
Download the base model and return the locally stored filename.
@ -248,6 +269,7 @@ class BaseModel(object):
return self._get_base_model().download_model_weights()
def get_weights_package(self, return_path=False):
# type: (bool) -> Union[str, List[Path]]
"""
Download the base model package into a temporary directory (extract the files), or return a list of the
locally stored filenames.
@ -368,6 +390,7 @@ class Model(BaseModel):
"""
def __init__(self, model_id):
# type: (str) -> None
"""
Load model based on id, returned object is read-only and can be connected to a task
@ -380,6 +403,7 @@ class Model(BaseModel):
self._base_model = None
def get_local_copy(self, extract_archive=True):
# type: (bool) -> str
"""
Retrieve a valid link to the model file(s).
If the model URL is a file system link, it will be returned directly.
@ -424,17 +448,18 @@ class InputModel(Model):
@classmethod
def import_model(
cls,
weights_url,
config_text=None,
config_dict=None,
label_enumeration=None,
name=None,
tags=None,
comment=None,
is_package=False,
create_as_published=False,
framework=None,
weights_url, # type: str
config_text=None, # type: Optional[str]
config_dict=None, # type: Optional[dict]
label_enumeration=None, # type: Optional[Dict[str, int]]
name=None, # type: Optional[str]
tags=None, # type: Optional[List[str]]
comment=None, # type: Optional[str]
is_package=False, # type: bool
create_as_published=False, # type: bool
framework=None, # type: Optional[str]
):
# type: (...) -> InputModel
"""
Create an InputModel object from a pre-trained model by specifying the URL of an initial weight files.
Optionally, input a configuration, label enumeration, name for the model, tags describing the model,
@ -577,11 +602,8 @@ class InputModel(Model):
return this_model
@classmethod
def load_model(
cls,
weights_url,
load_archived=False
):
def load_model(cls, weights_url, load_archived=False):
# type: (str, bool) -> InputModel
"""
Load an already registered model based on a pre-existing model file (link must be valid).
@ -625,12 +647,8 @@ class InputModel(Model):
return InputModel(model_id=model.id)
@classmethod
def empty(
cls,
config_text=None,
config_dict=None,
label_enumeration=None,
):
def empty(cls, config_text=None, config_dict=None, label_enumeration=None):
# type: (Optional[str], Optional[dict], Optional[Dict[str, int]]) -> InputModel
"""
Create an empty model object. Later, you can assign a model to the empty model object.
@ -664,6 +682,7 @@ class InputModel(Model):
return this_model
def __init__(self, model_id):
# type: (str) -> None
"""
:param str model_id: The Trains Id (system UUID) of the input model whose metadata the **Trains Server**
(backend) stores.
@ -672,9 +691,11 @@ class InputModel(Model):
@property
def id(self):
    # type: () -> str
    """
    The Id of this input model.

    :return: The value stored in ``_base_model_id``.
    :rtype: str
    """
    return self._base_model_id
def connect(self, task):
# type: (Task) -> None
"""
Connect the current model to a Task object, if the model is preexisting. Preexisting models include:
@ -738,12 +759,21 @@ class OutputModel(BaseModel):
@property
def published(self):
    # type: () -> bool
    """
    Get the published state of this model.

    :return: ``True`` if the model is published, ``False`` otherwise.
    :rtype: bool
    """
    # Without an id the base model is not queried at all.
    if not self.id:
        return False
    return self._get_base_model().locked
@property
def config_text(self):
# type: () -> str
"""
Get the configuration as a string. For example, prototxt, an ini file, or Python code to evaluate.
@ -755,6 +785,7 @@ class OutputModel(BaseModel):
@config_text.setter
def config_text(self, value):
# type: (str) -> None
"""
Set the configuration. Store a blob of text for custom usage.
"""
@ -762,6 +793,7 @@ class OutputModel(BaseModel):
@property
def config_dict(self):
# type: () -> dict
"""
Get the configuration as a dictionary parsed from the ``config_text`` text. This usually represents the model
configuration. For example, from prototxt to ini file or python code to evaluate.
@ -774,6 +806,7 @@ class OutputModel(BaseModel):
@config_dict.setter
def config_dict(self, value):
# type: (dict) -> None
"""
Set the configuration. Saved in the model object.
@ -783,6 +816,7 @@ class OutputModel(BaseModel):
@property
def labels(self):
# type: () -> Dict[str, int]
"""
Get the label enumeration as a dictionary of string (label) to integer (value) pairs.
@ -803,6 +837,7 @@ class OutputModel(BaseModel):
@labels.setter
def labels(self, value):
# type: (Dict[str, int]) -> None
"""
Set the label enumeration.
@ -822,19 +857,20 @@ class OutputModel(BaseModel):
@property
def upload_storage_uri(self):
    # type: () -> str
    """
    The upload storage URI held by the base model.

    :return: The ``upload_storage_uri`` attribute of the object returned by ``_get_base_model()``.
    :rtype: str
    """
    return self._get_base_model().upload_storage_uri
def __init__(
self,
task,
config_text=None,
config_dict=None,
label_enumeration=None,
name=None,
tags=None,
comment=None,
framework=None,
base_model_id=None,
task, # type: Task
config_text=None, # type: Optional[str]
config_dict=None, # type: Optional[dict]
label_enumeration=None, # type: Optional[Dict[str, int]]
name=None, # type: Optional[str]
tags=None, # type: Optional[List[str]]
comment=None, # type: Optional[str]
framework=None, # type: Optional[Union[str, Framework]]
base_model_id=None, # type: Optional[str]
):
"""
Create a new model and immediately connect it to a task.
@ -908,6 +944,7 @@ class OutputModel(BaseModel):
self.connect(task)
def connect(self, task):
# type: (Task) -> None
"""
Connect the current model to a Task object, if the model is a preexisting model. Preexisting models include:
@ -947,6 +984,7 @@ class OutputModel(BaseModel):
self.task._save_output_model(self)
def set_upload_destination(self, uri):
# type: (str) -> None
"""
Set the URI of the storage destination for uploaded model weight files. Supported storage destinations include
S3, Google Cloud Storage, and file locations.
@ -989,8 +1027,17 @@ class OutputModel(BaseModel):
# store default uri
self._get_base_model().upload_storage_uri = uri
def update_weights(self, weights_filename=None, upload_uri=None, target_filename=None,
auto_delete_file=True, register_uri=None, iteration=None, update_comment=True):
def update_weights(
self,
weights_filename=None, # type: Optional[str]
upload_uri=None, # type: Optional[str]
target_filename=None, # type: Optional[str]
auto_delete_file=True, # type: bool
register_uri=None, # type: Optional[str]
iteration=None, # type: Optional[int]
update_comment=True # type: bool
):
# type: (...) -> str
"""
Update the model weights from a locally stored model filename.
@ -1010,6 +1057,7 @@ class OutputModel(BaseModel):
:param str register_uri: The URI of an already uploaded weights file. The URI must be valid. Specify
``register_uri`` or ``weights_filename``, but not both.
:param int iteration: The iteration number.
:param bool update_comment: Update the model comment with the local weights file name (to maintain
provenance)? (Optional)
@ -1111,8 +1159,16 @@ class OutputModel(BaseModel):
return output_uri
def update_weights_package(self, weights_filenames=None, weights_path=None, upload_uri=None,
target_filename=None, auto_delete_file=True, iteration=None):
def update_weights_package(
self,
weights_filenames=None, # type: Optional[List[str]]
weights_path=None, # type: Optional[str]
upload_uri=None, # type: Optional[str]
target_filename=None, # type: Optional[str]
auto_delete_file=True, # type: bool
iteration=None # type: Optional[int]
):
# type: (...) -> str
"""
Update the model weights from locally stored model files, or from directory containing multiple files.
@ -1134,6 +1190,8 @@ class OutputModel(BaseModel):
- ``True`` - Delete (Default)
- ``False`` - Do not delete
:param int iteration: The iteration number.
:return: The uploaded URI for the weights package.
:rtype: str
@ -1185,6 +1243,7 @@ class OutputModel(BaseModel):
return uploaded_uri
def update_design(self, config_text=None, config_dict=None):
# type: (Optional[str], Optional[dict]) -> bool
"""
Update the model configuration. Store a blob of text for custom usage.
@ -1223,6 +1282,7 @@ class OutputModel(BaseModel):
return result
def update_labels(self, labels):
# type: (Dict[str, int]) -> Optional[Waitable]
"""
Update the label enumeration.
@ -1259,6 +1319,7 @@ class OutputModel(BaseModel):
@classmethod
def wait_for_uploads(cls, timeout=None, max_num_uploads=None):
# type: (Optional[float], Optional[int]) -> None
"""
Wait for any pending or in-progress model uploads to complete. If no uploads are pending or in-progress,
then the ``wait_for_uploads`` returns immediately.

View File

@ -12,7 +12,7 @@ try:
except ImportError:
from collections import Callable, Sequence as CollectionsSequence
from typing import Optional, Union, Mapping, Sequence, Any, Dict, List
from typing import Optional, Union, Mapping, Sequence, Any, Dict, List, TYPE_CHECKING
import psutil
import six
@ -52,6 +52,12 @@ from .utilities.resource_monitor import ResourceMonitor
from .utilities.seed import make_deterministic
if TYPE_CHECKING:
import pandas
import numpy
from PIL import Image
class Task(_Task):
"""
The ``Task`` class is a code template for a Task object which, together with its connected experiment components,
@ -157,17 +163,17 @@ class Task(_Task):
@classmethod
def init(
cls,
project_name=None,
task_name=None,
task_type=TaskTypes.training,
reuse_last_task_id=True,
output_uri=None,
auto_connect_arg_parser=True,
auto_connect_frameworks=True,
auto_resource_monitoring=True,
cls,
project_name=None, # type: Optional[str]
task_name=None, # type: Optional[str]
task_type=TaskTypes.training, # type: Task.TaskTypes
reuse_last_task_id=True, # type: bool
output_uri=None, # type: Optional[str]
auto_connect_arg_parser=True, # type: bool
auto_connect_frameworks=True, # type: Union[bool, Mapping[str, bool]]
auto_resource_monitoring=True, # type: bool
):
# type: (Optional[str], Optional[str], TaskTypes, bool, Optional[str], bool, Union[bool, Mapping[str, bool]], bool) -> Task
# type: (...) -> Task
"""
Creates a new Task (experiment), or returns the existing Task, depending upon the following:
@ -464,12 +470,7 @@ class Task(_Task):
return task
@classmethod
def create(
cls,
project_name=None,
task_name=None,
task_type=TaskTypes.training,
):
def create(cls, project_name=None, task_name=None, task_type=TaskTypes.training):
# type: (Optional[str], Optional[str], TaskTypes) -> Task
"""
Create a new, non-reproducible Task (experiment). This is called a sub-task.
@ -523,8 +524,8 @@ class Task(_Task):
"""
Get a Task by Id, or project name / task name combination.
:param str task_id: The Id (system UUID) of the experiment to get. If specified, ``project_name`` and ``task_name``
are ignored.
:param str task_id: The Id (system UUID) of the experiment to get.
If specified, ``project_name`` and ``task_name`` are ignored.
:param str project_name: The project name of the Task to get.
:param str task_name: The name of the Task within ``project_name`` to get.
@ -1028,7 +1029,7 @@ class Task(_Task):
self.__register_at_exit(None)
def register_artifact(self, name, artifact, metadata=None, uniqueness_columns=True):
# type: (str, "pandas.DataFrame", Dict, Union[bool, Sequence[str]]) -> None
# type: (str, pandas.DataFrame, Dict, Union[bool, Sequence[str]]) -> None
"""
Register (add) an artifact for the current Task. Registered artifacts are dynamically synchronized with the
**Trains Server** (backend). If a registered artifact is updated, the update is stored in the
@ -1089,8 +1090,14 @@ class Task(_Task):
"""
return self._artifacts_manager.registered_artifacts
def upload_artifact(self, name, artifact_object, metadata=None, delete_after_upload=False):
# type: (str, Union[str, Mapping, "pandas.DataFrame", "numpy.ndarray", "PIL.Image.Image"], Optional[Mapping], bool) -> bool
def upload_artifact(
self,
name, # type: str
artifact_object, # type: Union[str, Mapping, pandas.DataFrame, numpy.ndarray, Image.Image]
metadata=None, # type: Optional[Mapping]
delete_after_upload=False # type: bool
):
# type: (...) -> bool
"""
Upload (add) a static artifact to a Task object. The artifact is uploaded in the background.
@ -1416,8 +1423,9 @@ class Task(_Task):
task._dev_worker = None
@classmethod
def _create_dev_task(cls, default_project_name, default_task_name, default_task_type, reuse_last_task_id,
detect_repo=True):
def _create_dev_task(
cls, default_project_name, default_task_name, default_task_type, reuse_last_task_id, detect_repo=True
):
if not default_project_name or not default_task_name:
# get project name and task name from repository name and entry_point
result, _ = ScriptInfo.get(create_requirements=False, check_uncommitted=False)