From 10c6629982b1dc40a5d2ed2e95602a70cee33889 Mon Sep 17 00:00:00 2001
From: allegroai <>
Date: Fri, 19 Apr 2024 23:46:57 +0300
Subject: [PATCH] Support skipping task re-enqueue for k8s pods suspected of being preempted

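When the pending-pods monitor sees a task whose pod appears to have been preempted, it
normally forces the task back into the pending queue to keep the ClearML task status in
sync. Pod information is now refreshed right before that decision, and a new environment
variable, K8S_GLUE_POD_MONITOR_DISABLE_ENQUEUE_ON_PREEMPTION, allows skipping the forced
re-enqueue entirely, so tasks whose pods only looked pending because of cluster
communication lag are not killed.

Usage sketch (set on the agent before starting the k8s glue; the accepted truthy values
are assumed to follow the usual EnvEntry boolean conversion, e.g. "1" or "true"):

    export K8S_GLUE_POD_MONITOR_DISABLE_ENQUEUE_ON_PREEMPTION=1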
---
 clearml_agent/glue/definitions.py         |  3 +
 clearml_agent/glue/pending_pods_daemon.py | 69 ++++++++++++++---------
 2 files changed, 44 insertions(+), 28 deletions(-)

diff --git a/clearml_agent/glue/definitions.py b/clearml_agent/glue/definitions.py
index e04e13e..ee808e2 100644
--- a/clearml_agent/glue/definitions.py
+++ b/clearml_agent/glue/definitions.py
@@ -9,3 +9,6 @@ Script will be appended to the specified file.
 ENV_DEFAULT_EXECUTION_AGENT_ARGS = EnvEntry("K8S_GLUE_DEF_EXEC_AGENT_ARGS", default="--full-monitoring --require-queue")
 ENV_POD_AGENT_INSTALL_ARGS = EnvEntry("K8S_GLUE_POD_AGENT_INSTALL_ARGS", default="", lstrip=False)
 ENV_POD_MONITOR_LOG_BATCH_SIZE = EnvEntry("K8S_GLUE_POD_MONITOR_LOG_BATCH_SIZE", default=5, converter=int)
+ENV_POD_MONITOR_DISABLE_ENQUEUE_ON_PREEMPTION = EnvEntry(
+    "K8S_GLUE_POD_MONITOR_DISABLE_ENQUEUE_ON_PREEMPTION", default=False, converter=bool
+)
diff --git a/clearml_agent/glue/pending_pods_daemon.py b/clearml_agent/glue/pending_pods_daemon.py
index 59a19bc..01d996e 100644
--- a/clearml_agent/glue/pending_pods_daemon.py
+++ b/clearml_agent/glue/pending_pods_daemon.py
@@ -9,6 +9,7 @@ from clearml_agent.helper.process import stringify_bash_output
 from .daemon import K8sDaemon
 from .utilities import get_path
 from .errors import GetPodsError
+from .definitions import ENV_POD_MONITOR_DISABLE_ENQUEUE_ON_PREEMPTION
 
 
 class PendingPodsDaemon(K8sDaemon):
@@ -17,16 +18,16 @@ class PendingPodsDaemon(K8sDaemon):
         self._polling_interval = polling_interval
         self._last_tasks_msgs = {}  # last msg updated for every task
 
-    def get_pods(self, pod_name=None):
+    def get_pods(self, pod_name=None, debug_msg="Detecting pending pods: {cmd}"):
         filters = ["status.phase=Pending"]
         if pod_name:
             filters.append(f"metadata.name={pod_name}")
 
         if self._agent.using_jobs:
             return self._agent.get_pods_for_jobs(
-                job_condition="status.active=1", pod_filters=filters, debug_msg="Detecting pending pods: {cmd}"
+                job_condition="status.active=1", pod_filters=filters, debug_msg=debug_msg
             )
-        return self._agent.get_pods(filters=filters, debug_msg="Detecting pending pods: {cmd}")
+        return self._agent.get_pods(filters=filters, debug_msg=debug_msg)
 
     def _get_pod_name(self, pod: dict):
         return get_path(pod, "metadata", "name")
@@ -72,6 +73,11 @@ class PendingPodsDaemon(K8sDaemon):
                     if not namespace:
                         continue
 
+                    updated_pod = self.get_pods(pod_name=pod_name, debug_msg="Refreshing pod information: {cmd}")
+                    if not updated_pod:
+                        continue
+                    pod = updated_pod[0]
+
                     task_id_to_pod[task_id] = pod
 
                     msg = None
@@ -190,32 +196,39 @@ class PendingPodsDaemon(K8sDaemon):
         if not msg or self._last_tasks_msgs.get(task_id, None) == (msg, tags):
             return
         try:
-            # Make sure the task is queued
-            result = self._session.send_request(
-                service='tasks',
-                action='get_all',
-                json={"id": task_id, "only_fields": ["status"]},
-                method=Request.def_method,
-                async_enable=False,
-            )
-            if result.ok:
-                status = get_path(result.json(), 'data', 'tasks', 0, 'status')
-                # if task is in progress, change its status to enqueued
-                if status == "in_progress":
-                    result = self._session.send_request(
-                        service='tasks', action='enqueue',
-                        json={
-                            "task": task_id, "force": True, "queue": self._agent.k8s_pending_queue_id
-                        },
-                        method=Request.def_method,
-                        async_enable=False,
-                    )
-                    if not result.ok:
-                        result_msg = get_path(result.json(), 'meta', 'result_msg')
-                        self.log.debug(
-                            "K8S Glue pods monitor: failed forcing task status change"
-                            " for pending task {}: {}".format(task_id, result_msg)
+            if ENV_POD_MONITOR_DISABLE_ENQUEUE_ON_PREEMPTION.get():
+                # Skip the forced re-enqueue that normally re-syncs the ClearML task status when its pod
+                # is suspected to have been preempted. Sometimes a pod only appears to be pending because of
+                # cluster communication lag; re-enqueuing a task whose pod is actually already running
+                # would effectively kill that task.
+                pass
+            else:
+                # Make sure the task is queued
+                result = self._session.send_request(
+                    service='tasks',
+                    action='get_all',
+                    json={"id": task_id, "only_fields": ["status"]},
+                    method=Request.def_method,
+                    async_enable=False,
+                )
+                if result.ok:
+                    status = get_path(result.json(), 'data', 'tasks', 0, 'status')
+                    # if task is in progress, change its status to enqueued
+                    if status == "in_progress":
+                        result = self._session.send_request(
+                            service='tasks', action='enqueue',
+                            json={
+                                "task": task_id, "force": True, "queue": self._agent.k8s_pending_queue_id
+                            },
+                            method=Request.def_method,
+                            async_enable=False,
                         )
+                        if not result.ok:
+                            result_msg = get_path(result.json(), 'meta', 'result_msg')
+                            self.log.debug(
+                                "K8S Glue pods monitor: failed forcing task status change"
+                                " for pending task {}: {}".format(task_id, result_msg)
+                            )
 
             # Update task status message
             payload = {"task": task_id, "status_message": "K8S glue status: {}".format(msg)}