From 0e2657421f9463fbbff276be77f5ed2a6dc3f7c1 Mon Sep 17 00:00:00 2001 From: clearml <> Date: Mon, 24 Feb 2025 13:46:00 +0200 Subject: [PATCH] Added CLEARML_AGENT_ABORT_CALLBACK_CMD and CLEARML_AGENT_ABORT_CALLBACK_TIMEOUT (default 180 sec) to define callback command to be called on abort status change --- clearml_agent/commands/worker.py | 139 ++++++++++++++++++++++++++++--- clearml_agent/definitions.py | 3 + 2 files changed, 132 insertions(+), 10 deletions(-) diff --git a/clearml_agent/commands/worker.py b/clearml_agent/commands/worker.py index 45237c6..eb56ebe 100644 --- a/clearml_agent/commands/worker.py +++ b/clearml_agent/commands/worker.py @@ -22,6 +22,7 @@ from datetime import datetime from functools import partial from os.path import basename from tempfile import mkdtemp, NamedTemporaryFile +from threading import Thread from time import sleep, time from typing import Text, Optional, Any, Tuple, List, Dict, Mapping, Union @@ -78,7 +79,7 @@ from clearml_agent.definitions import ( ENV_AGENT_FORCE_EXEC_SCRIPT, ENV_TEMP_STDOUT_FILE_DIR, ENV_AGENT_FORCE_TASK_INIT, - ENV_AGENT_DEBUG_GET_NEXT_TASK, + ENV_AGENT_DEBUG_GET_NEXT_TASK, ENV_ABORT_CALLBACK_CMD, ENV_ABORT_CALLBACK_CMD_TIMEOUT, ) from clearml_agent.definitions import WORKING_REPOSITORY_DIR, PIP_EXTRA_INDICES from clearml_agent.errors import ( @@ -534,14 +535,18 @@ class TaskStopSignal(object): ] default = TaskStopReason.no_stop stopping_message = "stopping" + property_abort_callback_completed = "_abort_callback_completed" + property_abort_callback_timeout = "_abort_callback_timeout" + property_abort_poll_freq = "_abort_poll_freq" - def __init__(self, command, session, events_service, task_id): - # type: (Worker, Session, Events, Text) -> () + def __init__(self, command, session, events_service, task_id, bash_cwd=None): + # type: (Worker, Session, Events, Text, Text) -> None """ :param command: workers command :param session: command session :param events_service: events service object :param task_id: followed task ID + :param bash_cwd: cwd for bash on_abort callback """ self.command = command self.session = session @@ -553,6 +558,90 @@ class TaskStopSignal(object): self._active_callback_timestamp = None self._active_callback_timeout = None self._abort_callback_max_timeout = float(self.session.config.get('agent.abort_callback_max_timeout', 1800)) + self._bash_callback_cwd = bash_cwd + self._bash_callback = None + self._bash_callback_timeout = None + self._bash_callback_thread = None + self._self_monitor_thread = None + self.register_bash_callback() + + @classmethod + def check_registered_bash_callback(cls): + return bool((ENV_ABORT_CALLBACK_CMD.get() or + os.environ.get(ENV_ABORT_CALLBACK_CMD.vars[0] + "_REGISTERED") or + "").strip()) + + def register_bash_callback(self): + # check if the env variable defined a callback + if not (ENV_ABORT_CALLBACK_CMD.get() or "").strip(): + return + self._bash_callback = shlex.split(ENV_ABORT_CALLBACK_CMD.get()) + # make sure we are re-testing in subprocesses + os.environ[ENV_ABORT_CALLBACK_CMD.vars[0]+"_REGISTERED"] = ENV_ABORT_CALLBACK_CMD.get() + os.environ.pop(ENV_ABORT_CALLBACK_CMD.vars[0], None) + + # noinspection PyBroadException + try: + self._bash_callback_timeout = int((ENV_ABORT_CALLBACK_CMD_TIMEOUT.get() or "").strip() or 180) + except Exception: + self._bash_callback_timeout = 180 + + # update the task runtime properties + try: + task_info = self.session.get( + service="tasks", action="get_all", version="2.13", + id=[self.task_id], only_fields=["runtime"])['tasks'][0] + runtime_properties = task_info.get("runtime") or {} + runtime_properties[self.property_abort_callback_timeout] = str(self._bash_callback_timeout) + runtime_properties[self.property_abort_poll_freq] = str(10) + self.session.post(service="tasks", action="edit", version="2.13", + task=self.task_id, runtime=runtime_properties, force=True) + except Exception as ex: + print("WARNING: failed registering bash callback: {}".format(ex)) + return + + print("INFO: registering bash callback, timout {}s: {}".format( + self._bash_callback_timeout, self._bash_callback)) + + def start_monitor_thread(self, polling_interval_sec=10): + if self._self_monitor_thread: + return + self._self_monitor_thread = Thread(target=self._monitor_thread_loop, args=(polling_interval_sec, )) + self._self_monitor_thread.daemon = True + self._self_monitor_thread.start() + print("INFO: bash on_abort monitor thread started") + + def stop_monitor_thread(self): + self._self_monitor_thread = None + + def _monitor_thread_loop(self, polling_interval_sec=10): + while self._self_monitor_thread is not None: + sleep(polling_interval_sec) + stop_reason = self.test() + if stop_reason != TaskStopSignal.default: + # mark quit loop + break + + def _bash_callback_launch_thread(self): + command = Argv(*self._bash_callback) + print("INFO: running bash on_abort callback: {}".format(command.pretty())) + try: + command.check_call(cwd=self._bash_callback_cwd or None) + except Exception as ex: + print("WARNING: failed running bash callback: {}".format(ex)) + + # update the task runtime properties + try: + task_info = self.session.get( + service="tasks", action="get_all", version="2.13", + id=[self.task_id], only_fields=["runtime"])['tasks'][0] + runtime_properties = task_info.get("runtime") or {} + runtime_properties[self.property_abort_callback_completed] = str(1) + self.session.post(service="tasks", action="edit", version="2.13", + task=self.task_id, runtime=runtime_properties, force=True) + except Exception as ex: + print("WARNING: failed updating bash callback completed: {}".format(ex)) + return def test(self): # type: () -> TaskStopReason @@ -583,8 +672,12 @@ class TaskStopSignal(object): try: task_info = self.session.get( service="tasks", action="get_all", version="2.13", id=[self.task_id], - only_fields=["status", "status_message", "runtime._abort_callback_completed"]) - cb_completed = task_info['tasks'][0]['runtime'].get('_abort_callback_completed', None) + only_fields=[ + "status", "status_message", + "runtime.{}".format(self.property_abort_callback_completed) + ] + ) + cb_completed = task_info['tasks'][0]['runtime'].get(self.property_abort_callback_completed, None) except: # noqa pass @@ -620,6 +713,13 @@ class TaskStopSignal(object): session=self.session ) + # launch bash callback if exists + if self._bash_callback and not self._bash_callback_thread: + print("INFO: bash on_abort callback thread started") + self._bash_callback_thread = Thread(target=self._bash_callback_launch_thread) + self._bash_callback_thread.daemon = True + self._bash_callback_thread.start() + self._active_callback_timestamp = time() self._active_callback_timeout = timeout return bool(cb_completed) @@ -629,11 +729,16 @@ class TaskStopSignal(object): try: task_info = self.session.get( service="tasks", action="get_all", version="2.13", id=[self.task_id], - only_fields=["status", "status_message", "runtime._abort_callback_timeout", - "runtime._abort_poll_freq", "runtime._abort_callback_completed"]) - abort_timeout = task_info['tasks'][0]['runtime'].get('_abort_callback_timeout', 0) - poll_timeout = task_info['tasks'][0]['runtime'].get('_abort_poll_freq', 0) - cb_completed = task_info['tasks'][0]['runtime'].get('_abort_callback_completed', None) + only_fields=[ + "status", "status_message", + "runtime.{}".format(self.property_abort_callback_timeout), + "runtime.{}".format(self.property_abort_poll_freq), + "runtime.{}".format(self.property_abort_callback_completed) + ] + ) + abort_timeout = task_info['tasks'][0]['runtime'].get(self.property_abort_callback_timeout, 0) + poll_timeout = task_info['tasks'][0]['runtime'].get(self.property_abort_poll_freq, 0) + cb_completed = task_info['tasks'][0]['runtime'].get(self.property_abort_callback_completed, None) except: # noqa abort_timeout = None poll_timeout = None @@ -3087,6 +3192,18 @@ class Worker(ServiceCommandSection): force=True, ) + stop_signal = None + if TaskStopSignal.check_registered_bash_callback(): + use_execv = False + stop_signal = TaskStopSignal( + command=self, + session=self._session, + events_service=self.get_service(Events), + task_id=task_id, + bash_cwd=script_dir + ) + stop_signal.start_monitor_thread(polling_interval_sec=self._polling_interval) + # check if we need to add encoding to the subprocess if sys.getfilesystemencoding() == 'ascii' and not os.environ.get("PYTHONIOENCODING"): os.environ["PYTHONIOENCODING"] = "utf-8" @@ -3149,6 +3266,8 @@ class Worker(ServiceCommandSection): exit_code = exit_code if exit_code != ExitStatus.interrupted else -1 if not disable_monitoring: + if stop_signal: + stop_signal.stop_monitor_thread() # we need to change task status according to exit code self.handle_task_termination(current_task.id, exit_code, TaskStopReason.no_stop) self.stop_monitor() diff --git a/clearml_agent/definitions.py b/clearml_agent/definitions.py index 80bdced..57986d9 100644 --- a/clearml_agent/definitions.py +++ b/clearml_agent/definitions.py @@ -187,6 +187,9 @@ ENV_CHILD_AGENTS_COUNT_CMD = EnvironmentConfig("CLEARML_AGENT_CHILD_AGENTS_COUNT ENV_DOCKER_ARGS_FILTERS = EnvironmentConfig("CLEARML_AGENT_DOCKER_ARGS_FILTERS") ENV_DOCKER_ARGS_HIDE_ENV = EnvironmentConfig("CLEARML_AGENT_DOCKER_ARGS_HIDE_ENV") ENV_CONFIG_BC_IN_STANDALONE = EnvironmentConfig("CLEARML_AGENT_STANDALONE_CONFIG_BC", type=bool) +ENV_ABORT_CALLBACK_CMD = EnvironmentConfig("CLEARML_AGENT_ABORT_CALLBACK_CMD") +ENV_ABORT_CALLBACK_CMD_TIMEOUT = EnvironmentConfig("CLEARML_AGENT_ABORT_CALLBACK_TIMEOUT") + """ Maintain backwards compatible configuration when launching in standalone mode """ ENV_FORCE_DOCKER_AGENT_REPO = EnvironmentConfig("FORCE_CLEARML_AGENT_REPO", "CLEARML_AGENT_DOCKER_AGENT_REPO")