mirror of
				https://github.com/clearml/clearml-agent
				synced 2025-06-26 18:16:15 +00:00 
			
		
		
		
	Added CLEARML_AGENT_ABORT_CALLBACK_CMD and CLEARML_AGENT_ABORT_CALLBACK_TIMEOUT
(default 180 sec) to define a callback command that is called on an abort status change
This commit is contained in:
		
							parent
							
								
									ee286e2fb7
								
							
						
					
					
						commit
						0e2657421f
					
				| @ -22,6 +22,7 @@ from datetime import datetime | |||||||
| from functools import partial | from functools import partial | ||||||
| from os.path import basename | from os.path import basename | ||||||
| from tempfile import mkdtemp, NamedTemporaryFile | from tempfile import mkdtemp, NamedTemporaryFile | ||||||
|  | from threading import Thread | ||||||
| from time import sleep, time | from time import sleep, time | ||||||
| from typing import Text, Optional, Any, Tuple, List, Dict, Mapping, Union | from typing import Text, Optional, Any, Tuple, List, Dict, Mapping, Union | ||||||
| 
 | 
 | ||||||
| @ -78,7 +79,7 @@ from clearml_agent.definitions import ( | |||||||
|     ENV_AGENT_FORCE_EXEC_SCRIPT, |     ENV_AGENT_FORCE_EXEC_SCRIPT, | ||||||
|     ENV_TEMP_STDOUT_FILE_DIR, |     ENV_TEMP_STDOUT_FILE_DIR, | ||||||
|     ENV_AGENT_FORCE_TASK_INIT, |     ENV_AGENT_FORCE_TASK_INIT, | ||||||
|     ENV_AGENT_DEBUG_GET_NEXT_TASK, |     ENV_AGENT_DEBUG_GET_NEXT_TASK, ENV_ABORT_CALLBACK_CMD, ENV_ABORT_CALLBACK_CMD_TIMEOUT, | ||||||
| ) | ) | ||||||
| from clearml_agent.definitions import WORKING_REPOSITORY_DIR, PIP_EXTRA_INDICES | from clearml_agent.definitions import WORKING_REPOSITORY_DIR, PIP_EXTRA_INDICES | ||||||
| from clearml_agent.errors import ( | from clearml_agent.errors import ( | ||||||
| @ -534,14 +535,18 @@ class TaskStopSignal(object): | |||||||
|     ] |     ] | ||||||
|     default = TaskStopReason.no_stop |     default = TaskStopReason.no_stop | ||||||
|     stopping_message = "stopping" |     stopping_message = "stopping" | ||||||
|  |     property_abort_callback_completed = "_abort_callback_completed" | ||||||
|  |     property_abort_callback_timeout = "_abort_callback_timeout" | ||||||
|  |     property_abort_poll_freq = "_abort_poll_freq" | ||||||
| 
 | 
 | ||||||
|     def __init__(self, command, session, events_service, task_id): |     def __init__(self, command, session, events_service, task_id, bash_cwd=None): | ||||||
|         # type: (Worker, Session, Events, Text) -> () |         # type: (Worker, Session, Events, Text, Text) -> None | ||||||
|         """ |         """ | ||||||
|         :param command: workers command |         :param command: workers command | ||||||
|         :param session: command session |         :param session: command session | ||||||
|         :param events_service: events service object |         :param events_service: events service object | ||||||
|         :param task_id: followed task ID |         :param task_id: followed task ID | ||||||
|  |         :param bash_cwd: cwd for bash on_abort callback | ||||||
|         """ |         """ | ||||||
|         self.command = command |         self.command = command | ||||||
|         self.session = session |         self.session = session | ||||||
| @ -553,6 +558,90 @@ class TaskStopSignal(object): | |||||||
|         self._active_callback_timestamp = None |         self._active_callback_timestamp = None | ||||||
|         self._active_callback_timeout = None |         self._active_callback_timeout = None | ||||||
|         self._abort_callback_max_timeout = float(self.session.config.get('agent.abort_callback_max_timeout', 1800)) |         self._abort_callback_max_timeout = float(self.session.config.get('agent.abort_callback_max_timeout', 1800)) | ||||||
|  |         self._bash_callback_cwd = bash_cwd | ||||||
|  |         self._bash_callback = None | ||||||
|  |         self._bash_callback_timeout = None | ||||||
|  |         self._bash_callback_thread = None | ||||||
|  |         self._self_monitor_thread = None | ||||||
|  |         self.register_bash_callback() | ||||||
|  | 
 | ||||||
