diff --git a/cmd/nvidia-container-runtime-hook/container_config.go b/cmd/nvidia-container-runtime-hook/container_config.go index fa39bf2f..b509cf41 100644 --- a/cmd/nvidia-container-runtime-hook/container_config.go +++ b/cmd/nvidia-container-runtime-hook/container_config.go @@ -7,7 +7,6 @@ import ( "os" "path" "path/filepath" - "strings" "github.com/opencontainers/runtime-spec/specs-go" "golang.org/x/mod/semver" @@ -36,7 +35,7 @@ const ( ) type nvidiaConfig struct { - Devices string + Devices []string MigConfigDevices string MigMonitorDevices string ImexChannels string @@ -172,34 +171,19 @@ func isPrivileged(s *Spec) bool { return image.IsPrivileged(&fullSpec) } -func getDevicesFromEnvvar(image image.CUDA, swarmResourceEnvvars []string) *string { +func getDevicesFromEnvvar(image image.CUDA, swarmResourceEnvvars []string) []string { // We check if the image has at least one of the Swarm resource envvars defined and use this // if specified. - var hasSwarmEnvvar bool for _, envvar := range swarmResourceEnvvars { if image.HasEnvvar(envvar) { - hasSwarmEnvvar = true - break + return image.DevicesFromEnvvars(swarmResourceEnvvars...).List() } } - var devices []string - if hasSwarmEnvvar { - devices = image.DevicesFromEnvvars(swarmResourceEnvvars...).List() - } else { - devices = image.DevicesFromEnvvars(envNVVisibleDevices).List() - } - - if len(devices) == 0 { - return nil - } - - devicesString := strings.Join(devices, ",") - - return &devicesString + return image.DevicesFromEnvvars(envNVVisibleDevices).List() } -func getDevicesFromMounts(mounts []Mount) *string { +func getDevicesFromMounts(mounts []Mount) []string { var devices []string for _, m := range mounts { root := filepath.Clean(deviceListAsVolumeMountsRoot) @@ -232,22 +216,21 @@ func getDevicesFromMounts(mounts []Mount) *string { return nil } - ret := strings.Join(devices, ",") - return &ret + return devices } -func getDevices(hookConfig *HookConfig, image image.CUDA, mounts []Mount, privileged bool) *string { +func getDevices(hookConfig *HookConfig, image image.CUDA, mounts []Mount, privileged bool) []string { // If enabled, try and get the device list from volume mounts first if hookConfig.AcceptDeviceListAsVolumeMounts { devices := getDevicesFromMounts(mounts) - if devices != nil { + if len(devices) > 0 { return devices } } // Fallback to reading from the environment variable if privileges are correct devices := getDevicesFromEnvvar(image, hookConfig.getSwarmResourceEnvvars()) - if devices == nil { + if len(devices) == 0 { return nil } if privileged || hookConfig.AcceptEnvvarUnprivileged { @@ -314,11 +297,9 @@ func (c *HookConfig) getDriverCapabilities(cudaImage image.CUDA, legacyImage boo func getNvidiaConfig(hookConfig *HookConfig, image image.CUDA, mounts []Mount, privileged bool) *nvidiaConfig { legacyImage := image.IsLegacy() - var devices string - if d := getDevices(hookConfig, image, mounts, privileged); d != nil { - devices = *d - } else { - // 'nil' devices means this is not a GPU container. + devices := getDevices(hookConfig, image, mounts, privileged) + if len(devices) == 0 { + // empty devices means this is not a GPU container. return nil } diff --git a/cmd/nvidia-container-runtime-hook/container_config_test.go b/cmd/nvidia-container-runtime-hook/container_config_test.go index 43fac8aa..8f852ff8 100644 --- a/cmd/nvidia-container-runtime-hook/container_config_test.go +++ b/cmd/nvidia-container-runtime-hook/container_config_test.go @@ -1,7 +1,6 @@ package main import ( - "fmt" "path/filepath" "testing" @@ -38,7 +37,7 @@ func TestGetNvidiaConfig(t *testing.T) { }, privileged: false, expectedConfig: &nvidiaConfig{ - Devices: "all", + Devices: []string{"all"}, DriverCapabilities: image.SupportedDriverCapabilities.String(), Requirements: []string{"cuda>=9.0"}, }, @@ -51,7 +50,7 @@ func TestGetNvidiaConfig(t *testing.T) { }, privileged: false, expectedConfig: &nvidiaConfig{ - Devices: "all", + Devices: []string{"all"}, DriverCapabilities: image.SupportedDriverCapabilities.String(), Requirements: []string{"cuda>=9.0"}, }, @@ -82,7 +81,7 @@ func TestGetNvidiaConfig(t *testing.T) { }, privileged: false, expectedConfig: &nvidiaConfig{ - Devices: "", + Devices: []string{""}, DriverCapabilities: image.SupportedDriverCapabilities.String(), Requirements: []string{"cuda>=9.0"}, }, @@ -95,7 +94,7 @@ func TestGetNvidiaConfig(t *testing.T) { }, privileged: false, expectedConfig: &nvidiaConfig{ - Devices: "gpu0,gpu1", + Devices: []string{"gpu0", "gpu1"}, DriverCapabilities: image.SupportedDriverCapabilities.String(), Requirements: []string{"cuda>=9.0"}, }, @@ -109,7 +108,7 @@ func TestGetNvidiaConfig(t *testing.T) { }, privileged: false, expectedConfig: &nvidiaConfig{ - Devices: "gpu0,gpu1", + Devices: []string{"gpu0", "gpu1"}, DriverCapabilities: image.DefaultDriverCapabilities.String(), Requirements: []string{"cuda>=9.0"}, }, @@ -123,7 +122,7 @@ func TestGetNvidiaConfig(t *testing.T) { }, privileged: false, expectedConfig: &nvidiaConfig{ - Devices: "gpu0,gpu1", + Devices: []string{"gpu0", "gpu1"}, DriverCapabilities: image.SupportedDriverCapabilities.String(), Requirements: []string{"cuda>=9.0"}, }, @@ -137,7 +136,7 @@ func TestGetNvidiaConfig(t *testing.T) { }, privileged: false, expectedConfig: &nvidiaConfig{ - Devices: "gpu0,gpu1", + Devices: []string{"gpu0", "gpu1"}, DriverCapabilities: "display,video", Requirements: []string{"cuda>=9.0"}, }, @@ -153,7 +152,7 @@ func TestGetNvidiaConfig(t *testing.T) { }, privileged: false, expectedConfig: &nvidiaConfig{ - Devices: "gpu0,gpu1", + Devices: []string{"gpu0", "gpu1"}, DriverCapabilities: "display,video", Requirements: []string{"cuda>=9.0", "req0=true", "req1=false"}, }, @@ -170,7 +169,7 @@ func TestGetNvidiaConfig(t *testing.T) { }, privileged: false, expectedConfig: &nvidiaConfig{ - Devices: "gpu0,gpu1", + Devices: []string{"gpu0", "gpu1"}, DriverCapabilities: "display,video", Requirements: []string{}, }, @@ -200,7 +199,7 @@ func TestGetNvidiaConfig(t *testing.T) { }, privileged: false, expectedConfig: &nvidiaConfig{ - Devices: "all", + Devices: []string{"all"}, DriverCapabilities: image.DefaultDriverCapabilities.String(), Requirements: []string{"cuda>=9.0"}, }, @@ -231,7 +230,7 @@ func TestGetNvidiaConfig(t *testing.T) { }, privileged: false, expectedConfig: &nvidiaConfig{ - Devices: "", + Devices: []string{""}, DriverCapabilities: image.DefaultDriverCapabilities.String(), Requirements: []string{"cuda>=9.0"}, }, @@ -244,7 +243,7 @@ func TestGetNvidiaConfig(t *testing.T) { }, privileged: false, expectedConfig: &nvidiaConfig{ - Devices: "gpu0,gpu1", + Devices: []string{"gpu0", "gpu1"}, DriverCapabilities: image.DefaultDriverCapabilities.String(), Requirements: []string{"cuda>=9.0"}, }, @@ -258,7 +257,7 @@ func TestGetNvidiaConfig(t *testing.T) { }, privileged: false, expectedConfig: &nvidiaConfig{ - Devices: "gpu0,gpu1", + Devices: []string{"gpu0", "gpu1"}, DriverCapabilities: image.DefaultDriverCapabilities.String(), Requirements: []string{"cuda>=9.0"}, }, @@ -272,7 +271,7 @@ func TestGetNvidiaConfig(t *testing.T) { }, privileged: false, expectedConfig: &nvidiaConfig{ - Devices: "gpu0,gpu1", + Devices: []string{"gpu0", "gpu1"}, DriverCapabilities: image.SupportedDriverCapabilities.String(), Requirements: []string{"cuda>=9.0"}, }, @@ -286,7 +285,7 @@ func TestGetNvidiaConfig(t *testing.T) { }, privileged: false, expectedConfig: &nvidiaConfig{ - Devices: "gpu0,gpu1", + Devices: []string{"gpu0", "gpu1"}, DriverCapabilities: "display,video", Requirements: []string{"cuda>=9.0"}, }, @@ -302,7 +301,7 @@ func TestGetNvidiaConfig(t *testing.T) { }, privileged: false, expectedConfig: &nvidiaConfig{ - Devices: "gpu0,gpu1", + Devices: []string{"gpu0", "gpu1"}, DriverCapabilities: "display,video", Requirements: []string{"cuda>=9.0", "req0=true", "req1=false"}, }, @@ -319,7 +318,7 @@ func TestGetNvidiaConfig(t *testing.T) { }, privileged: false, expectedConfig: &nvidiaConfig{ - Devices: "gpu0,gpu1", + Devices: []string{"gpu0", "gpu1"}, DriverCapabilities: "display,video", Requirements: []string{}, }, @@ -332,7 +331,7 @@ func TestGetNvidiaConfig(t *testing.T) { privileged: false, expectedConfig: &nvidiaConfig{ - Devices: "all", + Devices: []string{"all"}, DriverCapabilities: image.DefaultDriverCapabilities.String(), Requirements: []string{}, }, @@ -346,7 +345,7 @@ func TestGetNvidiaConfig(t *testing.T) { }, privileged: true, expectedConfig: &nvidiaConfig{ - Devices: "all", + Devices: []string{"all"}, MigConfigDevices: "mig0,mig1", DriverCapabilities: image.DefaultDriverCapabilities.String(), Requirements: []string{"cuda>=9.0"}, @@ -371,7 +370,7 @@ func TestGetNvidiaConfig(t *testing.T) { }, privileged: true, expectedConfig: &nvidiaConfig{ - Devices: "all", + Devices: []string{"all"}, MigMonitorDevices: "mig0,mig1", DriverCapabilities: image.DefaultDriverCapabilities.String(), Requirements: []string{"cuda>=9.0"}, @@ -398,7 +397,7 @@ func TestGetNvidiaConfig(t *testing.T) { SupportedDriverCapabilities: "video,display", }, expectedConfig: &nvidiaConfig{ - Devices: "all", + Devices: []string{"all"}, DriverCapabilities: "display,video", }, }, @@ -413,7 +412,7 @@ func TestGetNvidiaConfig(t *testing.T) { SupportedDriverCapabilities: "video,display,compute,utility", }, expectedConfig: &nvidiaConfig{ - Devices: "all", + Devices: []string{"all"}, DriverCapabilities: "display,video", }, }, @@ -427,7 +426,7 @@ func TestGetNvidiaConfig(t *testing.T) { SupportedDriverCapabilities: "video,display,utility,compute", }, expectedConfig: &nvidiaConfig{ - Devices: "all", + Devices: []string{"all"}, DriverCapabilities: image.DefaultDriverCapabilities.String(), }, }, @@ -443,7 +442,7 @@ func TestGetNvidiaConfig(t *testing.T) { SupportedDriverCapabilities: "video,display,utility,compute", }, expectedConfig: &nvidiaConfig{ - Devices: "GPU1,GPU2", + Devices: []string{"GPU1", "GPU2"}, DriverCapabilities: image.DefaultDriverCapabilities.String(), }, }, @@ -459,7 +458,7 @@ func TestGetNvidiaConfig(t *testing.T) { SupportedDriverCapabilities: "video,display,utility,compute", }, expectedConfig: &nvidiaConfig{ - Devices: "GPU1,GPU2", + Devices: []string{"GPU1", "GPU2"}, DriverCapabilities: image.DefaultDriverCapabilities.String(), }, }, @@ -511,7 +510,7 @@ func TestGetDevicesFromMounts(t *testing.T) { var tests = []struct { description string mounts []Mount - expectedDevices *string + expectedDevices []string }{ { description: "No mounts", @@ -560,7 +559,7 @@ func TestGetDevicesFromMounts(t *testing.T) { Destination: filepath.Join(deviceListAsVolumeMountsRoot, "GPU1"), }, }, - expectedDevices: &[]string{"GPU0,GPU1"}[0], + expectedDevices: []string{"GPU0", "GPU1"}, }, { description: "Discover 2 devices with slashes in the name", @@ -574,7 +573,7 @@ func TestGetDevicesFromMounts(t *testing.T) { Destination: filepath.Join(deviceListAsVolumeMountsRoot, "GPU1-MIG0/0/1"), }, }, - expectedDevices: &[]string{"GPU0-MIG0/0/1,GPU1-MIG0/0/1"}[0], + expectedDevices: []string{"GPU0-MIG0/0/1", "GPU1-MIG0/0/1"}, }, } for _, tc := range tests { @@ -593,7 +592,7 @@ func TestDeviceListSourcePriority(t *testing.T) { privileged bool acceptUnprivileged bool acceptMounts bool - expectedDevices *string + expectedDevices []string }{ { description: "Mount devices, unprivileged, no accept unprivileged", @@ -611,7 +610,7 @@ func TestDeviceListSourcePriority(t *testing.T) { privileged: false, acceptUnprivileged: false, acceptMounts: true, - expectedDevices: &[]string{"GPU0,GPU1"}[0], + expectedDevices: []string{"GPU0", "GPU1"}, }, { description: "No mount devices, unprivileged, no accept unprivileged", @@ -629,7 +628,7 @@ func TestDeviceListSourcePriority(t *testing.T) { privileged: true, acceptUnprivileged: false, acceptMounts: true, - expectedDevices: &[]string{"GPU0,GPU1"}[0], + expectedDevices: []string{"GPU0", "GPU1"}, }, { description: "No mount devices, unprivileged, accept unprivileged", @@ -638,7 +637,7 @@ func TestDeviceListSourcePriority(t *testing.T) { privileged: false, acceptUnprivileged: true, acceptMounts: true, - expectedDevices: &[]string{"GPU0,GPU1"}[0], + expectedDevices: []string{"GPU0", "GPU1"}, }, { description: "Mount devices, unprivileged, accept unprivileged, no accept mounts", @@ -656,7 +655,7 @@ func TestDeviceListSourcePriority(t *testing.T) { privileged: false, acceptUnprivileged: true, acceptMounts: false, - expectedDevices: &[]string{"GPU2,GPU3"}[0], + expectedDevices: []string{"GPU2", "GPU3"}, }, { description: "Mount devices, unprivileged, no accept unprivileged, no accept mounts", @@ -680,7 +679,7 @@ func TestDeviceListSourcePriority(t *testing.T) { for _, tc := range tests { t.Run(tc.description, func(t *testing.T) { // Wrap the call to getDevices() in a closure. - var devices *string + var devices []string getDevices := func() { image, _ := image.New( image.WithEnvMap( @@ -704,8 +703,6 @@ func TestDeviceListSourcePriority(t *testing.T) { } func TestGetDevicesFromEnvvar(t *testing.T) { - all := "all" - empty := "" envDockerResourceGPUs := "DOCKER_RESOURCE_GPUS" gpuID := "GPU-12345" anotherGPUID := "GPU-67890" @@ -715,7 +712,7 @@ func TestGetDevicesFromEnvvar(t *testing.T) { description string swarmResourceEnvvars []string env map[string]string - expectedDevices *string + expectedDevices []string }{ { description: "empty env returns nil for non-legacy image", @@ -737,14 +734,14 @@ func TestGetDevicesFromEnvvar(t *testing.T) { env: map[string]string{ envNVVisibleDevices: "none", }, - expectedDevices: &empty, + expectedDevices: []string{""}, }, { description: "NVIDIA_VISIBLE_DEVICES set returns value for non-legacy image", env: map[string]string{ envNVVisibleDevices: gpuID, }, - expectedDevices: &gpuID, + expectedDevices: []string{gpuID}, }, { description: "NVIDIA_VISIBLE_DEVICES set returns value for legacy image", @@ -752,14 +749,14 @@ func TestGetDevicesFromEnvvar(t *testing.T) { envNVVisibleDevices: gpuID, envCUDAVersion: "legacy", }, - expectedDevices: &gpuID, + expectedDevices: []string{gpuID}, }, { description: "empty env returns all for legacy image", env: map[string]string{ envCUDAVersion: "legacy", }, - expectedDevices: &all, + expectedDevices: []string{"all"}, }, // Add the `DOCKER_RESOURCE_GPUS` envvar and ensure that this is ignored when // not enabled @@ -789,7 +786,7 @@ func TestGetDevicesFromEnvvar(t *testing.T) { envNVVisibleDevices: "none", envDockerResourceGPUs: anotherGPUID, }, - expectedDevices: &empty, + expectedDevices: []string{""}, }, { description: "NVIDIA_VISIBLE_DEVICES set returns value for non-legacy image", @@ -797,7 +794,7 @@ func TestGetDevicesFromEnvvar(t *testing.T) { envNVVisibleDevices: gpuID, envDockerResourceGPUs: anotherGPUID, }, - expectedDevices: &gpuID, + expectedDevices: []string{gpuID}, }, { description: "NVIDIA_VISIBLE_DEVICES set returns value for legacy image", @@ -806,7 +803,7 @@ func TestGetDevicesFromEnvvar(t *testing.T) { envDockerResourceGPUs: anotherGPUID, envCUDAVersion: "legacy", }, - expectedDevices: &gpuID, + expectedDevices: []string{gpuID}, }, { description: "empty env returns all for legacy image", @@ -814,7 +811,7 @@ func TestGetDevicesFromEnvvar(t *testing.T) { envDockerResourceGPUs: anotherGPUID, envCUDAVersion: "legacy", }, - expectedDevices: &all, + expectedDevices: []string{"all"}, }, // Add the `DOCKER_RESOURCE_GPUS` envvar and ensure that this is selected when // enabled @@ -842,7 +839,7 @@ func TestGetDevicesFromEnvvar(t *testing.T) { env: map[string]string{ envDockerResourceGPUs: "none", }, - expectedDevices: &empty, + expectedDevices: []string{""}, }, { description: "DOCKER_RESOURCE_GPUS set returns value for non-legacy image", @@ -850,7 +847,7 @@ func TestGetDevicesFromEnvvar(t *testing.T) { env: map[string]string{ envDockerResourceGPUs: gpuID, }, - expectedDevices: &gpuID, + expectedDevices: []string{gpuID}, }, { description: "DOCKER_RESOURCE_GPUS set returns value for legacy image", @@ -859,7 +856,7 @@ func TestGetDevicesFromEnvvar(t *testing.T) { envDockerResourceGPUs: gpuID, envCUDAVersion: "legacy", }, - expectedDevices: &gpuID, + expectedDevices: []string{gpuID}, }, { description: "DOCKER_RESOURCE_GPUS is selected if present", @@ -867,7 +864,7 @@ func TestGetDevicesFromEnvvar(t *testing.T) { env: map[string]string{ envDockerResourceGPUs: anotherGPUID, }, - expectedDevices: &anotherGPUID, + expectedDevices: []string{anotherGPUID}, }, { description: "DOCKER_RESOURCE_GPUS overrides NVIDIA_VISIBLE_DEVICES if present", @@ -876,7 +873,7 @@ func TestGetDevicesFromEnvvar(t *testing.T) { envNVVisibleDevices: gpuID, envDockerResourceGPUs: anotherGPUID, }, - expectedDevices: &anotherGPUID, + expectedDevices: []string{anotherGPUID}, }, { description: "DOCKER_RESOURCE_GPUS_ADDITIONAL overrides NVIDIA_VISIBLE_DEVICES if present", @@ -885,7 +882,7 @@ func TestGetDevicesFromEnvvar(t *testing.T) { envNVVisibleDevices: gpuID, "DOCKER_RESOURCE_GPUS_ADDITIONAL": anotherGPUID, }, - expectedDevices: &anotherGPUID, + expectedDevices: []string{anotherGPUID}, }, { description: "All available swarm resource envvars are selected and override NVIDIA_VISIBLE_DEVICES if present", @@ -895,10 +892,7 @@ func TestGetDevicesFromEnvvar(t *testing.T) { "DOCKER_RESOURCE_GPUS": thirdGPUID, "DOCKER_RESOURCE_GPUS_ADDITIONAL": anotherGPUID, }, - expectedDevices: func() *string { - result := fmt.Sprintf("%s,%s", thirdGPUID, anotherGPUID) - return &result - }(), + expectedDevices: []string{thirdGPUID, anotherGPUID}, }, { description: "DOCKER_RESOURCE_GPUS_ADDITIONAL or DOCKER_RESOURCE_GPUS override NVIDIA_VISIBLE_DEVICES if present", @@ -907,23 +901,17 @@ func TestGetDevicesFromEnvvar(t *testing.T) { envNVVisibleDevices: gpuID, "DOCKER_RESOURCE_GPUS_ADDITIONAL": anotherGPUID, }, - expectedDevices: &anotherGPUID, + expectedDevices: []string{anotherGPUID}, }, } - for i, tc := range tests { + for _, tc := range tests { t.Run(tc.description, func(t *testing.T) { image, _ := image.New( image.WithEnvMap(tc.env), ) devices := getDevicesFromEnvvar(image, tc.swarmResourceEnvvars) - if tc.expectedDevices == nil { - require.Nil(t, devices, "%d: %v", i, tc) - return - } - - require.NotNil(t, devices, "%d: %v", i, tc) - require.Equal(t, *tc.expectedDevices, *devices, "%d: %v", i, tc) + require.EqualValues(t, tc.expectedDevices, devices) }) } } diff --git a/cmd/nvidia-container-runtime-hook/main.go b/cmd/nvidia-container-runtime-hook/main.go index 1f4bd525..ad3e208d 100644 --- a/cmd/nvidia-container-runtime-hook/main.go +++ b/cmd/nvidia-container-runtime-hook/main.go @@ -120,8 +120,8 @@ func doPrestart() { if cli.NoCgroups { args = append(args, "--no-cgroups") } - if len(nvidia.Devices) > 0 { - args = append(args, fmt.Sprintf("--device=%s", nvidia.Devices)) + if devicesString := strings.Join(nvidia.Devices, ","); len(devicesString) > 0 { + args = append(args, fmt.Sprintf("--device=%s", devicesString)) } if len(nvidia.MigConfigDevices) > 0 { args = append(args, fmt.Sprintf("--mig-config=%s", nvidia.MigConfigDevices))