diff --git a/cmd/nvidia-container-toolkit/container_config.go b/cmd/nvidia-container-toolkit/container_config.go index 5920f471..64d737ce 100644 --- a/cmd/nvidia-container-toolkit/container_config.go +++ b/cmd/nvidia-container-toolkit/container_config.go @@ -165,7 +165,7 @@ func isPrivileged(s *Spec) bool { return false } -func getDevicesFromEnvvar(env map[string]string, legacyImage bool) *string { +func getDevicesFromEnvvar(image image.CUDA) *string { // Build a list of envvars to consider. envVars := []string{envNVVisibleDevices} if envSwarmGPU != nil { @@ -173,35 +173,14 @@ func getDevicesFromEnvvar(env map[string]string, legacyImage bool) *string { envVars = append([]string{*envSwarmGPU}, envVars...) } - // Grab a reference to devices from the first envvar - // in the list that actually exists in the environment. - var devices *string - for _, envVar := range envVars { - if devs, ok := env[envVar]; ok { - devices = &devs - break - } - } - - // Environment variable unset with legacy image: default to "all". - if devices == nil && legacyImage { - all := "all" - return &all - } - - // Environment variable unset or empty or "void": return nil - if devices == nil || len(*devices) == 0 || *devices == "void" { + devices := image.DevicesFromEnvvars(envVars...) + if len(devices) == 0 { return nil } - // Environment variable set to "none": reset to "". - if *devices == "none" { - empty := "" - return &empty - } + devicesString := strings.Join(devices, ",") - // Any other value. - return devices + return &devicesString } func getDevicesFromMounts(mounts []Mount) *string { @@ -241,7 +220,7 @@ func getDevicesFromMounts(mounts []Mount) *string { return &ret } -func getDevices(hookConfig *HookConfig, env map[string]string, mounts []Mount, privileged bool, legacyImage bool) *string { +func getDevices(hookConfig *HookConfig, image image.CUDA, mounts []Mount, privileged bool) *string { // If enabled, try and get the device list from volume mounts first if hookConfig.AcceptDeviceListAsVolumeMounts { devices := getDevicesFromMounts(mounts) @@ -251,7 +230,7 @@ func getDevices(hookConfig *HookConfig, env map[string]string, mounts []Mount, p } // Fallback to reading from the environment variable if privileges are correct - devices := getDevicesFromEnvvar(env, legacyImage) + devices := getDevicesFromEnvvar(image) if devices == nil { return nil } @@ -307,7 +286,7 @@ func getNvidiaConfig(hookConfig *HookConfig, image image.CUDA, mounts []Mount, p legacyImage := image.IsLegacy() var devices string - if d := getDevices(hookConfig, image, mounts, privileged, legacyImage); d != nil { + if d := getDevices(hookConfig, image, mounts, privileged); d != nil { devices = *d } else { // 'nil' devices means this is not a GPU container. diff --git a/cmd/nvidia-container-toolkit/container_config_test.go b/cmd/nvidia-container-toolkit/container_config_test.go index 19d0bf23..f1660db1 100644 --- a/cmd/nvidia-container-toolkit/container_config_test.go +++ b/cmd/nvidia-container-toolkit/container_config_test.go @@ -4,6 +4,7 @@ import ( "path/filepath" "testing" + "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" "github.com/stretchr/testify/require" ) @@ -671,7 +672,7 @@ func TestDeviceListSourcePriority(t *testing.T) { hookConfig := getDefaultHookConfig() hookConfig.AcceptEnvvarUnprivileged = tc.acceptUnprivileged hookConfig.AcceptDeviceListAsVolumeMounts = tc.acceptMounts - devices = getDevices(&hookConfig, env, tc.mountDevices, tc.privileged, false) + devices = getDevices(&hookConfig, env, tc.mountDevices, tc.privileged) } // For all other tests, just grab the devices and check the results @@ -693,7 +694,6 @@ func TestGetDevicesFromEnvvar(t *testing.T) { description string envSwarmGPU *string env map[string]string - legacyImage bool expectedDevices *string }{ { @@ -729,13 +729,15 @@ func TestGetDevicesFromEnvvar(t *testing.T) { description: "NVIDIA_VISIBLE_DEVICES set returns value for legacy image", env: map[string]string{ envNVVisibleDevices: gpuID, + envCUDAVersion: "legacy", }, - legacyImage: true, expectedDevices: &gpuID, }, { - description: "empty env returns all for legacy image", - legacyImage: true, + description: "empty env returns all for legacy image", + env: map[string]string{ + envCUDAVersion: "legacy", + }, expectedDevices: &all, }, // Add the `DOCKER_RESOURCE_GPUS` envvar and ensure that this is ignored when @@ -781,16 +783,16 @@ func TestGetDevicesFromEnvvar(t *testing.T) { env: map[string]string{ envNVVisibleDevices: gpuID, envDockerResourceGPUs: anotherGPUID, + envCUDAVersion: "legacy", }, - legacyImage: true, expectedDevices: &gpuID, }, { description: "empty env returns all for legacy image", env: map[string]string{ envDockerResourceGPUs: anotherGPUID, + envCUDAVersion: "legacy", }, - legacyImage: true, expectedDevices: &all, }, // Add the `DOCKER_RESOURCE_GPUS` envvar and ensure that this is selected when @@ -834,8 +836,8 @@ func TestGetDevicesFromEnvvar(t *testing.T) { envSwarmGPU: &envDockerResourceGPUs, env: map[string]string{ envDockerResourceGPUs: gpuID, + envCUDAVersion: "legacy", }, - legacyImage: true, expectedDevices: &gpuID, }, { @@ -860,7 +862,7 @@ func TestGetDevicesFromEnvvar(t *testing.T) { for i, tc := range tests { t.Run(tc.description, func(t *testing.T) { envSwarmGPU = tc.envSwarmGPU - devices := getDevicesFromEnvvar(tc.env, tc.legacyImage) + devices := getDevicesFromEnvvar(image.CUDA(tc.env)) if tc.expectedDevices == nil { require.Nil(t, devices, "%d: %v", i, tc) return diff --git a/internal/config/image/cuda_image.go b/internal/config/image/cuda_image.go index 2315d4d0..b9ad4c94 100644 --- a/internal/config/image/cuda_image.go +++ b/internal/config/image/cuda_image.go @@ -112,6 +112,36 @@ func (i CUDA) HasDisableRequire() bool { return false } +// DevicesFromEnvvars returns the devices requested by the image through environment variables +func (i CUDA) DevicesFromEnvvars(envVars ...string) []string { + // Grab a reference to devices from the first envvar + // in the list that actually exists in the environment. + var devices *string + for _, envVar := range envVars { + if devs, ok := i[envVar]; ok { + devices = &devs + break + } + } + + // Environment variable unset with legacy image: default to "all". + if devices == nil && i.IsLegacy() { + return []string{"all"} + } + + // Environment variable unset or empty or "void": return nil + if devices == nil || len(*devices) == 0 || *devices == "void" { + return nil + } + + // Environment variable set to "none": reset to "". + if *devices == "none" { + return []string{""} + } + + return strings.Split(*devices, ",") +} + func (i CUDA) legacyVersion() (string, error) { majorMinor, err := parseMajorMinorVersion(i[envCUDAVersion]) if err != nil {