mirror of
https://github.com/clearml/clearml
synced 2025-05-07 06:14:31 +00:00
Add Model and Task type-annotations
This commit is contained in:
parent
f90f8f06e2
commit
966cd6118a
129
trains/model.py
129
trains/model.py
@ -6,6 +6,7 @@ from tempfile import mkdtemp, mkstemp
|
||||
|
||||
import pyparsing
|
||||
import six
|
||||
from typing import List, Dict, Union, Optional, TYPE_CHECKING
|
||||
|
||||
from .backend_api import Session
|
||||
from .backend_api.services import models
|
||||
@ -20,6 +21,10 @@ from .backend_interface import Task as _Task
|
||||
from .backend_interface.model import create_dummy_model, Model as _Model
|
||||
from .config import running_remotely, get_cache_dir
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .task import Task
|
||||
|
||||
ARCHIVED_TAG = "archived"
|
||||
|
||||
|
||||
@ -98,6 +103,7 @@ class BaseModel(object):
|
||||
|
||||
@property
|
||||
def id(self):
|
||||
# type: () -> str
|
||||
"""
|
||||
The Id (system UUID) of the model.
|
||||
|
||||
@ -109,6 +115,7 @@ class BaseModel(object):
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
# type: () -> str
|
||||
"""
|
||||
The name of the model.
|
||||
|
||||
@ -120,6 +127,7 @@ class BaseModel(object):
|
||||
|
||||
@name.setter
|
||||
def name(self, value):
|
||||
# type: (str) -> None
|
||||
"""
|
||||
Set the model name.
|
||||
|
||||
@ -129,6 +137,7 @@ class BaseModel(object):
|
||||
|
||||
@property
|
||||
def comment(self):
|
||||
# type: () -> str
|
||||
"""
|
||||
The comment for the model. Also, use for a model description.
|
||||
|
||||
@ -140,6 +149,7 @@ class BaseModel(object):
|
||||
|
||||
@comment.setter
|
||||
def comment(self, value):
|
||||
# type: (str) -> None
|
||||
"""
|
||||
Set comment for the model. Also, use for a model description.
|
||||
|
||||
@ -149,6 +159,7 @@ class BaseModel(object):
|
||||
|
||||
@property
|
||||
def tags(self):
|
||||
# type: () -> List[str]
|
||||
"""
|
||||
A list of tags describing the model.
|
||||
|
||||
@ -160,6 +171,7 @@ class BaseModel(object):
|
||||
|
||||
@tags.setter
|
||||
def tags(self, value):
|
||||
# type: (List[str]) -> None
|
||||
"""
|
||||
Set the list of tags describing the model.
|
||||
|
||||
@ -171,6 +183,7 @@ class BaseModel(object):
|
||||
|
||||
@property
|
||||
def config_text(self):
|
||||
# type: () -> str
|
||||
"""
|
||||
The configuration as a string. For example, prototxt, an ini file, or Python code to evaluate.
|
||||
|
||||
@ -182,6 +195,7 @@ class BaseModel(object):
|
||||
|
||||
@property
|
||||
def config_dict(self):
|
||||
# type: () -> dict
|
||||
"""
|
||||
The configuration as a dictionary, parsed from the design text. This usually represents the model configuration.
|
||||
For example, prototxt, an ini file, or Python code to evaluate.
|
||||
@ -194,6 +208,7 @@ class BaseModel(object):
|
||||
|
||||
@property
|
||||
def labels(self):
|
||||
# type: () -> Dict[str, int]
|
||||
"""
|
||||
The label enumeration of string (label) to integer (value) pairs.
|
||||
|
||||
@ -206,6 +221,7 @@ class BaseModel(object):
|
||||
|
||||
@property
|
||||
def task(self):
|
||||
# type: () -> str
|
||||
"""
|
||||
Return the creating task id (str)
|
||||
|
||||
@ -215,6 +231,7 @@ class BaseModel(object):
|
||||
|
||||
@property
|
||||
def url(self):
|
||||
# type: () -> str
|
||||
"""
|
||||
Return the url of the model file (or archived files)
|
||||
|
||||
@ -224,19 +241,23 @@ class BaseModel(object):
|
||||
|
||||
@property
|
||||
def published(self):
|
||||
# type: () -> bool
|
||||
return self._get_base_model().locked
|
||||
|
||||
@property
|
||||
def framework(self):
|
||||
# type: () -> str
|
||||
return self._get_model_data().framework
|
||||
|
||||
def __init__(self, task=None):
|
||||
# type: (Task) -> None
|
||||
super(BaseModel, self).__init__()
|
||||
self._log = get_logger()
|
||||
self._task = None
|
||||
self._set_task(task)
|
||||
|
||||
def get_weights(self):
|
||||
# type: () -> str
|
||||
"""
|
||||
Download the base model and return the locally stored filename.
|
||||
|
||||
@ -248,6 +269,7 @@ class BaseModel(object):
|
||||
return self._get_base_model().download_model_weights()
|
||||
|
||||
def get_weights_package(self, return_path=False):
|
||||
# type: (bool) -> Union[str, List[Path]]
|
||||
"""
|
||||
Download the base model package into a temporary directory (extract the files), or return a list of the
|
||||
locally stored filenames.
|
||||
@ -368,6 +390,7 @@ class Model(BaseModel):
|
||||
"""
|
||||
|
||||
def __init__(self, model_id):
|
||||
# type: (str) ->None
|
||||
"""
|
||||
Load model based on id, returned object is read-only and can be connected to a task
|
||||
|
||||
@ -380,6 +403,7 @@ class Model(BaseModel):
|
||||
self._base_model = None
|
||||
|
||||
def get_local_copy(self, extract_archive=True):
|
||||
# type: (bool) -> str
|
||||
"""
|
||||
Retrieve a valid link to the model file(s).
|
||||
If the model URL is a file system link, it will be returned directly.
|
||||
@ -424,17 +448,18 @@ class InputModel(Model):
|
||||
@classmethod
|
||||
def import_model(
|
||||
cls,
|
||||
weights_url,
|
||||
config_text=None,
|
||||
config_dict=None,
|
||||
label_enumeration=None,
|
||||
name=None,
|
||||
tags=None,
|
||||
comment=None,
|
||||
is_package=False,
|
||||
create_as_published=False,
|
||||
framework=None,
|
||||
weights_url, # type: str
|
||||
config_text=None, # type: Optional[str]
|
||||
config_dict=None, # type: Optional[dict]
|
||||
label_enumeration=None, # type: Optional[Dict[str, int]]
|
||||
name=None, # type: Optional[str]
|
||||
tags=None, # type: Optional[List[str]]
|
||||
comment=None, # type: Optional[str]
|
||||
is_package=False, # type: bool
|
||||
create_as_published=False, # type: bool
|
||||
framework=None, # type: Optional[str]
|
||||
):
|
||||
# type: (...) -> InputModel
|
||||
"""
|
||||
Create an InputModel object from a pre-trained model by specifying the URL of an initial weight files.
|
||||
Optionally, input a configuration, label enumeration, name for the model, tags describing the model,
|
||||
@ -577,11 +602,8 @@ class InputModel(Model):
|
||||
return this_model
|
||||
|
||||
@classmethod
|
||||
def load_model(
|
||||
cls,
|
||||
weights_url,
|
||||
load_archived=False
|
||||
):
|
||||
def load_model(cls, weights_url, load_archived=False):
|
||||
# type: (str, bool) -> InputModel
|
||||
"""
|
||||
Load an already registered model based on a pre-existing model file (link must be valid).
|
||||
|
||||
@ -625,12 +647,8 @@ class InputModel(Model):
|
||||
return InputModel(model_id=model.id)
|
||||
|
||||
@classmethod
|
||||
def empty(
|
||||
cls,
|
||||
config_text=None,
|
||||
config_dict=None,
|
||||
label_enumeration=None,
|
||||
):
|
||||
def empty(cls, config_text=None, config_dict=None, label_enumeration=None):
|
||||
# type: (Optional[str], Optional[dict], Optional[Dict[str, int]]) -> InputModel
|
||||
"""
|
||||
Create an empty model object. Later, you can assign a model to the empty model object.
|
||||
|
||||
@ -664,6 +682,7 @@ class InputModel(Model):
|
||||
return this_model
|
||||
|
||||
def __init__(self, model_id):
|
||||
# type: (str) -> None
|
||||
"""
|
||||
:param str model_id: The Trains Id (system UUID) of the input model whose metadata the **Trains Server**
|
||||
(backend) stores.
|
||||
@ -672,9 +691,11 @@ class InputModel(Model):
|
||||
|
||||
@property
|
||||
def id(self):
|
||||
# type: () -> str
|
||||
return self._base_model_id
|
||||
|
||||
def connect(self, task):
|
||||
# type: (Task) -> None
|
||||
"""
|
||||
Connect the current model to a Task object, if the model is preexisting. Preexisting models include:
|
||||
|
||||
@ -738,12 +759,21 @@ class OutputModel(BaseModel):
|
||||
|
||||
@property
|
||||
def published(self):
|
||||
# type: () -> bool
|
||||
"""
|
||||
Get the published state of this model.
|
||||
|
||||
:return: ``True`` if the model is published, ``False`` otherwise.
|
||||
|
||||
:rtype: bool
|
||||
"""
|
||||
if not self.id:
|
||||
return False
|
||||
return self._get_base_model().locked
|
||||
|
||||
@property
|
||||
def config_text(self):
|
||||
# type: () -> str
|
||||
"""
|
||||
Get the configuration as a string. For example, prototxt, an ini file, or Python code to evaluate.
|
||||
|
||||
@ -755,6 +785,7 @@ class OutputModel(BaseModel):
|
||||
|
||||
@config_text.setter
|
||||
def config_text(self, value):
|
||||
# type: (str) -> None
|
||||
"""
|
||||
Set the configuration. Store a blob of text for custom usage.
|
||||
"""
|
||||
@ -762,6 +793,7 @@ class OutputModel(BaseModel):
|
||||
|
||||
@property
|
||||
def config_dict(self):
|
||||
# type: () -> dict
|
||||
"""
|
||||
Get the configuration as a dictionary parsed from the ``config_text`` text. This usually represents the model
|
||||
configuration. For example, from prototxt to ini file or python code to evaluate.
|
||||
@ -774,6 +806,7 @@ class OutputModel(BaseModel):
|
||||
|
||||
@config_dict.setter
|
||||
def config_dict(self, value):
|
||||
# type: (dict) -> None
|
||||
"""
|
||||
Set the configuration. Saved in the model object.
|
||||
|
||||
@ -783,6 +816,7 @@ class OutputModel(BaseModel):
|
||||
|
||||
@property
|
||||
def labels(self):
|
||||
# type: () -> Dict[str, int]
|
||||
"""
|
||||
Get the label enumeration as a dictionary of string (label) to integer (value) pairs.
|
||||
|
||||
@ -803,6 +837,7 @@ class OutputModel(BaseModel):
|
||||
|
||||
@labels.setter
|
||||
def labels(self, value):
|
||||
# type: (Dict[str, int]) -> None
|
||||
"""
|
||||
Set the label enumeration.
|
||||
|
||||
@ -822,19 +857,20 @@ class OutputModel(BaseModel):
|
||||
|
||||
@property
|
||||
def upload_storage_uri(self):
|
||||
# type: () -> str
|
||||
return self._get_base_model().upload_storage_uri
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
task,
|
||||
config_text=None,
|
||||
config_dict=None,
|
||||
label_enumeration=None,
|
||||
name=None,
|
||||
tags=None,
|
||||
comment=None,
|
||||
framework=None,
|
||||
base_model_id=None,
|
||||
task, # type: Task
|
||||
config_text=None, # type: Optional[str]
|
||||
config_dict=None, # type: Optional[dict]
|
||||
label_enumeration=None, # type: Optional[Dict[str, int]]
|
||||
name=None, # type: Optional[str]
|
||||
tags=None, # type: Optional[List[str]]
|
||||
comment=None, # type: Optional[str]
|
||||
framework=None, # type: Optional[Union[str, Framework]]
|
||||
base_model_id=None, # type: Optional[str]
|
||||
):
|
||||
"""
|
||||
Create a new model and immediately connect it to a task.
|
||||
@ -908,6 +944,7 @@ class OutputModel(BaseModel):
|
||||
self.connect(task)
|
||||
|
||||
def connect(self, task):
|
||||
# type: (Task) -> None
|
||||
"""
|
||||
Connect the current model to a Task object, if the model is a preexisting model. Preexisting models include:
|
||||
|
||||
@ -947,6 +984,7 @@ class OutputModel(BaseModel):
|
||||
self.task._save_output_model(self)
|
||||
|
||||
def set_upload_destination(self, uri):
|
||||
# type: (str) -> None
|
||||
"""
|
||||
Set the URI of the storage destination for uploaded model weight files. Supported storage destinations include
|
||||
S3, Google Cloud Storage), and file locations.
|
||||
@ -989,8 +1027,17 @@ class OutputModel(BaseModel):
|
||||
# store default uri
|
||||
self._get_base_model().upload_storage_uri = uri
|
||||
|
||||
def update_weights(self, weights_filename=None, upload_uri=None, target_filename=None,
|
||||
auto_delete_file=True, register_uri=None, iteration=None, update_comment=True):
|
||||
def update_weights(
|
||||
self,
|
||||
weights_filename=None, # type: Optional[str]
|
||||
upload_uri=None, # type: Optional[str]
|
||||
target_filename=None, # type: Optional[str]
|
||||
auto_delete_file=True, # type: bool
|
||||
register_uri=None, # type: Optional[str]
|
||||
iteration=None, # type: Optional[int]
|
||||
update_comment=True # type: bool
|
||||
):
|
||||
# type: (...) -> str
|
||||
"""
|
||||
Update the model weights from a locally stored model filename.
|
||||
|
||||
@ -1010,6 +1057,7 @@ class OutputModel(BaseModel):
|
||||
|
||||
:param str register_uri: The URI of an already uploaded weights file. The URI must be valid. Specify
|
||||
``register_uri`` or ``weights_filename``, but not both.
|
||||
:param int iteration: The iteration number.
|
||||
:param bool update_comment: Update the model comment with the local weights file name (to maintain
|
||||
provenance)? (Optional)
|
||||
|
||||
@ -1111,8 +1159,16 @@ class OutputModel(BaseModel):
|
||||
|
||||
return output_uri
|
||||
|
||||
def update_weights_package(self, weights_filenames=None, weights_path=None, upload_uri=None,
|
||||
target_filename=None, auto_delete_file=True, iteration=None):
|
||||
def update_weights_package(
|
||||
self,
|
||||
weights_filenames=None, # type: Optional[str]
|
||||
weights_path=None, # type: Optional[str]
|
||||
upload_uri=None, # type: Optional[str]
|
||||
target_filename=None, # type: Optional[str]
|
||||
auto_delete_file=True, # type: bool
|
||||
iteration=None # type: Optional[int]
|
||||
):
|
||||
# type: (...) -> str
|
||||
"""
|
||||
Update the model weights from locally stored model files, or from directory containing multiple files.
|
||||
|
||||
@ -1134,6 +1190,8 @@ class OutputModel(BaseModel):
|
||||
- ``True`` - Delete (Default)
|
||||
- ``False`` - Do not delete
|
||||
|
||||
:param int iteration: The iteration number.
|
||||
|
||||
:return: The uploaded URI for the weights package.
|
||||
|
||||
:rtype: str
|
||||
@ -1185,6 +1243,7 @@ class OutputModel(BaseModel):
|
||||
return uploaded_uri
|
||||
|
||||
def update_design(self, config_text=None, config_dict=None):
|
||||
# type: (Optional[str], Optional[dict]) -> bool
|
||||
"""
|
||||
Update the model configuration. Store a blob of text for custom usage.
|
||||
|
||||
@ -1223,6 +1282,7 @@ class OutputModel(BaseModel):
|
||||
return result
|
||||
|
||||
def update_labels(self, labels):
|
||||
# type: (Dict[str, int]) -> Optional[Waitable]
|
||||
"""
|
||||
Update the label enumeration.
|
||||
|
||||
@ -1259,6 +1319,7 @@ class OutputModel(BaseModel):
|
||||
|
||||
@classmethod
|
||||
def wait_for_uploads(cls, timeout=None, max_num_uploads=None):
|
||||
# type: (Optional[float], Optional[int]) -> None
|
||||
"""
|
||||
Wait for any pending or in-progress model uploads to complete. If no uploads are pending or in-progress,
|
||||
then the ``wait_for_uploads`` returns immediately.
|
||||
|
@ -12,7 +12,7 @@ try:
|
||||
except ImportError:
|
||||
from collections import Callable, Sequence as CollectionsSequence
|
||||
|
||||
from typing import Optional, Union, Mapping, Sequence, Any, Dict, List
|
||||
from typing import Optional, Union, Mapping, Sequence, Any, Dict, List, TYPE_CHECKING
|
||||
|
||||
import psutil
|
||||
import six
|
||||
@ -52,6 +52,12 @@ from .utilities.resource_monitor import ResourceMonitor
|
||||
from .utilities.seed import make_deterministic
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import pandas
|
||||
import numpy
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class Task(_Task):
|
||||
"""
|
||||
The ``Task`` class is a code template for a Task object which, together with its connected experiment components,
|
||||
@ -157,17 +163,17 @@ class Task(_Task):
|
||||
|
||||
@classmethod
|
||||
def init(
|
||||
cls,
|
||||
project_name=None,
|
||||
task_name=None,
|
||||
task_type=TaskTypes.training,
|
||||
reuse_last_task_id=True,
|
||||
output_uri=None,
|
||||
auto_connect_arg_parser=True,
|
||||
auto_connect_frameworks=True,
|
||||
auto_resource_monitoring=True,
|
||||
cls,
|
||||
project_name=None, # type: Optional[str]
|
||||
task_name=None, # type: Optional[str]
|
||||
task_type=TaskTypes.training, # type: Task.TaskTypes
|
||||
reuse_last_task_id=True, # type: bool
|
||||
output_uri=None, # type: Optional[str]
|
||||
auto_connect_arg_parser=True, # type: bool
|
||||
auto_connect_frameworks=True, # type: Union[bool, Mapping[str, bool]]
|
||||
auto_resource_monitoring=True, # type: bool
|
||||
):
|
||||
# type: (Optional[str], Optional[str], TaskTypes, bool, Optional[str], bool, Union[bool, Mapping[str, bool]], bool) -> Task
|
||||
# type: (...) -> Task
|
||||
"""
|
||||
Creates a new Task (experiment), or returns the existing Task, depending upon the following:
|
||||
|
||||
@ -464,12 +470,7 @@ class Task(_Task):
|
||||
return task
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
project_name=None,
|
||||
task_name=None,
|
||||
task_type=TaskTypes.training,
|
||||
):
|
||||
def create(cls, project_name=None, task_name=None, task_type=TaskTypes.training):
|
||||
# type: (Optional[str], Optional[str], TaskTypes) -> Task
|
||||
"""
|
||||
Create a new, non-reproducible Task (experiment). This is called a sub-task.
|
||||
@ -523,8 +524,8 @@ class Task(_Task):
|
||||
"""
|
||||
Get a Task by Id, or project name / task name combination.
|
||||
|
||||
:param str task_id: The Id (system UUID) of the experiment to get. If specified, ``project_name`` and ``task_name``
|
||||
are ignored.
|
||||
:param str task_id: The Id (system UUID) of the experiment to get.
|
||||
If specified, ``project_name`` and ``task_name`` are ignored.
|
||||
:param str project_name: The project name of the Task to get.
|
||||
:param str task_name: The name of the Task within ``project_name`` to get.
|
||||
|
||||
@ -1028,7 +1029,7 @@ class Task(_Task):
|
||||
self.__register_at_exit(None)
|
||||
|
||||
def register_artifact(self, name, artifact, metadata=None, uniqueness_columns=True):
|
||||
# type: (str, "pandas.DataFrame", Dict, Union[bool, Sequence[str]]) -> None
|
||||
# type: (str, pandas.DataFrame, Dict, Union[bool, Sequence[str]]) -> None
|
||||
"""
|
||||
Register (add) an artifact for the current Task. Registered artifacts are dynamically sychronized with the
|
||||
**Trains Server** (backend). If a registered artifact is updated, the update is stored in the
|
||||
@ -1089,8 +1090,14 @@ class Task(_Task):
|
||||
"""
|
||||
return self._artifacts_manager.registered_artifacts
|
||||
|
||||
def upload_artifact(self, name, artifact_object, metadata=None, delete_after_upload=False):
|
||||
# type: (str, Union[str, Mapping, "pandas.DataFrame", "numpy.ndarray", "PIL.Image.Image"], Optional[Mapping], bool) -> bool
|
||||
def upload_artifact(
|
||||
self,
|
||||
name, # type: str
|
||||
artifact_object, # type: Union[str, Mapping, pandas.DataFrame, numpy.ndarray, Image.Image]
|
||||
metadata=None, # type: Optional[Mapping]
|
||||
delete_after_upload=False # type: bool
|
||||
):
|
||||
# type: (...) -> bool
|
||||
"""
|
||||
Upload (add) a static artifact to a Task object. The artifact is uploaded in the background.
|
||||
|
||||
@ -1416,8 +1423,9 @@ class Task(_Task):
|
||||
task._dev_worker = None
|
||||
|
||||
@classmethod
|
||||
def _create_dev_task(cls, default_project_name, default_task_name, default_task_type, reuse_last_task_id,
|
||||
detect_repo=True):
|
||||
def _create_dev_task(
|
||||
cls, default_project_name, default_task_name, default_task_type, reuse_last_task_id, detect_repo=True
|
||||
):
|
||||
if not default_project_name or not default_task_name:
|
||||
# get project name and task name from repository name and entry_point
|
||||
result, _ = ScriptInfo.get(create_requirements=False, check_uncommitted=False)
|
||||
|
Loading…
Reference in New Issue
Block a user