Add type annotations and fix docstrings

This commit is contained in:
allegroai 2020-03-23 23:26:46 +02:00
parent 766c8ab24f
commit c4719f2e2f

View File

@ -12,7 +12,7 @@ try:
except ImportError:
from collections import Callable, Sequence
from typing import Optional
from typing import Optional, Union, Mapping, Sequence as TSequence, Any, Dict, List
import psutil
import six
@ -41,7 +41,7 @@ from .config.cache import SessionCache
from .debugging.log import LoggerRoot
from .errors import UsageError
from .logger import Logger
from .model import InputModel, OutputModel, ARCHIVED_TAG
from .model import Model, InputModel, OutputModel, ARCHIVED_TAG
from .task_parameters import TaskParameters
from .utilities.args import argparser_parseargs_called, get_argparser_last_args, \
argparser_update_currenttask
@ -99,7 +99,8 @@ class Task(_Task):
def __init__(self, private=None, **kwargs):
"""
Do not construct Task manually!
.. warning::
Do not construct Task manually!
**Please use Task.init() or Task.get_task(id=, project=, name=)**
"""
if private is not Task.__create_protection:
@ -141,6 +142,7 @@ class Task(_Task):
auto_connect_frameworks=True,
auto_resource_monitoring=True,
):
# type: (Optional[str], Optional[str], TaskTypes, bool, Optional[str], bool, Union[bool, Mapping[str, bool]], bool) -> Task
"""
Return the Task object for the main execution task (task context).
@ -369,6 +371,7 @@ class Task(_Task):
task_name=None,
task_type=TaskTypes.training,
):
# type: (Optional[str], Optional[str], TaskTypes) -> Task
"""
Create a new Task object, regardless of the main execution task (Task.init).
@ -403,6 +406,7 @@ class Task(_Task):
@classmethod
def get_task(cls, task_id=None, project_name=None, task_name=None):
# type: (Optional[str], Optional[str], Optional[str]) -> Task
"""
Returns Task object based on either, task_id (system uuid) or task name
@ -415,6 +419,7 @@ class Task(_Task):
@classmethod
def get_tasks(cls, task_ids=None, project_name=None, task_name=None):
# type: (Optional[TSequence[str]], Optional[str], Optional[str]) -> Task
"""
Returns a list of Task objects, matching requested task name (or partially matching)
@ -429,10 +434,12 @@ class Task(_Task):
@property
def output_uri(self):
# type: () -> str
return self.storage_uri
@output_uri.setter
def output_uri(self, value):
# type: (str) -> None
# check if we have the correct packages / configuration
if value and value != self.storage_uri:
from .storage.helper import StorageHelper
@ -445,9 +452,11 @@ class Task(_Task):
@property
def artifacts(self):
# type: () -> Dict[str, Artifact]
"""
read-only dictionary of Task artifacts (name, artifact)
:return: dict
Read-only dictionary of Task artifacts (name, artifact)
:return dict: dictionary of artifacts
"""
if not Session.check_min_api_version('2.3'):
return ReadOnlyDict()
@ -470,6 +479,7 @@ class Task(_Task):
@classmethod
def clone(cls, source_task=None, name=None, comment=None, parent=None, project=None):
# type: (Optional[Task], Optional[str], Optional[str], Optional[str], Optional[str], Optional[str]) -> Task
"""
Clone a task object, create a copy a task.
@ -502,6 +512,7 @@ class Task(_Task):
@classmethod
def enqueue(cls, task, queue_name=None, queue_id=None):
# type: (Task, Optional[str], Optional[str]) -> Any
"""
Enqueue (send) a task for execution, by adding it to an execution queue
@ -534,6 +545,7 @@ class Task(_Task):
@classmethod
def dequeue(cls, task):
# type: (Union[Task, str]) -> Any
"""
Dequeue (remove) task from execution queue.
@ -553,10 +565,11 @@ class Task(_Task):
return resp
def add_tags(self, tags):
# type: (Union[Sequence[str], str]) -> None
"""
Add tags to this task. Old tags are not deleted
In remote, this is a no-op.
When running remotely, this method has no effect.
:param tags: An iterable or space separated string of new tags (string) to add.
:type tags: str or iterable of str
@ -570,6 +583,7 @@ class Task(_Task):
self._edit(tags=list(set(self.data.tags)))
def connect(self, mutable):
# type: (Any) -> Any
"""
Connect an object to a task (see introduction to Task connect design)
@ -597,6 +611,7 @@ class Task(_Task):
raise Exception('Unsupported mutable type %s: no connect function found' % type(mutable).__name__)
def connect_configuration(self, configuration):
# type: (Union[Mapping, Path, str]) -> Union[Mapping, Path, str]
"""
Connect a configuration dict / file (pathlib.Path / str) with the Task
Connecting configuration file should be called before reading the configuration file.
@ -626,14 +641,14 @@ class Task(_Task):
# parameter dictionary
if isinstance(configuration, dict):
def _update_config_dict(task, config_dict):
task.set_model_config(config_dict=config_dict)
task._set_model_config(config_dict=config_dict)
if not running_remotely() or not self.is_main_task():
self.set_model_config(config_dict=configuration)
self._set_model_config(config_dict=configuration)
configuration = ProxyDictPostWrite(self, _update_config_dict, **configuration)
else:
configuration.clear()
configuration.update(self.get_model_config_dict())
configuration.update(self._get_model_config_dict())
configuration = ProxyDictPreWrite(False, False, **configuration)
return configuration
@ -649,10 +664,10 @@ class Task(_Task):
except Exception:
raise ValueError("Could not connect configuration file {}, file could not be read".format(
configuration_path.as_posix()))
self.set_model_config(config_text=configuration_text)
self._set_model_config(config_text=configuration_text)
return configuration
else:
configuration_text = self.get_model_config_text()
configuration_text = self._get_model_config_text()
configuration_path = Path(configuration)
fd, local_filename = mkstemp(prefix='trains_task_config_',
suffix=configuration_path.suffixes[-1] if
@ -662,6 +677,7 @@ class Task(_Task):
return Path(local_filename) if isinstance(configuration, Path) else local_filename
def connect_label_enumeration(self, enumeration):
# type: (Dict[str, int]) -> Dict[str, int]
"""
Connect a label enumeration dictionary with the Task
@ -686,7 +702,7 @@ class Task(_Task):
def get_logger(self):
# type: () -> Logger
"""
get a logger object for reporting, for this task context.
Get a logger object for reporting, for this task context.
All reports (metrics, text etc.) related to this task are accessible in the web UI
:return: Logger object
@ -695,7 +711,7 @@ class Task(_Task):
def mark_started(self):
"""
Manually Mark the task as started (will happen automatically)
Manually Mark the task as started (happens automatically)
"""
# UI won't let us see metrics if we're not started
self.started()
@ -703,7 +719,7 @@ class Task(_Task):
def mark_stopped(self):
"""
Manually Mark the task as stopped (also used in self._at_exit)
Manually Mark the task as stopped (also used in :func:`_at_exit`)
"""
# flush any outstanding logs
self.flush(wait_for_uploads=True)
@ -711,8 +727,9 @@ class Task(_Task):
self.stopped()
def flush(self, wait_for_uploads=False):
# type: (bool) -> bool
"""
flush any outstanding reports or console logs
Flush any outstanding reports or console logs
:param wait_for_uploads: if True the flush will exit only after all outstanding uploads are completed
"""
@ -736,6 +753,7 @@ class Task(_Task):
return True
def reset(self, set_started_on_success=False, force=False):
# type: (bool, bool) -> None
"""
Reset the task. Task will be reloaded following a successful reset.
@ -759,8 +777,9 @@ class Task(_Task):
self.__register_at_exit(None)
def register_artifact(self, name, artifact, metadata=None, uniqueness_columns=True):
# type: (str, "pandas.DataFrame", Dict, Union[bool, TSequence[str]]) -> None
"""
Add artifact for the current Task, used mostly for Data Audition.
Add artifact for the current Task, used mostly for Data Auditing.
Currently supported artifacts object types: pandas.DataFrame
:param str name: name of the artifacts. Notice! it will override previous artifacts if name already exists.
@ -776,6 +795,7 @@ class Task(_Task):
self._artifacts_manager.register_artifact(name=name, artifact=artifact, metadata=metadata, uniqueness_columns=uniqueness_columns)
def unregister_artifact(self, name):
# type: (str) -> None
"""
Remove artifact from the watch list. Notice this will not remove the artifacts from the Task.
It will only stop monitoring the artifact,
@ -784,6 +804,7 @@ class Task(_Task):
self._artifacts_manager.unregister_artifact(name=name)
def get_registered_artifacts(self):
# type: () -> Dict[str, Artifact]
"""
dictionary of Task registered artifacts (name, artifact object)
Notice these objects can be modified, changes will be uploaded automatically
@ -793,9 +814,11 @@ class Task(_Task):
return self._artifacts_manager.registered_artifacts
def upload_artifact(self, name, artifact_object, metadata=None, delete_after_upload=False):
# type: (str, Union[str, Mapping, "pandas.DataFrame", "numpy.ndarray", "PIL.Image.Image"], Optional[Mapping], bool) -> bool
"""
Add static artifact to Task. Artifact file/object will be uploaded in the background
Raise ValueError if artifact_object is not supported
:raises ValueError: if artifact_object is not supported
:param str name: Artifact name. Notice! it will override previous artifact if name already exists
:param object artifact_object: Artifact object to upload. Currently supports:
@ -814,8 +837,9 @@ class Task(_Task):
metadata=metadata, delete_after_upload=delete_after_upload)
def get_models(self):
# type: () -> Dict[str, List[Model]]
"""
Return a dictionary with {'input': [], 'output': []} loaded/stored models of the current Task.
Return a dictionary with {'input': [], 'output': []} loaded/stored models of the current Task
Input models are files loaded in the task, either manually or automatically logged
Output models are files stored in the task, either manually or automatically logged
Automatically logged frameworks are for example: TensorFlow, Keras, PyTorch, ScikitLearn(joblib) etc.
@ -828,26 +852,29 @@ class Task(_Task):
return task_models
def is_current_task(self):
# type: () -> bool
"""
Check if this task is the main task (returned by Task.init())
NOTE: This call is deprecated. Please use Task.is_main_task()
.. deprecated:: 0.1.0
Use :func:`is_main_task()` instead
If Task.init() was never called, this method will *not* create
it, making this test cheaper than Task.init() == task
:return: True if this task is the current task
:return: True if this task is the main task
"""
return self.is_main_task()
def is_main_task(self):
# type: () -> bool
"""
Check if this task is the main task (returned by Task.init())
Check if this task is the main task (created/returned by Task.init())
If Task.init() was never called, this method will *not* create
it, making this test cheaper than Task.init() == task
:return: True if this task is the current task
:return: True if this task is the main task
"""
return self is self.__main_task
@ -876,6 +903,7 @@ class Task(_Task):
return self._get_model_config_dict()
def set_model_label_enumeration(self, enumeration=None):
# type: (Optional[Mapping[str, int]]) -> ()
"""
Set Task output label enumeration (before creating an output model)
When an output model is created it will inherit these properties
@ -886,6 +914,7 @@ class Task(_Task):
super(Task, self).set_model_label_enumeration(enumeration=enumeration)
def get_last_iteration(self):
# type: () -> int
"""
Return the maximum reported iteration (i.e. the maximum iteration the task reported a metric for)
Notice, this is not a cached call, it will ask the backend for the answer (no local caching)
@ -896,6 +925,7 @@ class Task(_Task):
return max(self.data.last_iteration, self._reporter.max_iteration if self._reporter else 0)
def set_last_iteration(self, last_iteration):
# type: (int) -> None
"""
Forcefully set the last reported iteration
(i.e. the maximum iteration the task reported a metric for)
@ -907,6 +937,7 @@ class Task(_Task):
self._edit(last_iteration=self.data.last_iteration)
def set_initial_iteration(self, offset=0):
# type: (int) -> int
"""
Set initial iteration, instead of zero. Useful when continuing training from previous checkpoints
@ -916,17 +947,19 @@ class Task(_Task):
return super(Task, self).set_initial_iteration(offset=offset)
def get_initial_iteration(self):
# type: () -> int
"""
Return the initial iteration offset, default is 0.
Useful when continuing training from previous checkpoints.
Return the initial iteration offset, default is 0
Useful when continuing training from previous checkpoints
:return int: initial iteration offset
"""
return super(Task, self).get_initial_iteration()
def get_last_scalar_metrics(self):
# type: () -> Dict[str, Dict[str, Dict[str, float]]]
"""
Extract the last scalar metrics, ordered by title & series in a nested dictionary
Extract the last scalar metrics, ordered by title and series in a nested dictionary
:return: dict. Example: {'title': {'series': {'last': 0.5, 'min': 0.1, 'max': 0.9}}}
"""
@ -940,34 +973,40 @@ class Task(_Task):
return scalar_metrics
def get_parameters_as_dict(self):
# type: () -> Dict
"""
Get task parameters as a raw nested dict
Note that values are not parsed and returned as is (i.e. string)
.. note::
values are not parsed and returned as is (i.e. string)
"""
return naive_nested_from_flat_dictionary(self.get_parameters())
def set_parameters_as_dict(self, dictionary):
# type: (Dict) -> None
"""
Set task parameters from a (possibly nested) dict.
Set task parameters from a (possibly nested) dict
While parameters are set just as they would be in connect(dict), this does not link the dict to the task,
but rather does a one-time update.
but rather performs a one-time update.
"""
self._arguments.copy_from_dict(flatten_dictionary(dictionary))
@classmethod
def set_credentials(cls, api_host=None, web_host=None, files_host=None, key=None, secret=None, host=None):
# type: (Optional[str], Optional[str], Optional[str], Optional[str], Optional[str], Optional[str]) -> ()
"""
Set new default TRAINS-server host and credentials
Set new default trains-server host values and credentials
These configurations will be overridden by either OS environment variables or trains.conf configuration file
Notice! credentials needs to be set *prior* to Task initialization
.. note::
credentials need to be set *prior* to Task initialization
:param str api_host: Trains API server url, example: host='http://localhost:8008'
:param str web_host: Trains WEB server url, example: host='http://localhost:8080'
:param str files_host: Trains Files server url, example: host='http://localhost:8081'
:param str key: user key/secret pair, example: key='thisisakey123'
:param str secret: user key/secret pair, example: secret='thisisseceret123'
:param str host: host url, example: host='http://localhost:8008' (deprecated)
:param str host: host url (overrides api_host), example: host='http://localhost:8008'
"""
if api_host:
Session.default_host = api_host