diff --git a/clearml_agent/commands/worker.py b/clearml_agent/commands/worker.py
index e3b196b..fda7c50 100644
--- a/clearml_agent/commands/worker.py
+++ b/clearml_agent/commands/worker.py
@@ -67,8 +67,10 @@ from clearml_agent.definitions import (
     ENV_SSH_AUTH_SOCK,
     ENV_AGENT_SKIP_PIP_VENV_INSTALL,
     ENV_EXTRA_DOCKER_ARGS,
-    ENV_CUSTOM_BUILD_SCRIPT, ENV_AGENT_SKIP_PYTHON_ENV_INSTALL, WORKING_STANDALONE_DIR,
-
+    ENV_CUSTOM_BUILD_SCRIPT,
+    ENV_AGENT_SKIP_PYTHON_ENV_INSTALL,
+    WORKING_STANDALONE_DIR,
+    ENV_DEBUG_INFO,
 )
 from clearml_agent.definitions import WORKING_REPOSITORY_DIR, PIP_EXTRA_INDICES
 from clearml_agent.errors import (
@@ -406,6 +408,10 @@ class TaskStopSignal(object):
         self.worker_id = command.worker_id
         self._task_reset_state_counter = 0
         self.task_id = task_id
+        self._support_callback = None
+        self._active_callback_timestamp = None
+        self._active_callback_timeout = None
+        self._abort_callback_max_timeout = float(self.session.config.get('agent.abort_callback_max_timeout', 1800))
 
     def test(self):
         # type: () -> TaskStopReason
@@ -423,11 +429,84 @@ class TaskStopSignal(object):
             # make sure we break nothing
             return TaskStopSignal.default
 
+    def _wait_for_abort_callback(self):
+        if not self._support_callback:
+            return None
+
+        if self._active_callback_timestamp:
+            if time() - self._active_callback_timestamp < self._active_callback_timeout:
+                # print("waiting for callback to complete")
+                self.command.log("waiting for callback to complete")
+                # check state
+                cb_completed = None
+                try:
+                    task_info = self.session.get(
+                        service="tasks", action="get_all", version="2.13", id=[self.task_id],
+                        only_fields=["status", "status_message", "runtime._abort_callback_completed"])
+                    cb_completed = task_info['tasks'][0]['runtime'].get('_abort_callback_completed', None)
+                except:  # noqa
+                    pass
+
+                if not bool(cb_completed):
+                    return False
+
+                msg = "Task abort callback completed in {:.2f} seconds".format(
+                    time() - self._active_callback_timestamp)
+            else:
+                msg = "Task abort callback timed out [timeout: {}, elapsed: {:.2f}]".format(
+                    self._active_callback_timeout, time() - self._active_callback_timestamp)
+
+            self.command.send_logs(self.task_id, ["### " + msg + " ###"], session=self.session)
+            return True
+
+        # check if abort callback is turned on
+        cb_completed = None
+        # TODO: add retries on network error with timeout
+        try:
+            task_info = self.session.get(
+                service="tasks", action="get_all", version="2.13", id=[self.task_id],
+                only_fields=["status", "status_message", "runtime._abort_callback_timeout",
+                             "runtime._abort_poll_freq", "runtime._abort_callback_completed"])
+            abort_timeout = task_info['tasks'][0]['runtime'].get('_abort_callback_timeout', 0)
+            poll_timeout = task_info['tasks'][0]['runtime'].get('_abort_poll_freq', 0)
+            cb_completed = task_info['tasks'][0]['runtime'].get('_abort_callback_completed', None)
+        except:  # noqa
+            abort_timeout = None
+            poll_timeout = None
+
+        if not abort_timeout:
+            # no callback set we can leave
+            return None
+
+        try:
+            timeout = min(float(abort_timeout) + float(poll_timeout), self._abort_callback_max_timeout)
+        except:  # noqa
+            self.command.log("Failed parsing runtime timeout shutdown callback [{}, {}]".format(
+                abort_timeout, poll_timeout))
+            return None
+
+        self.command.send_logs(
+            self.task_id,
+            ["### Task abort callback timeout set, waiting for max {} sec ###".format(timeout)],
+            session=self.session
+        )
+
+        self._active_callback_timestamp = time()
+        self._active_callback_timeout = timeout
+        return bool(cb_completed)
+
+    def was_abort_function_called(self):
+        return bool(self._active_callback_timestamp)
+
     def _test(self):
         # type: () -> TaskStopReason
         """
         "Unsafe" version of test()
         """
+        if self._support_callback is None:
+            # test if backend support callback
+            self._support_callback = self.session.check_min_api_version("2.13")
+
         task_info = get_task(
             self.session, self.task_id, only_fields=["status", "status_message"]
         )
@@ -439,10 +518,16 @@ class TaskStopSignal(object):
                 "task status_message has '%s', task will terminate",
                 self.stopping_message,
             )
