diff --git a/docs/trains.conf b/docs/trains.conf index 77a04d3..13e215e 100644 --- a/docs/trains.conf +++ b/docs/trains.conf @@ -83,6 +83,13 @@ agent { # apt cache folder used mapped into docker, for ubuntu package caching docker_apt_cache = ~/.trains/apt-cache + # optional arguments to pass to docker image + # these are local for this agent and will not be updated in the experiment's docker_cmd section + # extra_docker_arguments: ["--ipc=host", ] + + # optional shell script to run in docker when started before the experiment is started + # extra_docker_shell_script: ["apt-get install -y bindfs", ] + # set to true in order to force "docker pull" before running an experiment using a docker image. # This makes sure the docker image is updated. docker_force_pull: false diff --git a/trains_agent/backend_api/config/default/agent.conf b/trains_agent/backend_api/config/default/agent.conf index cf4b850..6a4a43f 100644 --- a/trains_agent/backend_api/config/default/agent.conf +++ b/trains_agent/backend_api/config/default/agent.conf @@ -68,6 +68,13 @@ # apt cache folder used mapped into docker, for ubuntu package caching docker_apt_cache = ~/.trains/apt-cache + # optional arguments to pass to docker image + # these are local for this agent and will not be updated in the experiment's docker_cmd section + # extra_docker_arguments: ["--ipc=host", ] + + # optional shell script to run in docker when started before the experiment is started + # extra_docker_shell_script: ["apt-get install -y bindfs", ] + # set to true in order to force "docker pull" before running an experiment using a docker image. # This makes sure the docker image is updated. docker_force_pull: false diff --git a/trains_agent/commands/worker.py b/trains_agent/commands/worker.py index e120587..eb3e0f7 100644 --- a/trains_agent/commands/worker.py +++ b/trains_agent/commands/worker.py @@ -354,6 +354,8 @@ class Worker(ServiceCommandSection): self.docker_image_func = None self._docker_image = None self._docker_arguments = None + self._extra_docker_arguments = self._session.config.get("agent.extra_docker_arguments", None) + self._extra_shell_script = self._session.config.get("agent.extra_docker_shell_script", None) self._docker_force_pull = self._session.config.get("agent.docker_force_pull", False) self._daemon_foreground = None self._standalone_mode = None @@ -414,6 +416,7 @@ class Worker(ServiceCommandSection): ) ) + docker_image = None if self.docker_image_func: try: response = get_task(self._session, task_id, only_fields=["execution.docker_cmd"]) @@ -1755,9 +1758,19 @@ class Worker(ServiceCommandSection): # store docker arguments self._docker_image = docker_image self._docker_arguments = docker_arguments + + extra_shell_script_str = "" + if self._extra_shell_script: + cmds = self._extra_shell_script + if not isinstance(cmds, (list, tuple)): + cmds = [cmds] + extra_shell_script_str = " ; ".join(map(str, cmds)) + " ; " + docker_cmd = dict(worker_id=self.worker_id, # docker_image=docker_image, # docker_arguments=docker_arguments, + extra_docker_arguments=self._extra_docker_arguments, + extra_shell_script=extra_shell_script_str, python_version=python_version, conf_file=self.temp_config_path, host_apt_cache=host_apt_cache, host_pip_cache=host_pip_cache, @@ -1776,7 +1789,8 @@ class Worker(ServiceCommandSection): host_ssh_cache, host_cache, mounted_cache, host_pip_dl, mounted_pip_dl, - host_vcs_cache, mounted_vcs_cache, standalone_mode=False): + host_vcs_cache, mounted_vcs_cache, + standalone_mode=False, extra_docker_arguments=None, extra_shell_script=None): docker = 'docker' base_cmd = [docker, 'run', '-t'] @@ -1793,6 +1807,11 @@ class Worker(ServiceCommandSection): if isinstance(docker_arguments, (list, tuple)) else [docker_arguments] base_cmd += [a for a in docker_arguments if a] + if extra_docker_arguments: + extra_docker_arguments = [extra_docker_arguments] \ + if isinstance(extra_docker_arguments, six.string_types) else extra_docker_arguments + base_cmd += [str(a) for a in extra_docker_arguments if a] + base_cmd += ['-e', 'TRAINS_WORKER_ID='+worker_id, ] if host_ssh_cache: @@ -1831,6 +1850,7 @@ class Worker(ServiceCommandSection): '-v', host_vcs_cache+':'+mounted_vcs_cache, '--rm', docker_image, 'bash', '-c', update_scheme + + extra_shell_script + "NVIDIA_VISIBLE_DEVICES=all {python} -u -m trains_agent ".format(python=python_version) ]