From 8f41002845eee6a7ad149cfd4a753671873e7a0d Mon Sep 17 00:00:00 2001
From: allegroai <>
Date: Wed, 24 Jul 2024 17:37:26 +0300
Subject: [PATCH] Add task.script.binary /bin/bash support. Fix -m module $env
 to support parsing the $env before launching

---
 clearml_agent/commands/worker.py | 120 ++++++++++++++++++++-----------
 1 file changed, 80 insertions(+), 40 deletions(-)

diff --git a/clearml_agent/commands/worker.py b/clearml_agent/commands/worker.py
index 02d0d8a..752d4cb 100644
--- a/clearml_agent/commands/worker.py
+++ b/clearml_agent/commands/worker.py
@@ -960,11 +960,14 @@ class Worker(ServiceCommandSection):
 
             self.send_logs(
                 task_id=task_id,
-                lines=
-                ['Running Task {} inside {}docker: {} arguments: {}\n'.format(
-                    task_id, "default " if default_docker else '',
-                    docker_image, DockerArgsSanitizer.sanitize_docker_command(self._session, docker_arguments or []))]
-                + (['custom_setup_bash_script:\n{}'.format(docker_setup_script)] if docker_setup_script else []),
+                lines=[
+                    'Running Task {} inside {}docker: {} arguments: {}\n'.format(
+                        task_id,
+                        "default " if default_docker else '',
+                        docker_image,
+                        DockerArgsSanitizer.sanitize_docker_command(self._session, docker_arguments or [])
+                    )
+                ] + (['custom_setup_bash_script:\n{}'.format(docker_setup_script)] if docker_setup_script else []),
                 level="INFO",
                 session=task_session,
             )
@@ -2405,7 +2408,11 @@ class Worker(ServiceCommandSection):
         # noinspection PyBroadException
         try:
             python_ver = task.script.binary
