diff --git a/clearml/backend_interface/task/development/worker.py b/clearml/backend_interface/task/development/worker.py
index fd6331fa..fa24ad40 100644
--- a/clearml/backend_interface/task/development/worker.py
+++ b/clearml/backend_interface/task/development/worker.py
@@ -6,10 +6,15 @@ from time import time
 from ....config import deferred_config
 from ....backend_interface.task.development.stop_signal import TaskStopSignal
 from ....backend_api.services import tasks
+from ....utilities.lowlevel.threads import kill_thread
 from ....utilities.process.mp import SafeEvent
 
 
 class DevWorker(object):
+    property_abort_callback_completed = "_abort_callback_completed"
+    property_abort_callback_timeout = "_abort_callback_timeout"
+    property_abort_poll_freq = "_abort_poll_freq"
+
     prefix = attr.ib(type=str, default="MANUAL:")
 
     report_stdout = deferred_config('development.worker.log_stdout', True)
@@ -26,12 +31,16 @@ class DevWorker(object):
         self._exit_event = SafeEvent()
         self._task = None
         self._support_ping = False
+        self._poll_freq = None
+        self._abort_cb = None
+        self._abort_cb_timeout = None
+        self._cb_completed = None
 
     def ping(self, timestamp=None):
         try:
             if self._task:
                 self._task.send(tasks.PingRequest(self._task.id))
-        except Exception:
+        except Exception:  # noqa
             return False
         return True
 
@@ -51,21 +60,116 @@ class DevWorker(object):
         self._thread.start()
         return True
 
+    def register_abort_callback(self, callback_function, execution_timeout, poll_freq):
+        """
+        Store the abort callback and publish its settings as Task runtime properties.
+        Does nothing when no Task is attached to this worker.
+
+        :param callback_function: Callable to run when a stop request is detected.
+            A falsy value clears the callback and publishes a timeout of -1
+        :param execution_timeout: Maximum callback execution time in seconds
+        :param poll_freq: Upper bound in seconds on the daemon wait between stop checks (0 for none)
+        """
+        if not self._task:
+            return
+
+        self._poll_freq = float(poll_freq) if poll_freq else None
+        self._abort_cb = callback_function
+        self._abort_cb_timeout = float(execution_timeout)
+
+        if not callback_function:
+            # noinspection PyProtectedMember
+            self._task._set_runtime_properties({DevWorker.property_abort_callback_timeout: float(-1)})
+            return
+
+        # noinspection PyProtectedMember
+        self._task._set_runtime_properties({
+            self.property_abort_callback_timeout: float(execution_timeout),
+            self.property_abort_poll_freq: float(poll_freq),
+            self.property_abort_callback_completed: "",
+        })
+
+    def _inner_abort_cb_wrapper(self):
+        """
+        Run the user abort callback, then set the abort-callback-completed runtime property to 1.
+        An exception raised by the callback is logged as a warning; on SystemError the method
+        returns without setting the property.
+        """
+        # store the task object because we might nullify it
+        task = self._task
+        # call the user abort callback
+        try:
+            if self._abort_cb:
+                self._abort_cb()
+                self._cb_completed = True
+        except SystemError:
+            # we will get here if we killed the thread externally,
+            # we should not try to mark as completed, just leave the thread
+            return
+        except BaseException as ex:  # noqa
+            if task and task.log:
+                task.log.warning(
+                    "### TASK STOPPING - USER ABORTED - CALLBACK EXCEPTION: {} ###".format(ex))
+
+        # set runtime property, abort completed for the agent to know we are done
+        if task:
+            # noinspection PyProtectedMember
+            task._set_runtime_properties({self.property_abort_callback_completed: 1})
+
+    def _launch_abort_cb(self):
+        """
+        Run the abort callback in a daemon thread and wait for it up to the registered timeout
+        (300 seconds when unset or zero). A callback still running after the timeout is killed.
+        """
+        timeout = self._abort_cb_timeout or 300.
+        if self._task and self._task.log:
+            self._task.log.warning(
+                "### TASK STOPPING - USER ABORTED - "
+                "LAUNCHING CALLBACK (timeout {} sec) ###".format(timeout))
+
+        tic = time()
+        timed_out = False
+        try:
+            callback_thread = Thread(target=self._inner_abort_cb_wrapper)
+            callback_thread.daemon = True
+            callback_thread.start()
+            callback_thread.join(timeout=timeout)
+            if callback_thread.is_alive():
+                kill_thread(callback_thread, wait=False)
+                timed_out = True
+        except:  # noqa
+            # something went wrong, just leave the process
+            pass
+
+        if self._task and self._task.log:
+            self._task.log.warning(
+                "### TASK STOPPING - USER ABORTED - CALLBACK {} ({:.2f} sec) ###".format(
+                    "TIMED OUT" if timed_out else ("COMPLETED" if self._cb_completed else "FAILED"), time()-tic))
+
     def _daemon(self):
         last_ping = time()
         while self._task is not None:
             try:
