Improve linters (PyCharm) support detecting return type of Task.init / current_task

This commit is contained in:
allegroai 2022-09-15 16:03:29 +03:00
parent 789ca3a76f
commit 916c273a08

View File

@ -30,6 +30,7 @@ from typing import (
Callable, Callable,
Tuple, Tuple,
List, List,
TypeVar,
) )
import psutil import psutil
@ -106,6 +107,9 @@ if TYPE_CHECKING:
import numpy import numpy
from PIL import Image from PIL import Image
# Forward declaration to help linters
TaskInstance = TypeVar("TaskInstance", bound="Task")
class Task(_Task): class Task(_Task):
""" """
@ -207,12 +211,13 @@ class Task(_Task):
@classmethod @classmethod
def current_task(cls): def current_task(cls):
# type: () -> Task # type: () -> TaskInstance
""" """
Get the current running Task (experiment). This is the main execution Task (task context) returned as a Task Get the current running Task (experiment). This is the main execution Task (task context) returned as a Task
object. object.
:return: The current running Task (experiment). :return: The current running Task (experiment).
:rtype: Task
""" """
# check if we have no main Task, but the main process created one. # check if we have no main Task, but the main process created one.
if not cls.__main_task and cls.__get_master_id_task_id(): if not cls.__main_task and cls.__get_master_id_task_id():
@ -237,7 +242,7 @@ class Task(_Task):
auto_connect_streams=True, # type: Union[bool, Mapping[str, bool]] auto_connect_streams=True, # type: Union[bool, Mapping[str, bool]]
deferred_init=False, # type: bool deferred_init=False, # type: bool
): ):
# type: (...) -> Task # type: (...) -> TaskInstance
""" """
Creates a new Task (experiment) if: Creates a new Task (experiment) if:
@ -462,6 +467,7 @@ class Task(_Task):
and Task init is called synchronously (default) and Task init is called synchronously (default)
:return: The main execution Task (Task context) :return: The main execution Task (Task context)
:rtype: Task
""" """
def verify_defaults_match(): def verify_defaults_match():
@ -805,7 +811,7 @@ class Task(_Task):
base_task_id=None, # type: Optional[str] base_task_id=None, # type: Optional[str]
add_task_init_call=True, # type: bool add_task_init_call=True, # type: bool
): ):
# type: (...) -> Task # type: (...) -> TaskInstance
""" """
Manually create and populate a new Task (experiment) in the system. Manually create and populate a new Task (experiment) in the system.
If the code does not already contain a call to ``Task.init``, pass add_task_init_call=True, If the code does not already contain a call to ``Task.init``, pass add_task_init_call=True,
@ -847,6 +853,7 @@ class Task(_Task):
:param add_task_init_call: If True, a 'Task.init()' call is added to the script entry point in remote execution. :param add_task_init_call: If True, a 'Task.init()' call is added to the script entry point in remote execution.
:return: The newly created Task (experiment) :return: The newly created Task (experiment)
:rtype: Task
""" """
if not project_name and not base_task_id: if not project_name and not base_task_id:
if not cls.__main_task: if not cls.__main_task:
@ -881,7 +888,7 @@ class Task(_Task):
allow_archived=True, # type: bool allow_archived=True, # type: bool
task_filter=None # type: Optional[dict] task_filter=None # type: Optional[dict]
): ):
# type: (...) -> "Task" # type: (...) -> TaskInstance
""" """
Get a Task by Id, or project name / task name combination. Get a Task by Id, or project name / task name combination.
@ -927,6 +934,7 @@ class Task(_Task):
Pass additional query filters, on top of project/name. See details in Task.get_tasks. Pass additional query filters, on top of project/name. See details in Task.get_tasks.
:return: The Task specified by ID, or project name / experiment name combination. :return: The Task specified by ID, or project name / experiment name combination.
:rtype: Task
""" """
return cls.__get_task( return cls.__get_task(
task_id=task_id, project_name=project_name, task_name=task_name, tags=tags, task_id=task_id, project_name=project_name, task_name=task_name, tags=tags,
@ -942,7 +950,7 @@ class Task(_Task):
tags=None, # type: Optional[Sequence[str]] tags=None, # type: Optional[Sequence[str]]
task_filter=None # type: Optional[Dict] task_filter=None # type: Optional[Dict]
): ):
# type: (...) -> List["Task"] # type: (...) -> List[TaskInstance]
""" """
Get a list of Tasks objects matching the queries/filters Get a list of Tasks objects matching the queries/filters
@ -996,6 +1004,7 @@ class Task(_Task):
{'order_by'=['-last_update'], '_all_'=dict(fields=['script.repository'], pattern='github.com/user')) {'order_by'=['-last_update'], '_all_'=dict(fields=['script.repository'], pattern='github.com/user'))
:return: The Tasks specified by the parameter combinations (see the parameters). :return: The Tasks specified by the parameter combinations (see the parameters).
:rtype: List[Task]
""" """
return cls.__get_tasks(task_ids=task_ids, project_name=project_name, tags=tags, return cls.__get_tasks(task_ids=task_ids, project_name=project_name, tags=tags,
task_name=task_name, **(task_filter or {})) task_name=task_name, **(task_filter or {}))
@ -1172,7 +1181,7 @@ class Task(_Task):
parent=None, # type: Optional[str] parent=None, # type: Optional[str]
project=None, # type: Optional[str] project=None, # type: Optional[str]
): ):
# type: (...) -> Task # type: (...) -> TaskInstance
""" """
Create a duplicate (a clone) of a Task (experiment). The status of the cloned Task is ``Draft`` Create a duplicate (a clone) of a Task (experiment). The status of the cloned Task is ``Draft``
and modifiable. and modifiable.
@ -1192,6 +1201,7 @@ class Task(_Task):
If ``None``, the new task inherits the original Task's project. (Optional) If ``None``, the new task inherits the original Task's project. (Optional)
:return: The new cloned Task (experiment). :return: The new cloned Task (experiment).
:rtype: Task
""" """
assert isinstance(source_task, (six.string_types, Task)) assert isinstance(source_task, (six.string_types, Task))
if not Session.check_min_api_version('2.4'): if not Session.check_min_api_version('2.4'):
@ -1873,7 +1883,7 @@ class Task(_Task):
preview=None, # type: Any preview=None, # type: Any
wait_on_upload=False, # type: bool wait_on_upload=False, # type: bool
extension_name=None, # type: Optional[str] extension_name=None, # type: Optional[str]
serialization_function=None # type: Optional[Callable[Any, Union[bytes, bytearray]]] serialization_function=None # type: Optional[Callable[[Any], Union[bytes, bytearray]]]
): ):
# type: (...) -> bool # type: (...) -> bool
""" """
@ -2916,7 +2926,7 @@ class Task(_Task):
@classmethod @classmethod
def _create(cls, project_name=None, task_name=None, task_type=TaskTypes.training): def _create(cls, project_name=None, task_name=None, task_type=TaskTypes.training):
# type: (Optional[str], Optional[str], Task.TaskTypes) -> Task # type: (Optional[str], Optional[str], Task.TaskTypes) -> TaskInstance
""" """
Create a new unpopulated Task (experiment). Create a new unpopulated Task (experiment).
@ -2928,6 +2938,7 @@ class Task(_Task):
:param TaskTypes task_type: The task type. :param TaskTypes task_type: The task type.
:return: The newly created task created. :return: The newly created task created.
:rtype: Task
""" """
if not project_name: if not project_name:
if not cls.__main_task: if not cls.__main_task:
@ -3854,7 +3865,7 @@ class Task(_Task):
tags=None, # type: Optional[Sequence[str]] tags=None, # type: Optional[Sequence[str]]
task_filter=None # type: Optional[dict] task_filter=None # type: Optional[dict]
): ):
# type: (...) -> Task # type: (...) -> TaskInstance
if task_id: if task_id:
return cls(private=cls.__create_protection, task_id=task_id, log_to_backend=False) return cls(private=cls.__create_protection, task_id=task_id, log_to_backend=False)
@ -3904,7 +3915,8 @@ class Task(_Task):
include_archived=include_archived, task_filter=task_filter).items() if v}, include_archived=include_archived, task_filter=task_filter).items() if v},
results=res_tasks, raise_on_error=False) results=res_tasks, raise_on_error=False)
if not task: if not task:
return None # should never happen
return None # noqa
return cls( return cls(
private=cls.__create_protection, private=cls.__create_protection,