Fix on_abort bash callback: if the main process exits while the on_abort callback is still running, wait for the on_abort callback to complete

This commit is contained in:
clearml 2025-02-24 13:46:55 +02:00
parent d30d4e7e61
commit 66494c598d

View File

@ -578,7 +578,7 @@ class TaskStopSignal(object):
self._bash_callback = shlex.split(ENV_ABORT_CALLBACK_CMD.get())
# make sure we are re-testing in subprocesses
os.environ[ENV_ABORT_CALLBACK_CMD.vars[0]+"_REGISTERED"] = ENV_ABORT_CALLBACK_CMD.get()
os.environ.pop(ENV_ABORT_CALLBACK_CMD.vars[0], None)
ENV_ABORT_CALLBACK_CMD.pop()
# noinspection PyBroadException
try:
@ -611,7 +611,21 @@ class TaskStopSignal(object):
self._self_monitor_thread.start()
print("INFO: bash on_abort monitor thread started")
def stop_monitor_thread(self):
def stop_monitor_thread(self, wait_on_abort_timeout=-1):
    """
    Stop the self-monitoring thread, first waiting for a running on_abort callback.

    If no on_abort callback is in flight (``_bash_callback_thread`` is falsy) or
    the monitor thread handle is already cleared, the handle is cleared and the
    call returns immediately.

    :param wait_on_abort_timeout: Maximum number of seconds to wait for the
        on_abort callback to finish. A negative value (the default) means
        "use ``self._abort_callback_max_timeout``"; ``0`` or ``None`` skips
        the wait entirely.
    """
    # local import: a monotonic clock is unaffected by wall-clock adjustments
    from time import monotonic

    # if no on_abort callback was launched, or it's done, nothing to do
    if not self._bash_callback_thread or not self._self_monitor_thread:
        self._self_monitor_thread = None
        return

    # we need to wait for the on_abort callback to be done;
    # a negative timeout selects the configured maximum
    if wait_on_abort_timeout and wait_on_abort_timeout < 0:
        wait_on_abort_timeout = self._abort_callback_max_timeout

    if wait_on_abort_timeout:
        deadline = monotonic() + wait_on_abort_timeout
        # poll until `_bash_callback_thread` becomes falsy or the deadline passes
        while self._bash_callback_thread:
            remaining = deadline - monotonic()
            if remaining <= 0:
                break
            # never sleep past the deadline (keeps sub-second timeouts accurate)
            sleep(min(1, remaining))

    self._self_monitor_thread = None
def _monitor_thread_loop(self, polling_interval_sec=10):
@ -620,6 +634,7 @@ class TaskStopSignal(object):
stop_reason = self.test()
if stop_reason != TaskStopSignal.default:
# mark quit loop
self._self_monitor_thread = None
break
def _bash_callback_launch_thread(self):
@ -641,7 +656,10 @@ class TaskStopSignal(object):
task=self.task_id, runtime=runtime_properties, force=True)
except Exception as ex:
print("WARNING: failed updating bash callback completed: {}".format(ex))
return
# mark we are done
self._bash_callback_thread = None
return
def test(self):
# type: () -> TaskStopReason
@ -3224,6 +3242,9 @@ class Worker(ServiceCommandSection):
os.execv(command.argv[0].as_posix(), tuple([command.argv[0].as_posix()])+command.argv[1:])
else:
exit_code = command.check_call(cwd=script_dir)
# if we have an on_abort callback still running, wait for it
if stop_signal:
stop_signal.stop_monitor_thread()
exit(exit_code)
except subprocess.CalledProcessError as ex:
# non zero return code
@ -3257,6 +3278,10 @@ class Worker(ServiceCommandSection):
self.log_traceback(e)
exit_code = -1
# if we have an on_abort callback still running, wait for it
if stop_signal:
stop_signal.stop_monitor_thread()
# kill leftover processes
kill_all_child_processes()
@ -3266,8 +3291,6 @@ class Worker(ServiceCommandSection):
exit_code = exit_code if exit_code != ExitStatus.interrupted else -1
if not disable_monitoring:
if stop_signal:
stop_signal.stop_monitor_thread()
# we need to change task status according to exit code
self.handle_task_termination(current_task.id, exit_code, TaskStopReason.no_stop)
self.stop_monitor()