mirror of
https://github.com/clearml/clearml
synced 2025-06-26 18:16:07 +00:00
Add support for registering an abort callback
This commit is contained in:
parent
f298271876
commit
1cc87c9a21
@ -6,10 +6,15 @@ from time import time
|
|||||||
from ....config import deferred_config
|
from ....config import deferred_config
|
||||||
from ....backend_interface.task.development.stop_signal import TaskStopSignal
|
from ....backend_interface.task.development.stop_signal import TaskStopSignal
|
||||||
from ....backend_api.services import tasks
|
from ....backend_api.services import tasks
|
||||||
|
from ....utilities.lowlevel.threads import kill_thread
|
||||||
from ....utilities.process.mp import SafeEvent
|
from ....utilities.process.mp import SafeEvent
|
||||||
|
|
||||||
|
|
||||||
class DevWorker(object):
|
class DevWorker(object):
|
||||||
|
property_abort_callback_completed = "_abort_callback_completed"
|
||||||
|
property_abort_callback_timeout = "_abort_callback_timeout"
|
||||||
|
property_abort_poll_freq = "_abort_poll_freq"
|
||||||
|
|
||||||
prefix = attr.ib(type=str, default="MANUAL:")
|
prefix = attr.ib(type=str, default="MANUAL:")
|
||||||
|
|
||||||
report_stdout = deferred_config('development.worker.log_stdout', True)
|
report_stdout = deferred_config('development.worker.log_stdout', True)
|
||||||
@ -26,12 +31,16 @@ class DevWorker(object):
|
|||||||
self._exit_event = SafeEvent()
|
self._exit_event = SafeEvent()
|
||||||
self._task = None
|
self._task = None
|
||||||
self._support_ping = False
|
self._support_ping = False
|
||||||
|
self._poll_freq = None
|
||||||
|
self._abort_cb = None
|
||||||
|
self._abort_cb_timeout = None
|
||||||
|
self._cb_completed = None
|
||||||
|
|
||||||
def ping(self, timestamp=None):
    """
    Send a ping request for the monitored task.

    Nothing is sent when no task is attached to this worker.

    :param timestamp: Not used by this implementation.
    :return: False if anything raised while sending the ping, True otherwise
        (including the case where there is no task to ping).
    """
    try:
        if self._task:
            self._task.send(tasks.PingRequest(self._task.id))
    except Exception:  # noqa
        # best effort: a failed ping is reported through the return value only
        succeeded = False
    else:
        succeeded = True
    return succeeded
|
||||||
|
|
||||||
@ -51,21 +60,98 @@ class DevWorker(object):
|
|||||||
self._thread.start()
|
self._thread.start()
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def register_abort_callback(self, callback_function, execution_timeout, poll_freq):
    """
    Register (or remove) the callback to launch when the monitored task is aborted.

    The settings are stored on this worker and published as runtime properties on the task
    (through ``task._set_runtime_properties``). Does nothing if no task is attached to this worker.

    :param callback_function: Callable taking no arguments. A false value (e.g. None) removes the
        current callback, in which case a timeout of -1 is published.
    :param execution_timeout: Maximum callback execution time in seconds (anything ``float()`` accepts,
        None is treated as 0).
    :param poll_freq: Polling interval in seconds, used to cap the wait of the daemon loop.
        A false value (0 / None) means no dedicated interval and is published as 0.
    :raises ValueError, TypeError: If ``execution_timeout`` or ``poll_freq`` cannot be converted to float.
        In that case neither the worker state nor the task runtime properties are modified.
    """
    if not self._task:
        return

    # convert before touching any state, so a bad argument cannot leave the worker with an
    # armed callback whose runtime properties were never published
    timeout_sec = float(execution_timeout or 0)
    poll_sec = float(poll_freq) if poll_freq else None

    self._poll_freq = poll_sec
    self._abort_cb = callback_function
    self._abort_cb_timeout = timeout_sec
    # new registration: forget the outcome of any previously executed callback
    self._cb_completed = None

    if not callback_function:
        # noinspection PyProtectedMember
        self._task._set_runtime_properties({self.property_abort_callback_timeout: float(-1)})
        return

    # noinspection PyProtectedMember
    self._task._set_runtime_properties({
        self.property_abort_callback_timeout: timeout_sec,
        self.property_abort_poll_freq: poll_sec if poll_sec is not None else 0.0,
        # empty value == callback did not complete (yet)
        self.property_abort_callback_completed: "",
    })
