Add support for skipping container apt installs using CLEARML_AGENT_SKIP_CONTAINER_APT env var in k8s

Add runtime callback support for setting runtime properties per task in k8s
Fix: remove the task from the pending queue and set it to failed when kubectl apply fails
This commit is contained in:
allegroai 2024-08-27 23:01:27 +03:00
parent 760bbca74e
commit 6302d43990

View File

@ -69,16 +69,23 @@ class K8sIntegration(Worker):
'echo "ldconfig" >> /etc/profile',
"/usr/sbin/sshd -p {port}"]
CONTAINER_BASH_SCRIPT = [
_CONTAINER_APT_SCRIPT_SECTION = [
"export DEBIAN_FRONTEND='noninteractive'",
"echo 'Binary::apt::APT::Keep-Downloaded-Packages \"true\";' > /etc/apt/apt.conf.d/docker-clean",
"chown -R root /root/.cache/pip",
"apt-get update",
"apt-get install -y git libsm6 libxext6 libxrender-dev libglib2.0-0",
]
CONTAINER_BASH_SCRIPT = [
*(
'[ ! -z "$CLEARML_AGENT_SKIP_CONTAINER_APT" ] || {}'.format(line)
for line in _CONTAINER_APT_SCRIPT_SECTION
),
"declare LOCAL_PYTHON",
"[ ! -z $LOCAL_PYTHON ] || for i in {{15..5}}; do which python3.$i && python3.$i -m pip --version && "
"export LOCAL_PYTHON=$(which python3.$i) && break ; done",
"[ ! -z $LOCAL_PYTHON ] || apt-get install -y python3-pip",
'[ ! -z "$CLEARML_AGENT_SKIP_CONTAINER_APT" ] || [ ! -z "$LOCAL_PYTHON" ] || apt-get install -y python3-pip',
"[ ! -z $LOCAL_PYTHON ] || export LOCAL_PYTHON=python3",
"{extra_bash_init_cmd}",
"[ ! -z $CLEARML_AGENT_NO_UPDATE ] || $LOCAL_PYTHON -m pip install clearml-agent{agent_install_args}",
@ -100,6 +107,7 @@ class K8sIntegration(Worker):
num_of_services=20,
base_pod_num=1,
user_props_cb=None,
runtime_cb=None,
overrides_yaml=None,
template_yaml=None,
clearml_conf_file=None,
@ -127,6 +135,7 @@ class K8sIntegration(Worker):
:param callable user_props_cb: An Optional callable allowing additional user properties to be specified
when scheduling a task to run in a pod. Callable can receive an optional pod number and should return
a dictionary of user properties (name and value). Signature is [[Optional[int]], Dict[str,str]]
:param callable runtime_cb: An Optional callable allowing additional task runtime to be specified (see user_props_cb)
:param str overrides_yaml: YAML file containing the overrides for the pod (optional)
:param str template_yaml: YAML file containing the template for the pod (optional).
If provided the pod is scheduled with kubectl apply and overrides are ignored, otherwise with kubectl run.
@ -161,6 +170,7 @@ class K8sIntegration(Worker):
self.base_pod_num = base_pod_num
self._edit_hyperparams_support = None
self._user_props_cb = user_props_cb
self._runtime_cb = runtime_cb
self.conf_file_content = None
self.overrides_json_string = None
self.template_dict = None
@ -198,6 +208,10 @@ class K8sIntegration(Worker):
self._session.feature_set != "basic" and self._session.check_min_server_version("3.22.3")
)
@property
def agent_label(self):
    """Read-only convenience accessor for this agent's k8s label (delegates to ``_get_agent_label()``)."""
    return self._get_agent_label()
def _create_daemon_instance(self, cls_, **kwargs):
return cls_(agent=self, **kwargs)
@ -430,6 +444,9 @@ class K8sIntegration(Worker):
""" Called when a resource (pod/job) was applied """
pass
def ports_mode_supported_for_task(self, task_id: str, task_data):
    """Hook: decide whether ports-mode applies to a specific task.

    The base implementation ignores *task_id* and *task_data* and simply
    mirrors the global ``self.ports_mode`` setting; subclasses may
    override this to make a per-task decision.
    """
    supported = self.ports_mode
    return supported
def run_one_task(self, queue: Text, task_id: Text, worker_args=None, task_session=None, **_):
print('Pulling task {} launching on kubernetes cluster'.format(task_id))
session = task_session or self._session
@ -501,8 +518,10 @@ class K8sIntegration(Worker):
)
)
if self.ports_mode:
ports_mode = False
if self.ports_mode_supported_for_task(task_id, task_data):
print("Kubernetes looking for available pod to use")
ports_mode = True
# noinspection PyBroadException
try:
@ -513,12 +532,12 @@ class K8sIntegration(Worker):
# Search for a free pod number
pod_count = 0
pod_number = self.base_pod_num
while self.ports_mode or self.max_pods_limit:
while ports_mode or self.max_pods_limit:
pod_number = self.base_pod_num + pod_count
try:
items_count = self._get_pod_count(
extra_labels=[self.limit_pod_label.format(pod_number=pod_number)] if self.ports_mode else None,
extra_labels=[self.limit_pod_label.format(pod_number=pod_number)] if ports_mode else None,
msg="Looking for a free pod/port"
)
except GetPodCountError:
@ -568,11 +587,11 @@ class K8sIntegration(Worker):
break
pod_count += 1
labels = self._get_pod_labels(queue, queue_name)
if self.ports_mode:
labels = self._get_pod_labels(queue, queue_name, task_data)
if ports_mode:
labels.append(self.limit_pod_label.format(pod_number=pod_number))
if self.ports_mode:
if ports_mode:
print("Kubernetes scheduling task id={} on pod={} (pod_count={})".format(task_id, pod_number, pod_count))
else:
print("Kubernetes scheduling task id={}".format(task_id))
@ -611,6 +630,14 @@ class K8sIntegration(Worker):
send_log = "Running kubectl encountered an error: {}".format(error)
self.log.error(send_log)
self.send_logs(task_id, send_log.splitlines())
# Make sure to remove the task from our k8s pending queue
self._session.api_client.queues.remove_task(
task=task_id,
queue=self.k8s_pending_queue_id,
)
# Set task as failed
session.api_client.tasks.failed(task_id, force=True)
return
if pod_name:
@ -618,25 +645,41 @@ class K8sIntegration(Worker):
resource_name=pod_name, namespace=namespace, task_id=task_id, session=session
)
self.set_task_info(
task_id=task_id, task_session=task_session, queue_name=queue_name, ports_mode=ports_mode,
pod_number=pod_number, pod_count=pod_count, task_data=task_data
)
def set_task_info(
self, task_id: str, task_session, task_data, queue_name: str, ports_mode: bool, pod_number, pod_count
):
user_props = {"k8s-queue": str(queue_name)}
if self.ports_mode:
user_props.update(
{
"k8s-pod-number": pod_number,
"k8s-pod-label": labels[0],
"k8s-internal-pod-count": pod_count,
"k8s-agent": self._get_agent_label(),
}
)
runtime = {}
if ports_mode:
agent_label = self._get_agent_label()
user_props.update({
"k8s-pod-number": pod_number,
"k8s-pod-label": agent_label, # backwards-compatibility / legacy
"k8s-internal-pod-count": pod_count,
"k8s-agent": agent_label,
})
if self._user_props_cb:
# noinspection PyBroadException
try:
custom_props = self._user_props_cb(pod_number) if self.ports_mode else self._user_props_cb()
custom_props = self._user_props_cb(pod_number) if ports_mode else self._user_props_cb()
user_props.update(custom_props)
except Exception:
pass
if self._runtime_cb:
# noinspection PyBroadException
try:
custom_runtime = self._runtime_cb(pod_number) if ports_mode else self._runtime_cb()
runtime.update(custom_runtime)
except Exception:
pass
if user_props:
self._set_task_user_properties(
task_id=task_id,
@ -644,7 +687,38 @@ class K8sIntegration(Worker):
**user_props
)
def _get_pod_labels(self, queue, queue_name):
if runtime:
task_runtime = self._get_task_runtime(task_id) or {}
task_runtime.update(runtime)
try:
res = task_session.send_request(
service='tasks', action='edit', method=Request.def_method,
json={
"task": task_id, "force": True, "runtime": task_runtime
},
)
if not res.ok:
raise Exception("failed setting runtime property")
except Exception as ex:
print("WARNING: failed setting custom runtime properties for task '{}': {}".format(task_id, ex))
def _get_task_runtime(self, task_id) -> Optional[dict]:
    """Fetch the current ``runtime`` properties dict for a task from the server.

    Returns the (possibly empty) runtime dict on success, or ``None``
    after printing an error if the request fails or the response payload
    is malformed.
    """
    # noinspection PyBroadException
    try:
        res = self._session.send_request(
            service='tasks',
            action='get_by_id',
            method=Request.def_method,
            json={"task": task_id, "only_fields": ["runtime"]},
        )
        if not res.ok:
            raise ValueError(f"request returned {res.status_code}")
        payload = res.json().get("data")
        if not payload or "task" not in payload:
            raise ValueError("empty data in result")
        return payload["task"].get("runtime", {})
    except Exception as ex:
        # Best-effort: report and fall through to an implicit None
        print(f"ERROR: Failed getting runtime properties for task {task_id}: {ex}")
def _get_pod_labels(self, queue, queue_name, task_data):
return [
self._get_agent_label(),
"{}={}".format(self.QUEUE_LABEL, self._safe_k8s_label_value(queue)),
@ -1012,6 +1086,9 @@ class K8sIntegration(Worker):
return deleted_pods
def check_if_suspended(self) -> bool:
    """Hook: report whether the agent should pause pulling tasks.

    The base implementation is never suspended; subclasses override this
    to implement suspension logic. The polling loop sleeps for one
    polling interval whenever this returns a truthy value.
    """
    # Was `pass` (implicit None); return an explicit bool to honor the
    # annotation. Backward compatible: None and False are both falsy.
    return False
def run_tasks_loop(self, queues: List[Text], worker_params, **kwargs):
"""
:summary: Pull and run tasks from queues.
@ -1061,6 +1138,11 @@ class K8sIntegration(Worker):
# delete old completed / failed pods
self._cleanup_old_pods(namespaces, extra_msg="Cleanup cycle {cmd}")
if self.check_if_suspended():
print("Agent is suspended, sleeping for {:.1f} seconds".format(self._polling_interval))
sleep(self._polling_interval)
break
# get next task in queue
try:
# print(f"debug> getting tasks for queue {queue}")