Added CLEARML_AGENT_ABORT_CALLBACK_CMD and CLEARML_AGENT_ABORT_CALLBACK_TIMEOUT

(default 180 sec) to define a callback command to be called on abort status change
This commit is contained in:
clearml 2025-02-24 13:46:00 +02:00
parent ee286e2fb7
commit 0e2657421f
2 changed files with 132 additions and 10 deletions

View File

@ -22,6 +22,7 @@ from datetime import datetime
from functools import partial
from os.path import basename
from tempfile import mkdtemp, NamedTemporaryFile
from threading import Thread
from time import sleep, time
from typing import Text, Optional, Any, Tuple, List, Dict, Mapping, Union
@ -78,7 +79,7 @@ from clearml_agent.definitions import (
ENV_AGENT_FORCE_EXEC_SCRIPT,
ENV_TEMP_STDOUT_FILE_DIR,
ENV_AGENT_FORCE_TASK_INIT,
ENV_AGENT_DEBUG_GET_NEXT_TASK,
ENV_AGENT_DEBUG_GET_NEXT_TASK, ENV_ABORT_CALLBACK_CMD, ENV_ABORT_CALLBACK_CMD_TIMEOUT,
)
from clearml_agent.definitions import WORKING_REPOSITORY_DIR, PIP_EXTRA_INDICES
from clearml_agent.errors import (
@ -534,14 +535,18 @@ class TaskStopSignal(object):
]
default = TaskStopReason.no_stop
stopping_message = "stopping"
property_abort_callback_completed = "_abort_callback_completed"
property_abort_callback_timeout = "_abort_callback_timeout"
property_abort_poll_freq = "_abort_poll_freq"
def __init__(self, command, session, events_service, task_id):
# type: (Worker, Session, Events, Text) -> ()
def __init__(self, command, session, events_service, task_id, bash_cwd=None):
# type: (Worker, Session, Events, Text, Text) -> None
"""
:param command: workers command
:param session: command session
:param events_service: events service object
:param task_id: followed task ID
:param bash_cwd: cwd for bash on_abort callback
"""
self.command = command
self.session = session
@ -553,6 +558,90 @@ class TaskStopSignal(object):
self._active_callback_timestamp = None
self._active_callback_timeout = None
self._abort_callback_max_timeout = float(self.session.config.get('agent.abort_callback_max_timeout', 1800))
self._bash_callback_cwd = bash_cwd
self._bash_callback = None
self._bash_callback_timeout = None
self._bash_callback_thread = None
self._self_monitor_thread = None
self.register_bash_callback()
@classmethod
def check_registered_bash_callback(cls):
    # type: () -> bool
    """
    Tell whether a bash on_abort callback command is configured.

    The value of ENV_ABORT_CALLBACK_CMD is used when it is non-empty; otherwise
    the process environment is consulted for the same variable name with a
    "_REGISTERED" suffix.

    :return: True when the selected value is not blank, False otherwise
    """
    callback_cmd = ENV_ABORT_CALLBACK_CMD.get()
    if not callback_cmd:
        # fall back to the marker variable carrying the "_REGISTERED" suffix
        callback_cmd = os.environ.get(ENV_ABORT_CALLBACK_CMD.vars[0] + "_REGISTERED")
    if not callback_cmd:
        callback_cmd = ""
    return bool(callback_cmd.strip())