|
||||||
|
|
||||||
|
def _inner_abort_cb_wrapper(self):
|
||||||
|
# store the task object because we might nullify it
|
||||||
|
task = self._task
|
||||||
|
# call the user abort callback
|
||||||
|
try:
|
||||||
|
if self._abort_cb:
|
||||||
|
self._abort_cb()
|
||||||
|
self._cb_completed = True
|
||||||
|
except SystemError:
|
||||||
|
# we will get here if we killed the thread externally,
|
||||||
|
# we should not try to mark as completed, just leave the thread
|
||||||
|
return
|
||||||
|
except BaseException as ex: # noqa
|
||||||
|
if task and task.log:
|
||||||
|
task.log.warning(
|
||||||
|
"### TASK STOPPING - USER ABORTED - CALLBACK EXCEPTION: {} ###".format(ex))
|
||||||
|
|
||||||
|
# set runtime property, abort completed for the agent to know we are done
|
||||||
|
if task:
|
||||||
|
# noinspection PyProtectedMember
|
||||||
|
task._set_runtime_properties({self.property_abort_callback_completed: 1})
|
||||||
|
|
||||||
|
def _launch_abort_cb(self):
|
||||||
|
timeout = self._abort_cb_timeout or 300.
|
||||||
|
if self._task and self._task.log:
|
||||||
|
self._task.log.warning(
|
||||||
|
"### TASK STOPPING - USER ABORTED - "
|
||||||
|
"LAUNCHING CALLBACK (timeout {} sec) ###".format(timeout))
|
||||||
|
|
||||||
|
tic = time()
|
||||||
|
timed_out = False
|
||||||
|
try:
|
||||||
|
callback_thread = Thread(target=self._inner_abort_cb_wrapper)
|
||||||
|
callback_thread.daemon = True
|
||||||
|
callback_thread.start()
|
||||||
|
callback_thread.join(timeout=timeout)
|
||||||
|
if callback_thread.is_alive():
|
||||||
|
kill_thread(callback_thread, wait=False)
|
||||||
|
timed_out = True
|
||||||
|
except: # noqa
|
||||||
|
# something went wrong no just leave the process
|
||||||
|
pass
|
||||||
|
|
||||||
|
if self._task and self._task.log:
|
||||||
|
self._task.log.warning(
|
||||||
|
"### TASK STOPPING - USER ABORTED - CALLBACK {} ({:.2f} sec) ###".format(
|
||||||
|
"TIMED OUT" if timed_out else ("COMPLETED" if self._cb_completed else "FAILED"), time()-tic))
|
||||||
|
|
||||||
def _daemon(self):
|
def _daemon(self):
|
||||||
last_ping = time()
|
last_ping = time()
|
||||||
while self._task is not None:
|
while self._task is not None:
|
||||||
try:
|
try:
|
||||||
if self._exit_event.wait(min(float(self.ping_period), float(self.report_period))):
|
wait_timeout = min(float(self.ping_period), float(self.report_period))
|
||||||
|
if self._poll_freq:
|
||||||
|
wait_timeout = min(self._poll_freq, wait_timeout)
|
||||||
|
if self._exit_event.wait(wait_timeout):
|
||||||
return
|
return
|
||||||
# send ping request
|
# send ping request
|
||||||
if self._support_ping and (time() - last_ping) >= float(self.ping_period):
|
if self._support_ping and (time() - last_ping) >= float(self.ping_period):
|
||||||
self.ping()
|
self.ping()
|
||||||
last_ping = time()
|
last_ping = time()
|
||||||
|
|
||||||
if self._dev_stop_signal:
|
if self._dev_stop_signal:
|
||||||
stop_reason = self._dev_stop_signal.test()
|
stop_reason = self._dev_stop_signal.test()
|
||||||
if stop_reason and self._task:
|
if stop_reason and self._task:
|
||||||
|
# call abort callback
|
||||||
|
if self._abort_cb:
|
||||||
|
self._launch_abort_cb()
|
||||||
|
|
||||||
|
# noinspection PyProtectedMember
|
||||||
self._task._dev_mode_stop_task(stop_reason)
|
self._task._dev_mode_stop_task(stop_reason)
|
||||||
except Exception:
|
except Exception: # noqa
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def unregister(self):
|
def unregister(self):
|
||||||
|
@ -2071,7 +2071,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def _set_runtime_properties(self, runtime_properties):
|
def _set_runtime_properties(self, runtime_properties):
|
||||||
# type: (Mapping[str, str]) -> bool
|
# type: (Mapping[str, Union[str, int, float]]) -> bool
|
||||||
if not Session.check_min_api_version('2.13') or not runtime_properties:
|
if not Session.check_min_api_version('2.13') or not runtime_properties:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@ -2594,6 +2594,50 @@ class Task(_Task):
|
|||||||
self.reload()
|
self.reload()
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def register_abort_callback(
        self,
        callback_function,  # type: Optional[Callable]
        callback_execution_timeout=30.  # type: float
):  # type: (...) -> None
    """
    Register a Task abort callback (single callback function support only).
    Pass a function to be called from a background thread when the Task is **externally** being aborted.
    Users must specify a timeout for the callback function execution (default 30 seconds),
    if the callback function execution exceeds the timeout, the Task's process will be terminated

    Call this register function from the main process only.

    Note: Ctrl-C is Not considered external, only backend induced abort is covered here

    :param callback_function: Callback function to be called via external thread (from the main process).
        pass None to remove existing callback
    :param callback_execution_timeout: Maximum callback execution time in seconds, after which the process
        will be terminated even if the callback did not return

    :raises ValueError: If called from a subprocess, or if ``callback_execution_timeout`` is not positive
    :raises TypeError: If ``callback_function`` is neither None nor a callable
    """
    if self.__is_subprocess():
        raise ValueError("Register abort callback must be called from the main process, this is a subprocess.")

    if callback_function is None:
        # remove the existing callback (nothing to remove if there is no worker monitoring us)
        if self._dev_worker:
            self._dev_worker.register_abort_callback(callback_function=None, execution_timeout=0, poll_freq=0)
        return

    # fail at registration time, otherwise the error would only surface (as a logged warning)
    # when the Task is actually being aborted
    if not callable(callback_function):
        raise TypeError(
            "callback_function must be a callable or None, got {}".format(type(callback_function)))

    if float(callback_execution_timeout) <= 0:
        raise ValueError(
            "callback_execution_timeout must be positive timeout in seconds, got {}".format(
                callback_execution_timeout))

    # if we are running remotely we might not have a DevWorker monitoring us, so let's create one
    if not self._dev_worker:
        self._dev_worker = DevWorker()
        self._dev_worker.register(self, stop_signal_support=True)

    # how often (in seconds) the worker polls for an abort request
    poll_freq = 15.0
    self._dev_worker.register_abort_callback(
        callback_function=callback_function,
        execution_timeout=callback_execution_timeout,
        poll_freq=poll_freq
    )
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def import_task(cls, task_data, target_task=None, update=False):
|
def import_task(cls, task_data, target_task=None, update=False):
|
||||||
# type: (dict, Optional[Union[str, Task]], bool) -> Optional[Task]
|
# type: (dict, Optional[Union[str, Task]], bool) -> Optional[Task]
|
||||||
|
Loading…
Reference in New Issue
Block a user