+            # actively waiting for task to complete
+            if self._wait_for_abort_callback() is False:
+                return TaskStopReason.no_stop
             return TaskStopReason.stopped
 
         if status in self.unexpected_statuses:  # ## and "worker" not in message:
             self.command.log("unexpected status change, task will terminate")
+            # actively waiting for task to complete
+            if self._wait_for_abort_callback() is False:
+                return TaskStopReason.no_stop
             return TaskStopReason.status_changed
 
         if status == self.statuses.created:
@@ -451,13 +536,18 @@ class TaskStopSignal(object):
                 >= self._number_of_consecutive_reset_tests
             ):
                 self.command.log("task was reset, task will terminate")
+                # actively waiting for task to complete
+                if self._wait_for_abort_callback() is False:
+                    return TaskStopReason.no_stop
                 return TaskStopReason.reset
+
             self._task_reset_state_counter += 1
             warning_msg = "Warning: Task {} was reset! if state is consistent we shall terminate ({}/{}).".format(
                 self.task_id,
                 self._task_reset_state_counter,
                 self._number_of_consecutive_reset_tests,
             )
+
             if self.events_service:
                 self.events_service.send_log_events(
                     self.worker_id,
@@ -526,6 +616,7 @@ class Worker(ServiceCommandSection):
 
     def __init__(self, *args, **kwargs):
         super(Worker, self).__init__(*args, **kwargs)
+        self._debug_context = ENV_DEBUG_INFO.get()
         self.monitor = None
         self.log = self._session.get_logger(__name__)
         self.register_signal_handler()
@@ -1726,6 +1817,10 @@ class Worker(ServiceCommandSection):
                 printed_lines, stderr_pos_count = _print_file(stderr_path, stderr_pos_count)
                 stderr_line_count += report_lines(printed_lines, "stderr")
 
+        # make sure that if the abort function was called, the task is marked as aborted
+        if stop_signal and stop_signal.was_abort_function_called():
+            stop_reason = TaskStopReason.stopped
+
         return status, stop_reason
 
     def _check_if_internal_agent_started(self, printed_lines, service_mode_internal_agent_started, task_id):
@@ -2883,8 +2978,8 @@ class Worker(ServiceCommandSection):
             if self._session.debug_mode:
                 self.log(traceback.format_exc())
 
-    def debug(self, message):
-        if self._session.debug_mode:
+    def debug(self, message, context=None):
+        if self._session.debug_mode and (not context or context == self._debug_context):
             print("clearml_agent: {}".format(message))
 
     @staticmethod
@@ -3313,7 +3408,7 @@ class Worker(ServiceCommandSection):
         mounted_vcs_cache = temp_config.get(
             "agent.docker_internal_mounts.vcs_cache", '/root/.clearml/vcs-cache')
         mounted_venv_dir = temp_config.get(
-            "agent.docker_internal_mounts.venv_build", '/root/.clearml/venvs-builds')
+            "agent.docker_internal_mounts.venv_build", '~/.clearml/venvs-builds')
         temp_config.put("sdk.storage.cache.default_base_dir", mounted_cache_dir)
         temp_config.put("agent.pip_download_cache.path", mounted_pip_dl_dir)
         temp_config.put("agent.vcs_cache.path", mounted_vcs_cache)