-                if self._exit_event.wait(min(float(self.ping_period), float(self.report_period))):
+                wait_timeout = min(float(self.ping_period), float(self.report_period))
+                if self._poll_freq:
+                    wait_timeout = min(self._poll_freq, wait_timeout)
+                if self._exit_event.wait(wait_timeout):
                     return
                 # send ping request
                 if self._support_ping and (time() - last_ping) >= float(self.ping_period):
                     self.ping()
                     last_ping = time()
+
                 if self._dev_stop_signal:
                     stop_reason = self._dev_stop_signal.test()
                     if stop_reason and self._task:
+                        # call abort callback
+                        if self._abort_cb:
+                            self._launch_abort_cb()
+
+                        # noinspection PyProtectedMember
                         self._task._dev_mode_stop_task(stop_reason)
-            except Exception:
+            except Exception:  # noqa
                 pass
 
     def unregister(self):
diff --git a/clearml/backend_interface/task/task.py b/clearml/backend_interface/task/task.py
index ea8dfbc9..50b661b4 100644
--- a/clearml/backend_interface/task/task.py
+++ b/clearml/backend_interface/task/task.py
@@ -2071,7 +2071,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
             return None
 
     def _set_runtime_properties(self, runtime_properties):
-        # type: (Mapping[str, str]) -> bool
+        # type: (Mapping[str, Union[str, int, float]]) -> bool
         if not Session.check_min_api_version('2.13') or not runtime_properties:
             return False
 
diff --git a/clearml/task.py b/clearml/task.py
index 600c5480..4f0c213e 100644
--- a/clearml/task.py
+++ b/clearml/task.py
@@ -2594,6 +2594,50 @@ class Task(_Task):
         self.reload()
         return result
 
+    def register_abort_callback(
+            self,
+            callback_function,  # type: Optional[Callable]
+            callback_execution_timeout=30.  # type: float
+    ):  # type: (...) -> None
+        """
+        Register a Task abort callback (single callback function support only).
+        Pass a function to be called from a background thread when the Task is **externally** being aborted.
+        Users must specify a timeout for the callback function execution (default 30 seconds).
+        If the callback execution exceeds the timeout, the Task's process will be terminated.
+
+        Call this register function from the main process only.
+
+        Note: Ctrl-C is not considered external; only a backend-induced abort is covered here
+
+        :param callback_function: Callback function to be called via external thread (from the main process).
+            pass None to remove existing callback
+        :param callback_execution_timeout: Maximum callback execution time in seconds, after which the process
+            will be terminated even if the callback did not return
+        """
+        if self.__is_subprocess():
+            raise ValueError("Register abort callback must be called from the main process, this is a subprocess.")
+
+        if callback_function is None:
+            if self._dev_worker:
+                self._dev_worker.register_abort_callback(callback_function=None, execution_timeout=0, poll_freq=0)
+            return
+
+        if float(callback_execution_timeout) <= 0:
+            raise ValueError(
+                "callback_execution_timeout must be positive (in seconds), got {}".format(callback_execution_timeout))
+
+        # if we are running remotely we might not have a DevWorker monitoring us, so let's create one
+        if not self._dev_worker:
+            self._dev_worker = DevWorker()
+            self._dev_worker.register(self, stop_signal_support=True)
+
+        poll_freq = 15.0
+        self._dev_worker.register_abort_callback(
+            callback_function=callback_function,
+            execution_timeout=callback_execution_timeout,
+            poll_freq=poll_freq
+        )
+
     @classmethod
     def import_task(cls, task_data, target_task=None, update=False):
         # type: (dict, Optional[Union[str, Task]], bool) -> Optional[Task]