diff --git a/clearml_agent/commands/worker.py b/clearml_agent/commands/worker.py index eb56ebe..98b6533 100644 --- a/clearml_agent/commands/worker.py +++ b/clearml_agent/commands/worker.py @@ -578,7 +578,7 @@ class TaskStopSignal(object): self._bash_callback = shlex.split(ENV_ABORT_CALLBACK_CMD.get()) # make sure we are re-testing in subprocesses os.environ[ENV_ABORT_CALLBACK_CMD.vars[0]+"_REGISTERED"] = ENV_ABORT_CALLBACK_CMD.get() - os.environ.pop(ENV_ABORT_CALLBACK_CMD.vars[0], None) + ENV_ABORT_CALLBACK_CMD.pop() # noinspection PyBroadException try: @@ -611,7 +611,21 @@ class TaskStopSignal(object): self._self_monitor_thread.start() print("INFO: bash on_abort monitor thread started") - def stop_monitor_thread(self): + def stop_monitor_thread(self, wait_on_abort_timeout=-1): + # if no on_abort callback was launched, or it's done, nothing to do + if not self._bash_callback_thread or not self._self_monitor_thread: + self._self_monitor_thread = None + return + + # we need to wait for the on_abort callback to be done + if wait_on_abort_timeout and wait_on_abort_timeout < 0: + wait_on_abort_timeout = self._abort_callback_max_timeout + + if wait_on_abort_timeout: + tic = time() + while self._bash_callback_thread and (time() - tic < wait_on_abort_timeout): + sleep(1) + self._self_monitor_thread = None def _monitor_thread_loop(self, polling_interval_sec=10): @@ -620,6 +634,7 @@ class TaskStopSignal(object): stop_reason = self.test() if stop_reason != TaskStopSignal.default: # mark quit loop + self._self_monitor_thread = None break def _bash_callback_launch_thread(self): @@ -641,7 +656,10 @@ class TaskStopSignal(object): task=self.task_id, runtime=runtime_properties, force=True) except Exception as ex: print("WARNING: failed updating bash callback completed: {}".format(ex)) - return + + # mark we are done + self._bash_callback_thread = None + return def test(self): # type: () -> TaskStopReason @@ -3224,6 +3242,9 @@ class Worker(ServiceCommandSection): os.execv(command.argv[0].as_posix(), tuple([command.argv[0].as_posix()])+command.argv[1:]) else: exit_code = command.check_call(cwd=script_dir) + # if we have an on_abort callback still running, wait for it + if stop_signal: + stop_signal.stop_monitor_thread() exit(exit_code) except subprocess.CalledProcessError as ex: # non zero return code @@ -3257,6 +3278,10 @@ class Worker(ServiceCommandSection): self.log_traceback(e) exit_code = -1 + # if we have an on_abort callback still running, wait for it + if stop_signal: + stop_signal.stop_monitor_thread() + # kill leftover processes kill_all_child_processes() @@ -3266,8 +3291,6 @@ class Worker(ServiceCommandSection): exit_code = exit_code if exit_code != ExitStatus.interrupted else -1 if not disable_monitoring: - if stop_signal: - stop_signal.stop_monitor_thread() # we need to change task status according to exit code self.handle_task_termination(current_task.id, exit_code, TaskStopReason.no_stop) self.stop_monitor()