mirror of
				https://github.com/clearml/clearml-agent
				synced 2025-06-26 18:16:15 +00:00 
			
		
		
		
	Added CLEARML_AGENT_ABORT_CALLBACK_CMD and CLEARML_AGENT_ABORT_CALLBACK_TIMEOUT
(default 180 sec) to define a callback command that is called on abort status change
This commit is contained in:
		
							parent
							
								
									ee286e2fb7
								
							
						
					
					
						commit
						0e2657421f
					
				| @ -22,6 +22,7 @@ from datetime import datetime | ||||
| from functools import partial | ||||
| from os.path import basename | ||||
| from tempfile import mkdtemp, NamedTemporaryFile | ||||
| from threading import Thread | ||||
| from time import sleep, time | ||||
| from typing import Text, Optional, Any, Tuple, List, Dict, Mapping, Union | ||||
| 
 | ||||
| @ -78,7 +79,7 @@ from clearml_agent.definitions import ( | ||||
|     ENV_AGENT_FORCE_EXEC_SCRIPT, | ||||
|     ENV_TEMP_STDOUT_FILE_DIR, | ||||
|     ENV_AGENT_FORCE_TASK_INIT, | ||||
|     ENV_AGENT_DEBUG_GET_NEXT_TASK, | ||||
|     ENV_AGENT_DEBUG_GET_NEXT_TASK, ENV_ABORT_CALLBACK_CMD, ENV_ABORT_CALLBACK_CMD_TIMEOUT, | ||||
| ) | ||||
| from clearml_agent.definitions import WORKING_REPOSITORY_DIR, PIP_EXTRA_INDICES | ||||
| from clearml_agent.errors import ( | ||||
| @ -534,14 +535,18 @@ class TaskStopSignal(object): | ||||
|     ] | ||||
|     default = TaskStopReason.no_stop | ||||
|     stopping_message = "stopping" | ||||
|     property_abort_callback_completed = "_abort_callback_completed" | ||||
|     property_abort_callback_timeout = "_abort_callback_timeout" | ||||
|     property_abort_poll_freq = "_abort_poll_freq" | ||||
| 
 | ||||
|     def __init__(self, command, session, events_service, task_id): | ||||
|         # type: (Worker, Session, Events, Text) -> () | ||||
|     def __init__(self, command, session, events_service, task_id, bash_cwd=None): | ||||
|         # type: (Worker, Session, Events, Text, Text) -> None | ||||
|         """ | ||||
|         :param command: workers command | ||||
|         :param session: command session | ||||
|         :param events_service: events service object | ||||
|         :param task_id: followed task ID | ||||
|         :param bash_cwd: cwd for bash on_abort callback | ||||
|         """ | ||||
|         self.command = command | ||||
|         self.session = session | ||||
| @ -553,6 +558,90 @@ class TaskStopSignal(object): | ||||
|         self._active_callback_timestamp = None | ||||
|         self._active_callback_timeout = None | ||||
|         self._abort_callback_max_timeout = float(self.session.config.get('agent.abort_callback_max_timeout', 1800)) | ||||
|         self._bash_callback_cwd = bash_cwd | ||||
|         self._bash_callback = None | ||||
|         self._bash_callback_timeout = None | ||||
|         self._bash_callback_thread = None | ||||
|         self._self_monitor_thread = None | ||||
|         self.register_bash_callback() | ||||
| 
 | ||||
|     @classmethod | ||||
|     def check_registered_bash_callback(cls): | ||||
|         return bool((ENV_ABORT_CALLBACK_CMD.get() or | ||||
|                      os.environ.get(ENV_ABORT_CALLBACK_CMD.vars[0] + "_REGISTERED") or | ||||
|                      "").strip()) | ||||
| 
 | ||||
|     def register_bash_callback(self): | ||||
|         # check if the env variable defined a callback | ||||
|         if not (ENV_ABORT_CALLBACK_CMD.get() or "").strip(): | ||||
|             return | ||||
|         self._bash_callback = shlex.split(ENV_ABORT_CALLBACK_CMD.get()) | ||||
|         # make sure we are re-testing in subprocesses | ||||
|         os.environ[ENV_ABORT_CALLBACK_CMD.vars[0]+"_REGISTERED"] = ENV_ABORT_CALLBACK_CMD.get() | ||||
|         os.environ.pop(ENV_ABORT_CALLBACK_CMD.vars[0], None) | ||||
| 
 | ||||
|         # noinspection PyBroadException | ||||
|         try: | ||||
|             self._bash_callback_timeout = int((ENV_ABORT_CALLBACK_CMD_TIMEOUT.get() or "").strip() or 180) | ||||
|         except Exception: | ||||
|             self._bash_callback_timeout = 180 | ||||
| 
 | ||||
