mirror of
https://github.com/clearml/clearml
synced 2025-04-16 21:42:10 +00:00
Add support for registering an abort callback
This commit is contained in:
parent
f298271876
commit
1cc87c9a21
@ -6,10 +6,15 @@ from time import time
|
||||
from ....config import deferred_config
|
||||
from ....backend_interface.task.development.stop_signal import TaskStopSignal
|
||||
from ....backend_api.services import tasks
|
||||
from ....utilities.lowlevel.threads import kill_thread
|
||||
from ....utilities.process.mp import SafeEvent
|
||||
|
||||
|
||||
class DevWorker(object):
|
||||
property_abort_callback_completed = "_abort_callback_completed"
|
||||
property_abort_callback_timeout = "_abort_callback_timeout"
|
||||
property_abort_poll_freq = "_abort_poll_freq"
|
||||
|
||||
prefix = attr.ib(type=str, default="MANUAL:")
|
||||
|
||||
report_stdout = deferred_config('development.worker.log_stdout', True)
|
||||
@ -26,12 +31,16 @@ class DevWorker(object):
|
||||
self._exit_event = SafeEvent()
|
||||
self._task = None
|
||||
self._support_ping = False
|
||||
self._poll_freq = None
|
||||
self._abort_cb = None
|
||||
self._abort_cb_timeout = None
|
||||
self._cb_completed = None
|
||||
|
||||
def ping(self, timestamp=None):
|
||||
try:
|
||||
if self._task:
|
||||
self._task.send(tasks.PingRequest(self._task.id))
|
||||
except Exception:
|
||||
except Exception: # noqa
|
||||
return False
|
||||
return True
|
||||
|
||||
@ -51,21 +60,98 @@ class DevWorker(object):
|
||||
self._thread.start()
|
||||
return True
|
||||
|
||||
def register_abort_callback(self, callback_function, execution_timeout, poll_freq):
|
||||
if not self._task:
|
||||
return
|
||||
|
||||
self._poll_freq = float(poll_freq) if poll_freq else None
|
||||
self._abort_cb = callback_function
|
||||
self._abort_cb_timeout = float(execution_timeout)
|
||||
|
||||
if not callback_function:
|
||||
# noinspection PyProtectedMember
|
||||
self._task._set_runtime_properties({DevWorker.property_abort_callback_timeout: float(-1)})
|
||||
return
|
||||
|
||||
# noinspection PyProtectedMember
|
||||
self._task._set_runtime_properties({
|
||||
self.property_abort_callback_timeout: float(execution_timeout),
|
||||
self.property_abort_poll_freq: float(poll_freq),
|
||||
self.property_abort_callback_completed: "",
|
||||
})
|
||||
|
||||
def _inner_abort_cb_wrapper(self):
|
||||
# store the task object because we might nullify it
|
||||
task = self._task
|
||||
# call the user abort callback
|
||||
try:
|
||||
if self._abort_cb:
|
||||
self._abort_cb()
|
||||
self._cb_completed = True
|
||||
except SystemError:
|
||||
# we will get here if we killed the thread externally,
|
||||
# we should not try to mark as completed, just leave the thread
|
||||
return
|
||||
except BaseException as ex: # noqa
|
||||
if task and task.log:
|
||||
task.log.warning(
|
||||
"### TASK STOPPING - USER ABORTED - CALLBACK EXCEPTION: {} ###".format(ex))
|
||||
|
||||
# set runtime property, abort completed for the agent to know we are done
|
||||
if task:
|
||||
# noinspection PyProtectedMember
|
||||
task._set_runtime_properties({self.property_abort_callback_completed: 1})
|
||||
|
||||
def _launch_abort_cb(self):
|
||||
timeout = self._abort_cb_timeout or 300.
|
||||
if self._task and self._task.log:
|
||||
self._task.log.warning(
|
||||
"### TASK STOPPING - USER ABORTED - "
|
||||
"LAUNCHING CALLBACK (timeout {} sec) ###".format(timeout))
|
||||
|
||||
tic = time()
|
||||
timed_out = False
|
||||
try:
|
||||
callback_thread = Thread(target=self._inner_abort_cb_wrapper)
|
||||
callback_thread.daemon = True
|
||||
callback_thread.start()
|
||||
callback_thread.join(timeout=timeout)
|
||||
if callback_thread.is_alive():
|
||||
kill_thread(callback_thread, wait=False)
|
||||
timed_out = True
|
||||
except: # noqa
|
||||
# something went wrong no just leave the process
|
||||
pass
|
||||
|
||||
if self._task and self._task.log:
|
||||
self._task.log.warning(
|
||||
"### TASK STOPPING - USER ABORTED - CALLBACK {} ({:.2f} sec) ###".format(
|
||||
"TIMED OUT" if timed_out else ("COMPLETED" if self._cb_completed else "FAILED"), time()-tic))
|
||||
|
||||
def _daemon(self):
|
||||
last_ping = time()
|
||||
while self._task is not None:
|
||||
try:
|
||||
if self._exit_event.wait(min(float(self.ping_period), float(self.report_period))):
|
||||
wait_timeout = min(float(self.ping_period), float(self.report_period))
|
||||
if self._poll_freq:
|
||||
wait_timeout = min(self._poll_freq, wait_timeout)
|
||||
if self._exit_event.wait(wait_timeout):
|
||||
return
|
||||
# send ping request
|
||||
if self._support_ping and (time() - last_ping) >= float(self.ping_period):
|
||||
self.ping()
|
||||
last_ping = time()
|
||||
|
||||
if self._dev_stop_signal:
|
||||
stop_reason = self._dev_stop_signal.test()
|
||||
if stop_reason and self._task:
|
||||
# call abort callback
|
||||
if self._abort_cb:
|
||||
self._launch_abort_cb()
|
||||
|
||||
# noinspection PyProtectedMember
|
||||
self._task._dev_mode_stop_task(stop_reason)
|
||||
except Exception:
|
||||
except Exception: # noqa
|
||||
pass
|
||||
|
||||
def unregister(self):
|
||||
|
@ -2071,7 +2071,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
||||
return None
|
||||
|
||||
def _set_runtime_properties(self, runtime_properties):
|
||||
# type: (Mapping[str, str]) -> bool
|
||||
# type: (Mapping[str, Union[str, int, float]]) -> bool
|
||||
if not Session.check_min_api_version('2.13') or not runtime_properties:
|
||||
return False
|
||||
|
||||
|
@ -2594,6 +2594,50 @@ class Task(_Task):
|
||||
self.reload()
|
||||
return result
|
||||
|
||||
def register_abort_callback(
|
||||
self,
|
||||
callback_function, # type: Optional[Callable]
|
||||
callback_execution_timeout=30. # type: float
|
||||
): # type (...) -> None
|
||||
"""
|
||||
Register a Task abort callback (single callback function support only).
|
||||
Pass a function to be called from a background thread when the Task is **externally** being aborted.
|
||||
Users must specify a timeout for the callback function execution (default 30 seconds)
|
||||
if the callback execution function exceeds the timeout, the Task's process will be terminated
|
||||
|
||||
Call this register function from the main process only.
|
||||
|
||||
Note: Ctrl-C is Not considered external, only backend induced abort is covered here
|
||||
|
||||
:param callback_function: Callback function to be called via external thread (from the main process).
|
||||
pass None to remove existing callback
|
||||
:param callback_execution_timeout: Maximum callback execution time in seconds, after which the process
|
||||
will be terminated even if the callback did not return
|
||||
"""
|
||||
if self.__is_subprocess():
|
||||
raise ValueError("Register abort callback must be called from the main process, this is a subprocess.")
|
||||
|
||||
if callback_function is None:
|
||||
if self._dev_worker:
|
||||
self._dev_worker.register_abort_callback(callback_function=None, execution_timeout=0, poll_freq=0)
|
||||
return
|
||||
|
||||
if float(callback_execution_timeout) <= 0:
|
||||
raise ValueError(
|
||||
"function_timeout_sec must be positive timeout in seconds, got {}".format(callback_execution_timeout))
|
||||
|
||||
# if we are running remotely we might not have a DevWorker monitoring us, so let's create one
|
||||
if not self._dev_worker:
|
||||
self._dev_worker = DevWorker()
|
||||
self._dev_worker.register(self, stop_signal_support=True)
|
||||
|
||||
poll_freq = 15.0
|
||||
self._dev_worker.register_abort_callback(
|
||||
callback_function=callback_function,
|
||||
execution_timeout=callback_execution_timeout,
|
||||
poll_freq=poll_freq
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def import_task(cls, task_data, target_task=None, update=False):
|
||||
# type: (dict, Optional[Union[str, Task]], bool) -> Optional[Task]
|
||||
|
Loading…
Reference in New Issue
Block a user