mirror of
				https://github.com/clearml/clearml
				synced 2025-06-26 18:16:07 +00:00 
			
		
		
		
	Add support for registering an abort callback
This commit is contained in:
		
							parent
							
								
									f298271876
								
							
						
					
					
						commit
						1cc87c9a21
					
				| @ -6,10 +6,15 @@ from time import time | ||||
| from ....config import deferred_config | ||||
| from ....backend_interface.task.development.stop_signal import TaskStopSignal | ||||
| from ....backend_api.services import tasks | ||||
| from ....utilities.lowlevel.threads import kill_thread | ||||
| from ....utilities.process.mp import SafeEvent | ||||
| 
 | ||||
| 
 | ||||
| class DevWorker(object): | ||||
|     property_abort_callback_completed = "_abort_callback_completed" | ||||
|     property_abort_callback_timeout = "_abort_callback_timeout" | ||||
|     property_abort_poll_freq = "_abort_poll_freq" | ||||
| 
 | ||||
|     prefix = attr.ib(type=str, default="MANUAL:") | ||||
| 
 | ||||
|     report_stdout = deferred_config('development.worker.log_stdout', True) | ||||
| @ -26,12 +31,16 @@ class DevWorker(object): | ||||
|         self._exit_event = SafeEvent() | ||||
|         self._task = None | ||||
|         self._support_ping = False | ||||
|         self._poll_freq = None | ||||
|         self._abort_cb = None | ||||
|         self._abort_cb_timeout = None | ||||
|         self._cb_completed = None | ||||
| 
 | ||||
|     def ping(self, timestamp=None): | ||||
|         try: | ||||
|             if self._task: | ||||
|                 self._task.send(tasks.PingRequest(self._task.id)) | ||||
|         except Exception: | ||||
|         except Exception:  # noqa | ||||
|             return False | ||||
|         return True | ||||
| 
 | ||||
| @ -51,21 +60,98 @@ class DevWorker(object): | ||||
|         self._thread.start() | ||||
|         return True | ||||
| 
 | ||||
|     def register_abort_callback(self, callback_function, execution_timeout, poll_freq): | ||||
|         if not self._task: | ||||
|             return | ||||
| 
 | ||||
|         self._poll_freq = float(poll_freq) if poll_freq else None | ||||
|         self._abort_cb = callback_function | ||||
|         self._abort_cb_timeout = float(execution_timeout) | ||||
| 
 | ||||
|         if not callback_function: | ||||
|             # noinspection PyProtectedMember | ||||
|             self._task._set_runtime_properties({DevWorker.property_abort_callback_timeout: float(-1)}) | ||||
|             return | ||||
| 
 | ||||
|         # noinspection PyProtectedMember | ||||
|         self._task._set_runtime_properties({ | ||||
|             self.property_abort_callback_timeout: float(execution_timeout), | ||||
|             self.property_abort_poll_freq: float(poll_freq), | ||||
|             self.property_abort_callback_completed: "", | ||||
|         }) | ||||
| 
 | ||||
|     def _inner_abort_cb_wrapper(self): | ||||
|         # store the task object because we might nullify it | ||||
|         task = self._task | ||||
|         # call the user abort callback | ||||
|         try: | ||||
|             if self._abort_cb: | ||||
|                 self._abort_cb() | ||||
|             self._cb_completed = True | ||||
|         except SystemError: | ||||
|             # we will get here if we killed the thread externally, | ||||
|             # we should not try to mark as completed, just leave the thread | ||||
|             return | ||||
|         except BaseException as ex:  # noqa | ||||
|             if task and task.log: | ||||
|                 task.log.warning( | ||||
|                     "### TASK STOPPING - USER ABORTED - CALLBACK EXCEPTION: {} ###".format(ex)) | ||||
| 
 | ||||
|         # set runtime property, abort completed for the agent to know we are done | ||||
|         if task: | ||||
|             # noinspection PyProtectedMember | ||||
|             task._set_runtime_properties({self.property_abort_callback_completed: 1}) | ||||
| 
 | ||||
|     def _launch_abort_cb(self): | ||||
|         timeout = self._abort_cb_timeout or 300. | ||||
|         if self._task and self._task.log: | ||||
|             self._task.log.warning( | ||||
|                 "### TASK STOPPING - USER ABORTED - " | ||||
|                 "LAUNCHING CALLBACK (timeout {} sec) ###".format(timeout)) | ||||
| 
 | ||||
|         tic = time() | ||||
|         timed_out = False | ||||
|         try: | ||||
|             callback_thread = Thread(target=self._inner_abort_cb_wrapper) | ||||
|             callback_thread.daemon = True | ||||
|             callback_thread.start() | ||||
|             callback_thread.join(timeout=timeout) | ||||
|             if callback_thread.is_alive(): | ||||
|                 kill_thread(callback_thread, wait=False) | ||||
|                 timed_out = True | ||||
|         except:  # noqa | ||||
|             # something went wrong no just leave the process | ||||
|             pass | ||||
| 
 | ||||
