Added CLEARML_AGENT_ABORT_CALLBACK_CMD and CLEARML_AGENT_ABORT_CALLBACK_TIMEOUT

(default 180 sec) to define a callback command to be called on abort status change
This commit is contained in:
clearml 2025-02-24 13:46:00 +02:00
parent ee286e2fb7
commit 0e2657421f
2 changed files with 132 additions and 10 deletions

View File

@ -22,6 +22,7 @@ from datetime import datetime
from functools import partial from functools import partial
from os.path import basename from os.path import basename
from tempfile import mkdtemp, NamedTemporaryFile from tempfile import mkdtemp, NamedTemporaryFile
from threading import Thread
from time import sleep, time from time import sleep, time
from typing import Text, Optional, Any, Tuple, List, Dict, Mapping, Union from typing import Text, Optional, Any, Tuple, List, Dict, Mapping, Union
@ -78,7 +79,7 @@ from clearml_agent.definitions import (
ENV_AGENT_FORCE_EXEC_SCRIPT, ENV_AGENT_FORCE_EXEC_SCRIPT,
ENV_TEMP_STDOUT_FILE_DIR, ENV_TEMP_STDOUT_FILE_DIR,
ENV_AGENT_FORCE_TASK_INIT, ENV_AGENT_FORCE_TASK_INIT,
ENV_AGENT_DEBUG_GET_NEXT_TASK, ENV_AGENT_DEBUG_GET_NEXT_TASK, ENV_ABORT_CALLBACK_CMD, ENV_ABORT_CALLBACK_CMD_TIMEOUT,
) )
from clearml_agent.definitions import WORKING_REPOSITORY_DIR, PIP_EXTRA_INDICES from clearml_agent.definitions import WORKING_REPOSITORY_DIR, PIP_EXTRA_INDICES
from clearml_agent.errors import ( from clearml_agent.errors import (
@ -534,14 +535,18 @@ class TaskStopSignal(object):
] ]
default = TaskStopReason.no_stop default = TaskStopReason.no_stop
stopping_message = "stopping" stopping_message = "stopping"
property_abort_callback_completed = "_abort_callback_completed"
property_abort_callback_timeout = "_abort_callback_timeout"
property_abort_poll_freq = "_abort_poll_freq"
def __init__(self, command, session, events_service, task_id): def __init__(self, command, session, events_service, task_id, bash_cwd=None):
# type: (Worker, Session, Events, Text) -> () # type: (Worker, Session, Events, Text, Text) -> None
""" """
:param command: workers command :param command: workers command
:param session: command session :param session: command session
:param events_service: events service object :param events_service: events service object
:param task_id: followed task ID :param task_id: followed task ID
:param bash_cwd: cwd for bash on_abort callback
""" """
self.command = command self.command = command
self.session = session self.session = session
@ -553,6 +558,90 @@ class TaskStopSignal(object):
self._active_callback_timestamp = None self._active_callback_timestamp = None
self._active_callback_timeout = None self._active_callback_timeout = None
self._abort_callback_max_timeout = float(self.session.config.get('agent.abort_callback_max_timeout', 1800)) self._abort_callback_max_timeout = float(self.session.config.get('agent.abort_callback_max_timeout', 1800))
self._bash_callback_cwd = bash_cwd
self._bash_callback = None
self._bash_callback_timeout = None
self._bash_callback_thread = None
self._self_monitor_thread = None
self.register_bash_callback()
@classmethod
def check_registered_bash_callback(cls):
    # type: () -> bool
    """
    Tell whether a bash on_abort callback command is configured.

    The callback environment variable is consulted first; if it is empty, the
    "<variable name>_REGISTERED" entry of os.environ is used instead.

    :return: True if one of the two holds a non-blank value, False otherwise
    """
    callback_cmd = ENV_ABORT_CALLBACK_CMD.get()
    if not callback_cmd:
        callback_cmd = os.environ.get(ENV_ABORT_CALLBACK_CMD.vars[0] + "_REGISTERED")
    return bool((callback_cmd or "").strip())
