Fix type hints, add ignore/fix pep8 warnings

This commit is contained in:
allegroai 2020-05-31 11:59:05 +03:00
parent 7440799bb0
commit 9259e4efeb

View File

@ -8,11 +8,12 @@ from argparse import ArgumentParser
from tempfile import mkstemp from tempfile import mkstemp
try: try:
# noinspection PyCompatibility
from collections.abc import Callable, Sequence as CollectionsSequence from collections.abc import Callable, Sequence as CollectionsSequence
except ImportError: except ImportError:
from collections import Callable, Sequence as CollectionsSequence from collections import Callable, Sequence as CollectionsSequence
from typing import Optional, Union, Mapping, Sequence, Any, Dict, List, TYPE_CHECKING from typing import Optional, Union, Mapping, Sequence, Any, Dict, TYPE_CHECKING
import psutil import psutil
import six import six
@ -23,10 +24,9 @@ from .backend_api.session.session import Session, ENV_ACCESS_KEY, ENV_SECRET_KEY
from .backend_interface.metrics import Metrics from .backend_interface.metrics import Metrics
from .backend_interface.model import Model as BackendModel from .backend_interface.model import Model as BackendModel
from .backend_interface.task import Task as _Task from .backend_interface.task import Task as _Task
from .backend_interface.task.args import _Arguments
from .backend_interface.task.development.worker import DevWorker from .backend_interface.task.development.worker import DevWorker
from .backend_interface.task.repo import ScriptInfo from .backend_interface.task.repo import ScriptInfo
from .backend_interface.util import get_single_result, exact_match_regex, make_message from .backend_interface.util import get_single_result, exact_match_regex, make_message, mutually_exclusive
from .binding.absl_bind import PatchAbsl from .binding.absl_bind import PatchAbsl
from .binding.artifacts import Artifacts, Artifact from .binding.artifacts import Artifacts, Artifact
from .binding.environ_bind import EnvironmentBind, PatchOsFork from .binding.environ_bind import EnvironmentBind, PatchOsFork
@ -50,6 +50,8 @@ from .utilities.proxy_object import ProxyDictPreWrite, ProxyDictPostWrite, flatt
nested_from_flat_dictionary, naive_nested_from_flat_dictionary nested_from_flat_dictionary, naive_nested_from_flat_dictionary
from .utilities.resource_monitor import ResourceMonitor from .utilities.resource_monitor import ResourceMonitor
from .utilities.seed import make_deterministic from .utilities.seed import make_deterministic
# noinspection PyProtectedMember
from .backend_interface.task.args import _Arguments
if TYPE_CHECKING: if TYPE_CHECKING:
@ -105,7 +107,7 @@ class Task(_Task):
NotSet = object() NotSet = object()
__create_protection = object() __create_protection = object()
__main_task = None # type: Task __main_task = None # type: Optional[Task]
__exit_hook = None __exit_hook = None
__forked_proc_main_pid = None __forked_proc_main_pid = None
__task_id_reuse_time_window_in_hours = float(config.get('development.task_reuse_time_window_in_hours', 24.0)) __task_id_reuse_time_window_in_hours = float(config.get('development.task_reuse_time_window_in_hours', 24.0))
@ -326,6 +328,7 @@ class Task(_Task):
cls.__main_task.get_logger() cls.__main_task.get_logger()
cls.__main_task._artifacts_manager = Artifacts(cls.__main_task) cls.__main_task._artifacts_manager = Artifacts(cls.__main_task)
# unregister signal hooks, they cause subprocess to hang # unregister signal hooks, they cause subprocess to hang
# noinspection PyProtectedMember
cls.__main_task.__register_at_exit(cls.__main_task._at_exit) cls.__main_task.__register_at_exit(cls.__main_task._at_exit)
# TODO: Check if the signal handler method is safe enough, for the time being, do not unhook # TODO: Check if the signal handler method is safe enough, for the time being, do not unhook
# cls.__main_task.__register_at_exit(None, only_remove_signal_and_exception_hooks=True) # cls.__main_task.__register_at_exit(None, only_remove_signal_and_exception_hooks=True)
@ -335,6 +338,7 @@ class Task(_Task):
return cls.__main_task return cls.__main_task
is_sub_process_task_id = None
# check that we are not a child process, in that case do nothing. # check that we are not a child process, in that case do nothing.
# we should not get here unless this is Windows platform, all others support fork # we should not get here unless this is Windows platform, all others support fork
if cls.__is_subprocess(): if cls.__is_subprocess():
@ -370,7 +374,7 @@ class Task(_Task):
if task_type not in Task.TaskTypes.__members__: if task_type not in Task.TaskTypes.__members__:
raise ValueError("Task type '{}' not supported, options are: {}".format( raise ValueError("Task type '{}' not supported, options are: {}".format(
task_type, Task.TaskTypes.__members__.keys())) task_type, Task.TaskTypes.__members__.keys()))
task_type = Task.TaskTypes.__members__[task_type] task_type = Task.TaskTypes.__members__[str(task_type)]
try: try:
if not running_remotely(): if not running_remotely():
@ -606,7 +610,7 @@ class Task(_Task):
@property @property
def models(self): def models(self):
# type: () -> Dict[str, List[Model]] # type: () -> Dict[str, Sequence[Model]]
""" """
Read-only dictionary of the Task's loaded/stored models Read-only dictionary of the Task's loaded/stored models
@ -712,6 +716,9 @@ class Task(_Task):
raise ValueError("Trains-server does not support DevOps features, " raise ValueError("Trains-server does not support DevOps features, "
"upgrade trains-server to 0.12.0 or above") "upgrade trains-server to 0.12.0 or above")
# make sure we have wither name ot id
mutually_exclusive(queue_name=queue_name, queue_id=queue_id)
task_id = task if isinstance(task, six.string_types) else task.id task_id = task if isinstance(task, six.string_types) else task.id
session = cls._get_default_session() session = cls._get_default_session()
if not queue_id: if not queue_id:
@ -873,6 +880,7 @@ class Task(_Task):
# parameter dictionary # parameter dictionary
if isinstance(configuration, dict): if isinstance(configuration, dict):
def _update_config_dict(task, config_dict): def _update_config_dict(task, config_dict):
# noinspection PyProtectedMember
task._set_model_config(config_dict=config_dict) task._set_model_config(config_dict=config_dict)
if not running_remotely() or not self.is_main_task(): if not running_remotely() or not self.is_main_task():
@ -1139,7 +1147,7 @@ class Task(_Task):
metadata=metadata, delete_after_upload=delete_after_upload) metadata=metadata, delete_after_upload=delete_after_upload)
def get_models(self): def get_models(self):
# type: () -> Dict[str, List[Model]] # type: () -> Dict[str, Sequence[Model]]
""" """
Return a dictionary with {'input': [], 'output': []} loaded/stored models of the current Task Return a dictionary with {'input': [], 'output': []} loaded/stored models of the current Task
Input models are files loaded in the task, either manually or automatically logged Input models are files loaded in the task, either manually or automatically logged
@ -1381,6 +1389,7 @@ class Task(_Task):
:param config_dict: model configuration parameters dictionary. :param config_dict: model configuration parameters dictionary.
If `config_dict` is not None, `config_text` must not be provided. If `config_dict` is not None, `config_text` must not be provided.
""" """
# noinspection PyProtectedMember
design = OutputModel._resolve_config(config_text=config_text, config_dict=config_dict) design = OutputModel._resolve_config(config_text=config_text, config_dict=config_dict)
super(Task, self)._set_model_design(design=design) super(Task, self)._set_model_design(design=design)
@ -1403,6 +1412,7 @@ class Task(_Task):
:return: config_dict: model configuration parameters dictionary :return: config_dict: model configuration parameters dictionary
""" """
config_text = self._get_model_config_text() config_text = self._get_model_config_text()
# noinspection PyProtectedMember
return OutputModel._text_to_config_dict(config_text) return OutputModel._text_to_config_dict(config_text)
@classmethod @classmethod
@ -1449,6 +1459,7 @@ class Task(_Task):
closed_old_task = False closed_old_task = False
default_task_id = None default_task_id = None
task = None
in_dev_mode = not running_remotely() in_dev_mode = not running_remotely()
if in_dev_mode: if in_dev_mode:
@ -1648,6 +1659,7 @@ class Task(_Task):
# noinspection PyBroadException # noinspection PyBroadException
try: try:
if 'IPython' in sys.modules: if 'IPython' in sys.modules:
# noinspection PyPackageRequirements
from IPython import get_ipython from IPython import get_ipython
ip = get_ipython() ip = get_ipython()
if ip is not None and 'IPKernelApp' in ip.config: if ip is not None and 'IPKernelApp' in ip.config:
@ -1676,14 +1688,17 @@ class Task(_Task):
def _connect_dictionary(self, dictionary): def _connect_dictionary(self, dictionary):
def _update_args_dict(task, config_dict): def _update_args_dict(task, config_dict):
# noinspection PyProtectedMember
task._arguments.copy_from_dict(flatten_dictionary(config_dict)) task._arguments.copy_from_dict(flatten_dictionary(config_dict))
def _refresh_args_dict(task, config_dict): def _refresh_args_dict(task, config_dict):
# reread from task including newly added keys # reread from task including newly added keys
flat_dict = task._arguments.copy_to_dict(flatten_dictionary(config_dict)) # noinspection PyProtectedMember
a_flat_dict = task._arguments.copy_to_dict(flatten_dictionary(config_dict))
# noinspection PyProtectedMember
nested_dict = config_dict._to_dict() nested_dict = config_dict._to_dict()
config_dict.clear() config_dict.clear()
config_dict.update(nested_from_flat_dictionary(nested_dict, flat_dict)) config_dict.update(nested_from_flat_dictionary(nested_dict, a_flat_dict))
self._try_set_connected_parameter_type(self._ConnectedParametersType.dictionary) self._try_set_connected_parameter_type(self._ConnectedParametersType.dictionary)
@ -1790,6 +1805,7 @@ class Task(_Task):
with self._repo_detect_lock: with self._repo_detect_lock:
if not self._detect_repo_async_thread: if not self._detect_repo_async_thread:
return return
# noinspection PyBroadException
try: try:
if self._detect_repo_async_thread.is_alive(): if self._detect_repo_async_thread.is_alive():
# if negative timeout, just kill the thread: # if negative timeout, just kill the thread:
@ -1858,9 +1874,10 @@ class Task(_Task):
is_exception = self.__exit_hook.exception is_exception = self.__exit_hook.exception
# check if we are running inside a debugger # check if we are running inside a debugger
if not is_exception and sys.modules.get('pydevd'): if not is_exception and sys.modules.get('pydevd'):
# noinspection PyBroadException
try: try:
is_exception = sys.last_type is_exception = sys.last_type
except: except Exception:
pass pass
if (is_exception and not isinstance(self.__exit_hook.exception, KeyboardInterrupt)) \ if (is_exception and not isinstance(self.__exit_hook.exception, KeyboardInterrupt)) \
@ -1906,15 +1923,17 @@ class Task(_Task):
# notice: this will close the jupyter monitoring # notice: this will close the jupyter monitoring
ScriptInfo.close() ScriptInfo.close()
if self.is_main_task(): if self.is_main_task():
# noinspection PyBroadException
try: try:
from .storage.helper import StorageHelper from .storage.helper import StorageHelper
StorageHelper.close_async_threads() StorageHelper.close_async_threads()
except: except Exception:
pass pass
if print_done_waiting: if print_done_waiting:
self.log.info('Finished uploading') self.log.info('Finished uploading')
elif self._logger: elif self._logger:
# noinspection PyProtectedMember
self._logger._flush_stdout_handler() self._logger._flush_stdout_handler()
# from here, do not check worker status # from here, do not check worker status
@ -1940,6 +1959,7 @@ class Task(_Task):
if self._logger: if self._logger:
self._logger.set_flush_period(None) self._logger.set_flush_period(None)
# noinspection PyProtectedMember
self._logger._close_stdout_handler(wait=wait_for_uploads or wait_for_std_log) self._logger._close_stdout_handler(wait=wait_for_uploads or wait_for_std_log)
# this is so in theory we can close a main task and start a new one # this is so in theory we can close a main task and start a new one
@ -1949,10 +1969,10 @@ class Task(_Task):
# make sure we do not interrupt the exit process # make sure we do not interrupt the exit process
pass pass
# delete locking object (lock file) # delete locking object (lock file)
# noinspection PyBroadException
if self._edit_lock: if self._edit_lock:
# noinspection PyBroadException
try: try:
del self._edit_lock del self.__edit_lock
except Exception: except Exception:
pass pass
self._edit_lock = None self._edit_lock = None
@ -1975,6 +1995,7 @@ class Task(_Task):
def update_callback(self, callback): def update_callback(self, callback):
if self._exit_callback and not six.PY2: if self._exit_callback and not six.PY2:
# noinspection PyBroadException
try: try:
atexit.unregister(self._exit_callback) atexit.unregister(self._exit_callback)
except Exception: except Exception:
@ -1987,10 +2008,10 @@ class Task(_Task):
if self._orig_exc_handler: if self._orig_exc_handler:
sys.excepthook = self._orig_exc_handler sys.excepthook = self._orig_exc_handler
self._orig_exc_handler = None self._orig_exc_handler = None
for s in self._org_handlers: for h in self._org_handlers:
# noinspection PyBroadException # noinspection PyBroadException
try: try:
signal.signal(s, self._org_handlers[s]) signal.signal(h, self._org_handlers[h])
except Exception: except Exception:
pass pass
self._org_handlers = {} self._org_handlers = {}
@ -2015,11 +2036,11 @@ class Task(_Task):
else: else:
catch_signals = [signal.SIGINT, signal.SIGTERM, signal.SIGSEGV, signal.SIGABRT, catch_signals = [signal.SIGINT, signal.SIGTERM, signal.SIGSEGV, signal.SIGABRT,
signal.SIGILL, signal.SIGFPE, signal.SIGQUIT] signal.SIGILL, signal.SIGFPE, signal.SIGQUIT]
for s in catch_signals: for c in catch_signals:
# noinspection PyBroadException # noinspection PyBroadException
try: try:
self._org_handlers[s] = signal.getsignal(s) self._org_handlers[c] = signal.getsignal(c)
signal.signal(s, self.signal_handler) signal.signal(c, self.signal_handler)
except Exception: except Exception:
pass pass
@ -2029,13 +2050,16 @@ class Task(_Task):
def exc_handler(self, exctype, value, traceback, *args, **kwargs): def exc_handler(self, exctype, value, traceback, *args, **kwargs):
if self._except_recursion_protection_flag: if self._except_recursion_protection_flag:
# noinspection PyArgumentList
return sys.__excepthook__(exctype, value, traceback, *args, **kwargs) return sys.__excepthook__(exctype, value, traceback, *args, **kwargs)
self._except_recursion_protection_flag = True self._except_recursion_protection_flag = True
self.exception = value self.exception = value
if self._orig_exc_handler: if self._orig_exc_handler:
# noinspection PyArgumentList
ret = self._orig_exc_handler(exctype, value, traceback, *args, **kwargs) ret = self._orig_exc_handler(exctype, value, traceback, *args, **kwargs)
else: else:
# noinspection PyNoneFunctionAssignment, PyArgumentList
ret = sys.__excepthook__(exctype, value, traceback, *args, **kwargs) ret = sys.__excepthook__(exctype, value, traceback, *args, **kwargs)
self._except_recursion_protection_flag = False self._except_recursion_protection_flag = False
@ -2069,6 +2093,7 @@ class Task(_Task):
# remove stdout logger, just in case # remove stdout logger, just in case
# noinspection PyBroadException # noinspection PyBroadException
try: try:
# noinspection PyProtectedMember
Logger._remove_std_logger() Logger._remove_std_logger()
except Exception: except Exception:
pass pass