|  |     @classmethod | ||||||
|  |     def check_registered_bash_callback(cls): | ||||||
|  |         return bool((ENV_ABORT_CALLBACK_CMD.get() or | ||||||
|  |                      os.environ.get(ENV_ABORT_CALLBACK_CMD.vars[0] + "_REGISTERED") or | ||||||
|  |                      "").strip()) | ||||||
|  | 
 | ||||||
|  |     def register_bash_callback(self): | ||||||
|  |         # check if the env variable defined a callback | ||||||
|  |         if not (ENV_ABORT_CALLBACK_CMD.get() or "").strip(): | ||||||
|  |             return | ||||||
|  |         self._bash_callback = shlex.split(ENV_ABORT_CALLBACK_CMD.get()) | ||||||
|  |         # make sure we are re-testing in subprocesses | ||||||
|  |         os.environ[ENV_ABORT_CALLBACK_CMD.vars[0]+"_REGISTERED"] = ENV_ABORT_CALLBACK_CMD.get() | ||||||
|  |         os.environ.pop(ENV_ABORT_CALLBACK_CMD.vars[0], None) | ||||||
|  | 
 | ||||||
|  |         # noinspection PyBroadException | ||||||
|  |         try: | ||||||
|  |             self._bash_callback_timeout = int((ENV_ABORT_CALLBACK_CMD_TIMEOUT.get() or "").strip() or 180) | ||||||
|  |         except Exception: | ||||||
|  |             self._bash_callback_timeout = 180 | ||||||
|  | 
 | ||||||
|  |         # update the task runtime properties | ||||||
|  |         try: | ||||||
|  |             task_info = self.session.get( | ||||||
|  |                 service="tasks", action="get_all", version="2.13", | ||||||
|  |                 id=[self.task_id], only_fields=["runtime"])['tasks'][0] | ||||||
|  |             runtime_properties = task_info.get("runtime") or {} | ||||||
|  |             runtime_properties[self.property_abort_callback_timeout] = str(self._bash_callback_timeout) | ||||||
|  |             runtime_properties[self.property_abort_poll_freq] = str(10) | ||||||
|  |             self.session.post(service="tasks", action="edit", version="2.13", | ||||||
|  |                               task=self.task_id, runtime=runtime_properties, force=True) | ||||||
|  |         except Exception as ex: | ||||||
|  |             print("WARNING: failed registering bash callback: {}".format(ex)) | ||||||
|  |             return | ||||||
|  | 
 | ||||||
|  |         print("INFO: registering bash callback, timout {}s: {}".format( | ||||||
|  |             self._bash_callback_timeout, self._bash_callback)) | ||||||
|  | 
 | ||||||
|  |     def start_monitor_thread(self, polling_interval_sec=10): | ||||||
|  |         if self._self_monitor_thread: | ||||||
|  |            return | ||||||
|  |         self._self_monitor_thread = Thread(target=self._monitor_thread_loop, args=(polling_interval_sec, )) | ||||||
|  |         self._self_monitor_thread.daemon = True | ||||||
|  |         self._self_monitor_thread.start() | ||||||
|  |         print("INFO: bash on_abort monitor thread started") | ||||||
|  | 
 | ||||||
|  |     def stop_monitor_thread(self): | ||||||
|  |         self._self_monitor_thread = None | ||||||
|  | 
 | ||||||
|  |     def _monitor_thread_loop(self, polling_interval_sec=10): | ||||||
|  |         while self._self_monitor_thread is not None: | ||||||
|  |             sleep(polling_interval_sec) | ||||||
|  |             stop_reason = self.test() | ||||||
|  |             if stop_reason != TaskStopSignal.default: | ||||||
|  |                 # mark quit loop | ||||||
|  |                 break | ||||||
|  | 
 | ||||||
|  |     def _bash_callback_launch_thread(self): | ||||||
|  |         command = Argv(*self._bash_callback) | ||||||
|  |         print("INFO: running bash on_abort callback: {}".format(command.pretty())) | ||||||
|  |         try: | ||||||
|  |             command.check_call(cwd=self._bash_callback_cwd or None) | ||||||
|  |         except Exception as ex: | ||||||
|  |             print("WARNING: failed running bash callback: {}".format(ex)) | ||||||
|  | 
 | ||||||