def register_bash_callback(self):
    # type: () -> None
    """
    Register the bash on_abort callback defined by the callback environment variable.

    Nothing happens when the variable is unset or blank. Otherwise:
    - the command is split into an argument list (a malformed command, e.g. one with
      unbalanced quotes, is reported with a warning and ignored instead of raising),
    - the command is copied to "<variable name>_REGISTERED" in os.environ and the
      original variable is removed from os.environ,
    - the callback timeout is read from the timeout environment variable
      (180 is used when it is unset or cannot be parsed as an integer),
    - the timeout and a poll frequency of 10 are stored in the task's runtime properties.

    Failure to update the runtime properties is reported with a warning only.
    """
    # check if the env variable defined a callback
    raw_command = ENV_ABORT_CALLBACK_CMD.get()
    if not (raw_command or "").strip():
        return

    try:
        parsed_command = shlex.split(raw_command)
    except ValueError as ex:
        # shlex raises ValueError on malformed input (e.g. no closing quotation);
        # a broken callback definition should not break the stop-signal construction
        print("WARNING: failed parsing bash callback command, ignoring it: {}".format(ex))
        return
    self._bash_callback = parsed_command

    # make sure we are re-testing in subprocesses
    env_name = ENV_ABORT_CALLBACK_CMD.vars[0]
    os.environ[env_name + "_REGISTERED"] = raw_command
    os.environ.pop(env_name, None)

    # noinspection PyBroadException
    try:
        self._bash_callback_timeout = int((ENV_ABORT_CALLBACK_CMD_TIMEOUT.get() or "").strip() or 180)
    except Exception:
        self._bash_callback_timeout = 180

    # update the task runtime properties
    try:
        task_info = self.session.get(
            service="tasks", action="get_all", version="2.13",
            id=[self.task_id], only_fields=["runtime"])['tasks'][0]
        runtime_properties = task_info.get("runtime") or {}
        runtime_properties[self.property_abort_callback_timeout] = str(self._bash_callback_timeout)
        runtime_properties[self.property_abort_poll_freq] = str(10)
        self.session.post(service="tasks", action="edit", version="2.13",
                          task=self.task_id, runtime=runtime_properties, force=True)
    except Exception as ex:
        print("WARNING: failed registering bash callback: {}".format(ex))
        return

    print("INFO: registering bash callback, timeout {}s: {}".format(
        self._bash_callback_timeout, self._bash_callback))
def start_monitor_thread(self, polling_interval_sec=10):
if self._self_monitor_thread:
return
self._self_monitor_thread = Thread(target=self._monitor_thread_loop, args=(polling_interval_sec, ))
self._self_monitor_thread.daemon = True
self._self_monitor_thread.start()
print("INFO: bash on_abort monitor thread started")
def stop_monitor_thread(self):
self._self_monitor_thread = None
def _monitor_thread_loop(self, polling_interval_sec=10):
while self._self_monitor_thread is not None:
sleep(polling_interval_sec)
stop_reason = self.test()
if stop_reason != TaskStopSignal.default:
# mark quit loop
break
def _bash_callback_launch_thread(self):
    # type: () -> None
    """
    Thread target: run the registered bash on_abort callback, then flag it as completed.

    The command is executed with the configured callback working directory (or the
    default one when none was given). A failure of the command is reported with a warning;
    in either case the "callback completed" runtime property of the task is then set to "1".
    A failure to update the task is reported with a warning as well.
    """
    callback_argv = Argv(*self._bash_callback)
    print("INFO: running bash on_abort callback: {}".format(callback_argv.pretty()))
    try:
        callback_argv.check_call(cwd=self._bash_callback_cwd or None)
    except Exception as ex:
        print("WARNING: failed running bash callback: {}".format(ex))

    # update the task runtime properties
    try:
        response = self.session.get(
            service="tasks", action="get_all", version="2.13",
            id=[self.task_id], only_fields=["runtime"])
        runtime = response['tasks'][0].get("runtime") or {}
        runtime[self.property_abort_callback_completed] = str(1)
        self.session.post(service="tasks", action="edit", version="2.13",
                          task=self.task_id, runtime=runtime, force=True)
    except Exception as ex:
        print("WARNING: failed updating bash callback completed: {}".format(ex))
