From 6302d43990433f1f085b9bb569da2906c00a1765 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Tue, 27 Aug 2024 23:01:27 +0300 Subject: [PATCH] Add support for skipping container apt installs using CLEARML_AGENT_SKIP_CONTAINER_APT env var in k8s Add runtime callback support for setting runtime properties per task in k8s Fix remove task from pending queue and set to failed when kubectl apply fails --- clearml_agent/glue/k8s.py | 120 ++++++++++++++++++++++++++++++++------ 1 file changed, 101 insertions(+), 19 deletions(-) diff --git a/clearml_agent/glue/k8s.py b/clearml_agent/glue/k8s.py index 798d779..daa1b99 100644 --- a/clearml_agent/glue/k8s.py +++ b/clearml_agent/glue/k8s.py @@ -69,16 +69,23 @@ class K8sIntegration(Worker): 'echo "ldconfig" >> /etc/profile', "/usr/sbin/sshd -p {port}"] - CONTAINER_BASH_SCRIPT = [ + _CONTAINER_APT_SCRIPT_SECTION = [ "export DEBIAN_FRONTEND='noninteractive'", "echo 'Binary::apt::APT::Keep-Downloaded-Packages \"true\";' > /etc/apt/apt.conf.d/docker-clean", "chown -R root /root/.cache/pip", "apt-get update", "apt-get install -y git libsm6 libxext6 libxrender-dev libglib2.0-0", + ] + + CONTAINER_BASH_SCRIPT = [ + *( + '[ ! -z "$CLEARML_AGENT_SKIP_CONTAINER_APT" ] || {}'.format(line) + for line in _CONTAINER_APT_SCRIPT_SECTION + ), "declare LOCAL_PYTHON", "[ ! -z $LOCAL_PYTHON ] || for i in {{15..5}}; do which python3.$i && python3.$i -m pip --version && " "export LOCAL_PYTHON=$(which python3.$i) && break ; done", - "[ ! -z $LOCAL_PYTHON ] || apt-get install -y python3-pip", + '[ ! -z "$CLEARML_AGENT_SKIP_CONTAINER_APT" ] || [ ! -z "$LOCAL_PYTHON" ] || apt-get install -y python3-pip', "[ ! -z $LOCAL_PYTHON ] || export LOCAL_PYTHON=python3", "{extra_bash_init_cmd}", "[ ! 
-z $CLEARML_AGENT_NO_UPDATE ] || $LOCAL_PYTHON -m pip install clearml-agent{agent_install_args}", @@ -100,6 +107,7 @@ class K8sIntegration(Worker): num_of_services=20, base_pod_num=1, user_props_cb=None, + runtime_cb=None, overrides_yaml=None, template_yaml=None, clearml_conf_file=None, @@ -127,6 +135,7 @@ class K8sIntegration(Worker): :param callable user_props_cb: An Optional callable allowing additional user properties to be specified when scheduling a task to run in a pod. Callable can receive an optional pod number and should return a dictionary of user properties (name and value). Signature is [[Optional[int]], Dict[str,str]] + :param callable runtime_cb: An Optional callable allowing additional task runtime properties to be specified when scheduling a task to run in a pod. Called like user_props_cb (with an optional pod number) and should return a dictionary of runtime properties :param str overrides_yaml: YAML file containing the overrides for the pod (optional) :param str template_yaml: YAML file containing the template for the pod (optional). If provided the pod is scheduled with kubectl apply and overrides are ignored, otherwise with kubectl run.
@@ -161,6 +170,7 @@ class K8sIntegration(Worker): self.base_pod_num = base_pod_num self._edit_hyperparams_support = None self._user_props_cb = user_props_cb + self._runtime_cb = runtime_cb self.conf_file_content = None self.overrides_json_string = None self.template_dict = None @@ -198,6 +208,10 @@ class K8sIntegration(Worker): self._session.feature_set != "basic" and self._session.check_min_server_version("3.22.3") ) + @property + def agent_label(self): + return self._get_agent_label() + def _create_daemon_instance(self, cls_, **kwargs): return cls_(agent=self, **kwargs) @@ -430,6 +444,9 @@ class K8sIntegration(Worker): """ Called when a resource (pod/job) was applied """ pass + def ports_mode_supported_for_task(self, task_id: str, task_data): + return self.ports_mode + def run_one_task(self, queue: Text, task_id: Text, worker_args=None, task_session=None, **_): print('Pulling task {} launching on kubernetes cluster'.format(task_id)) session = task_session or self._session @@ -501,8 +518,10 @@ class K8sIntegration(Worker): ) ) - if self.ports_mode: + ports_mode = False + if self.ports_mode_supported_for_task(task_id, task_data): print("Kubernetes looking for available pod to use") + ports_mode = True # noinspection PyBroadException try: @@ -513,12 +532,12 @@ class K8sIntegration(Worker): # Search for a free pod number pod_count = 0 pod_number = self.base_pod_num - while self.ports_mode or self.max_pods_limit: + while ports_mode or self.max_pods_limit: pod_number = self.base_pod_num + pod_count try: items_count = self._get_pod_count( - extra_labels=[self.limit_pod_label.format(pod_number=pod_number)] if self.ports_mode else None, + extra_labels=[self.limit_pod_label.format(pod_number=pod_number)] if ports_mode else None, msg="Looking for a free pod/port" ) except GetPodCountError: @@ -568,11 +587,11 @@ class K8sIntegration(Worker): break pod_count += 1 - labels = self._get_pod_labels(queue, queue_name) - if self.ports_mode: + labels = self._get_pod_labels(queue, 
queue_name, task_data) + if ports_mode: labels.append(self.limit_pod_label.format(pod_number=pod_number)) - if self.ports_mode: + if ports_mode: print("Kubernetes scheduling task id={} on pod={} (pod_count={})".format(task_id, pod_number, pod_count)) else: print("Kubernetes scheduling task id={}".format(task_id)) @@ -611,6 +630,14 @@ class K8sIntegration(Worker): send_log = "Running kubectl encountered an error: {}".format(error) self.log.error(send_log) self.send_logs(task_id, send_log.splitlines()) + + # Make sure to remove the task from our k8s pending queue + self._session.api_client.queues.remove_task( + task=task_id, + queue=self.k8s_pending_queue_id, + ) + # Set task as failed + session.api_client.tasks.failed(task_id, force=True) return if pod_name: @@ -618,25 +645,41 @@ class K8sIntegration(Worker): resource_name=pod_name, namespace=namespace, task_id=task_id, session=session ) + self.set_task_info( + task_id=task_id, task_session=task_session, queue_name=queue_name, ports_mode=ports_mode, + pod_number=pod_number, pod_count=pod_count, task_data=task_data + ) + + def set_task_info( + self, task_id: str, task_session, task_data, queue_name: str, ports_mode: bool, pod_number, pod_count + ): user_props = {"k8s-queue": str(queue_name)} - if self.ports_mode: - user_props.update( - { - "k8s-pod-number": pod_number, - "k8s-pod-label": labels[0], - "k8s-internal-pod-count": pod_count, - "k8s-agent": self._get_agent_label(), - } - ) + runtime = {} + if ports_mode: + agent_label = self._get_agent_label() + user_props.update({ + "k8s-pod-number": pod_number, + "k8s-pod-label": agent_label, # backwards-compatibility / legacy + "k8s-internal-pod-count": pod_count, + "k8s-agent": agent_label, + }) if self._user_props_cb: # noinspection PyBroadException try: - custom_props = self._user_props_cb(pod_number) if self.ports_mode else self._user_props_cb() + custom_props = self._user_props_cb(pod_number) if ports_mode else self._user_props_cb() user_props.update(custom_props) 
except Exception: pass + if self._runtime_cb: + # noinspection PyBroadException + try: + custom_runtime = self._runtime_cb(pod_number) if ports_mode else self._runtime_cb() + runtime.update(custom_runtime) + except Exception: + pass + if user_props: self._set_task_user_properties( task_id=task_id, @@ -644,7 +687,38 @@ class K8sIntegration(Worker): **user_props ) - def _get_pod_labels(self, queue, queue_name): + if runtime: + task_runtime = self._get_task_runtime(task_id) or {} + task_runtime.update(runtime) + + try: + res = (task_session or self._session).send_request(  # task_session may be None: fall back to the agent session + service='tasks', action='edit', method=Request.def_method, + json={ + "task": task_id, "force": True, "runtime": task_runtime + }, + ) + if not res.ok: + raise Exception("failed setting runtime property") + except Exception as ex: + print("WARNING: failed setting custom runtime properties for task '{}': {}".format(task_id, ex)) + + def _get_task_runtime(self, task_id) -> Optional[dict]: + try: + res = self._session.send_request( + service='tasks', action='get_by_id', method=Request.def_method, + json={"task": task_id, "only_fields": ["runtime"]}, + ) + if not res.ok: + raise ValueError(f"request returned {res.status_code}") + data = res.json().get("data") + if not data or "task" not in data: + raise ValueError("empty data in result") + return data["task"].get("runtime", {}) + except Exception as ex: + print(f"ERROR: Failed getting runtime properties for task {task_id}: {ex}") + + def _get_pod_labels(self, queue, queue_name, task_data): return [ self._get_agent_label(), "{}={}".format(self.QUEUE_LABEL, self._safe_k8s_label_value(queue)), @@ -1012,6 +1086,9 @@ class K8sIntegration(Worker): return deleted_pods + def check_if_suspended(self) -> bool: + return False  # base implementation: the agent is never suspended + def run_tasks_loop(self, queues: List[Text], worker_params, **kwargs): """ :summary: Pull and run tasks from queues.
@@ -1061,6 +1138,11 @@ class K8sIntegration(Worker): # delete old completed / failed pods self._cleanup_old_pods(namespaces, extra_msg="Cleanup cycle {cmd}") + if self.check_if_suspended(): + print("Agent is suspended, sleeping for {:.1f} seconds".format(self._polling_interval)) + sleep(self._polling_interval) + break + # get next task in queue try: # print(f"debug> getting tasks for queue {queue}")