|  |         # update the task runtime properties | ||||||
|  |         try: | ||||||
|  |             task_info = self.session.get( | ||||||
|  |                 service="tasks", action="get_all", version="2.13", | ||||||
|  |                 id=[self.task_id], only_fields=["runtime"])['tasks'][0] | ||||||
|  |             runtime_properties = task_info.get("runtime") or {} | ||||||
|  |             runtime_properties[self.property_abort_callback_completed] = str(1) | ||||||
|  |             self.session.post(service="tasks", action="edit", version="2.13", | ||||||
|  |                               task=self.task_id, runtime=runtime_properties, force=True) | ||||||
|  |         except Exception as ex: | ||||||
|  |             print("WARNING: failed updating bash callback completed: {}".format(ex)) | ||||||
|  |             return | ||||||
| 
 | 
 | ||||||
|     def test(self): |     def test(self): | ||||||
|         # type: () -> TaskStopReason |         # type: () -> TaskStopReason | ||||||
| @ -583,8 +672,12 @@ class TaskStopSignal(object): | |||||||
|                 try: |                 try: | ||||||
|                     task_info = self.session.get( |                     task_info = self.session.get( | ||||||
|                         service="tasks", action="get_all", version="2.13", id=[self.task_id], |                         service="tasks", action="get_all", version="2.13", id=[self.task_id], | ||||||
|                         only_fields=["status", "status_message", "runtime._abort_callback_completed"]) |                         only_fields=[ | ||||||
|                     cb_completed = task_info['tasks'][0]['runtime'].get('_abort_callback_completed', None) |                             "status", "status_message", | ||||||
|  |                             "runtime.{}".format(self.property_abort_callback_completed) | ||||||
|  |                         ] | ||||||
|  |                     ) | ||||||
|  |                     cb_completed = task_info['tasks'][0]['runtime'].get(self.property_abort_callback_completed, None) | ||||||
|                 except:  # noqa |                 except:  # noqa | ||||||
|                     pass |                     pass | ||||||
| 
 | 
 | ||||||
| @ -620,6 +713,13 @@ class TaskStopSignal(object): | |||||||
|             session=self.session |             session=self.session | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|  |         # launch bash callback if exists | ||||||
|  |         if self._bash_callback and not self._bash_callback_thread: | ||||||
|  |             print("INFO: bash on_abort callback thread started") | ||||||
|  |             self._bash_callback_thread = Thread(target=self._bash_callback_launch_thread) | ||||||
|  |             self._bash_callback_thread.daemon = True | ||||||
|  |             self._bash_callback_thread.start() | ||||||
|  | 
 | ||||||
|         self._active_callback_timestamp = time() |         self._active_callback_timestamp = time() | ||||||
|         self._active_callback_timeout = timeout |         self._active_callback_timeout = timeout | ||||||
|         return bool(cb_completed) |         return bool(cb_completed) | ||||||
| @ -629,11 +729,16 @@ class TaskStopSignal(object): | |||||||
|         try: |         try: | ||||||
|             task_info = self.session.get( |             task_info = self.session.get( | ||||||
|                 service="tasks", action="get_all", version="2.13", id=[self.task_id], |                 service="tasks", action="get_all", version="2.13", id=[self.task_id], | ||||||
|                 only_fields=["status", "status_message", "runtime._abort_callback_timeout", |                 only_fields=[ | ||||||
|                              "runtime._abort_poll_freq", "runtime._abort_callback_completed"]) |                     "status", "status_message", | ||||||
|             abort_timeout = task_info['tasks'][0]['runtime'].get('_abort_callback_timeout', 0) |                     "runtime.{}".format(self.property_abort_callback_timeout), | ||||||
|             poll_timeout = task_info['tasks'][0]['runtime'].get('_abort_poll_freq', 0) |                     "runtime.{}".format(self.property_abort_poll_freq), | ||||||
|             cb_completed = task_info['tasks'][0]['runtime'].get('_abort_callback_completed', None) |                     "runtime.{}".format(self.property_abort_callback_completed) | ||||||
|  |                 ] | ||||||
|  |             ) | ||||||
|  |             abort_timeout = task_info['tasks'][0]['runtime'].get(self.property_abort_callback_timeout, 0) | ||||||
|  |             poll_timeout = task_info['tasks'][0]['runtime'].get(self.property_abort_poll_freq, 0) | ||||||
|  |             cb_completed = task_info['tasks'][0]['runtime'].get(self.property_abort_callback_completed, None) | ||||||
|         except:  # noqa |         except:  # noqa | ||||||
|             abort_timeout = None |             abort_timeout = None | ||||||
|             poll_timeout = None |             poll_timeout = None | ||||||
| @ -3087,6 +3192,18 @@ class Worker(ServiceCommandSection): | |||||||
|             force=True, |             force=True, | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|  |         stop_signal = None | ||||||
|  |         if TaskStopSignal.check_registered_bash_callback(): | ||||||
|  |             use_execv = False | ||||||
|  |             stop_signal = TaskStopSignal( | ||||||
|  |                 command=self, | ||||||
|  |                 session=self._session, | ||||||
|  |                 events_service=self.get_service(Events), | ||||||
|  |                 task_id=task_id, | ||||||
|  |                 bash_cwd=script_dir | ||||||
|  |             ) | ||||||
|  |             stop_signal.start_monitor_thread(polling_interval_sec=self._polling_interval) | ||||||
|  | 
 | ||||||