def test(self): def test(self):
# type: () -> TaskStopReason # type: () -> TaskStopReason
@ -583,8 +672,12 @@ class TaskStopSignal(object):
try: try:
task_info = self.session.get( task_info = self.session.get(
service="tasks", action="get_all", version="2.13", id=[self.task_id], service="tasks", action="get_all", version="2.13", id=[self.task_id],
only_fields=["status", "status_message", "runtime._abort_callback_completed"]) only_fields=[
cb_completed = task_info['tasks'][0]['runtime'].get('_abort_callback_completed', None) "status", "status_message",
"runtime.{}".format(self.property_abort_callback_completed)
]
)
cb_completed = task_info['tasks'][0]['runtime'].get(self.property_abort_callback_completed, None)
except: # noqa except: # noqa
pass pass
@ -620,6 +713,13 @@ class TaskStopSignal(object):
session=self.session session=self.session
) )
# launch bash callback if exists
if self._bash_callback and not self._bash_callback_thread:
print("INFO: bash on_abort callback thread started")
self._bash_callback_thread = Thread(target=self._bash_callback_launch_thread)
self._bash_callback_thread.daemon = True
self._bash_callback_thread.start()
self._active_callback_timestamp = time() self._active_callback_timestamp = time()
self._active_callback_timeout = timeout self._active_callback_timeout = timeout
return bool(cb_completed) return bool(cb_completed)
@ -629,11 +729,16 @@ class TaskStopSignal(object):
try: try:
task_info = self.session.get( task_info = self.session.get(
service="tasks", action="get_all", version="2.13", id=[self.task_id], service="tasks", action="get_all", version="2.13", id=[self.task_id],
only_fields=["status", "status_message", "runtime._abort_callback_timeout", only_fields=[
"runtime._abort_poll_freq", "runtime._abort_callback_completed"]) "status", "status_message",
abort_timeout = task_info['tasks'][0]['runtime'].get('_abort_callback_timeout', 0) "runtime.{}".format(self.property_abort_callback_timeout),
poll_timeout = task_info['tasks'][0]['runtime'].get('_abort_poll_freq', 0) "runtime.{}".format(self.property_abort_poll_freq),
cb_completed = task_info['tasks'][0]['runtime'].get('_abort_callback_completed', None) "runtime.{}".format(self.property_abort_callback_completed)
]
)
abort_timeout = task_info['tasks'][0]['runtime'].get(self.property_abort_callback_timeout, 0)
poll_timeout = task_info['tasks'][0]['runtime'].get(self.property_abort_poll_freq, 0)
cb_completed = task_info['tasks'][0]['runtime'].get(self.property_abort_callback_completed, None)
except: # noqa except: # noqa
abort_timeout = None abort_timeout = None
poll_timeout = None poll_timeout = None
@ -3087,6 +3192,18 @@ class Worker(ServiceCommandSection):
force=True, force=True,
) )
stop_signal = None
if TaskStopSignal.check_registered_bash_callback():
use_execv = False
stop_signal = TaskStopSignal(
command=self,
session=self._session,
events_service=self.get_service(Events),
task_id=task_id,
bash_cwd=script_dir
)
stop_signal.start_monitor_thread(polling_interval_sec=self._polling_interval)
# check if we need to add encoding to the subprocess # check if we need to add encoding to the subprocess
if sys.getfilesystemencoding() == 'ascii' and not os.environ.get("PYTHONIOENCODING"): if sys.getfilesystemencoding() == 'ascii' and not os.environ.get("PYTHONIOENCODING"):
os.environ["PYTHONIOENCODING"] = "utf-8" os.environ["PYTHONIOENCODING"] = "utf-8"
@ -3149,6 +3266,8 @@ class Worker(ServiceCommandSection):
exit_code = exit_code if exit_code != ExitStatus.interrupted else -1 exit_code = exit_code if exit_code != ExitStatus.interrupted else -1
if not disable_monitoring: if not disable_monitoring:
if stop_signal:
stop_signal.stop_monitor_thread()
# we need to change task status according to exit code # we need to change task status according to exit code
self.handle_task_termination(current_task.id, exit_code, TaskStopReason.no_stop) self.handle_task_termination(current_task.id, exit_code, TaskStopReason.no_stop)
self.stop_monitor() self.stop_monitor()

View File

@ -187,6 +187,9 @@ ENV_CHILD_AGENTS_COUNT_CMD = EnvironmentConfig("CLEARML_AGENT_CHILD_AGENTS_COUNT
ENV_DOCKER_ARGS_FILTERS = EnvironmentConfig("CLEARML_AGENT_DOCKER_ARGS_FILTERS") ENV_DOCKER_ARGS_FILTERS = EnvironmentConfig("CLEARML_AGENT_DOCKER_ARGS_FILTERS")
ENV_DOCKER_ARGS_HIDE_ENV = EnvironmentConfig("CLEARML_AGENT_DOCKER_ARGS_HIDE_ENV") ENV_DOCKER_ARGS_HIDE_ENV = EnvironmentConfig("CLEARML_AGENT_DOCKER_ARGS_HIDE_ENV")
ENV_CONFIG_BC_IN_STANDALONE = EnvironmentConfig("CLEARML_AGENT_STANDALONE_CONFIG_BC", type=bool) ENV_CONFIG_BC_IN_STANDALONE = EnvironmentConfig("CLEARML_AGENT_STANDALONE_CONFIG_BC", type=bool)
# Command line of the bash on_abort callback (CLEARML_AGENT_ABORT_CALLBACK_CMD);
# it is split with shlex and executed when an abort is detected for the task
ENV_ABORT_CALLBACK_CMD = EnvironmentConfig("CLEARML_AGENT_ABORT_CALLBACK_CMD")
# Timeout in seconds for the on_abort callback (CLEARML_AGENT_ABORT_CALLBACK_TIMEOUT);
# parsed as an integer, 180 is used when it is unset or invalid
ENV_ABORT_CALLBACK_CMD_TIMEOUT = EnvironmentConfig("CLEARML_AGENT_ABORT_CALLBACK_TIMEOUT")
# NOTE(review): the string literal following these definitions appears to document
# ENV_CONFIG_BC_IN_STANDALONE, not the two variables above - consider moving it back - confirm
""" Maintain backwards compatible configuration when launching in standalone mode """ """ Maintain backwards compatible configuration when launching in standalone mode """
ENV_FORCE_DOCKER_AGENT_REPO = EnvironmentConfig("FORCE_CLEARML_AGENT_REPO", "CLEARML_AGENT_DOCKER_AGENT_REPO") ENV_FORCE_DOCKER_AGENT_REPO = EnvironmentConfig("FORCE_CLEARML_AGENT_REPO", "CLEARML_AGENT_DOCKER_AGENT_REPO")