mirror of https://github.com/clearml/clearml
synced 2025-04-05 13:15:17 +00:00

Add type annotations and fix docstrings

This commit is contained in:
parent 766c8ab24f
commit c4719f2e2f
trains/task.py (101 changed lines)
@@ -12,7 +12,7 @@ try:
 except ImportError:
     from collections import Callable, Sequence

-from typing import Optional
+from typing import Optional, Union, Mapping, Sequence as TSequence, Any, Dict, List

 import psutil
 import six
@@ -41,7 +41,7 @@ from .config.cache import SessionCache
 from .debugging.log import LoggerRoot
 from .errors import UsageError
 from .logger import Logger
-from .model import InputModel, OutputModel, ARCHIVED_TAG
+from .model import Model, InputModel, OutputModel, ARCHIVED_TAG
 from .task_parameters import TaskParameters
 from .utilities.args import argparser_parseargs_called, get_argparser_last_args, \
     argparser_update_currenttask
@@ -99,7 +99,8 @@ class Task(_Task):

     def __init__(self, private=None, **kwargs):
         """
-        Do not construct Task manually!
+        .. warning::
+            Do not construct Task manually!
             **Please use Task.init() or Task.get_task(id=, project=, name=)**
         """
         if private is not Task.__create_protection:
@@ -141,6 +142,7 @@ class Task(_Task):
             auto_connect_frameworks=True,
             auto_resource_monitoring=True,
     ):
+        # type: (Optional[str], Optional[str], TaskTypes, bool, Optional[str], bool, Union[bool, Mapping[str, bool]], bool) -> Task
         """
         Return the Task object for the main execution task (task context).

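For context only (not part of the commit), a minimal usage sketch of the Task.init() call annotated above. The project_name/task_name keyword names are assumptions based on the first two Optional[str] entries in the type comment; the mapping passed to auto_connect_frameworks illustrates the Union[bool, Mapping[str, bool]] annotation.

from trains import Task

# Create (or reuse) the main execution task for this script.
task = Task.init(
    project_name='examples',               # assumed keyword name
    task_name='type-annotation demo',      # assumed keyword name
    auto_connect_frameworks={'pytorch': False},  # per-framework switch, per the type comment
    auto_resource_monitoring=True,
)
task.add_tags(['demo'])  # add_tags() accepts a sequence of strings or a space-separated string (see the add_tags hunk below)
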
@@ -369,6 +371,7 @@ class Task(_Task):
             task_name=None,
             task_type=TaskTypes.training,
     ):
+        # type: (Optional[str], Optional[str], TaskTypes) -> Task
         """
         Create a new Task object, regardless of the main execution task (Task.init).

@@ -403,6 +406,7 @@ class Task(_Task):

     @classmethod
     def get_task(cls, task_id=None, project_name=None, task_name=None):
+        # type: (Optional[str], Optional[str], Optional[str]) -> Task
         """
         Returns Task object based on either, task_id (system uuid) or task name

@@ -415,6 +419,7 @@ class Task(_Task):

     @classmethod
     def get_tasks(cls, task_ids=None, project_name=None, task_name=None):
+        # type: (Optional[TSequence[str]], Optional[str], Optional[str]) -> Task
         """
         Returns a list of Task objects, matching requested task name (or partially matching)

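A brief usage sketch (not part of the diff) for the lookup classmethods annotated above; the task id string is a placeholder.

from trains import Task

# Fetch a single task either by its system uuid or by project/name.
existing = Task.get_task(task_id='<task-uuid>')   # placeholder id
by_name = Task.get_task(project_name='examples', task_name='type-annotation demo')

# get_tasks() matches the task name partially and returns a list,
# even though the type comment above declares `-> Task`.
candidates = Task.get_tasks(project_name='examples', task_name='demo')
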
@@ -429,10 +434,12 @@ class Task(_Task):

     @property
     def output_uri(self):
+        # type: () -> str
         return self.storage_uri

     @output_uri.setter
     def output_uri(self, value):
+        # type: (str) -> None
         # check if we have the correct packages / configuration
         if value and value != self.storage_uri:
             from .storage.helper import StorageHelper
@@ -445,9 +452,11 @@ class Task(_Task):

     @property
     def artifacts(self):
+        # type: () -> Dict[str, Artifact]
         """
-        read-only dictionary of Task artifacts (name, artifact)
-        :return: dict
+        Read-only dictionary of Task artifacts (name, artifact)
+
+        :return dict: dictionary of artifacts
         """
         if not Session.check_min_api_version('2.3'):
             return ReadOnlyDict()
@@ -470,6 +479,7 @@ class Task(_Task):

     @classmethod
     def clone(cls, source_task=None, name=None, comment=None, parent=None, project=None):
+        # type: (Optional[Task], Optional[str], Optional[str], Optional[str], Optional[str], Optional[str]) -> Task
         """
         Clone a task object, create a copy a task.

@@ -502,6 +512,7 @@ class Task(_Task):

     @classmethod
     def enqueue(cls, task, queue_name=None, queue_id=None):
+        # type: (Task, Optional[str], Optional[str]) -> Any
         """
         Enqueue (send) a task for execution, by adding it to an execution queue

@@ -534,6 +545,7 @@ class Task(_Task):

     @classmethod
     def dequeue(cls, task):
+        # type: (Union[Task, str]) -> Any
         """
         Dequeue (remove) task from execution queue.

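For illustration only (not part of the commit), a sketch of the clone-and-enqueue workflow these annotated classmethods support; the queue name and comment are placeholders.

from trains import Task

base = Task.get_task(project_name='examples', task_name='type-annotation demo')

# Clone the task, then send the copy to an execution queue by name.
cloned = Task.clone(source_task=base, name='demo clone', comment='parameter sweep #1')
Task.enqueue(cloned, queue_name='default')   # queue_id could be passed instead of queue_name

# Remove it from the queue again if needed; dequeue() accepts a Task or a task id string.
Task.dequeue(cloned)
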
@@ -553,10 +565,11 @@ class Task(_Task):
         return resp

     def add_tags(self, tags):
+        # type: (Union[Sequence[str], str]) -> None
         """
         Add tags to this task. Old tags are not deleted

-        In remote, this is a no-op.
+        When running remotely, this method has no effect.

         :param tags: An iterable or space separated string of new tags (string) to add.
         :type tags: str or iterable of str
@@ -570,6 +583,7 @@ class Task(_Task):
         self._edit(tags=list(set(self.data.tags)))

     def connect(self, mutable):
+        # type: (Any) -> Any
         """
         Connect an object to a task (see introduction to Task connect design)

@@ -597,6 +611,7 @@ class Task(_Task):
         raise Exception('Unsupported mutable type %s: no connect function found' % type(mutable).__name__)

     def connect_configuration(self, configuration):
+        # type: (Union[Mapping, Path, str]) -> Union[Mapping, Path, str]
         """
         Connect a configuration dict / file (pathlib.Path / str) with the Task
         Connecting configuration file should be called before reading the configuration file.
@@ -626,14 +641,14 @@ class Task(_Task):
         # parameter dictionary
         if isinstance(configuration, dict):
             def _update_config_dict(task, config_dict):
-                task.set_model_config(config_dict=config_dict)
+                task._set_model_config(config_dict=config_dict)

             if not running_remotely() or not self.is_main_task():
-                self.set_model_config(config_dict=configuration)
+                self._set_model_config(config_dict=configuration)
                 configuration = ProxyDictPostWrite(self, _update_config_dict, **configuration)
             else:
                 configuration.clear()
-                configuration.update(self.get_model_config_dict())
+                configuration.update(self._get_model_config_dict())
                 configuration = ProxyDictPreWrite(False, False, **configuration)
             return configuration

@@ -649,10 +664,10 @@ class Task(_Task):
             except Exception:
                 raise ValueError("Could not connect configuration file {}, file could not be read".format(
                     configuration_path.as_posix()))
-            self.set_model_config(config_text=configuration_text)
+            self._set_model_config(config_text=configuration_text)
             return configuration
         else:
-            configuration_text = self.get_model_config_text()
+            configuration_text = self._get_model_config_text()
         configuration_path = Path(configuration)
         fd, local_filename = mkstemp(prefix='trains_task_config_',
                                      suffix=configuration_path.suffixes[-1] if
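A hypothetical usage sketch (not part of the diff) of connect() and connect_configuration(), which the hunks above refactor onto the private _set_model_config() / _get_model_config_dict() helpers. The file name and parameter values are made up; locally the file must exist.

from trains import Task

task = Task.init(project_name='examples', task_name='config demo')  # assumed keyword names

# Connect a mutable parameter dict: logged when running locally,
# overridden with the server-side values when running remotely.
params = {'batch_size': 32, 'lr': 0.001}
params = task.connect(params)

# Connect a configuration file (str or pathlib.Path); call this before reading the file.
config_path = task.connect_configuration('model_config.json')
with open(config_path, 'rt') as fh:
    config_text = fh.read()
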
@@ -662,6 +677,7 @@ class Task(_Task):
         return Path(local_filename) if isinstance(configuration, Path) else local_filename

     def connect_label_enumeration(self, enumeration):
+        # type: (Dict[str, int]) -> Dict[str, int]
         """
         Connect a label enumeration dictionary with the Task

@@ -686,7 +702,7 @@ class Task(_Task):
     def get_logger(self):
         # type: () -> Logger
         """
-        get a logger object for reporting, for this task context.
+        Get a logger object for reporting, for this task context.
         All reports (metrics, text etc.) related to this task are accessible in the web UI

         :return: Logger object
@@ -695,7 +711,7 @@ class Task(_Task):

     def mark_started(self):
         """
-        Manually Mark the task as started (will happen automatically)
+        Manually Mark the task as started (happens automatically)
         """
         # UI won't let us see metrics if we're not started
         self.started()
@@ -703,7 +719,7 @@ class Task(_Task):

     def mark_stopped(self):
         """
-        Manually Mark the task as stopped (also used in self._at_exit)
+        Manually Mark the task as stopped (also used in :func:`_at_exit`)
         """
         # flush any outstanding logs
         self.flush(wait_for_uploads=True)
@@ -711,8 +727,9 @@ class Task(_Task):
         self.stopped()

     def flush(self, wait_for_uploads=False):
+        # type: (bool) -> bool
         """
-        flush any outstanding reports or console logs
+        Flush any outstanding reports or console logs

         :param wait_for_uploads: if True the flush will exit only after all outstanding uploads are completed
         """
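An illustrative reporting sketch (not part of the commit) combining get_logger() and flush() from the hunks above; the metric names and values are invented.

from trains import Task

task = Task.init(project_name='examples', task_name='reporting demo')  # assumed keyword names
logger = task.get_logger()

for iteration in range(10):
    # Report a scalar series; all reports are accessible in the web UI per the docstring.
    logger.report_scalar(title='loss', series='train', value=1.0 / (iteration + 1), iteration=iteration)

# Block until all outstanding reports (and uploads) reach the server.
task.flush(wait_for_uploads=True)
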
@@ -736,6 +753,7 @@ class Task(_Task):
         return True

     def reset(self, set_started_on_success=False, force=False):
+        # type: (bool, bool) -> None
         """
         Reset the task. Task will be reloaded following a successful reset.

@@ -759,8 +777,9 @@ class Task(_Task):
         self.__register_at_exit(None)

     def register_artifact(self, name, artifact, metadata=None, uniqueness_columns=True):
+        # type: (str, "pandas.DataFrame", Dict, Union[bool, TSequence[str]]) -> None
         """
-        Add artifact for the current Task, used mostly for Data Audition.
+        Add artifact for the current Task, used mostly for Data Auditing.
         Currently supported artifacts object types: pandas.DataFrame

         :param str name: name of the artifacts. Notice! it will override previous artifacts if name already exists.
@@ -776,6 +795,7 @@ class Task(_Task):
         self._artifacts_manager.register_artifact(name=name, artifact=artifact, metadata=metadata, uniqueness_columns=uniqueness_columns)

     def unregister_artifact(self, name):
+        # type: (str) -> None
         """
         Remove artifact from the watch list. Notice this will not remove the artifacts from the Task.
         It will only stop monitoring the artifact,
@@ -784,6 +804,7 @@ class Task(_Task):
         self._artifacts_manager.unregister_artifact(name=name)

     def get_registered_artifacts(self):
+        # type: () -> Dict[str, Artifact]
         """
         dictionary of Task registered artifacts (name, artifact object)
         Notice these objects can be modified, changes will be uploaded automatically
@@ -793,9 +814,11 @@ class Task(_Task):
         return self._artifacts_manager.registered_artifacts

     def upload_artifact(self, name, artifact_object, metadata=None, delete_after_upload=False):
+        # type: (str, Union[str, Mapping, "pandas.DataFrame", "numpy.ndarray", "PIL.Image.Image"], Optional[Mapping], bool) -> bool
         """
         Add static artifact to Task. Artifact file/object will be uploaded in the background
-        Raise ValueError if artifact_object is not supported
+
+        :raises ValueError: if artifact_object is not supported

         :param str name: Artifact name. Notice! it will override previous artifact if name already exists
         :param object artifact_object: Artifact object to upload. Currently supports:
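A sketch (not part of the diff) of the two artifact flows annotated above: register_artifact() for a live, auditable DataFrame and upload_artifact() for a static object; the data is made up.

import pandas as pd
from trains import Task

task = Task.init(project_name='examples', task_name='artifact demo')  # assumed keyword names

# Register a DataFrame that is monitored for changes and re-uploaded automatically.
df = pd.DataFrame({'id': [1, 2, 3], 'score': [0.1, 0.5, 0.9]})
task.register_artifact(name='scores', artifact=df, uniqueness_columns=['id'])

# Upload a static artifact; dict / str path / DataFrame / ndarray / PIL image are listed as supported.
task.upload_artifact(name='run summary', artifact_object={'best_score': 0.9})

# Inspect what is currently attached to the task.
print(task.get_registered_artifacts().keys())
print(task.artifacts)
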
@@ -814,8 +837,9 @@ class Task(_Task):
             metadata=metadata, delete_after_upload=delete_after_upload)

     def get_models(self):
+        # type: () -> Dict[str, List[Model]]
         """
-        Return a dictionary with {'input': [], 'output': []} loaded/stored models of the current Task.
+        Return a dictionary with {'input': [], 'output': []} loaded/stored models of the current Task
         Input models are files loaded in the task, either manually or automatically logged
         Output models are files stored in the task, either manually or automatically logged
         Automatically logged frameworks are for example: TensorFlow, Keras, PyTorch, ScikitLearn(joblib) etc.
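For context (not part of the commit), a small sketch of reading the input/output model split described above; the attribute access on the returned Model objects is illustrative.

from trains import Task

task = Task.get_task(project_name='examples', task_name='artifact demo')

models = task.get_models()            # {'input': [...], 'output': [...]} per the docstring
for model in models['output']:
    # Model comes from trains.model (the newly imported Model class in this diff);
    # printing the object (or its name, assumed attribute) is enough for a quick inspection.
    print(model)
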
@@ -828,26 +852,29 @@ class Task(_Task):
         return task_models

     def is_current_task(self):
+        # type: () -> bool
         """
         Check if this task is the main task (returned by Task.init())

-        NOTE: This call is deprecated. Please use Task.is_main_task()
+        .. deprecated:: 0.1.0
+            Use :func:`is_main_task()` instead

         If Task.init() was never called, this method will *not* create
         it, making this test cheaper than Task.init() == task

-        :return: True if this task is the current task
+        :return: True if this task is the main task
         """
         return self.is_main_task()

     def is_main_task(self):
+        # type: () -> bool
         """
-        Check if this task is the main task (returned by Task.init())
+        Check if this task is the main task (created/returned by Task.init())

         If Task.init() was never called, this method will *not* create
         it, making this test cheaper than Task.init() == task

-        :return: True if this task is the current task
+        :return: True if this task is the main task
         """
         return self is self.__main_task

@@ -876,6 +903,7 @@ class Task(_Task):
         return self._get_model_config_dict()

     def set_model_label_enumeration(self, enumeration=None):
+        # type: (Optional[Mapping[str, int]]) -> ()
         """
         Set Task output label enumeration (before creating an output model)
         When an output model is created it will inherit these properties
@@ -886,6 +914,7 @@ class Task(_Task):
         super(Task, self).set_model_label_enumeration(enumeration=enumeration)

     def get_last_iteration(self):
+        # type: () -> int
         """
         Return the maximum reported iteration (i.e. the maximum iteration the task reported a metric for)
         Notice, this is not a cached call, it will ask the backend for the answer (no local caching)
@@ -896,6 +925,7 @@ class Task(_Task):
         return max(self.data.last_iteration, self._reporter.max_iteration if self._reporter else 0)

     def set_last_iteration(self, last_iteration):
+        # type: (int) -> None
         """
         Forcefully set the last reported iteration
         (i.e. the maximum iteration the task reported a metric for)
@@ -907,6 +937,7 @@ class Task(_Task):
         self._edit(last_iteration=self.data.last_iteration)

     def set_initial_iteration(self, offset=0):
+        # type: (int) -> int
         """
         Set initial iteration, instead of zero. Useful when continuing training from previous checkpoints

@@ -916,17 +947,19 @@ class Task(_Task):
         return super(Task, self).set_initial_iteration(offset=offset)

     def get_initial_iteration(self):
+        # type: () -> int
         """
-        Return the initial iteration offset, default is 0.
-        Useful when continuing training from previous checkpoints.
+        Return the initial iteration offset, default is 0
+        Useful when continuing training from previous checkpoints

         :return int: initial iteration offset
         """
         return super(Task, self).get_initial_iteration()

     def get_last_scalar_metrics(self):
+        # type: () -> Dict[str, Dict[str, Dict[str, float]]]
         """
-        Extract the last scalar metrics, ordered by title & series in a nested dictionary
+        Extract the last scalar metrics, ordered by title and series in a nested dictionary

         :return: dict. Example: {'title': {'series': {'last': 0.5, 'min': 0.1, 'max': 0.9}}}
         """
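A sketch (not part of the diff) of the checkpoint-continuation and metrics helpers above; the numbers are illustrative.

from trains import Task

task = Task.init(project_name='examples', task_name='resume demo')  # assumed keyword names

# Continue counting iterations from a previous checkpoint instead of zero.
task.set_initial_iteration(1000)
print(task.get_initial_iteration())   # -> 1000

# Later, inspect the maximum reported iteration and the last scalars,
# e.g. {'loss': {'train': {'last': 0.5, 'min': 0.1, 'max': 0.9}}}
print(task.get_last_iteration())
print(task.get_last_scalar_metrics())
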
@@ -940,34 +973,40 @@ class Task(_Task):
         return scalar_metrics

     def get_parameters_as_dict(self):
+        # type: () -> Dict
         """
         Get task parameters as a raw nested dict
-        Note that values are not parsed and returned as is (i.e. string)
+
+        .. note::
+            values are not parsed and returned as is (i.e. string)
         """
         return naive_nested_from_flat_dictionary(self.get_parameters())

     def set_parameters_as_dict(self, dictionary):
+        # type: (Dict) -> None
         """
-        Set task parameters from a (possibly nested) dict.
+        Set task parameters from a (possibly nested) dict
         While parameters are set just as they would be in connect(dict), this does not link the dict to the task,
-        but rather does a one-time update.
+        but rather performs a one-time update.
         """
         self._arguments.copy_from_dict(flatten_dictionary(dictionary))

     @classmethod
     def set_credentials(cls, api_host=None, web_host=None, files_host=None, key=None, secret=None, host=None):
+        # type: (Optional[str], Optional[str], Optional[str], Optional[str], Optional[str], Optional[str]) -> ()
         """
-        Set new default TRAINS-server host and credentials
+        Set new default trains-server host values and credentials
         These configurations will be overridden by either OS environment variables or trains.conf configuration file

-        Notice! credentials needs to be set *prior* to Task initialization
+        .. note::
+            credentials need to be set *prior* to Task initialization

         :param str api_host: Trains API server url, example: host='http://localhost:8008'
         :param str web_host: Trains WEB server url, example: host='http://localhost:8080'
         :param str files_host: Trains Files server url, example: host='http://localhost:8081'
         :param str key: user key/secret pair, example: key='thisisakey123'
         :param str secret: user key/secret pair, example: secret='thisisseceret123'
-        :param str host: host url, example: host='http://localhost:8008' (deprecated)
+        :param str host: host url (overrides api_host), example: host='http://localhost:8008'
         """
         if api_host:
             Session.default_host = api_host
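Finally, an illustrative sketch (not part of the commit) of the credential and parameter helpers above; the hosts and keys reuse the placeholder values from the docstring.

from trains import Task

# Must be called before Task.init(); values are overridden by environment variables or trains.conf.
Task.set_credentials(
    api_host='http://localhost:8008',
    web_host='http://localhost:8080',
    files_host='http://localhost:8081',
    key='thisisakey123',
    secret='thisisseceret123',
)

task = Task.init(project_name='examples', task_name='credentials demo')  # assumed keyword names

# One-time (non-linked) parameter update, then read back the raw (string) values.
task.set_parameters_as_dict({'training': {'epochs': 10, 'lr': 0.01}})
print(task.get_parameters_as_dict())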