|         # check if we need to add encoding to the subprocess |         # check if we need to add encoding to the subprocess | ||||||
|         if sys.getfilesystemencoding() == 'ascii' and not os.environ.get("PYTHONIOENCODING"): |         if sys.getfilesystemencoding() == 'ascii' and not os.environ.get("PYTHONIOENCODING"): | ||||||
|             os.environ["PYTHONIOENCODING"] = "utf-8" |             os.environ["PYTHONIOENCODING"] = "utf-8" | ||||||
| @ -3149,6 +3266,8 @@ class Worker(ServiceCommandSection): | |||||||
|         exit_code = exit_code if exit_code != ExitStatus.interrupted else -1 |         exit_code = exit_code if exit_code != ExitStatus.interrupted else -1 | ||||||
| 
 | 
 | ||||||
|         if not disable_monitoring: |         if not disable_monitoring: | ||||||
|  |             if stop_signal: | ||||||
|  |                 stop_signal.stop_monitor_thread() | ||||||
|             # we need to change task status according to exit code |             # we need to change task status according to exit code | ||||||
|             self.handle_task_termination(current_task.id, exit_code, TaskStopReason.no_stop) |             self.handle_task_termination(current_task.id, exit_code, TaskStopReason.no_stop) | ||||||
|             self.stop_monitor() |             self.stop_monitor() | ||||||
|  | |||||||
| @ -187,6 +187,9 @@ ENV_CHILD_AGENTS_COUNT_CMD = EnvironmentConfig("CLEARML_AGENT_CHILD_AGENTS_COUNT | |||||||
| ENV_DOCKER_ARGS_FILTERS = EnvironmentConfig("CLEARML_AGENT_DOCKER_ARGS_FILTERS") | ENV_DOCKER_ARGS_FILTERS = EnvironmentConfig("CLEARML_AGENT_DOCKER_ARGS_FILTERS") | ||||||
| ENV_DOCKER_ARGS_HIDE_ENV = EnvironmentConfig("CLEARML_AGENT_DOCKER_ARGS_HIDE_ENV") | ENV_DOCKER_ARGS_HIDE_ENV = EnvironmentConfig("CLEARML_AGENT_DOCKER_ARGS_HIDE_ENV") | ||||||
| ENV_CONFIG_BC_IN_STANDALONE = EnvironmentConfig("CLEARML_AGENT_STANDALONE_CONFIG_BC", type=bool) | ENV_CONFIG_BC_IN_STANDALONE = EnvironmentConfig("CLEARML_AGENT_STANDALONE_CONFIG_BC", type=bool) | ||||||
|  | ENV_ABORT_CALLBACK_CMD = EnvironmentConfig("CLEARML_AGENT_ABORT_CALLBACK_CMD") | ||||||
|  | ENV_ABORT_CALLBACK_CMD_TIMEOUT = EnvironmentConfig("CLEARML_AGENT_ABORT_CALLBACK_TIMEOUT") | ||||||
|  | 
 | ||||||
| """ Maintain backwards compatible configuration when launching in standalone mode """ | """ Maintain backwards compatible configuration when launching in standalone mode """ | ||||||
| 
 | 
 | ||||||
| ENV_FORCE_DOCKER_AGENT_REPO = EnvironmentConfig("FORCE_CLEARML_AGENT_REPO", "CLEARML_AGENT_DOCKER_AGENT_REPO") | ENV_FORCE_DOCKER_AGENT_REPO = EnvironmentConfig("FORCE_CLEARML_AGENT_REPO", "CLEARML_AGENT_DOCKER_AGENT_REPO") | ||||||
|  | |||||||
		Loading…
	
		Reference in New Issue
	
	Block a user
	 clearml
						clearml