|         # update the task runtime properties | ||||
|         try: | ||||
|             task_info = self.session.get( | ||||
|                 service="tasks", action="get_all", version="2.13", | ||||
|                 id=[self.task_id], only_fields=["runtime"])['tasks'][0] | ||||
|             runtime_properties = task_info.get("runtime") or {} | ||||
|             runtime_properties[self.property_abort_callback_timeout] = str(self._bash_callback_timeout) | ||||
|             runtime_properties[self.property_abort_poll_freq] = str(10) | ||||
|             self.session.post(service="tasks", action="edit", version="2.13", | ||||
|                               task=self.task_id, runtime=runtime_properties, force=True) | ||||
|         except Exception as ex: | ||||
|             print("WARNING: failed registering bash callback: {}".format(ex)) | ||||
|             return | ||||
| 
 | ||||
|         print("INFO: registering bash callback, timout {}s: {}".format( | ||||
|             self._bash_callback_timeout, self._bash_callback)) | ||||
| 
 | ||||
|     def start_monitor_thread(self, polling_interval_sec=10): | ||||
|         if self._self_monitor_thread: | ||||
|            return | ||||
|         self._self_monitor_thread = Thread(target=self._monitor_thread_loop, args=(polling_interval_sec, )) | ||||
|         self._self_monitor_thread.daemon = True | ||||
|         self._self_monitor_thread.start() | ||||
|         print("INFO: bash on_abort monitor thread started") | ||||
| 
 | ||||
|     def stop_monitor_thread(self): | ||||
|         self._self_monitor_thread = None | ||||
| 
 | ||||
|     def _monitor_thread_loop(self, polling_interval_sec=10): | ||||
|         while self._self_monitor_thread is not None: | ||||
|             sleep(polling_interval_sec) | ||||
|             stop_reason = self.test() | ||||
|             if stop_reason != TaskStopSignal.default: | ||||
|                 # mark quit loop | ||||
|                 break | ||||
| 
 | ||||
|     def _bash_callback_launch_thread(self): | ||||
|         command = Argv(*self._bash_callback) | ||||
|         print("INFO: running bash on_abort callback: {}".format(command.pretty())) | ||||
|         try: | ||||
|             command.check_call(cwd=self._bash_callback_cwd or None) | ||||
|         except Exception as ex: | ||||
|             print("WARNING: failed running bash callback: {}".format(ex)) | ||||
| 
 | ||||
|         # update the task runtime properties | ||||
|         try: | ||||
|             task_info = self.session.get( | ||||
|                 service="tasks", action="get_all", version="2.13", | ||||
|                 id=[self.task_id], only_fields=["runtime"])['tasks'][0] | ||||
|             runtime_properties = task_info.get("runtime") or {} | ||||
|             runtime_properties[self.property_abort_callback_completed] = str(1) | ||||
|             self.session.post(service="tasks", action="edit", version="2.13", | ||||
|                               task=self.task_id, runtime=runtime_properties, force=True) | ||||
|         except Exception as ex: | ||||
|             print("WARNING: failed updating bash callback completed: {}".format(ex)) | ||||
|             return | ||||
| 
 | ||||
|     def test(self): | ||||
|         # type: () -> TaskStopReason | ||||
| @ -583,8 +672,12 @@ class TaskStopSignal(object): | ||||
|                 try: | ||||
|                     task_info = self.session.get( | ||||
|                         service="tasks", action="get_all", version="2.13", id=[self.task_id], | ||||
|                         only_fields=["status", "status_message", "runtime._abort_callback_completed"]) | ||||
|                     cb_completed = task_info['tasks'][0]['runtime'].get('_abort_callback_completed', None) | ||||
|                         only_fields=[ | ||||
|                             "status", "status_message", | ||||
|                             "runtime.{}".format(self.property_abort_callback_completed) | ||||
|                         ] | ||||
|                     ) | ||||
|                     cb_completed = task_info['tasks'][0]['runtime'].get(self.property_abort_callback_completed, None) | ||||
|                 except:  # noqa | ||||
|                     pass | ||||
| 
 | ||||
| @ -620,6 +713,13 @@ class TaskStopSignal(object): | ||||
|             session=self.session | ||||
|         ) | ||||
| 
 | ||||
|         # launch bash callback if exists | ||||
|         if self._bash_callback and not self._bash_callback_thread: | ||||
|             print("INFO: bash on_abort callback thread started") | ||||
|             self._bash_callback_thread = Thread(target=self._bash_callback_launch_thread) | ||||
|             self._bash_callback_thread.daemon = True | ||||
|             self._bash_callback_thread.start() | ||||
| 
 | ||||
