Use os.register_at_fork instead of monkey patching fork for python > 3.6

This commit is contained in:
Alex Burlacu 2023-08-21 14:06:59 +03:00
parent f6ad5e6c06
commit 0b521b00a6

View File

@ -89,6 +89,7 @@ class SimpleQueueWrapper(object):
class PatchOsFork(object): class PatchOsFork(object):
_original_fork = None _original_fork = None
_registered_fork_callbacks = False
_current_task = None _current_task = None
_original_process_run = None _original_process_run = None
@ -104,13 +105,20 @@ class PatchOsFork(object):
# noinspection PyBroadException # noinspection PyBroadException
try: try:
# only once # only once
if cls._original_fork: if cls._registered_fork_callbacks or cls._original_fork:
return return
try:
os.register_at_fork(before=PatchOsFork._fork_callback_before,
after_in_child=PatchOsFork._fork_callback_after_child)
cls._registered_fork_callbacks = True
except Exception:
# python <3.6
if six.PY2: if six.PY2:
cls._original_fork = staticmethod(os.fork) cls._original_fork = staticmethod(os.fork)
else: else:
cls._original_fork = os.fork cls._original_fork = os.fork
os.fork = cls._patched_fork os.fork = cls._patched_fork
except Exception: except Exception:
pass pass
@ -182,10 +190,9 @@ class PatchOsFork(object):
pass pass
@staticmethod @staticmethod
def _patched_fork(*args, **kwargs): def _fork_callback_before():
if not PatchOsFork._current_task: if not PatchOsFork._current_task:
return PatchOsFork._original_fork(*args, **kwargs) return
from ..task import Task from ..task import Task
# ensure deferred is done, but never try to generate a Task object # ensure deferred is done, but never try to generate a Task object
@ -195,15 +202,17 @@ class PatchOsFork(object):
# noinspection PyProtectedMember # noinspection PyProtectedMember
Task._wait_for_deferred(task) Task._wait_for_deferred(task)
ret = PatchOsFork._original_fork(*args, **kwargs) @staticmethod
def _fork_callback_after_child():
if not PatchOsFork._current_task: if not PatchOsFork._current_task:
return ret return
# Make sure the new process stdout is logged
if not ret: from ..task import Task
# force creating a Task # force creating a Task
task = Task.current_task() task = Task.current_task()
if not task: if not task:
return ret return
PatchOsFork._current_task = task PatchOsFork._current_task = task
# # Hack: now make sure we setup the reporter threads (Log+Reporter) # # Hack: now make sure we setup the reporter threads (Log+Reporter)
@ -237,4 +246,19 @@ class PatchOsFork(object):
os._exit = _at_exit_callback os._exit = _at_exit_callback
@staticmethod
def _patched_fork(*args, **kwargs):
if not PatchOsFork._current_task:
return PatchOsFork._original_fork(*args, **kwargs)
PatchOsFork._fork_callback_before()
ret = PatchOsFork._original_fork(*args, **kwargs)
if not PatchOsFork._current_task:
return ret
# Make sure the new process stdout is logged
if not ret:
PatchOsFork._fork_callback_after_child()
return ret return ret