diff --git a/trains_agent/commands/worker.py b/trains_agent/commands/worker.py index 7852336..061e032 100644 --- a/trains_agent/commands/worker.py +++ b/trains_agent/commands/worker.py @@ -1868,11 +1868,19 @@ class Worker(ServiceCommandSection): docker = 'docker' base_cmd = [docker, 'run', '-t'] + dockers_nvidia_visible_devices = 'all' gpu_devices = os.environ.get('NVIDIA_VISIBLE_DEVICES', None) if gpu_devices is None or gpu_devices.lower().strip() == 'all': - base_cmd += ['--gpus', 'all', ] + if os.environ.get('TRAINS_DOCKER_SKIP_GPUS_FLAG', None): + dockers_nvidia_visible_devices = os.environ.get('NVIDIA_VISIBLE_DEVICES') or \ + dockers_nvidia_visible_devices + else: + base_cmd += ['--gpus', 'all', ] elif gpu_devices.strip() and gpu_devices.strip() != 'none': - base_cmd += ['--gpus', 'device='+gpu_devices, ] + if os.environ.get('TRAINS_DOCKER_SKIP_GPUS_FLAG', None): + dockers_nvidia_visible_devices = gpu_devices + else: + base_cmd += ['--gpus', 'device='+gpu_devices, ] # We are using --gpu, so we should not pass NVIDIA_VISIBLE_DEVICES, I think. # base_cmd += ['-e', 'NVIDIA_VISIBLE_DEVICES=' + gpu_devices, ]