@@ -3351,27 +3446,36 @@ class Worker(ServiceCommandSection):
         )
 
     def _get_docker_config_cmd(self, temp_config, clean_api_credentials=False, **kwargs):
+        self.debug("Setting up docker config command")
         host_cache = Path(os.path.expandvars(
             self._session.config["sdk.storage.cache.default_base_dir"])).expanduser().as_posix()
+        self.debug("host_cache: {}".format(host_cache))
         host_pip_dl = Path(os.path.expandvars(
             self._session.config["agent.pip_download_cache.path"])).expanduser().as_posix()
+        self.debug("host_pip_dl: {}".format(host_pip_dl))
         host_vcs_cache = Path(os.path.expandvars(
             self._session.config["agent.vcs_cache.path"])).expanduser().as_posix()
+        self.debug("host_vcs_cache: {}".format(host_vcs_cache))
         host_venvs_cache = Path(os.path.expandvars(
             self._session.config["agent.venvs_cache.path"])).expanduser().as_posix() \
             if self._session.config.get("agent.venvs_cache.path", None) else None
+        self.debug("host_venvs_cache: {}".format(host_venvs_cache))
         host_ssh_cache = self._host_ssh_cache
+        self.debug("host_ssh_cache: {}".format(host_ssh_cache))
 
         host_apt_cache = Path(os.path.expandvars(self._session.config.get(
             "agent.docker_apt_cache", '~/.clearml/apt-cache'))).expanduser().as_posix()
+        self.debug("host_apt_cache: {}".format(host_apt_cache))
         host_pip_cache = Path(os.path.expandvars(self._session.config.get(
             "agent.docker_pip_cache", '~/.clearml/pip-cache'))).expanduser().as_posix()
+        self.debug("host_pip_cache: {}".format(host_pip_cache))
 
         if self.poetry.enabled:
             host_poetry_cache = Path(os.path.expandvars(self._session.config.get(
                 "agent.docker_poetry_cache", '~/.clearml/poetry-cache'))).expanduser().as_posix()
         else:
            host_poetry_cache = None
+        self.debug("host_poetry_cache: {}".format(host_poetry_cache))
 
         # make sure all folders are valid
         if host_apt_cache:
@@ -3491,9 +3595,8 @@ class Worker(ServiceCommandSection):
 
         return len(output.splitlines()) if output else 0
 
-    @classmethod
     def _get_docker_cmd(
-            cls,
+            self,
             worker_id, parent_worker_id,
             docker_image, docker_arguments,
             python_version,
@@ -3518,6 +3621,7 @@
             mount_ssh=None, mount_apt_cache=None, mount_pip_cache=None, mount_poetry_cache=None,
             env_task_id=None,
     ):
+        self.debug("Constructing docker command", context="docker")
        docker = 'docker'
 
         base_cmd = [docker, 'run', '-t']
@@ -3552,8 +3656,10 @@
             base_cmd += [str(a) for a in extra_docker_arguments if a]
 
         # set docker labels
