Fix type hints, add noinspection ignores, and fix PEP 8 warnings

allegroai 2020-05-31 11:59:05 +03:00
parent 7440799bb0
commit 9259e4efeb

@@ -8,11 +8,12 @@ from argparse import ArgumentParser
from tempfile import mkstemp
try:
# noinspection PyCompatibility
from collections.abc import Callable, Sequence as CollectionsSequence
except ImportError:
from collections import Callable, Sequence as CollectionsSequence
from typing import Optional, Union, Mapping, Sequence, Any, Dict, List, TYPE_CHECKING
from typing import Optional, Union, Mapping, Sequence, Any, Dict, TYPE_CHECKING
import psutil
import six
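Throughout this file the annotations are PEP 484 comment-style type hints rather than inline annotations, which keeps them valid on Python 2 as well (note the six import and the collections fallback above). A minimal, self-contained sketch of the form the fixed hints take; the names here are placeholders, not the library's API:

from typing import Dict, Optional, Sequence


def get_models():
    # type: () -> Dict[str, Sequence[str]]
    # Comment-style return annotation, readable by type checkers on Python 2 and 3.
    return {'input': [], 'output': []}


main_task = None  # type: Optional[str]  # same comment style as the __main_task fix below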
@@ -23,10 +24,9 @@ from .backend_api.session.session import Session, ENV_ACCESS_KEY, ENV_SECRET_KEY
from .backend_interface.metrics import Metrics
from .backend_interface.model import Model as BackendModel
from .backend_interface.task import Task as _Task
from .backend_interface.task.args import _Arguments
from .backend_interface.task.development.worker import DevWorker
from .backend_interface.task.repo import ScriptInfo
from .backend_interface.util import get_single_result, exact_match_regex, make_message
from .backend_interface.util import get_single_result, exact_match_regex, make_message, mutually_exclusive
from .binding.absl_bind import PatchAbsl
from .binding.artifacts import Artifacts, Artifact
from .binding.environ_bind import EnvironmentBind, PatchOsFork
@@ -50,6 +50,8 @@ from .utilities.proxy_object import ProxyDictPreWrite, ProxyDictPostWrite, flatt
nested_from_flat_dictionary, naive_nested_from_flat_dictionary
from .utilities.resource_monitor import ResourceMonitor
from .utilities.seed import make_deterministic
# noinspection PyProtectedMember
from .backend_interface.task.args import _Arguments
if TYPE_CHECKING:
@@ -105,7 +107,7 @@ class Task(_Task):
NotSet = object()
__create_protection = object()
__main_task = None # type: Task
__main_task = None # type: Optional[Task]
__exit_hook = None
__forked_proc_main_pid = None
__task_id_reuse_time_window_in_hours = float(config.get('development.task_reuse_time_window_in_hours', 24.0))
@@ -326,6 +328,7 @@ class Task(_Task):
cls.__main_task.get_logger()
cls.__main_task._artifacts_manager = Artifacts(cls.__main_task)
# unregister signal hooks, they cause subprocess to hang
# noinspection PyProtectedMember
cls.__main_task.__register_at_exit(cls.__main_task._at_exit)
# TODO: Check if the signal handler method is safe enough, for the time being, do not unhook
# cls.__main_task.__register_at_exit(None, only_remove_signal_and_exception_hooks=True)
@@ -335,6 +338,7 @@ class Task(_Task):
return cls.__main_task
is_sub_process_task_id = None
# check that we are not a child process, in that case do nothing.
# we should not get here unless this is Windows platform, all others support fork
if cls.__is_subprocess():
@@ -370,7 +374,7 @@ class Task(_Task):
if task_type not in Task.TaskTypes.__members__:
raise ValueError("Task type '{}' not supported, options are: {}".format(
task_type, Task.TaskTypes.__members__.keys()))
task_type = Task.TaskTypes.__members__[task_type]
task_type = Task.TaskTypes.__members__[str(task_type)]
try:
if not running_remotely():
@@ -606,7 +610,7 @@ class Task(_Task):
@property
def models(self):
# type: () -> Dict[str, List[Model]]
# type: () -> Dict[str, Sequence[Model]]
"""
Read-only dictionary of the Task's loaded/stored models
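One of the recurring hint changes in this commit is Dict[str, List[Model]] becoming Dict[str, Sequence[Model]]. A plausible reading, consistent with the "Read-only dictionary" docstring: typing.Sequence exposes no mutating methods and is covariant, so it documents that callers should not append to the returned lists. A small illustrative sketch with stand-in classes, not the library's Model types:

from typing import List, Sequence


class Model(object):
    pass


class OutputModel(Model):
    pass


def count_models(models):
    # type: (Sequence[Model]) -> int
    # Accepts any read-only sequence of Model or its subclasses.
    return len(models)


outputs = [OutputModel()]  # type: List[OutputModel]
count_models(outputs)  # accepted: Sequence[Model] is covariant, a List[Model] parameter would reject this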
@@ -712,6 +716,9 @@ class Task(_Task):
raise ValueError("Trains-server does not support DevOps features, "
"upgrade trains-server to 0.12.0 or above")
# make sure we have either name or id
mutually_exclusive(queue_name=queue_name, queue_id=queue_id)
task_id = task if isinstance(task, six.string_types) else task.id
session = cls._get_default_session()
if not queue_id:
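The newly imported mutually_exclusive helper guards against passing both queue_name and queue_id. Its implementation lives in backend_interface.util and is not part of this diff, so the sketch below is an assumed shape, not the actual code:

def mutually_exclusive(_require_at_least_one=True, **kwargs):
    # Hypothetical validator: at most one keyword argument may be non-None.
    provided = [name for name, value in kwargs.items() if value is not None]
    if len(provided) > 1:
        raise ValueError('{} are mutually exclusive'.format(', '.join(sorted(provided))))
    if _require_at_least_one and not provided:
        raise ValueError('provide one of: {}'.format(', '.join(sorted(kwargs))))


mutually_exclusive(queue_name='default', queue_id=None)  # passes: only one value provided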
@@ -873,6 +880,7 @@ class Task(_Task):
# parameter dictionary
if isinstance(configuration, dict):
def _update_config_dict(task, config_dict):
# noinspection PyProtectedMember
task._set_model_config(config_dict=config_dict)
if not running_remotely() or not self.is_main_task():
@@ -1139,7 +1147,7 @@ class Task(_Task):
metadata=metadata, delete_after_upload=delete_after_upload)
def get_models(self):
# type: () -> Dict[str, List[Model]]
# type: () -> Dict[str, Sequence[Model]]
"""
Return a dictionary with {'input': [], 'output': []} loaded/stored models of the current Task
Input models are files loaded in the task, either manually or automatically logged
@@ -1381,6 +1389,7 @@ class Task(_Task):
:param config_dict: model configuration parameters dictionary.
If `config_dict` is not None, `config_text` must not be provided.
"""
# noinspection PyProtectedMember
design = OutputModel._resolve_config(config_text=config_text, config_dict=config_dict)
super(Task, self)._set_model_design(design=design)
@@ -1403,6 +1412,7 @@ class Task(_Task):
:return: config_dict: model configuration parameters dictionary
"""
config_text = self._get_model_config_text()
# noinspection PyProtectedMember
return OutputModel._text_to_config_dict(config_text)
@classmethod
@@ -1449,6 +1459,7 @@ class Task(_Task):
closed_old_task = False
default_task_id = None
task = None
in_dev_mode = not running_remotely()
if in_dev_mode:
@@ -1648,6 +1659,7 @@ class Task(_Task):
# noinspection PyBroadException
try:
if 'IPython' in sys.modules:
# noinspection PyPackageRequirements
from IPython import get_ipython
ip = get_ipython()
if ip is not None and 'IPKernelApp' in ip.config:
@@ -1676,14 +1688,17 @@ class Task(_Task):
def _connect_dictionary(self, dictionary):
def _update_args_dict(task, config_dict):
# noinspection PyProtectedMember
task._arguments.copy_from_dict(flatten_dictionary(config_dict))
def _refresh_args_dict(task, config_dict):
# reread from task including newly added keys
flat_dict = task._arguments.copy_to_dict(flatten_dictionary(config_dict))
# noinspection PyProtectedMember
a_flat_dict = task._arguments.copy_to_dict(flatten_dictionary(config_dict))
# noinspection PyProtectedMember
nested_dict = config_dict._to_dict()
config_dict.clear()
config_dict.update(nested_from_flat_dictionary(nested_dict, flat_dict))
config_dict.update(nested_from_flat_dictionary(nested_dict, a_flat_dict))
self._try_set_connected_parameter_type(self._ConnectedParametersType.dictionary)
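_refresh_args_dict above relies on a flatten/rebuild round trip: the nested config dict is flattened, pushed through the task arguments, then rebuilt into nested form. The helpers below are stand-ins with an assumed '/' key separator; the real flatten_dictionary / nested_from_flat_dictionary in utilities.proxy_object may differ in signature and behaviour:

def flatten(nested, prefix=''):
    # Collapse a nested dict into {'a/b': value} style keys.
    flat = {}
    for key, value in nested.items():
        full_key = '{}/{}'.format(prefix, key) if prefix else key
        if isinstance(value, dict):
            flat.update(flatten(value, full_key))
        else:
            flat[full_key] = value
    return flat


def unflatten(flat):
    # Rebuild the nested dict from the flat key paths.
    nested = {}
    for full_key, value in flat.items():
        parts = full_key.split('/')
        node = nested
        for part in parts[:-1]:
            node = node.setdefault(part, {})
        node[parts[-1]] = value
    return nested


params = {'train': {'lr': 0.01, 'epochs': 10}}
assert unflatten(flatten(params)) == params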
@@ -1790,6 +1805,7 @@ class Task(_Task):
with self._repo_detect_lock:
if not self._detect_repo_async_thread:
return
# noinspection PyBroadException
try:
if self._detect_repo_async_thread.is_alive():
# if negative timeout, just kill the thread:
@@ -1858,9 +1874,10 @@ class Task(_Task):
is_exception = self.__exit_hook.exception
# check if we are running inside a debugger
if not is_exception and sys.modules.get('pydevd'):
# noinspection PyBroadException
try:
is_exception = sys.last_type
except:
except Exception:
pass
if (is_exception and not isinstance(self.__exit_hook.exception, KeyboardInterrupt)) \
@@ -1906,15 +1923,17 @@ class Task(_Task):
# notice: this will close the jupyter monitoring
ScriptInfo.close()
if self.is_main_task():
# noinspection PyBroadException
try:
from .storage.helper import StorageHelper
StorageHelper.close_async_threads()
except:
except Exception:
pass
if print_done_waiting:
self.log.info('Finished uploading')
elif self._logger:
# noinspection PyProtectedMember
self._logger._flush_stdout_handler()
# from here, do not check worker status
@@ -1940,6 +1959,7 @@ class Task(_Task):
if self._logger:
self._logger.set_flush_period(None)
# noinspection PyProtectedMember
self._logger._close_stdout_handler(wait=wait_for_uploads or wait_for_std_log)
# this is so in theory we can close a main task and start a new one
@@ -1949,10 +1969,10 @@ class Task(_Task):
# make sure we do not interrupt the exit process
pass
# delete locking object (lock file)
# noinspection PyBroadException
if self._edit_lock:
# noinspection PyBroadException
try:
del self._edit_lock
del self.__edit_lock
except Exception:
pass
self._edit_lock = None
@@ -1975,6 +1995,7 @@ class Task(_Task):
def update_callback(self, callback):
if self._exit_callback and not six.PY2:
# noinspection PyBroadException
try:
atexit.unregister(self._exit_callback)
except Exception:
@@ -1987,10 +2008,10 @@ class Task(_Task):
if self._orig_exc_handler:
sys.excepthook = self._orig_exc_handler
self._orig_exc_handler = None
for s in self._org_handlers:
for h in self._org_handlers:
# noinspection PyBroadException
try:
signal.signal(s, self._org_handlers[s])
signal.signal(h, self._org_handlers[h])
except Exception:
pass
self._org_handlers = {}
@@ -2015,11 +2036,11 @@ class Task(_Task):
else:
catch_signals = [signal.SIGINT, signal.SIGTERM, signal.SIGSEGV, signal.SIGABRT,
signal.SIGILL, signal.SIGFPE, signal.SIGQUIT]
for s in catch_signals:
for c in catch_signals:
# noinspection PyBroadException
try:
self._org_handlers[s] = signal.getsignal(s)
signal.signal(s, self.signal_handler)
self._org_handlers[c] = signal.getsignal(c)
signal.signal(c, self.signal_handler)
except Exception:
pass
@@ -2029,13 +2050,16 @@ class Task(_Task):
def exc_handler(self, exctype, value, traceback, *args, **kwargs):
if self._except_recursion_protection_flag:
# noinspection PyArgumentList
return sys.__excepthook__(exctype, value, traceback, *args, **kwargs)
self._except_recursion_protection_flag = True
self.exception = value
if self._orig_exc_handler:
# noinspection PyArgumentList
ret = self._orig_exc_handler(exctype, value, traceback, *args, **kwargs)
else:
# noinspection PyNoneFunctionAssignment, PyArgumentList
ret = sys.__excepthook__(exctype, value, traceback, *args, **kwargs)
self._except_recursion_protection_flag = False
@@ -2069,6 +2093,7 @@ class Task(_Task):
# remove stdout logger, just in case
# noinspection PyBroadException
try:
# noinspection PyProtectedMember
Logger._remove_std_logger()
except Exception:
pass