Add support for registering an abort callback

This commit is contained in:
allegroai 2022-08-20 22:55:10 +03:00
parent f298271876
commit 1cc87c9a21
3 changed files with 134 additions and 4 deletions

View File

@ -6,10 +6,15 @@ from time import time
from ....config import deferred_config
from ....backend_interface.task.development.stop_signal import TaskStopSignal
from ....backend_api.services import tasks
from ....utilities.lowlevel.threads import kill_thread
from ....utilities.process.mp import SafeEvent
class DevWorker(object):
property_abort_callback_completed = "_abort_callback_completed"
property_abort_callback_timeout = "_abort_callback_timeout"
property_abort_poll_freq = "_abort_poll_freq"
prefix = attr.ib(type=str, default="MANUAL:")
report_stdout = deferred_config('development.worker.log_stdout', True)
@ -26,12 +31,16 @@ class DevWorker(object):
self._exit_event = SafeEvent()
self._task = None
self._support_ping = False
self._poll_freq = None
self._abort_cb = None
self._abort_cb_timeout = None
self._cb_completed = None
def ping(self, timestamp=None):
    """
    Send a ping request for the currently registered task (best effort).

    :param timestamp: Accepted for interface compatibility; not used by this method.
    :return: False if sending the ping raised an exception, True otherwise
        (including when no task is registered and nothing was sent).
    """
    try:
        if self._task:
            self._task.send(tasks.PingRequest(self._task.id))
    except Exception:  # noqa
        # a failed ping must never propagate to the caller, just report it
        return False
    return True
@ -51,21 +60,98 @@ class DevWorker(object):
self._thread.start()
return True
def register_abort_callback(self, callback_function, execution_timeout, poll_freq):
    """
    Register (or remove) the callback to launch when the monitored task is stopped.

    The callback, its timeout and the polling frequency are stored on the worker and
    published through the task's runtime properties. Nothing happens when no task
    is registered on this worker.

    :param callback_function: Callable to invoke on abort. A falsy value removes the
        current callback; in that case a timeout of -1 is published.
    :param execution_timeout: Maximum callback execution time in seconds (converted with ``float``).
    :param poll_freq: Polling frequency in seconds for the stop-signal check.
        A falsy value (None / 0) means no dedicated polling frequency is used and
        the matching runtime property is not published.
    """
    if not self._task:
        return

    self._poll_freq = float(poll_freq) if poll_freq else None
    self._abort_cb = callback_function
    self._abort_cb_timeout = float(execution_timeout)

    if not callback_function:
        # noinspection PyProtectedMember
        self._task._set_runtime_properties({self.property_abort_callback_timeout: float(-1)})
        return

    runtime_properties = {self.property_abort_callback_timeout: self._abort_cb_timeout}
    # only publish a polling frequency when one was given; float(None) would raise
    # after the worker state above was already updated
    if self._poll_freq is not None:
        runtime_properties[self.property_abort_poll_freq] = self._poll_freq
    # an empty value marks the callback as not completed yet
    runtime_properties[self.property_abort_callback_completed] = ""

    # noinspection PyProtectedMember
    self._task._set_runtime_properties(runtime_properties)
