Improve linters (PyCharm) support detecting return type of Task.init / current_task

This commit is contained in:
allegroai 2022-09-15 16:03:29 +03:00
parent 789ca3a76f
commit 916c273a08

View File

@ -30,6 +30,7 @@ from typing import (
Callable,
Tuple,
List,
TypeVar,
)
import psutil
@ -106,6 +107,9 @@ if TYPE_CHECKING:
import numpy
from PIL import Image
# Forward declaration to help linters
TaskInstance = TypeVar("TaskInstance", bound="Task")
class Task(_Task):
"""
@ -207,12 +211,13 @@ class Task(_Task):
@classmethod
def current_task(cls):
# type: () -> Task
# type: () -> TaskInstance
"""
Get the current running Task (experiment). This is the main execution Task (task context) returned as a Task
object.
:return: The current running Task (experiment).
:rtype: Task
"""
# check if we have no main Task, but the main process created one.
if not cls.__main_task and cls.__get_master_id_task_id():
@ -237,7 +242,7 @@ class Task(_Task):
auto_connect_streams=True, # type: Union[bool, Mapping[str, bool]]
deferred_init=False, # type: bool
):
# type: (...) -> Task
# type: (...) -> TaskInstance
"""
Creates a new Task (experiment) if:
@ -462,6 +467,7 @@ class Task(_Task):
and Task init is called synchronously (default)
:return: The main execution Task (Task context)
:rtype: Task
"""
def verify_defaults_match():
@ -805,7 +811,7 @@ class Task(_Task):
base_task_id=None, # type: Optional[str]
add_task_init_call=True, # type: bool
):
# type: (...) -> Task
# type: (...) -> TaskInstance
"""
Manually create and populate a new Task (experiment) in the system.
If the code does not already contain a call to ``Task.init``, pass add_task_init_call=True,
@ -847,6 +853,7 @@ class Task(_Task):
:param add_task_init_call: If True, a 'Task.init()' call is added to the script entry point in remote execution.
:return: The newly created Task (experiment)
:rtype: Task
"""
if not project_name and not base_task_id:
if not cls.__main_task:
@ -881,7 +888,7 @@ class Task(_Task):
allow_archived=True, # type: bool
task_filter=None # type: Optional[dict]
):
# type: (...) -> "Task"
# type: (...) -> TaskInstance
"""
Get a Task by Id, or project name / task name combination.
@ -927,6 +934,7 @@ class Task(_Task):
Pass additional query filters, on top of project/name. See details in Task.get_tasks.
:return: The Task specified by ID, or project name / experiment name combination.
:rtype: Task
"""
return cls.__get_task(
task_id=task_id, project_name=project_name, task_name=task_name, tags=tags,
@ -942,7 +950,7 @@ class Task(_Task):
tags=None, # type: Optional[Sequence[str]]
task_filter=None # type: Optional[Dict]
):
# type: (...) -> List["Task"]
# type: (...) -> List[TaskInstance]
"""
Get a list of Tasks objects matching the queries/filters
@ -996,6 +1004,7 @@ class Task(_Task):
{'order_by'=['-last_update'], '_all_'=dict(fields=['script.repository'], pattern='github.com/user'))
:return: The Tasks specified by the parameter combinations (see the parameters).
:rtype: List[Task]
"""
return cls.__get_tasks(task_ids=task_ids, project_name=project_name, tags=tags,
task_name=task_name, **(task_filter or {}))
@ -1172,7 +1181,7 @@ class Task(_Task):
parent=None, # type: Optional[str]
project=None, # type: Optional[str]
):
# type: (...) -> Task
# type: (...) -> TaskInstance
"""
Create a duplicate (a clone) of a Task (experiment). The status of the cloned Task is ``Draft``
and modifiable.
@ -1192,6 +1201,7 @@ class Task(_Task):
If ``None``, the new task inherits the original Task's project. (Optional)
:return: The new cloned Task (experiment).
:rtype: Task
"""
assert isinstance(source_task, (six.string_types, Task))
if not Session.check_min_api_version('2.4'):
@ -1873,7 +1883,7 @@ class Task(_Task):
preview=None, # type: Any
wait_on_upload=False, # type: bool
extension_name=None, # type: Optional[str]
serialization_function=None # type: Optional[Callable[Any, Union[bytes, bytearray]]]
serialization_function=None # type: Optional[Callable[[Any], Union[bytes, bytearray]]]
):
# type: (...) -> bool
"""
@ -2916,7 +2926,7 @@ class Task(_Task):
@classmethod
def _create(cls, project_name=None, task_name=None, task_type=TaskTypes.training):
# type: (Optional[str], Optional[str], Task.TaskTypes) -> Task
# type: (Optional[str], Optional[str], Task.TaskTypes) -> TaskInstance
"""
Create a new unpopulated Task (experiment).
@ -2928,6 +2938,7 @@ class Task(_Task):
:param TaskTypes task_type: The task type.
:return: The newly created task created.
:rtype: Task
"""
if not project_name:
if not cls.__main_task:
@ -3854,7 +3865,7 @@ class Task(_Task):
tags=None, # type: Optional[Sequence[str]]
task_filter=None # type: Optional[dict]
):
# type: (...) -> Task
# type: (...) -> TaskInstance
if task_id:
return cls(private=cls.__create_protection, task_id=task_id, log_to_backend=False)
@ -3904,7 +3915,8 @@ class Task(_Task):
include_archived=include_archived, task_filter=task_filter).items() if v},
results=res_tasks, raise_on_error=False)
if not task:
return None
# should never happen
return None # noqa
return cls(
private=cls.__create_protection,