mirror of
https://github.com/clearml/clearml
synced 2025-06-26 18:16:07 +00:00
Fix multiprocessing Pool throw exception in pool would hang the program. Call original signal handler and re-flush stdout.
Add py2.7 support for get_current_thread_id.
This commit is contained in:
parent
42148345b4
commit
2c68a188d9
@ -55,6 +55,7 @@ from .utilities.proxy_object import ProxyDictPreWrite, ProxyDictPostWrite, flatt
|
|||||||
nested_from_flat_dictionary, naive_nested_from_flat_dictionary
|
nested_from_flat_dictionary, naive_nested_from_flat_dictionary
|
||||||
from .utilities.resource_monitor import ResourceMonitor
|
from .utilities.resource_monitor import ResourceMonitor
|
||||||
from .utilities.seed import make_deterministic
|
from .utilities.seed import make_deterministic
|
||||||
|
from .utilities.lowlevel.threads import get_current_thread_id
|
||||||
# noinspection PyProtectedMember
|
# noinspection PyProtectedMember
|
||||||
from .backend_interface.task.args import _Arguments
|
from .backend_interface.task.args import _Arguments
|
||||||
|
|
||||||
@ -2307,7 +2308,7 @@ class Task(_Task):
|
|||||||
|
|
||||||
def _at_exit(self):
|
def _at_exit(self):
|
||||||
# protect sub-process at_exit (should never happen)
|
# protect sub-process at_exit (should never happen)
|
||||||
if self._at_exit_called:
|
if self._at_exit_called and self._at_exit_called != get_current_thread_id():
|
||||||
return
|
return
|
||||||
# shutdown will clear the main, so we have to store it before.
|
# shutdown will clear the main, so we have to store it before.
|
||||||
# is_main = self.is_main_task()
|
# is_main = self.is_main_task()
|
||||||
@ -2324,14 +2325,22 @@ class Task(_Task):
|
|||||||
"""
|
"""
|
||||||
# protect sub-process at_exit
|
# protect sub-process at_exit
|
||||||
if self._at_exit_called:
|
if self._at_exit_called:
|
||||||
|
# if we are called twice (signal in the middle of the shutdown),
|
||||||
|
# make sure we flush stdout, this is the best we can do.
|
||||||
|
if self._at_exit_called == get_current_thread_id() and self._logger and self.__is_subprocess():
|
||||||
|
self._logger.set_flush_period(None)
|
||||||
|
# noinspection PyProtectedMember
|
||||||
|
self._logger._close_stdout_handler(wait=True)
|
||||||
|
self._at_exit_called = True
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# from here only a single thread can re-enter
|
||||||
|
self._at_exit_called = get_current_thread_id()
|
||||||
|
|
||||||
is_sub_process = self.__is_subprocess()
|
is_sub_process = self.__is_subprocess()
|
||||||
|
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
# from here do not get into watch dog
|
|
||||||
self._at_exit_called = True
|
|
||||||
wait_for_uploads = True
|
wait_for_uploads = True
|
||||||
# first thing mark task as stopped, so we will not end up with "running" on lost tasks
|
# first thing mark task as stopped, so we will not end up with "running" on lost tasks
|
||||||
# if we are running remotely, the daemon will take care of it
|
# if we are running remotely, the daemon will take care of it
|
||||||
@ -2464,6 +2473,9 @@ class Task(_Task):
|
|||||||
pass
|
pass
|
||||||
self._edit_lock = None
|
self._edit_lock = None
|
||||||
|
|
||||||
|
# make sure no one will re-enter the shutdown method
|
||||||
|
self._at_exit_called = True
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def __register_at_exit(cls, exit_callback, only_remove_signal_and_exception_hooks=False):
|
def __register_at_exit(cls, exit_callback, only_remove_signal_and_exception_hooks=False):
|
||||||
class ExitHooks(object):
|
class ExitHooks(object):
|
||||||
@ -2556,8 +2568,11 @@ class Task(_Task):
|
|||||||
if self._signal_recursion_protection_flag:
|
if self._signal_recursion_protection_flag:
|
||||||
# call original
|
# call original
|
||||||
org_handler = self._org_handlers.get(sig)
|
org_handler = self._org_handlers.get(sig)
|
||||||
if isinstance(org_handler, Callable):
|
if callable(org_handler):
|
||||||
org_handler = org_handler(sig, frame)
|
org_handler = org_handler(sig, frame)
|
||||||
|
else:
|
||||||
|
signal.signal(sig, org_handler or signal.SIG_DFL)
|
||||||
|
os.kill(os.getpid(), sig)
|
||||||
return org_handler
|
return org_handler
|
||||||
|
|
||||||
self._signal_recursion_protection_flag = True
|
self._signal_recursion_protection_flag = True
|
||||||
@ -2571,12 +2586,13 @@ class Task(_Task):
|
|||||||
pass
|
pass
|
||||||
# call original signal handler
|
# call original signal handler
|
||||||
org_handler = self._org_handlers.get(sig)
|
org_handler = self._org_handlers.get(sig)
|
||||||
if isinstance(org_handler, Callable):
|
self._org_handlers[sig] = None
|
||||||
# noinspection PyBroadException
|
if callable(org_handler):
|
||||||
try:
|
ret = org_handler(sig, frame)
|
||||||
org_handler = org_handler(sig, frame)
|
else:
|
||||||
except Exception:
|
signal.signal(sig, org_handler or signal.SIG_DFL)
|
||||||
org_handler = signal.SIG_DFL
|
ret = 0
|
||||||
|
|
||||||
# remove stdout logger, just in case
|
# remove stdout logger, just in case
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
@ -2584,9 +2600,14 @@ class Task(_Task):
|
|||||||
Logger._remove_std_logger()
|
Logger._remove_std_logger()
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
if not callable(org_handler):
|
||||||
|
os.kill(os.getpid(), sig)
|
||||||
|
|
||||||
self._signal_recursion_protection_flag = False
|
self._signal_recursion_protection_flag = False
|
||||||
|
|
||||||
# return handler result
|
# return handler result
|
||||||
return org_handler
|
return ret
|
||||||
|
|
||||||
# we only remove the signals since this will hang subprocesses
|
# we only remove the signals since this will hang subprocesses
|
||||||
if only_remove_signal_and_exception_hooks:
|
if only_remove_signal_and_exception_hooks:
|
||||||
|
@ -1,9 +1,14 @@
|
|||||||
import ctypes
|
import ctypes
|
||||||
import threading
|
import threading
|
||||||
|
import six
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
def get_current_thread_id():
|
||||||
|
return threading._get_ident() if six.PY2 else threading.get_ident()
|
||||||
|
|
||||||
|
|
||||||
# Nasty hack to raise exception for other threads
|
# Nasty hack to raise exception for other threads
|
||||||
def _lowlevel_async_raise(thread_obj, exception=None):
|
def _lowlevel_async_raise(thread_obj, exception=None):
|
||||||
NULL = 0
|
NULL = 0
|
||||||
|
Loading…
Reference in New Issue
Block a user