Add support for abort callback registration
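
The agent now polls the task's runtime properties for an abort callback
and, when one is registered, waits for it to complete (or time out)
before terminating the task. From the task side this pairs with a
registration along the lines of the ClearML SDK's
Task.register_abort_callback (a minimal sketch; the exact SDK signature
may differ):

    from clearml import Task

    task = Task.init(project_name="examples", task_name="abortable")

    def cleanup_on_abort():
        # flush state, close files, upload final artifacts, etc.
        print("abort requested, cleaning up")

    # the SDK advertises the callback's timeout and poll frequency through
    # the task's runtime properties, which the agent polls below
    task.register_abort_callback(cleanup_on_abort)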

allegroai 2022-08-29 18:06:59 +03:00
parent ec216198a0
commit 9006c2d28f


@@ -67,8 +67,10 @@ from clearml_agent.definitions import (
ENV_SSH_AUTH_SOCK,
ENV_AGENT_SKIP_PIP_VENV_INSTALL,
ENV_EXTRA_DOCKER_ARGS,
-    ENV_CUSTOM_BUILD_SCRIPT, ENV_AGENT_SKIP_PYTHON_ENV_INSTALL, WORKING_STANDALONE_DIR,
+    ENV_CUSTOM_BUILD_SCRIPT,
+    ENV_AGENT_SKIP_PYTHON_ENV_INSTALL,
+    WORKING_STANDALONE_DIR,
+    ENV_DEBUG_INFO,
)
from clearml_agent.definitions import WORKING_REPOSITORY_DIR, PIP_EXTRA_INDICES
from clearml_agent.errors import (
@@ -406,6 +408,10 @@ class TaskStopSignal(object):
self.worker_id = command.worker_id
self._task_reset_state_counter = 0
self.task_id = task_id
self._support_callback = None
self._active_callback_timestamp = None
self._active_callback_timeout = None
self._abort_callback_max_timeout = float(self.session.config.get('agent.abort_callback_max_timeout', 1800))
def test(self):
# type: () -> TaskStopReason
@@ -423,11 +429,84 @@ class TaskStopSignal(object):
# make sure we break nothing
return TaskStopSignal.default
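
    # Return contract: None - no abort callback registered, nothing to wait
    # for; False - the callback is still running and the task should be kept
    # alive; True - the callback completed or its timeout (capped by
    # agent.abort_callback_max_timeout) expired, the task may be terminated.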
def _wait_for_abort_callback(self):
if not self._support_callback:
return None
if self._active_callback_timestamp:
if time() - self._active_callback_timestamp < self._active_callback_timeout:
# print("waiting for callback to complete")
self.command.log("waiting for callback to complete")
# check state
cb_completed = None
try:
task_info = self.session.get(
service="tasks", action="get_all", version="2.13", id=[self.task_id],
only_fields=["status", "status_message", "runtime._abort_callback_completed"])
cb_completed = task_info['tasks'][0]['runtime'].get('_abort_callback_completed', None)
except: # noqa
pass
if not bool(cb_completed):
return False
msg = "Task abort callback completed in {:.2f} seconds".format(
time() - self._active_callback_timestamp)
else:
msg = "Task abort callback timed out [timeout: {}, elapsed: {:.2f}]".format(
self._active_callback_timeout, time() - self._active_callback_timestamp)
self.command.send_logs(self.task_id, ["### " + msg + " ###"], session=self.session)
return True
# check if abort callback is turned on
cb_completed = None
# TODO: add retries on network error with timeout
try:
task_info = self.session.get(
service="tasks", action="get_all", version="2.13", id=[self.task_id],
only_fields=["status", "status_message", "runtime._abort_callback_timeout",
"runtime._abort_poll_freq", "runtime._abort_callback_completed"])
abort_timeout = task_info['tasks'][0]['runtime'].get('_abort_callback_timeout', 0)
poll_timeout = task_info['tasks'][0]['runtime'].get('_abort_poll_freq', 0)
cb_completed = task_info['tasks'][0]['runtime'].get('_abort_callback_completed', None)
except: # noqa
abort_timeout = None
poll_timeout = None
if not abort_timeout:
# no callback set, we can leave
return None
try:
timeout = min(float(abort_timeout) + float(poll_timeout), self._abort_callback_max_timeout)
except: # noqa
self.command.log("Failed parsing runtime timeout shutdown callback [{}, {}]".format(
abort_timeout, poll_timeout))
return None
self.command.send_logs(
self.task_id,
["### Task abort callback timeout set, waiting for max {} sec ###".format(timeout)],
session=self.session
)
self._active_callback_timestamp = time()
self._active_callback_timeout = timeout
return bool(cb_completed)
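
    # a set timestamp doubles as "an abort callback was engaged at least once"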
def was_abort_function_called(self):
return bool(self._active_callback_timestamp)
def _test(self):
# type: () -> TaskStopReason
"""
"Unsafe" version of test()
"""
if self._support_callback is None:
# test if backend supports callback
self._support_callback = self.session.check_min_api_version("2.13")
task_info = get_task(
self.session, self.task_id, only_fields=["status", "status_message"]
)
@@ -439,10 +518,16 @@ class TaskStopSignal(object):
"task status_message has '%s', task will terminate",
self.stopping_message,
)
# actively wait for the abort callback to complete before terminating
if self._wait_for_abort_callback() is False:
return TaskStopReason.no_stop
return TaskStopReason.stopped
if status in self.unexpected_statuses: # ## and "worker" not in message:
self.command.log("unexpected status change, task will terminate")
# actively wait for the abort callback to complete before terminating
if self._wait_for_abort_callback() is False:
return TaskStopReason.no_stop
return TaskStopReason.status_changed
if status == self.statuses.created:
@@ -451,13 +536,18 @@ class TaskStopSignal(object):
>= self._number_of_consecutive_reset_tests
):
self.command.log("task was reset, task will terminate")
# actively wait for the abort callback to complete before terminating
if self._wait_for_abort_callback() is False:
return TaskStopReason.no_stop
return TaskStopReason.reset
self._task_reset_state_counter += 1
warning_msg = "Warning: Task {} was reset! If state is consistent we shall terminate ({}/{}).".format(
self.task_id,
self._task_reset_state_counter,
self._number_of_consecutive_reset_tests,
)
if self.events_service:
self.events_service.send_log_events(
self.worker_id,
@@ -526,6 +616,7 @@ class Worker(ServiceCommandSection):
def __init__(self, *args, **kwargs):
super(Worker, self).__init__(*args, **kwargs)
self._debug_context = ENV_DEBUG_INFO.get()
self.monitor = None
self.log = self._session.get_logger(__name__)
self.register_signal_handler()
@@ -1726,6 +1817,10 @@ class Worker(ServiceCommandSection):
printed_lines, stderr_pos_count = _print_file(stderr_path, stderr_pos_count)
stderr_line_count += report_lines(printed_lines, "stderr")
# make sure that if the abort function was called, the task is marked as aborted
if stop_signal and stop_signal.was_abort_function_called():
stop_reason = TaskStopReason.stopped
return status, stop_reason
def _check_if_internal_agent_started(self, printed_lines, service_mode_internal_agent_started, task_id):
@@ -2883,8 +2978,8 @@ class Worker(ServiceCommandSection):
if self._session.debug_mode:
self.log(traceback.format_exc())
-    def debug(self, message):
-        if self._session.debug_mode:
+    def debug(self, message, context=None):
+        if self._session.debug_mode and (not context or context == self._debug_context):
print("clearml_agent: {}".format(message))
@staticmethod
@@ -3313,7 +3408,7 @@ class Worker(ServiceCommandSection):
mounted_vcs_cache = temp_config.get(
"agent.docker_internal_mounts.vcs_cache", '/root/.clearml/vcs-cache')
mounted_venv_dir = temp_config.get(
"agent.docker_internal_mounts.venv_build", '/root/.clearml/venvs-builds')
"agent.docker_internal_mounts.venv_build", '~/.clearml/venvs-builds')
temp_config.put("sdk.storage.cache.default_base_dir", mounted_cache_dir)
temp_config.put("agent.pip_download_cache.path", mounted_pip_dl_dir)
temp_config.put("agent.vcs_cache.path", mounted_vcs_cache)
@@ -3351,27 +3446,36 @@ class Worker(ServiceCommandSection):
)
def _get_docker_config_cmd(self, temp_config, clean_api_credentials=False, **kwargs):
self.debug("Setting up docker config command")
host_cache = Path(os.path.expandvars(
self._session.config["sdk.storage.cache.default_base_dir"])).expanduser().as_posix()
self.debug("host_cache: {}".format(host_cache))
host_pip_dl = Path(os.path.expandvars(
self._session.config["agent.pip_download_cache.path"])).expanduser().as_posix()
self.debug("host_pip_dl: {}".format(host_pip_dl))
host_vcs_cache = Path(os.path.expandvars(
self._session.config["agent.vcs_cache.path"])).expanduser().as_posix()
self.debug("host_vcs_cache: {}".format(host_vcs_cache))
host_venvs_cache = Path(os.path.expandvars(
self._session.config["agent.venvs_cache.path"])).expanduser().as_posix() \
if self._session.config.get("agent.venvs_cache.path", None) else None
self.debug("host_venvs_cache: {}".format(host_venvs_cache))
host_ssh_cache = self._host_ssh_cache
self.debug("host_ssh_cache: {}".format(host_ssh_cache))
host_apt_cache = Path(os.path.expandvars(self._session.config.get(
"agent.docker_apt_cache", '~/.clearml/apt-cache'))).expanduser().as_posix()
self.debug("host_apt_cache: {}".format(host_apt_cache))
host_pip_cache = Path(os.path.expandvars(self._session.config.get(
"agent.docker_pip_cache", '~/.clearml/pip-cache'))).expanduser().as_posix()
self.debug("host_pip_cache: {}".format(host_pip_cache))
if self.poetry.enabled:
host_poetry_cache = Path(os.path.expandvars(self._session.config.get(
"agent.docker_poetry_cache", '~/.clearml/poetry-cache'))).expanduser().as_posix()
else:
host_poetry_cache = None
self.debug("host_poetry_cache: {}".format(host_poetry_cache))
# make sure all folders are valid
if host_apt_cache:
@@ -3491,9 +3595,8 @@ class Worker(ServiceCommandSection):
return len(output.splitlines()) if output else 0
-    @classmethod
def _get_docker_cmd(
-            cls,
+            self,
worker_id, parent_worker_id,
docker_image, docker_arguments,
python_version,
@@ -3518,6 +3621,7 @@ class Worker(ServiceCommandSection):
mount_ssh=None, mount_apt_cache=None, mount_pip_cache=None, mount_poetry_cache=None,
env_task_id=None,
):
self.debug("Constructing docker command", context="docker")
docker = 'docker'
base_cmd = [docker, 'run', '-t']
@@ -3552,8 +3656,10 @@ class Worker(ServiceCommandSection):
base_cmd += [str(a) for a in extra_docker_arguments if a]
# set docker labels
-        base_cmd += ['-l', cls._worker_label.format(worker_id)]
-        base_cmd += ['-l', cls._parent_worker_label.format(parent_worker_id)]
+        base_cmd += ['-l', self._worker_label.format(worker_id)]
+        base_cmd += ['-l', self._parent_worker_label.format(parent_worker_id)]
self.debug("Command: {}".format(base_cmd), context="docker")
# check if running inside a kubernetes
if ENV_DOCKER_HOST_MOUNT.get() or (os.environ.get('KUBERNETES_SERVICE_HOST') and
@@ -3570,6 +3676,8 @@ class Worker(ServiceCommandSection):
pass
base_cmd += ['-e', 'NVIDIA_VISIBLE_DEVICES={}'.format(dockers_nvidia_visible_devices)]
self.debug("Running in k8s: {}".format(base_cmd), context="docker")
# check if we need to map host folders
if ENV_DOCKER_HOST_MOUNT.get():
# expect CLEARML_AGENT_K8S_HOST_MOUNT = '/mnt/host/data:/root/.clearml'
@@ -3577,6 +3685,7 @@ class Worker(ServiceCommandSection):
# search and replace all the host folders with the k8s
host_mounts = [host_apt_cache, host_pip_cache, host_poetry_cache, host_pip_dl,
host_cache, host_vcs_cache, host_venvs_cache]
self.debug("Mapping host mounts: {}".format(host_mounts), context="docker")
for i, m in enumerate(host_mounts):
if not m:
continue
@@ -3585,6 +3694,7 @@ class Worker(ServiceCommandSection):
host_mounts[i] = None
else:
host_mounts[i] = m.replace(k8s_pod_mnt, k8s_node_mnt, 1)
self.debug("Mapped host mounts: {}".format(host_mounts), context="docker")
host_apt_cache, host_pip_cache, host_poetry_cache, host_pip_dl, \
host_cache, host_vcs_cache, host_venvs_cache = host_mounts
@@ -3598,6 +3708,8 @@ class Worker(ServiceCommandSection):
except Exception:
raise ValueError('Error: could not copy configuration file into: {}'.format(new_conf_file))
self.debug("Config file target: {}, host: {}".format(new_conf_file, conf_file), context="docker")
if host_ssh_cache:
new_ssh_cache = os.path.join(k8s_pod_mnt, '.clearml_agent.{}.ssh'.format(quote(worker_id, safe="")))
try:
@@ -3606,6 +3718,7 @@ class Worker(ServiceCommandSection):
host_ssh_cache = new_ssh_cache.replace(k8s_pod_mnt, k8s_node_mnt)
except Exception:
raise ValueError('Error: could not copy .ssh directory into: {}'.format(new_ssh_cache))
self.debug("Copied host SSH cache to: {}, host {}".format(new_ssh_cache, host_ssh_cache), context="docker")
base_cmd += ['-e', 'CLEARML_WORKER_ID='+worker_id, ]
# update the docker image, so the system knows where it runs
@@ -3702,6 +3815,15 @@ class Worker(ServiceCommandSection):
mount_pip_cache = mount_pip_cache or '/root/.cache/pip'
mount_poetry_cache = mount_poetry_cache or '/root/.cache/pypoetry'
self.debug(
"Adding mounts: host_ssh_cache={}, host_apt_cache={}, host_pip_cache={}, host_poetry_cache={}, "
"host_pip_dl={}, host_cache={}, host_vcs_cache={}, host_venvs_cache={}".format(
host_ssh_cache, host_apt_cache, host_pip_cache, host_poetry_cache, host_pip_dl, host_cache,
host_vcs_cache, host_venvs_cache,
),
context="docker"
)
base_cmd += (
(['--name', name] if name else []) +
['-v', conf_file+':'+DOCKER_ROOT_CONF_FILE] +
@@ -3983,6 +4105,13 @@ class Worker(ServiceCommandSection):
if self._session.feature_set == "basic":
raise ValueError("Server does not support --use-owner-token option")
role = self._session.get_decoded_token(self._session.token).get("identity", {}).get("role", None)
if role and role not in ["admin", "root", "system"]:
raise ValueError(
"User role not suitable for --use-owner-token option (requires at least admin,"
" found {})".format(role)
)
if __name__ == "__main__":
pass
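
Taken together, the agent-side polling above expects the task side to keep
three runtime properties up to date. A minimal sketch of that handshake,
using the same tasks API the agent queries (set_runtime_property is a
hypothetical helper; the real SDK updates the runtime dict internally):

    import time

    def set_runtime_property(session, task_id, key, value):
        # hypothetical helper: merge a key into the task's `runtime` dict
        # via the tasks.edit endpoint (force=True since the task is running)
        session.post(service="tasks", action="edit", id=task_id,
                     runtime={key: value}, force=True)

    # on registration: advertise the callback so the agent starts waiting
    set_runtime_property(session, task_id, "_abort_callback_timeout", 180)
    set_runtime_property(session, task_id, "_abort_poll_freq", 30)

    # when the abort request is noticed: run the user callback, then report
    # completion so the agent can terminate without waiting out the timeout
    user_abort_callback()
    set_runtime_property(session, task_id, "_abort_callback_completed", time.time())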