mirror of
https://github.com/clearml/clearml-agent
synced 2025-02-26 05:59:24 +00:00
Added CLEARML_AGENT_ABORT_CALLBACK_CMD and CLEARML_AGENT_ABORT_CALLBACK_TIMEOUT
(default 180 sec) to define a callback command that is called when the task status changes to aborted
This commit is contained in:
parent
ee286e2fb7
commit
0e2657421f
@ -22,6 +22,7 @@ from datetime import datetime
|
||||
from functools import partial
|
||||
from os.path import basename
|
||||
from tempfile import mkdtemp, NamedTemporaryFile
|
||||
from threading import Thread
|
||||
from time import sleep, time
|
||||
from typing import Text, Optional, Any, Tuple, List, Dict, Mapping, Union
|
||||
|
||||
@ -78,7 +79,7 @@ from clearml_agent.definitions import (
|
||||
ENV_AGENT_FORCE_EXEC_SCRIPT,
|
||||
ENV_TEMP_STDOUT_FILE_DIR,
|
||||
ENV_AGENT_FORCE_TASK_INIT,
|
||||
ENV_AGENT_DEBUG_GET_NEXT_TASK,
|
||||
ENV_AGENT_DEBUG_GET_NEXT_TASK, ENV_ABORT_CALLBACK_CMD, ENV_ABORT_CALLBACK_CMD_TIMEOUT,
|
||||
)
|
||||
from clearml_agent.definitions import WORKING_REPOSITORY_DIR, PIP_EXTRA_INDICES
|
||||
from clearml_agent.errors import (
|
||||
@ -534,14 +535,18 @@ class TaskStopSignal(object):
|
||||
]
|
||||
default = TaskStopReason.no_stop
|
||||
stopping_message = "stopping"
|
||||
property_abort_callback_completed = "_abort_callback_completed"
|
||||
property_abort_callback_timeout = "_abort_callback_timeout"
|
||||
property_abort_poll_freq = "_abort_poll_freq"
|
||||
|
||||
def __init__(self, command, session, events_service, task_id):
|
||||
# type: (Worker, Session, Events, Text) -> ()
|
||||
def __init__(self, command, session, events_service, task_id, bash_cwd=None):
|
||||
# type: (Worker, Session, Events, Text, Text) -> None
|
||||
"""
|
||||
:param command: workers command
|
||||
:param session: command session
|
||||
:param events_service: events service object
|
||||
:param task_id: followed task ID
|
||||
:param bash_cwd: cwd for bash on_abort callback
|
||||
"""
|
||||
self.command = command
|
||||
self.session = session
|
||||
@ -553,6 +558,90 @@ class TaskStopSignal(object):
|
||||
self._active_callback_timestamp = None
|
||||
self._active_callback_timeout = None
|
||||
self._abort_callback_max_timeout = float(self.session.config.get('agent.abort_callback_max_timeout', 1800))
|
||||
self._bash_callback_cwd = bash_cwd
|
||||
self._bash_callback = None
|
||||
self._bash_callback_timeout = None
|
||||
self._bash_callback_thread = None
|
||||
self._self_monitor_thread = None
|
||||
self.register_bash_callback()
|
||||
|
||||
@classmethod
def check_registered_bash_callback(cls):
    """
    Check whether a bash on_abort callback command is configured.

    The command is looked up first through ``ENV_ABORT_CALLBACK_CMD`` and, when that
    yields nothing, through the ``<first env var name>_REGISTERED`` environment
    variable (the one written by ``register_bash_callback``).

    :return: True if the selected value is a non-blank string, False otherwise
    """
    callback_cmd = ENV_ABORT_CALLBACK_CMD.get()
    if not callback_cmd:
        # fall back to the "already registered" marker variable
        callback_cmd = os.environ.get(ENV_ABORT_CALLBACK_CMD.vars[0] + "_REGISTERED")
    if not callback_cmd:
        callback_cmd = ""
    return bool(callback_cmd.strip())