|         self._active_callback_timestamp = time() | ||||
|         self._active_callback_timeout = timeout | ||||
|         return bool(cb_completed) | ||||
| @ -629,11 +729,16 @@ class TaskStopSignal(object): | ||||
|         try: | ||||
|             task_info = self.session.get( | ||||
|                 service="tasks", action="get_all", version="2.13", id=[self.task_id], | ||||
|                 only_fields=["status", "status_message", "runtime._abort_callback_timeout", | ||||
|                              "runtime._abort_poll_freq", "runtime._abort_callback_completed"]) | ||||
|             abort_timeout = task_info['tasks'][0]['runtime'].get('_abort_callback_timeout', 0) | ||||
|             poll_timeout = task_info['tasks'][0]['runtime'].get('_abort_poll_freq', 0) | ||||
|             cb_completed = task_info['tasks'][0]['runtime'].get('_abort_callback_completed', None) | ||||
|                 only_fields=[ | ||||
|                     "status", "status_message", | ||||
|                     "runtime.{}".format(self.property_abort_callback_timeout), | ||||
|                     "runtime.{}".format(self.property_abort_poll_freq), | ||||
|                     "runtime.{}".format(self.property_abort_callback_completed) | ||||
|                 ] | ||||
|             ) | ||||
|             abort_timeout = task_info['tasks'][0]['runtime'].get(self.property_abort_callback_timeout, 0) | ||||
|             poll_timeout = task_info['tasks'][0]['runtime'].get(self.property_abort_poll_freq, 0) | ||||
|             cb_completed = task_info['tasks'][0]['runtime'].get(self.property_abort_callback_completed, None) | ||||
|         except:  # noqa | ||||
|             abort_timeout = None | ||||
|             poll_timeout = None | ||||
| @ -3087,6 +3192,18 @@ class Worker(ServiceCommandSection): | ||||
|             force=True, | ||||
|         ) | ||||
| 
 | ||||
|         stop_signal = None | ||||
|         if TaskStopSignal.check_registered_bash_callback(): | ||||
|             use_execv = False | ||||
|             stop_signal = TaskStopSignal( | ||||
|                 command=self, | ||||
|                 session=self._session, | ||||
|                 events_service=self.get_service(Events), | ||||
|                 task_id=task_id, | ||||
|                 bash_cwd=script_dir | ||||
|             ) | ||||
|             stop_signal.start_monitor_thread(polling_interval_sec=self._polling_interval) | ||||
| 
 | ||||
|         # check if we need to add encoding to the subprocess | ||||
|         if sys.getfilesystemencoding() == 'ascii' and not os.environ.get("PYTHONIOENCODING"): | ||||
|             os.environ["PYTHONIOENCODING"] = "utf-8" | ||||
| @ -3149,6 +3266,8 @@ class Worker(ServiceCommandSection): | ||||
|         exit_code = exit_code if exit_code != ExitStatus.interrupted else -1 | ||||
| 
 | ||||
|         if not disable_monitoring: | ||||
|             if stop_signal: | ||||
|                 stop_signal.stop_monitor_thread() | ||||
|             # we need to change task status according to exit code | ||||
|             self.handle_task_termination(current_task.id, exit_code, TaskStopReason.no_stop) | ||||
|             self.stop_monitor() | ||||
|  | ||||
| @ -187,6 +187,9 @@ ENV_CHILD_AGENTS_COUNT_CMD = EnvironmentConfig("CLEARML_AGENT_CHILD_AGENTS_COUNT | ||||
| ENV_DOCKER_ARGS_FILTERS = EnvironmentConfig("CLEARML_AGENT_DOCKER_ARGS_FILTERS") | ||||
| ENV_DOCKER_ARGS_HIDE_ENV = EnvironmentConfig("CLEARML_AGENT_DOCKER_ARGS_HIDE_ENV") | ||||
| ENV_CONFIG_BC_IN_STANDALONE = EnvironmentConfig("CLEARML_AGENT_STANDALONE_CONFIG_BC", type=bool) | ||||
| ENV_ABORT_CALLBACK_CMD = EnvironmentConfig("CLEARML_AGENT_ABORT_CALLBACK_CMD") | ||||
| ENV_ABORT_CALLBACK_CMD_TIMEOUT = EnvironmentConfig("CLEARML_AGENT_ABORT_CALLBACK_TIMEOUT") | ||||
| 
 | ||||
| """ Maintain backwards compatible configuration when launching in standalone mode """ | ||||
| 
 | ||||
| ENV_FORCE_DOCKER_AGENT_REPO = EnvironmentConfig("FORCE_CLEARML_AGENT_REPO", "CLEARML_AGENT_DOCKER_AGENT_REPO") | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user
	 clearml
						clearml