def register_bash_callback(self):
    # type: () -> None
    """
    Register the bash on_abort callback defined through the environment, if any.

    When ENV_ABORT_CALLBACK_CMD holds a non-blank command:
    - the command is split (shlex) and stored in self._bash_callback
    - the raw value is moved to the "<name>_REGISTERED" environment variable
      and the original variable is removed from os.environ
    - the callback timeout is read from ENV_ABORT_CALLBACK_CMD_TIMEOUT
      (180 seconds when unset or not an integer)
    - the task runtime properties are updated with the callback timeout and
      an abort poll frequency of 10

    Failures while updating the task are reported as a warning and are not raised.
    """
    default_timeout = 180

    # read the command once: the variable is removed from the environment below
    raw_cmd = ENV_ABORT_CALLBACK_CMD.get()
    if not (raw_cmd or "").strip():
        # no callback defined, nothing to register
        return

    self._bash_callback = shlex.split(raw_cmd)

    # make sure we are re-testing in subprocesses
    env_name = ENV_ABORT_CALLBACK_CMD.vars[0]
    os.environ[env_name + "_REGISTERED"] = raw_cmd
    os.environ.pop(env_name, None)

    # noinspection PyBroadException
    try:
        self._bash_callback_timeout = int(
            (ENV_ABORT_CALLBACK_CMD_TIMEOUT.get() or "").strip() or default_timeout)
    except Exception:
        # do not hide a misconfigured timeout, but keep going with the default
        print("WARNING: invalid {} value, using default {}s".format(
            ENV_ABORT_CALLBACK_CMD_TIMEOUT.vars[0], default_timeout))
        self._bash_callback_timeout = default_timeout

    # update the task runtime properties
    try:
        task_info = self.session.get(
            service="tasks", action="get_all", version="2.13",
            id=[self.task_id], only_fields=["runtime"])['tasks'][0]
        # start from the runtime dict returned for the task and add our entries
        runtime_properties = task_info.get("runtime") or {}
        runtime_properties[self.property_abort_callback_timeout] = str(self._bash_callback_timeout)
        runtime_properties[self.property_abort_poll_freq] = str(10)
        self.session.post(service="tasks", action="edit", version="2.13",
                          task=self.task_id, runtime=runtime_properties, force=True)
    except Exception as ex:
        print("WARNING: failed registering bash callback: {}".format(ex))
        return

    print("INFO: registering bash callback, timeout {}s: {}".format(
        self._bash_callback_timeout, self._bash_callback))
def start_monitor_thread(self, polling_interval_sec=10):
    # type: (float) -> None
    """
    Start the background thread running _monitor_thread_loop().

    Does nothing when a monitor thread reference is already stored on the object.

    :param polling_interval_sec: value forwarded to _monitor_thread_loop()
    """
    if not self._self_monitor_thread:
        monitor = Thread(target=self._monitor_thread_loop, args=(polling_interval_sec, ))
        monitor.daemon = True
        # store the reference before starting: the loop runs only while it is set
        self._self_monitor_thread = monitor
        monitor.start()
        print("INFO: bash on_abort monitor thread started")
def stop_monitor_thread(self):
    # type: () -> None
    """
    Request the monitor thread to stop.

    Only clears the stored thread reference, which is the condition the
    monitor loop checks; the thread itself is not joined here.
    """
    self._self_monitor_thread = None
def _monitor_thread_loop(self, polling_interval_sec=10):
while self._self_monitor_thread is not None:
sleep(polling_interval_sec)
stop_reason = self.test()
if stop_reason != TaskStopSignal.default:
# mark quit loop
break
def _bash_callback_launch_thread(self):
    # type: () -> None
    """
    Thread body: run the registered bash on_abort callback, then flag it as completed.

    The command stored in self._bash_callback is executed with
    self._bash_callback_cwd as working directory (None when unset). A failure
    of the command is only reported as a warning; the "callback completed"
    runtime property of the task is set afterwards in either case. A failure
    to update the task is reported as a warning as well.
    """
    callback_argv = Argv(*self._bash_callback)
    print("INFO: running bash on_abort callback: {}".format(callback_argv.pretty()))

    working_dir = self._bash_callback_cwd or None
    try:
        callback_argv.check_call(cwd=working_dir)
    except Exception as err:
        print("WARNING: failed running bash callback: {}".format(err))

    # update the task runtime properties
    try:
        response = self.session.get(
            service="tasks", action="get_all", version="2.13",
            id=[self.task_id], only_fields=["runtime"])
        runtime = response['tasks'][0].get("runtime") or {}
        runtime[self.property_abort_callback_completed] = "1"
        self.session.post(
            service="tasks", action="edit", version="2.13",
            task=self.task_id, runtime=runtime, force=True)
    except Exception as err:
        print("WARNING: failed updating bash callback completed: {}".format(err))