|
||||
|
||||
def register_bash_callback(self):
    """
    Register the bash on_abort callback defined through ``ENV_ABORT_CALLBACK_CMD``.

    Does nothing when the environment variable is missing or blank. Otherwise:

    - the command line is split (shell-like) and stored in ``self._bash_callback``
    - the command is moved from its original environment variable to
      ``<name>_REGISTERED``, the variable ``check_registered_bash_callback`` falls back to
    - the callback timeout is read from ``ENV_ABORT_CALLBACK_CMD_TIMEOUT``
      (180 seconds when unset or not an integer)
    - the timeout and a poll frequency of 10 are written to the task runtime properties

    A failure to update the task runtime properties is reported as a warning
    and is not raised to the caller.
    """
    # check if the env variable defined a callback (read the value once and reuse it)
    callback_cmd = ENV_ABORT_CALLBACK_CMD.get()
    if not (callback_cmd or "").strip():
        return

    self._bash_callback = shlex.split(callback_cmd)

    # make sure we are re-testing in subprocesses:
    # keep the command under the "_REGISTERED" name and drop the original variable
    env_var_name = ENV_ABORT_CALLBACK_CMD.vars[0]
    os.environ[env_var_name + "_REGISTERED"] = callback_cmd
    os.environ.pop(env_var_name, None)

    # noinspection PyBroadException
    try:
        self._bash_callback_timeout = int((ENV_ABORT_CALLBACK_CMD_TIMEOUT.get() or "").strip() or 180)
    except Exception:
        # unparsable timeout value, fall back to the default
        self._bash_callback_timeout = 180

    # update the task runtime properties
    # noinspection PyBroadException
    try:
        task_info = self.session.get(
            service="tasks", action="get_all", version="2.13",
            id=[self.task_id], only_fields=["runtime"])['tasks'][0]
        # preserve the existing runtime properties, only add/override ours
        runtime_properties = task_info.get("runtime") or {}
        runtime_properties[self.property_abort_callback_timeout] = str(self._bash_callback_timeout)
        runtime_properties[self.property_abort_poll_freq] = str(10)
        self.session.post(service="tasks", action="edit", version="2.13",
                          task=self.task_id, runtime=runtime_properties, force=True)
    except Exception as ex:
        print("WARNING: failed registering bash callback: {}".format(ex))
        return

    # fixed typo in the log message ("timout" -> "timeout")
    print("INFO: registering bash callback, timeout {}s: {}".format(
        self._bash_callback_timeout, self._bash_callback))
|
||||
|
||||
def start_monitor_thread(self, polling_interval_sec=10):
    """
    Start the daemon thread running ``_monitor_thread_loop``.

    Calling this again while a monitor thread is already registered is a no-op.

    :param polling_interval_sec: value forwarded to ``_monitor_thread_loop``
    """
    if not self._self_monitor_thread:
        monitor = Thread(target=self._monitor_thread_loop, args=(polling_interval_sec, ))
        # the attribute must be set before the thread starts, the loop tests it
        self._self_monitor_thread = monitor
        monitor.daemon = True
        monitor.start()
        print("INFO: bash on_abort monitor thread started")
|
||||
|
||||
def stop_monitor_thread(self):
    """
    Signal the monitor loop to stop.

    Clears the thread reference; ``_monitor_thread_loop`` exits the next time
    it checks that reference. The thread itself is not joined here.
    """
    setattr(self, "_self_monitor_thread", None)
|
||||
|
||||
def _monitor_thread_loop(self, polling_interval_sec=10):
|
||||
while self._self_monitor_thread is not None:
|
||||
sleep(polling_interval_sec)
|
||||
stop_reason = self.test()
|
||||
if stop_reason != TaskStopSignal.default:
|
||||
# mark quit loop
|
||||
break
|
||||
|
||||
def _bash_callback_launch_thread(self):
    """
    Run the registered bash on_abort callback and flag its completion on the task.

    The callback command is executed with ``self._bash_callback_cwd`` as working
    directory (when set). Whether or not the command succeeded, the task runtime
    property ``property_abort_callback_completed`` is then set to "1".
    Failures in either step are printed as warnings and never raised.
    """
    callback_cmd = Argv(*self._bash_callback)
    print("INFO: running bash on_abort callback: {}".format(callback_cmd.pretty()))
    working_dir = self._bash_callback_cwd or None
    # noinspection PyBroadException
    try:
        callback_cmd.check_call(cwd=working_dir)
    except Exception as ex:
        print("WARNING: failed running bash callback: {}".format(ex))

    # update the task runtime properties
    # noinspection PyBroadException
    try:
        response = self.session.get(
            service="tasks", action="get_all", version="2.13",
            id=[self.task_id], only_fields=["runtime"])
        # keep whatever runtime properties already exist on the task
        runtime = response['tasks'][0].get("runtime") or {}
        runtime[self.property_abort_callback_completed] = "1"
        self.session.post(service="tasks", action="edit", version="2.13",
                          task=self.task_id, runtime=runtime, force=True)
    except Exception as ex:
        print("WARNING: failed updating bash callback completed: {}".format(ex))