def _inner_abort_cb_wrapper(self):
# store the task object because we might nullify it
task = self._task
# call the user abort callback
try:
if self._abort_cb:
self._abort_cb()
self._cb_completed = True
except SystemError:
# we will get here if we killed the thread externally,
# we should not try to mark as completed, just leave the thread
return
except BaseException as ex: # noqa
if task and task.log:
task.log.warning(
"### TASK STOPPING - USER ABORTED - CALLBACK EXCEPTION: {} ###".format(ex))
# set runtime property, abort completed for the agent to know we are done
if task:
# noinspection PyProtectedMember
task._set_runtime_properties({self.property_abort_callback_completed: 1})
def _launch_abort_cb(self):
timeout = self._abort_cb_timeout or 300.
if self._task and self._task.log:
self._task.log.warning(
"### TASK STOPPING - USER ABORTED - "
"LAUNCHING CALLBACK (timeout {} sec) ###".format(timeout))
tic = time()
timed_out = False
try:
callback_thread = Thread(target=self._inner_abort_cb_wrapper)
callback_thread.daemon = True
callback_thread.start()
callback_thread.join(timeout=timeout)
if callback_thread.is_alive():
kill_thread(callback_thread, wait=False)
timed_out = True
except: # noqa
# something went wrong no just leave the process
pass
if self._task and self._task.log:
self._task.log.warning(
"### TASK STOPPING - USER ABORTED - CALLBACK {} ({:.2f} sec) ###".format(
"TIMED OUT" if timed_out else ("COMPLETED" if self._cb_completed else "FAILED"), time()-tic))
def _daemon(self):
last_ping = time()
while self._task is not None:
try:
if self._exit_event.wait(min(float(self.ping_period), float(self.report_period))):
wait_timeout = min(float(self.ping_period), float(self.report_period))
if self._poll_freq:
wait_timeout = min(self._poll_freq, wait_timeout)
if self._exit_event.wait(wait_timeout):
return
# send ping request
if self._support_ping and (time() - last_ping) >= float(self.ping_period):
self.ping()
last_ping = time()
if self._dev_stop_signal:
stop_reason = self._dev_stop_signal.test()
if stop_reason and self._task:
# call abort callback
if self._abort_cb:
self._launch_abort_cb()
# noinspection PyProtectedMember
self._task._dev_mode_stop_task(stop_reason)
except Exception:
except Exception: # noqa
pass
def unregister(self):

View File

@ -2071,7 +2071,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
return None
def _set_runtime_properties(self, runtime_properties):
# type: (Mapping[str, str]) -> bool
# type: (Mapping[str, Union[str, int, float]]) -> bool
if not Session.check_min_api_version('2.13') or not runtime_properties:
return False

View File

@ -2594,6 +2594,50 @@ class Task(_Task):
self.reload()
return result
def register_abort_callback(
        self,
        callback_function,  # type: Optional[Callable]
        callback_execution_timeout=30.  # type: float
):  # type: (...) -> None
    """
    Register a Task abort callback (single callback function support only).
    Pass a function to be called from a background thread when the Task is **externally** being aborted.
    Users must specify a timeout for the callback function execution (default 30 seconds),
    if the callback execution function exceeds the timeout, the Task's process will be terminated

    Call this register function from the main process only.

    Note: Ctrl-C is Not considered external, only backend induced abort is covered here

    :param callback_function: Callback function to be called via external thread (from the main process).
        pass None to remove existing callback
    :param callback_execution_timeout: Maximum callback execution time in seconds, after which the process
        will be terminated even if the callback did not return
    :raises ValueError: If called from a subprocess, or if a callback is given together with
        a non-positive ``callback_execution_timeout``
    """
    if self.__is_subprocess():
        raise ValueError("Register abort callback must be called from the main process, this is a subprocess.")

    if callback_function is None:
        # removing a callback only makes sense if a worker is already monitoring the task
        if self._dev_worker:
            self._dev_worker.register_abort_callback(callback_function=None, execution_timeout=0, poll_freq=0)
        return

    if float(callback_execution_timeout) <= 0:
        raise ValueError(
            "callback_execution_timeout must be positive timeout in seconds, got {}".format(
                callback_execution_timeout))

    # if we are running remotely we might not have a DevWorker monitoring us, so let's create one
    if not self._dev_worker:
        self._dev_worker = DevWorker()
        self._dev_worker.register(self, stop_signal_support=True)

    # stop-signal polling frequency (seconds) handed to the worker
    poll_freq = 15.0
    self._dev_worker.register_abort_callback(
        callback_function=callback_function,
        execution_timeout=callback_execution_timeout,
        poll_freq=poll_freq
    )
@classmethod
def import_task(cls, task_data, target_task=None, update=False):
# type: (dict, Optional[Union[str, Task]], bool) -> Optional[Task]