diff --git a/docs/trains.conf b/docs/trains.conf index 0b40119..77a04d3 100644 --- a/docs/trains.conf +++ b/docs/trains.conf @@ -83,6 +83,10 @@ agent { # apt cache folder used mapped into docker, for ubuntu package caching docker_apt_cache = ~/.trains/apt-cache + # set to true in order to force "docker pull" before running an experiment using a docker image. + # This makes sure the docker image is updated. + docker_force_pull: false + default_docker: { # default docker image to use when running in docker mode image: "nvidia/cuda" diff --git a/trains_agent/backend_api/config/default/agent.conf b/trains_agent/backend_api/config/default/agent.conf index 937be2c..cf4b850 100644 --- a/trains_agent/backend_api/config/default/agent.conf +++ b/trains_agent/backend_api/config/default/agent.conf @@ -68,6 +68,10 @@ # apt cache folder used mapped into docker, for ubuntu package caching docker_apt_cache = ~/.trains/apt-cache + # set to true in order to force "docker pull" before running an experiment using a docker image. + # This makes sure the docker image is updated. + docker_force_pull: false + default_docker: { # default docker image to use when running in docker mode image: "nvidia/cuda" diff --git a/trains_agent/commands/worker.py b/trains_agent/commands/worker.py index 113efee..d5e4205 100644 --- a/trains_agent/commands/worker.py +++ b/trains_agent/commands/worker.py @@ -354,6 +354,7 @@ class Worker(ServiceCommandSection): self.docker_image_func = None self._docker_image = None self._docker_arguments = None + self._docker_force_pull = self._session.config.get("agent.docker_force_pull", False) self._daemon_foreground = None self._standalone_mode = None @@ -467,6 +468,19 @@ class Worker(ServiceCommandSection): try: # set WORKER ID os.environ['TRAINS_WORKER_ID'] = self.worker_id + + if self._docker_force_pull and docker_image: + full_pull_cmd = ['docker', 'pull', docker_image] + pull_cmd = Argv(*full_pull_cmd) + status, stop_signal_status = self._log_command_output( + task_id=task_id, + cmd=pull_cmd, + stdout_path=temp_stdout_name, + stderr_path=temp_stderr_name, + daemon=True, + stop_signal=stop_signal, + ) + status, stop_signal_status = self._log_command_output( task_id=task_id, cmd=cmd,