-        base_cmd += ['-l', cls._worker_label.format(worker_id)]
-        base_cmd += ['-l', cls._parent_worker_label.format(parent_worker_id)]
+        base_cmd += ['-l', self._worker_label.format(worker_id)]
+        base_cmd += ['-l', self._parent_worker_label.format(parent_worker_id)]
+
+        self.debug("Command: {}".format(base_cmd), context="docker")
 
         # check if running inside a kubernetes
         if ENV_DOCKER_HOST_MOUNT.get() or (os.environ.get('KUBERNETES_SERVICE_HOST') and
@@ -3570,6 +3676,8 @@
                 pass
             base_cmd += ['-e', 'NVIDIA_VISIBLE_DEVICES={}'.format(dockers_nvidia_visible_devices)]
 
+            self.debug("Running in k8s: {}".format(base_cmd), context="docker")
+
         # check if we need to map host folders
         if ENV_DOCKER_HOST_MOUNT.get():
             # expect CLEARML_AGENT_K8S_HOST_MOUNT = '/mnt/host/data:/root/.clearml'
@@ -3577,6 +3685,7 @@
             # search and replace all the host folders with the k8s
             host_mounts = [host_apt_cache, host_pip_cache, host_poetry_cache,
                            host_pip_dl, host_cache, host_vcs_cache, host_venvs_cache]
+            self.debug("Mapping host mounts: {}".format(host_mounts), context="docker")
             for i, m in enumerate(host_mounts):
                 if not m:
                     continue
@@ -3585,6 +3694,7 @@
                     host_mounts[i] = None
                 else:
                     host_mounts[i] = m.replace(k8s_pod_mnt, k8s_node_mnt, 1)
+            self.debug("Mapped host mounts: {}".format(host_mounts), context="docker")
             host_apt_cache, host_pip_cache, host_poetry_cache, host_pip_dl, \
                 host_cache, host_vcs_cache, host_venvs_cache = host_mounts
 
@@ -3598,6 +3708,8 @@ class Worker(ServiceCommandSection):
             except Exception:
                 raise ValueError('Error: could not copy configuration file into: {}'.format(new_conf_file))
 
+            self.debug("Config file target: {}, host: {}".format(new_conf_file, conf_file), context="docker")
+
             if host_ssh_cache:
                 new_ssh_cache = os.path.join(k8s_pod_mnt, '.clearml_agent.{}.ssh'.format(quote(worker_id, safe="")))
                 try:
@@ -3606,6 +3718,7 @@
                     host_ssh_cache = new_ssh_cache.replace(k8s_pod_mnt, k8s_node_mnt)
                 except Exception:
                     raise ValueError('Error: could not copy .ssh directory into: {}'.format(new_ssh_cache))
+                self.debug("Copied host SSH cache to: {}, host {}".format(new_ssh_cache, host_ssh_cache), context="docker")
 
         base_cmd += ['-e', 'CLEARML_WORKER_ID='+worker_id, ]
         # update the docker image, so the system knows where it runs
@@ -3702,6 +3815,15 @@ class Worker(ServiceCommandSection):
         mount_pip_cache = mount_pip_cache or '/root/.cache/pip'
         mount_poetry_cache = mount_poetry_cache or '/root/.cache/pypoetry'
 
+        self.debug(
+            "Adding mounts: host_ssh_cache={}, host_apt_cache={}, host_pip_cache={}, host_poetry_cache={}, "
+            "host_pip_dl={}, host_cache={}, host_vcs_cache={}, host_venvs_cache={}".format(
+                host_ssh_cache, host_apt_cache, host_pip_cache, host_poetry_cache, host_pip_dl, host_cache,
+                host_vcs_cache, host_venvs_cache,
+            ),
+            context="docker"
+        )
+
         base_cmd += (
             (['--name', name] if name else []) +
             ['-v', conf_file+':'+DOCKER_ROOT_CONF_FILE] +
@@ -3983,6 +4105,13 @@ class Worker(ServiceCommandSection):
         if self._session.feature_set == "basic":
             raise ValueError("Server does not support --use-owner-token option")
 
+        role = self._session.get_decoded_token(self._session.token).get("identity", {}).get("role", None)
+        if role and role not in ["admin", "root", "system"]:
+            raise ValueError(
+                "User role not suitable for --use-owner-token option (requires at least admin,"
+                " found {})".format(role)
+            )
+
 
 if __name__ == "__main__":
     pass