mirror of
https://github.com/clearml/clearml-agent
synced 2025-05-04 20:21:10 +00:00
Add support for abort callback registration
This commit is contained in:
parent
ec216198a0
commit
9006c2d28f
@ -67,8 +67,10 @@ from clearml_agent.definitions import (
|
||||
ENV_SSH_AUTH_SOCK,
|
||||
ENV_AGENT_SKIP_PIP_VENV_INSTALL,
|
||||
ENV_EXTRA_DOCKER_ARGS,
|
||||
ENV_CUSTOM_BUILD_SCRIPT, ENV_AGENT_SKIP_PYTHON_ENV_INSTALL, WORKING_STANDALONE_DIR,
|
||||
|
||||
ENV_CUSTOM_BUILD_SCRIPT,
|
||||
ENV_AGENT_SKIP_PYTHON_ENV_INSTALL,
|
||||
WORKING_STANDALONE_DIR,
|
||||
ENV_DEBUG_INFO,
|
||||
)
|
||||
from clearml_agent.definitions import WORKING_REPOSITORY_DIR, PIP_EXTRA_INDICES
|
||||
from clearml_agent.errors import (
|
||||
@ -406,6 +408,10 @@ class TaskStopSignal(object):
|
||||
self.worker_id = command.worker_id
|
||||
self._task_reset_state_counter = 0
|
||||
self.task_id = task_id
|
||||
self._support_callback = None
|
||||
self._active_callback_timestamp = None
|
||||
self._active_callback_timeout = None
|
||||
self._abort_callback_max_timeout = float(self.session.config.get('agent.abort_callback_max_timeout', 1800))
|
||||
|
||||
def test(self):
|
||||
# type: () -> TaskStopReason
|
||||
@ -423,11 +429,84 @@ class TaskStopSignal(object):
|
||||
# make sure we break nothing
|
||||
return TaskStopSignal.default
|
||||
|
||||
def _wait_for_abort_callback(self):
|
||||
if not self._support_callback:
|
||||
return None
|
||||
|
||||
if self._active_callback_timestamp:
|
||||
if time() - self._active_callback_timestamp < self._active_callback_timeout:
|
||||
# print("waiting for callback to complete")
|
||||
self.command.log("waiting for callback to complete")
|
||||
# check state
|
||||
cb_completed = None
|
||||
try:
|
||||
task_info = self.session.get(
|
||||
service="tasks", action="get_all", version="2.13", id=[self.task_id],
|
||||
only_fields=["status", "status_message", "runtime._abort_callback_completed"])
|
||||
cb_completed = task_info['tasks'][0]['runtime'].get('_abort_callback_completed', None)
|
||||
except: # noqa
|
||||
pass
|
||||
|
||||
if not bool(cb_completed):
|
||||
return False
|
||||
|
||||
msg = "Task abort callback completed in {:.2f} seconds".format(
|
||||
time() - self._active_callback_timestamp)
|
||||
else:
|
||||
msg = "Task abort callback timed out [timeout: {}, elapsed: {:.2f}]".format(
|
||||
self._active_callback_timeout, time() - self._active_callback_timestamp)
|
||||
|
||||
self.command.send_logs(self.task_id, ["### " + msg + " ###"], session=self.session)
|
||||
return True
|
||||
|
||||
# check if abort callback is turned on
|
||||
cb_completed = None
|
||||
# TODO: add retries on network error with timeout
|
||||
try:
|
||||
task_info = self.session.get(
|
||||
service="tasks", action="get_all", version="2.13", id=[self.task_id],
|
||||
only_fields=["status", "status_message", "runtime._abort_callback_timeout",
|
||||
"runtime._abort_poll_freq", "runtime._abort_callback_completed"])
|
||||
abort_timeout = task_info['tasks'][0]['runtime'].get('_abort_callback_timeout', 0)
|
||||
poll_timeout = task_info['tasks'][0]['runtime'].get('_abort_poll_freq', 0)
|
||||
cb_completed = task_info['tasks'][0]['runtime'].get('_abort_callback_completed', None)
|
||||
except: # noqa
|
||||
abort_timeout = None
|
||||
poll_timeout = None
|
||||
|
||||
if not abort_timeout:
|
||||
# no callback set we can leave
|
||||
return None
|
||||
|
||||
try:
|
||||
timeout = min(float(abort_timeout) + float(poll_timeout), self._abort_callback_max_timeout)
|
||||
except: # noqa
|
||||
self.command.log("Failed parsing runtime timeout shutdown callback [{}, {}]".format(
|
||||
abort_timeout, poll_timeout))
|
||||
return None
|
||||
|
||||
self.command.send_logs(
|
||||
self.task_id,
|
||||
["### Task abort callback timeout set, waiting for max {} sec ###".format(timeout)],
|
||||
session=self.session
|
||||
)
|
||||
|
||||
self._active_callback_timestamp = time()
|
||||
self._active_callback_timeout = timeout
|
||||
return bool(cb_completed)
|
||||
|
||||
def was_abort_function_called(self):
    """
    Report whether an abort-callback wait window was started for this task.

    :return: True if ``_active_callback_timestamp`` holds a truthy value,
        False otherwise.
    """
    return bool(self._active_callback_timestamp)
|
||||
|
||||
def _test(self):
|
||||
# type: () -> TaskStopReason
|
||||
"""
|
||||
"Unsafe" version of test()
|
||||
"""
|
||||
if self._support_callback is None:
|
||||
# test if backend support callback
|
||||
self._support_callback = self.session.check_min_api_version("2.13")
|
||||
|
||||
task_info = get_task(
|
||||
self.session, self.task_id, only_fields=["status", "status_message"]
|
||||
)
|
||||
@ -439,10 +518,16 @@ class TaskStopSignal(object):
|
||||
"task status_message has '%s', task will terminate",
|
||||
self.stopping_message,
|
||||
)
|
||||
# actively waiting for task to complete
|
||||
if self._wait_for_abort_callback() is False:
|
||||
return TaskStopReason.no_stop
|
||||
return TaskStopReason.stopped
|
||||
|
||||
if status in self.unexpected_statuses: # ## and "worker" not in message:
|
||||
self.command.log("unexpected status change, task will terminate")
|
||||
# actively waiting for task to complete
|
||||
if self._wait_for_abort_callback() is False:
|
||||
return TaskStopReason.no_stop
|
||||
return TaskStopReason.status_changed
|
||||
|
||||
if status == self.statuses.created:
|
||||
@ -451,13 +536,18 @@ class TaskStopSignal(object):
|
||||
>= self._number_of_consecutive_reset_tests
|
||||
):
|
||||
self.command.log("task was reset, task will terminate")
|
||||
# actively waiting for task to complete
|
||||
if self._wait_for_abort_callback() is False:
|
||||
return TaskStopReason.no_stop
|
||||
return TaskStopReason.reset
|
||||
|
||||
self._task_reset_state_counter += 1
|
||||
warning_msg = "Warning: Task {} was reset! if state is consistent we shall terminate ({}/{}).".format(
|
||||
self.task_id,
|
||||
self._task_reset_state_counter,
|
||||
self._number_of_consecutive_reset_tests,
|
||||
)
|
||||
|
||||
if self.events_service:
|
||||
self.events_service.send_log_events(
|
||||
self.worker_id,
|
||||
@ -526,6 +616,7 @@ class Worker(ServiceCommandSection):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(Worker, self).__init__(*args, **kwargs)
|
||||
self._debug_context = ENV_DEBUG_INFO.get()
|
||||
self.monitor = None
|
||||
self.log = self._session.get_logger(__name__)
|
||||
self.register_signal_handler()
|
||||
@ -1726,6 +1817,10 @@ class Worker(ServiceCommandSection):
|
||||
printed_lines, stderr_pos_count = _print_file(stderr_path, stderr_pos_count)
|
||||
stderr_line_count += report_lines(printed_lines, "stderr")
|
||||
|
||||
# make sure that if the abort function was called, the task is marked as aborted
|
||||
if stop_signal and stop_signal.was_abort_function_called():
|
||||
stop_reason = TaskStopReason.stopped
|
||||
|
||||
return status, stop_reason
|
||||
|
||||
def _check_if_internal_agent_started(self, printed_lines, service_mode_internal_agent_started, task_id):
|
||||
@ -2883,8 +2978,8 @@ class Worker(ServiceCommandSection):
|
||||
if self._session.debug_mode:
|
||||
self.log(traceback.format_exc())
|
||||
|
||||
def debug(self, message):
|
||||
if self._session.debug_mode:
|
||||
def debug(self, message, context=None):
|
||||
if self._session.debug_mode and (not context or context == self._debug_context):
|
||||
print("clearml_agent: {}".format(message))
|
||||
|
||||
@staticmethod
|
||||
@ -3313,7 +3408,7 @@ class Worker(ServiceCommandSection):
|
||||
mounted_vcs_cache = temp_config.get(
|
||||
"agent.docker_internal_mounts.vcs_cache", '/root/.clearml/vcs-cache')
|
||||
mounted_venv_dir = temp_config.get(
|
||||
"agent.docker_internal_mounts.venv_build", '/root/.clearml/venvs-builds')
|
||||
"agent.docker_internal_mounts.venv_build", '~/.clearml/venvs-builds')
|
||||
temp_config.put("sdk.storage.cache.default_base_dir", mounted_cache_dir)
|
||||
temp_config.put("agent.pip_download_cache.path", mounted_pip_dl_dir)
|
||||
temp_config.put("agent.vcs_cache.path", mounted_vcs_cache)
|
||||
@ -3351,27 +3446,36 @@ class Worker(ServiceCommandSection):
|
||||
)
|
||||
|
||||
def _get_docker_config_cmd(self, temp_config, clean_api_credentials=False, **kwargs):
|
||||
self.debug("Setting up docker config command")
|
||||
host_cache = Path(os.path.expandvars(
|
||||
self._session.config["sdk.storage.cache.default_base_dir"])).expanduser().as_posix()
|
||||
self.debug("host_cache: {}".format(host_cache))
|
||||
host_pip_dl = Path(os.path.expandvars(
|
||||
self._session.config["agent.pip_download_cache.path"])).expanduser().as_posix()
|
||||
self.debug("host_pip_dl: {}".format(host_pip_dl))
|
||||
host_vcs_cache = Path(os.path.expandvars(
|
||||
self._session.config["agent.vcs_cache.path"])).expanduser().as_posix()
|
||||
self.debug("host_vcs_cache: {}".format(host_vcs_cache))
|
||||
host_venvs_cache = Path(os.path.expandvars(
|
||||
self._session.config["agent.venvs_cache.path"])).expanduser().as_posix() \
|
||||
if self._session.config.get("agent.venvs_cache.path", None) else None
|
||||
self.debug("host_venvs_cache: {}".format(host_venvs_cache))
|
||||
host_ssh_cache = self._host_ssh_cache
|
||||
self.debug("host_ssh_cache: {}".format(host_ssh_cache))
|
||||
|
||||
host_apt_cache = Path(os.path.expandvars(self._session.config.get(
|
||||
"agent.docker_apt_cache", '~/.clearml/apt-cache'))).expanduser().as_posix()
|
||||
self.debug("host_apt_cache: {}".format(host_apt_cache))
|
||||
host_pip_cache = Path(os.path.expandvars(self._session.config.get(
|
||||
"agent.docker_pip_cache", '~/.clearml/pip-cache'))).expanduser().as_posix()
|
||||
self.debug("host_pip_cache: {}".format(host_pip_cache))
|
||||
|
||||
if self.poetry.enabled:
|
||||
host_poetry_cache = Path(os.path.expandvars(self._session.config.get(
|
||||
"agent.docker_poetry_cache", '~/.clearml/poetry-cache'))).expanduser().as_posix()
|
||||
else:
|
||||
host_poetry_cache = None
|
||||
self.debug("host_poetry_cache: {}".format(host_poetry_cache))
|
||||
|
||||
# make sure all folders are valid
|
||||
if host_apt_cache:
|
||||
@ -3491,9 +3595,8 @@ class Worker(ServiceCommandSection):
|
||||
|
||||
return len(output.splitlines()) if output else 0
|
||||
|
||||
@classmethod
|
||||
def _get_docker_cmd(
|
||||
cls,
|
||||
self,
|
||||
worker_id, parent_worker_id,
|
||||
docker_image, docker_arguments,
|
||||
python_version,
|
||||
@ -3518,6 +3621,7 @@ class Worker(ServiceCommandSection):
|
||||
mount_ssh=None, mount_apt_cache=None, mount_pip_cache=None, mount_poetry_cache=None,
|
||||
env_task_id=None,
|
||||
):
|
||||
self.debug("Constructing docker command", context="docker")
|
||||
docker = 'docker'
|
||||
|
||||
base_cmd = [docker, 'run', '-t']
|
||||
@ -3552,8 +3656,10 @@ class Worker(ServiceCommandSection):
|
||||
base_cmd += [str(a) for a in extra_docker_arguments if a]
|
||||
|
||||
# set docker labels
|
||||
base_cmd += ['-l', cls._worker_label.format(worker_id)]
|
||||
base_cmd += ['-l', cls._parent_worker_label.format(parent_worker_id)]
|
||||
base_cmd += ['-l', self._worker_label.format(worker_id)]
|
||||
base_cmd += ['-l', self._parent_worker_label.format(parent_worker_id)]
|
||||
|
||||
self.debug("Command: {}".format(base_cmd), context="docker")
|
||||
|
||||
# check if running inside a kubernetes
|
||||
if ENV_DOCKER_HOST_MOUNT.get() or (os.environ.get('KUBERNETES_SERVICE_HOST') and
|
||||
@ -3570,6 +3676,8 @@ class Worker(ServiceCommandSection):
|
||||
pass
|
||||
base_cmd += ['-e', 'NVIDIA_VISIBLE_DEVICES={}'.format(dockers_nvidia_visible_devices)]
|
||||
|
||||
self.debug("Running in k8s: {}".format(base_cmd), context="docker")
|
||||
|
||||
# check if we need to map host folders
|
||||
if ENV_DOCKER_HOST_MOUNT.get():
|
||||
# expect CLEARML_AGENT_K8S_HOST_MOUNT = '/mnt/host/data:/root/.clearml'
|
||||
@ -3577,6 +3685,7 @@ class Worker(ServiceCommandSection):
|
||||
# search and replace all the host folders with the k8s
|
||||
host_mounts = [host_apt_cache, host_pip_cache, host_poetry_cache, host_pip_dl,
|
||||
host_cache, host_vcs_cache, host_venvs_cache]
|
||||
self.debug("Mapping host mounts: {}".format(host_mounts), context="docker")
|
||||
for i, m in enumerate(host_mounts):
|
||||
if not m:
|
||||
continue
|
||||
@ -3585,6 +3694,7 @@ class Worker(ServiceCommandSection):
|
||||
host_mounts[i] = None
|
||||
else:
|
||||
host_mounts[i] = m.replace(k8s_pod_mnt, k8s_node_mnt, 1)
|
||||
self.debug("Mapped host mounts: {}".format(host_mounts), context="docker")
|
||||
host_apt_cache, host_pip_cache, host_poetry_cache, host_pip_dl, \
|
||||
host_cache, host_vcs_cache, host_venvs_cache = host_mounts
|
||||
|
||||
@ -3598,6 +3708,8 @@ class Worker(ServiceCommandSection):
|
||||
except Exception:
|
||||
raise ValueError('Error: could not copy configuration file into: {}'.format(new_conf_file))
|
||||
|
||||
self.debug("Config file target: {}, host: {}".format(new_conf_file, conf_file), context="docker")
|
||||
|
||||
if host_ssh_cache:
|
||||
new_ssh_cache = os.path.join(k8s_pod_mnt, '.clearml_agent.{}.ssh'.format(quote(worker_id, safe="")))
|
||||
try:
|
||||
@ -3606,6 +3718,7 @@ class Worker(ServiceCommandSection):
|
||||
host_ssh_cache = new_ssh_cache.replace(k8s_pod_mnt, k8s_node_mnt)
|
||||
except Exception:
|
||||
raise ValueError('Error: could not copy .ssh directory into: {}'.format(new_ssh_cache))
|
||||
self.debug("Copied host SSH cache to: {}, host {}".format(new_ssh_cache, host_ssh_cache), context="docker")
|
||||
|
||||
base_cmd += ['-e', 'CLEARML_WORKER_ID='+worker_id, ]
|
||||
# update the docker image, so the system knows where it runs
|
||||
@ -3702,6 +3815,15 @@ class Worker(ServiceCommandSection):
|
||||
mount_pip_cache = mount_pip_cache or '/root/.cache/pip'
|
||||
mount_poetry_cache = mount_poetry_cache or '/root/.cache/pypoetry'
|
||||
|
||||
self.debug(
|
||||
"Adding mounts: host_ssh_cache={}, host_apt_cache={}, host_pip_cache={}, host_poetry_cache={}, "
|
||||
"host_pip_dl={}, host_cache={}, host_vcs_cache={}, host_venvs_cache={}".format(
|
||||
host_ssh_cache, host_apt_cache, host_pip_cache, host_poetry_cache, host_pip_dl, host_cache,
|
||||
host_vcs_cache, host_venvs_cache,
|
||||
),
|
||||
context="docker"
|
||||
)
|
||||
|
||||
base_cmd += (
|
||||
(['--name', name] if name else []) +
|
||||
['-v', conf_file+':'+DOCKER_ROOT_CONF_FILE] +
|
||||
@ -3983,6 +4105,13 @@ class Worker(ServiceCommandSection):
|
||||
if self._session.feature_set == "basic":
|
||||
raise ValueError("Server does not support --use-owner-token option")
|
||||
|
||||
role = self._session.get_decoded_token(self._session.token).get("identity", {}).get("role", None)
|
||||
if role and role not in ["admin", "root", "system"]:
|
||||
raise ValueError(
|
||||
"User role not suitable for --use-owner-token option (requires at least admin,"
|
||||
" found {})".format(role)
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pass
|
||||
|
Loading…
Reference in New Issue
Block a user