-            python_ver = python_ver.split('/')[-1].replace('python', '')
+            python_ver = python_ver.split('/')[-1]
+            if not python_ver.startswith("python"):
+                return None
+
+            python_ver = python_ver.replace('python', '')
             # if we can cast it, we are good
             return '{}.{}'.format(
                 int(python_ver.partition(".")[0]),
@@ -2709,29 +2716,58 @@ class Worker(ServiceCommandSection):
         # run code
         # print("Running task id [%s]:" % current_task.id)
         print(self._task_logging_pass_control_message.format(current_task.id))
-        extra = ['-u', ]
-        if optimization:
-            extra.append(
-                WorkerParams(optimization=optimization).get_optimization_flag()
-            )
 
         # check if we need to patch entry point script
         if ENV_AGENT_FORCE_TASK_INIT.get():
             patch_add_task_init_call((Path(script_dir) / execution.entry_point).as_posix())
 
+        is_python_binary = (current_task.script.binary or "").split("/")[-1].startswith('python')
+        is_bash_binary = (not is_python_binary and
+                          (current_task.script.binary or "").split("/")[-1] in ('bash', 'zsh', 'sh'))
+
+        if not is_bash_binary and not is_python_binary:
+            print("WARNING binary '{}' not supported, defaulting to python".format(current_task.script.binary))
+            is_python_binary = True
+
+        extra = []
+        if is_python_binary:
+            extra = ['-u', ]
+            if optimization:
+                extra.append(
+                    WorkerParams(optimization=optimization).get_optimization_flag()
+                )
+        elif is_bash_binary:
+            # if we needed some arguments for bash, that's where we will add them
+            extra = []
+
         # check if this is a module load, then load it.
         # noinspection PyBroadException
         try:
-            if current_task.script.binary and current_task.script.binary.startswith('python') and \
-                    execution.entry_point and execution.entry_point.split()[0].strip() == '-m':
-                # we need to split it
-                extra.extend(shlex.split(execution.entry_point))
+            if is_python_binary and execution.entry_point and execution.entry_point.split()[0].strip() == '-m':
+                # do not parse $env when running as user
+                if "$" in execution.entry_point and not ENV_TASK_EXECUTE_AS_USER.get() and is_linux_platform():
+                    print("INFO: parsing environment variables: {}".format(execution.entry_point))
+                    _org_env = copy(os.environ)
+                    os.environ.update(self._get_job_os_envs(current_task, log_level))
+                    os.environ.update(self._get_task_os_env(self._session.config, current_task) or dict())
+                    extra.extend(shlex.split(os.path.expandvars(execution.entry_point)))
+                    # restore (just in case, so we do not interfere with our local execution)
+                    os.environ = _org_env
+                else:
+                    extra.extend(shlex.split(execution.entry_point))
             else:
                 extra.append(execution.entry_point)
         except Exception:
             extra.append(execution.entry_point)
 
-        command = self.package_api.get_python_command(extra)
+        if is_python_binary:
+            command = self.package_api.get_python_command(extra)
+        elif is_bash_binary:
+            command = Argv(Path(os.environ.get("SHELL", "/bin/bash")), *extra)
+        else:
+            # actually we should not be here because we default to python if we do not recognize the binary
+            raise ValueError("Task execution binary requested {} is not supported!".format(current_task.script.binary))
+
         print("[{}]$ {}".format(execution.working_dir, command.pretty()))
 
         if freeze:
@@ -2742,29 +2778,14 @@ class Worker(ServiceCommandSection):
 
         print("Environment setup completed successfully\n")
 
-        sdk_env = {
-            # config_file updated in session.py
-            "task_id": current_task.id,
-            "log_level": log_level,
-            "log_to_backend": "0",
-            "config_file": self._session.config_file,  # The config file is the tmp file that clearml_agent created
-        }
-        os.environ.update(
-            {
-                sdk_key: str(value)
-                for key, value in sdk_env.items()
-                for sdk_key in ENVIRONMENT_SDK_PARAMS[key]
-            }
-        )
+        # update the jobs global environment variable
+        os.environ.update(self._get_job_os_envs(current_task, log_level))
 
         if repo_info:
             self._update_commit_id(current_task.id, execution, repo_info)
 
-        # get Task Environments and update the process
-        if self._session.config.get('agent.enable_task_env', None):
-            hyper_params = self._get_task_os_env(current_task)
-            if hyper_params:
-                os.environ.update(hyper_params)
+        # get Task Environments variables and update the process (if enabled)
+        os.environ.update(self._get_task_os_env(self._session.config, current_task) or dict())
 
         # Add the script CWD to the python path
         if repo_info and repo_info.root and self._session.config.get('agent.force_git_root_python_path', None):
@@ -2864,7 +2885,23 @@ class Worker(ServiceCommandSection):
 
         return 1 if exit_code is None else exit_code
 
-    def _get_task_os_env(self, current_task):
+    def _get_job_os_envs(self, current_task, log_level):
+        """Return the job's SDK environment variables: every ENVIRONMENT_SDK_PARAMS name -> str(value)."""
+        # config_file (updated in session.py) is the tmp file that clearml_agent created
+        job_params = dict(
+            task_id=current_task.id,
+            log_level=log_level,
+            log_to_backend="0",
+            config_file=self._session.config_file,
+        )
+        job_envs = {}
+        for param, value in job_params.items():
+            job_envs.update((env_key, str(value)) for env_key in ENVIRONMENT_SDK_PARAMS[param])
+        return job_envs
+
+    def _get_task_os_env(self, config, current_task):
+        if not config.get('agent.enable_task_env', None):
+            return None
         if not self._session.check_min_api_version('2.9'):
             return None
         # noinspection PyBroadException
@@ -2893,6 +2930,7 @@ class Worker(ServiceCommandSection):
                 status_reason=e.args[0], status_message=self._task_status_change_message
             )
             self.exit(e.args[0])
+
         if "\\" in execution.working_dir:
             warning(
                 'Working dir "{}" contains backslashes. '
@@ -3567,9 +3605,8 @@ class Worker(ServiceCommandSection):
                     override_interpreter_path = skip_pip_venv_install
                 else:
                     print(
-                        "Warning: interpreter {} could not be found. Reverting to the default interpreter resolution".format(
-                            skip_pip_venv_install
-                        )
+                        "Warning: interpreter {} could not be found. "
+                        "Reverting to the default interpreter resolution".format(skip_pip_venv_install)
                     )
             if override_interpreter_path:
                 print("Python interpreter {} is set from environment var".format(override_interpreter_path))
@@ -4199,7 +4236,10 @@ class Worker(ServiceCommandSection):
                     host_ssh_cache = new_ssh_cache.replace(k8s_pod_mnt, k8s_node_mnt)
                 except Exception:
                     raise ValueError('Error: could not copy .ssh directory into: {}'.format(new_ssh_cache))
-                self.debug("Copied host SSH cache to: {}, host {}".format(new_ssh_cache, host_ssh_cache), context="docker")
+                self.debug(
+                    "Copied host SSH cache to: {}, host {}".format(new_ssh_cache, host_ssh_cache),
+                    context="docker"
+                )
 
         base_cmd += ['-e', 'CLEARML_WORKER_ID='+worker_id, ]
         # update the docker image, so the system knows where it runs