|
||||
|
||||
def test(self):
|
||||
# type: () -> TaskStopReason
|
||||
@ -583,8 +672,12 @@ class TaskStopSignal(object):
|
||||
try:
|
||||
task_info = self.session.get(
|
||||
service="tasks", action="get_all", version="2.13", id=[self.task_id],
|
||||
only_fields=["status", "status_message", "runtime._abort_callback_completed"])
|
||||
cb_completed = task_info['tasks'][0]['runtime'].get('_abort_callback_completed', None)
|
||||
only_fields=[
|
||||
"status", "status_message",
|
||||
"runtime.{}".format(self.property_abort_callback_completed)
|
||||
]
|
||||
)
|
||||
cb_completed = task_info['tasks'][0]['runtime'].get(self.property_abort_callback_completed, None)
|
||||
except: # noqa
|
||||
pass
|
||||
|
||||
@ -620,6 +713,13 @@ class TaskStopSignal(object):
|
||||
session=self.session
|
||||
)
|
||||
|
||||
# launch bash callback if exists
|
||||
if self._bash_callback and not self._bash_callback_thread:
|
||||
print("INFO: bash on_abort callback thread started")
|
||||
self._bash_callback_thread = Thread(target=self._bash_callback_launch_thread)
|
||||
self._bash_callback_thread.daemon = True
|
||||
self._bash_callback_thread.start()
|
||||
|
||||
self._active_callback_timestamp = time()
|
||||
self._active_callback_timeout = timeout
|
||||
return bool(cb_completed)
|
||||
@ -629,11 +729,16 @@ class TaskStopSignal(object):
|
||||
try:
|
||||
task_info = self.session.get(
|
||||
service="tasks", action="get_all", version="2.13", id=[self.task_id],
|
||||
only_fields=["status", "status_message", "runtime._abort_callback_timeout",
|
||||
"runtime._abort_poll_freq", "runtime._abort_callback_completed"])
|
||||
abort_timeout = task_info['tasks'][0]['runtime'].get('_abort_callback_timeout', 0)
|
||||
poll_timeout = task_info['tasks'][0]['runtime'].get('_abort_poll_freq', 0)
|
||||
cb_completed = task_info['tasks'][0]['runtime'].get('_abort_callback_completed', None)
|
||||
only_fields=[
|
||||
"status", "status_message",
|
||||
"runtime.{}".format(self.property_abort_callback_timeout),
|
||||
"runtime.{}".format(self.property_abort_poll_freq),
|
||||
"runtime.{}".format(self.property_abort_callback_completed)
|
||||
]
|
||||
)
|
||||
abort_timeout = task_info['tasks'][0]['runtime'].get(self.property_abort_callback_timeout, 0)
|
||||
poll_timeout = task_info['tasks'][0]['runtime'].get(self.property_abort_poll_freq, 0)
|
||||
cb_completed = task_info['tasks'][0]['runtime'].get(self.property_abort_callback_completed, None)
|
||||
except: # noqa
|
||||
abort_timeout = None
|
||||
poll_timeout = None
|
||||
@ -3087,6 +3192,18 @@ class Worker(ServiceCommandSection):
|
||||
force=True,
|
||||
)
|
||||
|
||||
stop_signal = None
|
||||
if TaskStopSignal.check_registered_bash_callback():
|
||||
use_execv = False
|
||||
stop_signal = TaskStopSignal(
|
||||
command=self,
|
||||
session=self._session,
|
||||
events_service=self.get_service(Events),
|
||||
task_id=task_id,
|
||||
bash_cwd=script_dir
|
||||
)
|
||||
stop_signal.start_monitor_thread(polling_interval_sec=self._polling_interval)
|
||||
|
||||
# check if we need to add encoding to the subprocess
|
||||
if sys.getfilesystemencoding() == 'ascii' and not os.environ.get("PYTHONIOENCODING"):
|
||||
os.environ["PYTHONIOENCODING"] = "utf-8"
|
||||
@ -3149,6 +3266,8 @@ class Worker(ServiceCommandSection):
|
||||
exit_code = exit_code if exit_code != ExitStatus.interrupted else -1
|
||||
|
||||
if not disable_monitoring:
|
||||
if stop_signal:
|
||||
stop_signal.stop_monitor_thread()
|
||||
# we need to change task status according to exit code
|
||||
self.handle_task_termination(current_task.id, exit_code, TaskStopReason.no_stop)
|
||||
self.stop_monitor()
|
||||
|
@ -187,6 +187,9 @@ ENV_CHILD_AGENTS_COUNT_CMD = EnvironmentConfig("CLEARML_AGENT_CHILD_AGENTS_COUNT
|
||||
ENV_DOCKER_ARGS_FILTERS = EnvironmentConfig("CLEARML_AGENT_DOCKER_ARGS_FILTERS")
|
||||
ENV_DOCKER_ARGS_HIDE_ENV = EnvironmentConfig("CLEARML_AGENT_DOCKER_ARGS_HIDE_ENV")
|
||||
ENV_CONFIG_BC_IN_STANDALONE = EnvironmentConfig("CLEARML_AGENT_STANDALONE_CONFIG_BC", type=bool)
|
||||
ENV_ABORT_CALLBACK_CMD = EnvironmentConfig("CLEARML_AGENT_ABORT_CALLBACK_CMD")
|
||||
ENV_ABORT_CALLBACK_CMD_TIMEOUT = EnvironmentConfig("CLEARML_AGENT_ABORT_CALLBACK_TIMEOUT")
|
||||
|
||||
""" Maintain backwards compatible configuration when launching in standalone mode """
|
||||
|
||||
ENV_FORCE_DOCKER_AGENT_REPO = EnvironmentConfig("FORCE_CLEARML_AGENT_REPO", "CLEARML_AGENT_DOCKER_AGENT_REPO")
|
||||
|
Loading…
Reference in New Issue
Block a user