Refactor k8s glue template handling

This commit is contained in:
allegroai 2022-07-22 22:43:07 +03:00
parent a5a797ec5e
commit e687418194

View File

@ -119,7 +119,7 @@ class K8sIntegration(Worker):
when scheduling a task to run in a pod. Callable can receive an optional pod number and should return when scheduling a task to run in a pod. Callable can receive an optional pod number and should return
a dictionary of user properties (name and value). Signature is [[Optional[int]], Dict[str,str]] a dictionary of user properties (name and value). Signature is [[Optional[int]], Dict[str,str]]
:param str overrides_yaml: YAML file containing the overrides for the pod (optional) :param str overrides_yaml: YAML file containing the overrides for the pod (optional)
:param str template_yaml: YAML file containing the template for the pod (optional). :param str template_yaml: YAML file containing the template for the pod (optional).
If provided the pod is scheduled with kubectl apply and overrides are ignored, otherwise with kubectl run. If provided the pod is scheduled with kubectl apply and overrides are ignored, otherwise with kubectl run.
:param str clearml_conf_file: clearml.conf file to be use by the pod itself (optional) :param str clearml_conf_file: clearml.conf file to be use by the pod itself (optional)
:param str extra_bash_init_script: Additional bash script to run before starting the Task inside the container :param str extra_bash_init_script: Additional bash script to run before starting the Task inside the container
@ -503,8 +503,10 @@ class K8sIntegration(Worker):
queue=queue queue=queue
) )
if self.template_dict: template = self._resolve_template(task_session, task_data, queue)
output, error = self._kubectl_apply(**kubectl_kwargs)
if template:
output, error = self._kubectl_apply(template=template, **kubectl_kwargs)
else: else:
output, error = self._kubectl_run(task_data=task_data, **kubectl_kwargs) output, error = self._kubectl_run(task_data=task_data, **kubectl_kwargs)
@ -566,8 +568,10 @@ class K8sIntegration(Worker):
return {target: results} if results else {} return {target: results} if results else {}
return results return results
def _kubectl_apply(self, create_clearml_conf, docker_image, docker_args, docker_bash, labels, queue, task_id): def _kubectl_apply(
template = deepcopy(self.template_dict) self, create_clearml_conf, docker_image, docker_args, docker_bash, labels, queue, task_id, template=None
):
template = template or deepcopy(self.template_dict)
template.setdefault('apiVersion', 'v1') template.setdefault('apiVersion', 'v1')
template['kind'] = 'Pod' template['kind'] = 'Pod'
template.setdefault('metadata', {}) template.setdefault('metadata', {})
@ -753,13 +757,13 @@ class K8sIntegration(Worker):
# get next task in queue # get next task in queue
try: try:
response = get_next_task( response = self._get_next_task(queue=queue, get_task_info=self._impersonate_as_task_owner)
self._session, queue=queue, get_task_info=self._impersonate_as_task_owner
)
except Exception as e: except Exception as e:
print("Warning: Could not access task queue [{}], error: {}".format(queue, e)) print("Warning: Could not access task queue [{}], error: {}".format(queue, e))
continue continue
else: else:
if not response:
continue
try: try:
task_id = response["entry"]["task"] task_id = response["entry"]["task"]
except (KeyError, TypeError, AttributeError): except (KeyError, TypeError, AttributeError):
@ -820,6 +824,15 @@ class K8sIntegration(Worker):
log_level=logging.INFO, foreground=True, docker=False, **kwargs, log_level=logging.INFO, foreground=True, docker=False, **kwargs,
) )
def _get_next_task(self, queue, get_task_info):
return get_next_task(
self._session, queue=queue, get_task_info=get_task_info
)
def _resolve_template(self, task_session, task_data, queue):
if self.template_dict:
return deepcopy(self.template_dict)
@classmethod @classmethod
def get_ssh_server_bash(cls, ssh_port_number): def get_ssh_server_bash(cls, ssh_port_number):
return ' ; '.join(line.format(port=ssh_port_number) for line in cls.BASH_INSTALL_SSH_CMD) return ' ; '.join(line.format(port=ssh_port_number) for line in cls.BASH_INSTALL_SSH_CMD)