|         if self._task and self._task.log: | ||||
|             self._task.log.warning( | ||||
|                 "### TASK STOPPING - USER ABORTED - CALLBACK {} ({:.2f} sec) ###".format( | ||||
|                     "TIMED OUT" if timed_out else ("COMPLETED" if self._cb_completed else "FAILED"), time()-tic)) | ||||
| 
 | ||||
|     def _daemon(self): | ||||
|         last_ping = time() | ||||
|         while self._task is not None: | ||||
|             try: | ||||
|                 if self._exit_event.wait(min(float(self.ping_period), float(self.report_period))): | ||||
|                 wait_timeout = min(float(self.ping_period), float(self.report_period)) | ||||
|                 if self._poll_freq: | ||||
|                     wait_timeout = min(self._poll_freq, wait_timeout) | ||||
|                 if self._exit_event.wait(wait_timeout): | ||||
|                     return | ||||
|                 # send ping request | ||||
|                 if self._support_ping and (time() - last_ping) >= float(self.ping_period): | ||||
|                     self.ping() | ||||
|                     last_ping = time() | ||||
| 
 | ||||
|                 if self._dev_stop_signal: | ||||
|                     stop_reason = self._dev_stop_signal.test() | ||||
|                     if stop_reason and self._task: | ||||
|                         # call abort callback | ||||
|                         if self._abort_cb: | ||||
|                             self._launch_abort_cb() | ||||
| 
 | ||||
|                         # noinspection PyProtectedMember | ||||
|                         self._task._dev_mode_stop_task(stop_reason) | ||||
|             except Exception: | ||||
|             except Exception:  # noqa | ||||
|                 pass | ||||
| 
 | ||||
|     def unregister(self): | ||||
|  | ||||
| @ -2071,7 +2071,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): | ||||
|             return None | ||||
| 
 | ||||
|     def _set_runtime_properties(self, runtime_properties): | ||||
|         # type: (Mapping[str, str]) -> bool | ||||
|         # type: (Mapping[str, Union[str, int, float]]) -> bool | ||||
|         if not Session.check_min_api_version('2.13') or not runtime_properties: | ||||
|             return False | ||||
| 
 | ||||
|  | ||||
| @ -2594,6 +2594,50 @@ class Task(_Task): | ||||
|         self.reload() | ||||
|         return result | ||||
| 
 | ||||
|     def register_abort_callback( | ||||
|             self, | ||||
|             callback_function,  # type: Optional[Callable] | ||||
|             callback_execution_timeout=30.  # type: float | ||||
|     ):  # type (...) -> None | ||||
|         """ | ||||
|         Register a Task abort callback (single callback function support only). | ||||
|         Pass a function to be called from a background thread when the Task is **externally** being aborted. | ||||
|         Users must specify a timeout for the callback function execution (default 30 seconds) | ||||
|         if the callback execution function exceeds the timeout, the Task's process will be terminated | ||||
| 
 | ||||
|         Call this register function from the main process only. | ||||
| 
 | ||||
|         Note: Ctrl-C is Not considered external, only backend induced abort is covered here | ||||
| 
 | ||||
|         :param callback_function: Callback function to be called via external thread (from the main process). | ||||
|             pass None to remove existing callback | ||||
|         :param callback_execution_timeout: Maximum callback execution time in seconds, after which the process | ||||
|             will be terminated even if the callback did not return | ||||
|         """ | ||||
|         if self.__is_subprocess(): | ||||
|             raise ValueError("Register abort callback must be called from the main process, this is a subprocess.") | ||||
| 
 | ||||
|         if callback_function is None: | ||||
|             if self._dev_worker: | ||||
|                 self._dev_worker.register_abort_callback(callback_function=None, execution_timeout=0, poll_freq=0) | ||||
|             return | ||||
| 
 | ||||
|         if float(callback_execution_timeout) <= 0: | ||||
|             raise ValueError( | ||||
|                 "function_timeout_sec must be positive timeout in seconds, got {}".format(callback_execution_timeout)) | ||||
| 
 | ||||
|         # if we are running remotely we might not have a DevWorker monitoring us, so let's create one | ||||
|         if not self._dev_worker: | ||||
|             self._dev_worker = DevWorker() | ||||
|             self._dev_worker.register(self, stop_signal_support=True) | ||||
| 
 | ||||
|         poll_freq = 15.0 | ||||
|         self._dev_worker.register_abort_callback( | ||||
|             callback_function=callback_function, | ||||
|             execution_timeout=callback_execution_timeout, | ||||
|             poll_freq=poll_freq | ||||
|         ) | ||||
| 
 | ||||
|     @classmethod | ||||
|     def import_task(cls, task_data, target_task=None, update=False): | ||||
|         # type: (dict, Optional[Union[str, Task]], bool) -> Optional[Task] | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user
	 allegroai
						allegroai