def test(self):
# type: () -> TaskStopReason
@ -583,8 +672,12 @@ class TaskStopSignal(object):
try:
task_info = self.session.get(
service="tasks", action="get_all", version="2.13", id=[self.task_id],
only_fields=["status", "status_message", "runtime._abort_callback_completed"])
cb_completed = task_info['tasks'][0]['runtime'].get('_abort_callback_completed', None)
only_fields=[
"status", "status_message",
"runtime.{}".format(self.property_abort_callback_completed)
]
)
cb_completed = task_info['tasks'][0]['runtime'].get(self.property_abort_callback_completed, None)
except: # noqa
pass
@ -620,6 +713,13 @@ class TaskStopSignal(object):
session=self.session
)
# launch bash callback if exists
if self._bash_callback and not self._bash_callback_thread:
print("INFO: bash on_abort callback thread started")
self._bash_callback_thread = Thread(target=self._bash_callback_launch_thread)
self._bash_callback_thread.daemon = True
self._bash_callback_thread.start()
self._active_callback_timestamp = time()
self._active_callback_timeout = timeout
return bool(cb_completed)
@ -629,11 +729,16 @@ class TaskStopSignal(object):
try:
task_info = self.session.get(
service="tasks", action="get_all", version="2.13", id=[self.task_id],
only_fields=["status", "status_message", "runtime._abort_callback_timeout",
"runtime._abort_poll_freq", "runtime._abort_callback_completed"])
abort_timeout = task_info['tasks'][0]['runtime'].get('_abort_callback_timeout', 0)
poll_timeout = task_info['tasks'][0]['runtime'].get('_abort_poll_freq', 0)
cb_completed = task_info['tasks'][0]['runtime'].get('_abort_callback_completed', None)
only_fields=[
"status", "status_message",
"runtime.{}".format(self.property_abort_callback_timeout),
"runtime.{}".format(self.property_abort_poll_freq),
"runtime.{}".format(self.property_abort_callback_completed)
]
)
abort_timeout = task_info['tasks'][0]['runtime'].get(self.property_abort_callback_timeout, 0)
poll_timeout = task_info['tasks'][0]['runtime'].get(self.property_abort_poll_freq, 0)
cb_completed = task_info['tasks'][0]['runtime'].get(self.property_abort_callback_completed, None)
except: # noqa
abort_timeout = None
poll_timeout = None
@ -3087,6 +3192,18 @@ class Worker(ServiceCommandSection):
force=True,
)
stop_signal = None
if TaskStopSignal.check_registered_bash_callback():
use_execv = False
stop_signal = TaskStopSignal(
command=self,
session=self._session,
events_service=self.get_service(Events),
task_id=task_id,
bash_cwd=script_dir
)
stop_signal.start_monitor_thread(polling_interval_sec=self._polling_interval)
# check if we need to add encoding to the subprocess
if sys.getfilesystemencoding() == 'ascii' and not os.environ.get("PYTHONIOENCODING"):
os.environ["PYTHONIOENCODING"] = "utf-8"
@ -3149,6 +3266,8 @@ class Worker(ServiceCommandSection):
exit_code = exit_code if exit_code != ExitStatus.interrupted else -1
if not disable_monitoring:
if stop_signal:
stop_signal.stop_monitor_thread()
# we need to change task status according to exit code
self.handle_task_termination(current_task.id, exit_code, TaskStopReason.no_stop)
self.stop_monitor()

View File

@ -187,6 +187,9 @@ ENV_CHILD_AGENTS_COUNT_CMD = EnvironmentConfig("CLEARML_AGENT_CHILD_AGENTS_COUNT
ENV_DOCKER_ARGS_FILTERS = EnvironmentConfig("CLEARML_AGENT_DOCKER_ARGS_FILTERS")
ENV_DOCKER_ARGS_HIDE_ENV = EnvironmentConfig("CLEARML_AGENT_DOCKER_ARGS_HIDE_ENV")
ENV_CONFIG_BC_IN_STANDALONE = EnvironmentConfig("CLEARML_AGENT_STANDALONE_CONFIG_BC", type=bool)
""" Maintain backwards compatible configuration when launching in standalone mode """

ENV_ABORT_CALLBACK_CMD = EnvironmentConfig("CLEARML_AGENT_ABORT_CALLBACK_CMD")
""" Callback command to be called on abort status change """

ENV_ABORT_CALLBACK_CMD_TIMEOUT = EnvironmentConfig("CLEARML_AGENT_ABORT_CALLBACK_TIMEOUT")
""" Timeout (seconds) for the abort callback command, default 180 sec """
ENV_FORCE_DOCKER_AGENT_REPO = EnvironmentConfig("FORCE_CLEARML_AGENT_REPO", "CLEARML_AGENT_DOCKER_AGENT_REPO")