[no-relnote] Use string slice for devices in hook

Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
Evan Lezar 2024-10-14 14:53:36 +02:00
parent b077e2648d
commit 1991b3ef2a
3 changed files with 67 additions and 98 deletions

View File

@ -7,7 +7,6 @@ import (
"os" "os"
"path" "path"
"path/filepath" "path/filepath"
"strings"
"github.com/opencontainers/runtime-spec/specs-go" "github.com/opencontainers/runtime-spec/specs-go"
"golang.org/x/mod/semver" "golang.org/x/mod/semver"
@ -36,7 +35,7 @@ const (
) )
type nvidiaConfig struct { type nvidiaConfig struct {
Devices string Devices []string
MigConfigDevices string MigConfigDevices string
MigMonitorDevices string MigMonitorDevices string
ImexChannels string ImexChannels string
@ -172,34 +171,19 @@ func isPrivileged(s *Spec) bool {
return image.IsPrivileged(&fullSpec) return image.IsPrivileged(&fullSpec)
} }
func getDevicesFromEnvvar(image image.CUDA, swarmResourceEnvvars []string) *string { func getDevicesFromEnvvar(image image.CUDA, swarmResourceEnvvars []string) []string {
// We check if the image has at least one of the Swarm resource envvars defined and use this // We check if the image has at least one of the Swarm resource envvars defined and use this
// if specified. // if specified.
var hasSwarmEnvvar bool
for _, envvar := range swarmResourceEnvvars { for _, envvar := range swarmResourceEnvvars {
if image.HasEnvvar(envvar) { if image.HasEnvvar(envvar) {
hasSwarmEnvvar = true return image.DevicesFromEnvvars(swarmResourceEnvvars...).List()
break
} }
} }
var devices []string return image.DevicesFromEnvvars(envNVVisibleDevices).List()
if hasSwarmEnvvar {
devices = image.DevicesFromEnvvars(swarmResourceEnvvars...).List()
} else {
devices = image.DevicesFromEnvvars(envNVVisibleDevices).List()
} }
if len(devices) == 0 { func getDevicesFromMounts(mounts []Mount) []string {
return nil
}
devicesString := strings.Join(devices, ",")
return &devicesString
}
func getDevicesFromMounts(mounts []Mount) *string {
var devices []string var devices []string
for _, m := range mounts { for _, m := range mounts {
root := filepath.Clean(deviceListAsVolumeMountsRoot) root := filepath.Clean(deviceListAsVolumeMountsRoot)
@ -232,22 +216,21 @@ func getDevicesFromMounts(mounts []Mount) *string {
return nil return nil
} }
ret := strings.Join(devices, ",") return devices
return &ret
} }
func getDevices(hookConfig *HookConfig, image image.CUDA, mounts []Mount, privileged bool) *string { func getDevices(hookConfig *HookConfig, image image.CUDA, mounts []Mount, privileged bool) []string {
// If enabled, try and get the device list from volume mounts first // If enabled, try and get the device list from volume mounts first
if hookConfig.AcceptDeviceListAsVolumeMounts { if hookConfig.AcceptDeviceListAsVolumeMounts {
devices := getDevicesFromMounts(mounts) devices := getDevicesFromMounts(mounts)
if devices != nil { if len(devices) > 0 {
return devices return devices
} }
} }
// Fallback to reading from the environment variable if privileges are correct // Fallback to reading from the environment variable if privileges are correct
devices := getDevicesFromEnvvar(image, hookConfig.getSwarmResourceEnvvars()) devices := getDevicesFromEnvvar(image, hookConfig.getSwarmResourceEnvvars())
if devices == nil { if len(devices) == 0 {
return nil return nil
} }
if privileged || hookConfig.AcceptEnvvarUnprivileged { if privileged || hookConfig.AcceptEnvvarUnprivileged {
@ -314,11 +297,9 @@ func (c *HookConfig) getDriverCapabilities(cudaImage image.CUDA, legacyImage boo
func getNvidiaConfig(hookConfig *HookConfig, image image.CUDA, mounts []Mount, privileged bool) *nvidiaConfig { func getNvidiaConfig(hookConfig *HookConfig, image image.CUDA, mounts []Mount, privileged bool) *nvidiaConfig {
legacyImage := image.IsLegacy() legacyImage := image.IsLegacy()
var devices string devices := getDevices(hookConfig, image, mounts, privileged)
if d := getDevices(hookConfig, image, mounts, privileged); d != nil { if len(devices) == 0 {
devices = *d // empty devices means this is not a GPU container.
} else {
// 'nil' devices means this is not a GPU container.
return nil return nil
} }

View File

@ -1,7 +1,6 @@
package main package main
import ( import (
"fmt"
"path/filepath" "path/filepath"
"testing" "testing"
@ -38,7 +37,7 @@ func TestGetNvidiaConfig(t *testing.T) {
}, },
privileged: false, privileged: false,
expectedConfig: &nvidiaConfig{ expectedConfig: &nvidiaConfig{
Devices: "all", Devices: []string{"all"},
DriverCapabilities: image.SupportedDriverCapabilities.String(), DriverCapabilities: image.SupportedDriverCapabilities.String(),
Requirements: []string{"cuda>=9.0"}, Requirements: []string{"cuda>=9.0"},
}, },
@ -51,7 +50,7 @@ func TestGetNvidiaConfig(t *testing.T) {
}, },
privileged: false, privileged: false,
expectedConfig: &nvidiaConfig{ expectedConfig: &nvidiaConfig{
Devices: "all", Devices: []string{"all"},
DriverCapabilities: image.SupportedDriverCapabilities.String(), DriverCapabilities: image.SupportedDriverCapabilities.String(),
Requirements: []string{"cuda>=9.0"}, Requirements: []string{"cuda>=9.0"},
}, },
@ -82,7 +81,7 @@ func TestGetNvidiaConfig(t *testing.T) {
}, },
privileged: false, privileged: false,
expectedConfig: &nvidiaConfig{ expectedConfig: &nvidiaConfig{
Devices: "", Devices: []string{""},
DriverCapabilities: image.SupportedDriverCapabilities.String(), DriverCapabilities: image.SupportedDriverCapabilities.String(),
Requirements: []string{"cuda>=9.0"}, Requirements: []string{"cuda>=9.0"},
}, },
@ -95,7 +94,7 @@ func TestGetNvidiaConfig(t *testing.T) {
}, },
privileged: false, privileged: false,
expectedConfig: &nvidiaConfig{ expectedConfig: &nvidiaConfig{
Devices: "gpu0,gpu1", Devices: []string{"gpu0", "gpu1"},
DriverCapabilities: image.SupportedDriverCapabilities.String(), DriverCapabilities: image.SupportedDriverCapabilities.String(),
Requirements: []string{"cuda>=9.0"}, Requirements: []string{"cuda>=9.0"},
}, },
@ -109,7 +108,7 @@ func TestGetNvidiaConfig(t *testing.T) {
}, },
privileged: false, privileged: false,
expectedConfig: &nvidiaConfig{ expectedConfig: &nvidiaConfig{
Devices: "gpu0,gpu1", Devices: []string{"gpu0", "gpu1"},
DriverCapabilities: image.DefaultDriverCapabilities.String(), DriverCapabilities: image.DefaultDriverCapabilities.String(),
Requirements: []string{"cuda>=9.0"}, Requirements: []string{"cuda>=9.0"},
}, },
@ -123,7 +122,7 @@ func TestGetNvidiaConfig(t *testing.T) {
}, },
privileged: false, privileged: false,
expectedConfig: &nvidiaConfig{ expectedConfig: &nvidiaConfig{
Devices: "gpu0,gpu1", Devices: []string{"gpu0", "gpu1"},
DriverCapabilities: image.SupportedDriverCapabilities.String(), DriverCapabilities: image.SupportedDriverCapabilities.String(),
Requirements: []string{"cuda>=9.0"}, Requirements: []string{"cuda>=9.0"},
}, },
@ -137,7 +136,7 @@ func TestGetNvidiaConfig(t *testing.T) {
}, },
privileged: false, privileged: false,
expectedConfig: &nvidiaConfig{ expectedConfig: &nvidiaConfig{
Devices: "gpu0,gpu1", Devices: []string{"gpu0", "gpu1"},
DriverCapabilities: "display,video", DriverCapabilities: "display,video",
Requirements: []string{"cuda>=9.0"}, Requirements: []string{"cuda>=9.0"},
}, },
@ -153,7 +152,7 @@ func TestGetNvidiaConfig(t *testing.T) {
}, },
privileged: false, privileged: false,
expectedConfig: &nvidiaConfig{ expectedConfig: &nvidiaConfig{
Devices: "gpu0,gpu1", Devices: []string{"gpu0", "gpu1"},
DriverCapabilities: "display,video", DriverCapabilities: "display,video",
Requirements: []string{"cuda>=9.0", "req0=true", "req1=false"}, Requirements: []string{"cuda>=9.0", "req0=true", "req1=false"},
}, },
@ -170,7 +169,7 @@ func TestGetNvidiaConfig(t *testing.T) {
}, },
privileged: false, privileged: false,
expectedConfig: &nvidiaConfig{ expectedConfig: &nvidiaConfig{
Devices: "gpu0,gpu1", Devices: []string{"gpu0", "gpu1"},
DriverCapabilities: "display,video", DriverCapabilities: "display,video",
Requirements: []string{}, Requirements: []string{},
}, },
@ -200,7 +199,7 @@ func TestGetNvidiaConfig(t *testing.T) {
}, },
privileged: false, privileged: false,
expectedConfig: &nvidiaConfig{ expectedConfig: &nvidiaConfig{
Devices: "all", Devices: []string{"all"},
DriverCapabilities: image.DefaultDriverCapabilities.String(), DriverCapabilities: image.DefaultDriverCapabilities.String(),
Requirements: []string{"cuda>=9.0"}, Requirements: []string{"cuda>=9.0"},
}, },
@ -231,7 +230,7 @@ func TestGetNvidiaConfig(t *testing.T) {
}, },
privileged: false, privileged: false,
expectedConfig: &nvidiaConfig{ expectedConfig: &nvidiaConfig{
Devices: "", Devices: []string{""},
DriverCapabilities: image.DefaultDriverCapabilities.String(), DriverCapabilities: image.DefaultDriverCapabilities.String(),
Requirements: []string{"cuda>=9.0"}, Requirements: []string{"cuda>=9.0"},
}, },
@ -244,7 +243,7 @@ func TestGetNvidiaConfig(t *testing.T) {
}, },
privileged: false, privileged: false,
expectedConfig: &nvidiaConfig{ expectedConfig: &nvidiaConfig{
Devices: "gpu0,gpu1", Devices: []string{"gpu0", "gpu1"},
DriverCapabilities: image.DefaultDriverCapabilities.String(), DriverCapabilities: image.DefaultDriverCapabilities.String(),
Requirements: []string{"cuda>=9.0"}, Requirements: []string{"cuda>=9.0"},
}, },
@ -258,7 +257,7 @@ func TestGetNvidiaConfig(t *testing.T) {
}, },
privileged: false, privileged: false,
expectedConfig: &nvidiaConfig{ expectedConfig: &nvidiaConfig{
Devices: "gpu0,gpu1", Devices: []string{"gpu0", "gpu1"},
DriverCapabilities: image.DefaultDriverCapabilities.String(), DriverCapabilities: image.DefaultDriverCapabilities.String(),
Requirements: []string{"cuda>=9.0"}, Requirements: []string{"cuda>=9.0"},
}, },
@ -272,7 +271,7 @@ func TestGetNvidiaConfig(t *testing.T) {
}, },
privileged: false, privileged: false,
expectedConfig: &nvidiaConfig{ expectedConfig: &nvidiaConfig{
Devices: "gpu0,gpu1", Devices: []string{"gpu0", "gpu1"},
DriverCapabilities: image.SupportedDriverCapabilities.String(), DriverCapabilities: image.SupportedDriverCapabilities.String(),
Requirements: []string{"cuda>=9.0"}, Requirements: []string{"cuda>=9.0"},
}, },
@ -286,7 +285,7 @@ func TestGetNvidiaConfig(t *testing.T) {
}, },
privileged: false, privileged: false,
expectedConfig: &nvidiaConfig{ expectedConfig: &nvidiaConfig{
Devices: "gpu0,gpu1", Devices: []string{"gpu0", "gpu1"},
DriverCapabilities: "display,video", DriverCapabilities: "display,video",
Requirements: []string{"cuda>=9.0"}, Requirements: []string{"cuda>=9.0"},
}, },
@ -302,7 +301,7 @@ func TestGetNvidiaConfig(t *testing.T) {
}, },
privileged: false, privileged: false,
expectedConfig: &nvidiaConfig{ expectedConfig: &nvidiaConfig{
Devices: "gpu0,gpu1", Devices: []string{"gpu0", "gpu1"},
DriverCapabilities: "display,video", DriverCapabilities: "display,video",
Requirements: []string{"cuda>=9.0", "req0=true", "req1=false"}, Requirements: []string{"cuda>=9.0", "req0=true", "req1=false"},
}, },
@ -319,7 +318,7 @@ func TestGetNvidiaConfig(t *testing.T) {
}, },
privileged: false, privileged: false,
expectedConfig: &nvidiaConfig{ expectedConfig: &nvidiaConfig{
Devices: "gpu0,gpu1", Devices: []string{"gpu0", "gpu1"},
DriverCapabilities: "display,video", DriverCapabilities: "display,video",
Requirements: []string{}, Requirements: []string{},
}, },
@ -332,7 +331,7 @@ func TestGetNvidiaConfig(t *testing.T) {
privileged: false, privileged: false,
expectedConfig: &nvidiaConfig{ expectedConfig: &nvidiaConfig{
Devices: "all", Devices: []string{"all"},
DriverCapabilities: image.DefaultDriverCapabilities.String(), DriverCapabilities: image.DefaultDriverCapabilities.String(),
Requirements: []string{}, Requirements: []string{},
}, },
@ -346,7 +345,7 @@ func TestGetNvidiaConfig(t *testing.T) {
}, },
privileged: true, privileged: true,
expectedConfig: &nvidiaConfig{ expectedConfig: &nvidiaConfig{
Devices: "all", Devices: []string{"all"},
MigConfigDevices: "mig0,mig1", MigConfigDevices: "mig0,mig1",
DriverCapabilities: image.DefaultDriverCapabilities.String(), DriverCapabilities: image.DefaultDriverCapabilities.String(),
Requirements: []string{"cuda>=9.0"}, Requirements: []string{"cuda>=9.0"},
@ -371,7 +370,7 @@ func TestGetNvidiaConfig(t *testing.T) {
}, },
privileged: true, privileged: true,
expectedConfig: &nvidiaConfig{ expectedConfig: &nvidiaConfig{
Devices: "all", Devices: []string{"all"},
MigMonitorDevices: "mig0,mig1", MigMonitorDevices: "mig0,mig1",
DriverCapabilities: image.DefaultDriverCapabilities.String(), DriverCapabilities: image.DefaultDriverCapabilities.String(),
Requirements: []string{"cuda>=9.0"}, Requirements: []string{"cuda>=9.0"},
@ -398,7 +397,7 @@ func TestGetNvidiaConfig(t *testing.T) {
SupportedDriverCapabilities: "video,display", SupportedDriverCapabilities: "video,display",
}, },
expectedConfig: &nvidiaConfig{ expectedConfig: &nvidiaConfig{
Devices: "all", Devices: []string{"all"},
DriverCapabilities: "display,video", DriverCapabilities: "display,video",
}, },
}, },
@ -413,7 +412,7 @@ func TestGetNvidiaConfig(t *testing.T) {
SupportedDriverCapabilities: "video,display,compute,utility", SupportedDriverCapabilities: "video,display,compute,utility",
}, },
expectedConfig: &nvidiaConfig{ expectedConfig: &nvidiaConfig{
Devices: "all", Devices: []string{"all"},
DriverCapabilities: "display,video", DriverCapabilities: "display,video",
}, },
}, },
@ -427,7 +426,7 @@ func TestGetNvidiaConfig(t *testing.T) {
SupportedDriverCapabilities: "video,display,utility,compute", SupportedDriverCapabilities: "video,display,utility,compute",
}, },
expectedConfig: &nvidiaConfig{ expectedConfig: &nvidiaConfig{
Devices: "all", Devices: []string{"all"},
DriverCapabilities: image.DefaultDriverCapabilities.String(), DriverCapabilities: image.DefaultDriverCapabilities.String(),
}, },
}, },
@ -443,7 +442,7 @@ func TestGetNvidiaConfig(t *testing.T) {
SupportedDriverCapabilities: "video,display,utility,compute", SupportedDriverCapabilities: "video,display,utility,compute",
}, },
expectedConfig: &nvidiaConfig{ expectedConfig: &nvidiaConfig{
Devices: "GPU1,GPU2", Devices: []string{"GPU1", "GPU2"},
DriverCapabilities: image.DefaultDriverCapabilities.String(), DriverCapabilities: image.DefaultDriverCapabilities.String(),
}, },
}, },
@ -459,7 +458,7 @@ func TestGetNvidiaConfig(t *testing.T) {
SupportedDriverCapabilities: "video,display,utility,compute", SupportedDriverCapabilities: "video,display,utility,compute",
}, },
expectedConfig: &nvidiaConfig{ expectedConfig: &nvidiaConfig{
Devices: "GPU1,GPU2", Devices: []string{"GPU1", "GPU2"},
DriverCapabilities: image.DefaultDriverCapabilities.String(), DriverCapabilities: image.DefaultDriverCapabilities.String(),
}, },
}, },
@ -511,7 +510,7 @@ func TestGetDevicesFromMounts(t *testing.T) {
var tests = []struct { var tests = []struct {
description string description string
mounts []Mount mounts []Mount
expectedDevices *string expectedDevices []string
}{ }{
{ {
description: "No mounts", description: "No mounts",
@ -560,7 +559,7 @@ func TestGetDevicesFromMounts(t *testing.T) {
Destination: filepath.Join(deviceListAsVolumeMountsRoot, "GPU1"), Destination: filepath.Join(deviceListAsVolumeMountsRoot, "GPU1"),
}, },
}, },
expectedDevices: &[]string{"GPU0,GPU1"}[0], expectedDevices: []string{"GPU0", "GPU1"},
}, },
{ {
description: "Discover 2 devices with slashes in the name", description: "Discover 2 devices with slashes in the name",
@ -574,7 +573,7 @@ func TestGetDevicesFromMounts(t *testing.T) {
Destination: filepath.Join(deviceListAsVolumeMountsRoot, "GPU1-MIG0/0/1"), Destination: filepath.Join(deviceListAsVolumeMountsRoot, "GPU1-MIG0/0/1"),
}, },
}, },
expectedDevices: &[]string{"GPU0-MIG0/0/1,GPU1-MIG0/0/1"}[0], expectedDevices: []string{"GPU0-MIG0/0/1", "GPU1-MIG0/0/1"},
}, },
} }
for _, tc := range tests { for _, tc := range tests {
@ -593,7 +592,7 @@ func TestDeviceListSourcePriority(t *testing.T) {
privileged bool privileged bool
acceptUnprivileged bool acceptUnprivileged bool
acceptMounts bool acceptMounts bool
expectedDevices *string expectedDevices []string
}{ }{
{ {
description: "Mount devices, unprivileged, no accept unprivileged", description: "Mount devices, unprivileged, no accept unprivileged",
@ -611,7 +610,7 @@ func TestDeviceListSourcePriority(t *testing.T) {
privileged: false, privileged: false,
acceptUnprivileged: false, acceptUnprivileged: false,
acceptMounts: true, acceptMounts: true,
expectedDevices: &[]string{"GPU0,GPU1"}[0], expectedDevices: []string{"GPU0", "GPU1"},
}, },
{ {
description: "No mount devices, unprivileged, no accept unprivileged", description: "No mount devices, unprivileged, no accept unprivileged",
@ -629,7 +628,7 @@ func TestDeviceListSourcePriority(t *testing.T) {
privileged: true, privileged: true,
acceptUnprivileged: false, acceptUnprivileged: false,
acceptMounts: true, acceptMounts: true,
expectedDevices: &[]string{"GPU0,GPU1"}[0], expectedDevices: []string{"GPU0", "GPU1"},
}, },
{ {
description: "No mount devices, unprivileged, accept unprivileged", description: "No mount devices, unprivileged, accept unprivileged",
@ -638,7 +637,7 @@ func TestDeviceListSourcePriority(t *testing.T) {
privileged: false, privileged: false,
acceptUnprivileged: true, acceptUnprivileged: true,
acceptMounts: true, acceptMounts: true,
expectedDevices: &[]string{"GPU0,GPU1"}[0], expectedDevices: []string{"GPU0", "GPU1"},
}, },
{ {
description: "Mount devices, unprivileged, accept unprivileged, no accept mounts", description: "Mount devices, unprivileged, accept unprivileged, no accept mounts",
@ -656,7 +655,7 @@ func TestDeviceListSourcePriority(t *testing.T) {
privileged: false, privileged: false,
acceptUnprivileged: true, acceptUnprivileged: true,
acceptMounts: false, acceptMounts: false,
expectedDevices: &[]string{"GPU2,GPU3"}[0], expectedDevices: []string{"GPU2", "GPU3"},
}, },
{ {
description: "Mount devices, unprivileged, no accept unprivileged, no accept mounts", description: "Mount devices, unprivileged, no accept unprivileged, no accept mounts",
@ -680,7 +679,7 @@ func TestDeviceListSourcePriority(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.description, func(t *testing.T) { t.Run(tc.description, func(t *testing.T) {
// Wrap the call to getDevices() in a closure. // Wrap the call to getDevices() in a closure.
var devices *string var devices []string
getDevices := func() { getDevices := func() {
image, _ := image.New( image, _ := image.New(
image.WithEnvMap( image.WithEnvMap(
@ -704,8 +703,6 @@ func TestDeviceListSourcePriority(t *testing.T) {
} }
func TestGetDevicesFromEnvvar(t *testing.T) { func TestGetDevicesFromEnvvar(t *testing.T) {
all := "all"
empty := ""
envDockerResourceGPUs := "DOCKER_RESOURCE_GPUS" envDockerResourceGPUs := "DOCKER_RESOURCE_GPUS"
gpuID := "GPU-12345" gpuID := "GPU-12345"
anotherGPUID := "GPU-67890" anotherGPUID := "GPU-67890"
@ -715,7 +712,7 @@ func TestGetDevicesFromEnvvar(t *testing.T) {
description string description string
swarmResourceEnvvars []string swarmResourceEnvvars []string
env map[string]string env map[string]string
expectedDevices *string expectedDevices []string
}{ }{
{ {
description: "empty env returns nil for non-legacy image", description: "empty env returns nil for non-legacy image",
@ -737,14 +734,14 @@ func TestGetDevicesFromEnvvar(t *testing.T) {
env: map[string]string{ env: map[string]string{
envNVVisibleDevices: "none", envNVVisibleDevices: "none",
}, },
expectedDevices: &empty, expectedDevices: []string{""},
}, },
{ {
description: "NVIDIA_VISIBLE_DEVICES set returns value for non-legacy image", description: "NVIDIA_VISIBLE_DEVICES set returns value for non-legacy image",
env: map[string]string{ env: map[string]string{
envNVVisibleDevices: gpuID, envNVVisibleDevices: gpuID,
}, },
expectedDevices: &gpuID, expectedDevices: []string{gpuID},
}, },
{ {
description: "NVIDIA_VISIBLE_DEVICES set returns value for legacy image", description: "NVIDIA_VISIBLE_DEVICES set returns value for legacy image",
@ -752,14 +749,14 @@ func TestGetDevicesFromEnvvar(t *testing.T) {
envNVVisibleDevices: gpuID, envNVVisibleDevices: gpuID,
envCUDAVersion: "legacy", envCUDAVersion: "legacy",
}, },
expectedDevices: &gpuID, expectedDevices: []string{gpuID},
}, },
{ {
description: "empty env returns all for legacy image", description: "empty env returns all for legacy image",
env: map[string]string{ env: map[string]string{
envCUDAVersion: "legacy", envCUDAVersion: "legacy",
}, },
expectedDevices: &all, expectedDevices: []string{"all"},
}, },
// Add the `DOCKER_RESOURCE_GPUS` envvar and ensure that this is ignored when // Add the `DOCKER_RESOURCE_GPUS` envvar and ensure that this is ignored when
// not enabled // not enabled
@ -789,7 +786,7 @@ func TestGetDevicesFromEnvvar(t *testing.T) {
envNVVisibleDevices: "none", envNVVisibleDevices: "none",
envDockerResourceGPUs: anotherGPUID, envDockerResourceGPUs: anotherGPUID,
}, },
expectedDevices: &empty, expectedDevices: []string{""},
}, },
{ {
description: "NVIDIA_VISIBLE_DEVICES set returns value for non-legacy image", description: "NVIDIA_VISIBLE_DEVICES set returns value for non-legacy image",
@ -797,7 +794,7 @@ func TestGetDevicesFromEnvvar(t *testing.T) {
envNVVisibleDevices: gpuID, envNVVisibleDevices: gpuID,
envDockerResourceGPUs: anotherGPUID, envDockerResourceGPUs: anotherGPUID,
}, },
expectedDevices: &gpuID, expectedDevices: []string{gpuID},
}, },
{ {
description: "NVIDIA_VISIBLE_DEVICES set returns value for legacy image", description: "NVIDIA_VISIBLE_DEVICES set returns value for legacy image",
@ -806,7 +803,7 @@ func TestGetDevicesFromEnvvar(t *testing.T) {
envDockerResourceGPUs: anotherGPUID, envDockerResourceGPUs: anotherGPUID,
envCUDAVersion: "legacy", envCUDAVersion: "legacy",
}, },
expectedDevices: &gpuID, expectedDevices: []string{gpuID},
}, },
{ {
description: "empty env returns all for legacy image", description: "empty env returns all for legacy image",
@ -814,7 +811,7 @@ func TestGetDevicesFromEnvvar(t *testing.T) {
envDockerResourceGPUs: anotherGPUID, envDockerResourceGPUs: anotherGPUID,
envCUDAVersion: "legacy", envCUDAVersion: "legacy",
}, },
expectedDevices: &all, expectedDevices: []string{"all"},
}, },
// Add the `DOCKER_RESOURCE_GPUS` envvar and ensure that this is selected when // Add the `DOCKER_RESOURCE_GPUS` envvar and ensure that this is selected when
// enabled // enabled
@ -842,7 +839,7 @@ func TestGetDevicesFromEnvvar(t *testing.T) {
env: map[string]string{ env: map[string]string{
envDockerResourceGPUs: "none", envDockerResourceGPUs: "none",
}, },
expectedDevices: &empty, expectedDevices: []string{""},
}, },
{ {
description: "DOCKER_RESOURCE_GPUS set returns value for non-legacy image", description: "DOCKER_RESOURCE_GPUS set returns value for non-legacy image",
@ -850,7 +847,7 @@ func TestGetDevicesFromEnvvar(t *testing.T) {
env: map[string]string{ env: map[string]string{
envDockerResourceGPUs: gpuID, envDockerResourceGPUs: gpuID,
}, },
expectedDevices: &gpuID, expectedDevices: []string{gpuID},
}, },
{ {
description: "DOCKER_RESOURCE_GPUS set returns value for legacy image", description: "DOCKER_RESOURCE_GPUS set returns value for legacy image",
@ -859,7 +856,7 @@ func TestGetDevicesFromEnvvar(t *testing.T) {
envDockerResourceGPUs: gpuID, envDockerResourceGPUs: gpuID,
envCUDAVersion: "legacy", envCUDAVersion: "legacy",
}, },
expectedDevices: &gpuID, expectedDevices: []string{gpuID},
}, },
{ {
description: "DOCKER_RESOURCE_GPUS is selected if present", description: "DOCKER_RESOURCE_GPUS is selected if present",
@ -867,7 +864,7 @@ func TestGetDevicesFromEnvvar(t *testing.T) {
env: map[string]string{ env: map[string]string{
envDockerResourceGPUs: anotherGPUID, envDockerResourceGPUs: anotherGPUID,
}, },
expectedDevices: &anotherGPUID, expectedDevices: []string{anotherGPUID},
}, },
{ {
description: "DOCKER_RESOURCE_GPUS overrides NVIDIA_VISIBLE_DEVICES if present", description: "DOCKER_RESOURCE_GPUS overrides NVIDIA_VISIBLE_DEVICES if present",
@ -876,7 +873,7 @@ func TestGetDevicesFromEnvvar(t *testing.T) {
envNVVisibleDevices: gpuID, envNVVisibleDevices: gpuID,
envDockerResourceGPUs: anotherGPUID, envDockerResourceGPUs: anotherGPUID,
}, },
expectedDevices: &anotherGPUID, expectedDevices: []string{anotherGPUID},
}, },
{ {
description: "DOCKER_RESOURCE_GPUS_ADDITIONAL overrides NVIDIA_VISIBLE_DEVICES if present", description: "DOCKER_RESOURCE_GPUS_ADDITIONAL overrides NVIDIA_VISIBLE_DEVICES if present",
@ -885,7 +882,7 @@ func TestGetDevicesFromEnvvar(t *testing.T) {
envNVVisibleDevices: gpuID, envNVVisibleDevices: gpuID,
"DOCKER_RESOURCE_GPUS_ADDITIONAL": anotherGPUID, "DOCKER_RESOURCE_GPUS_ADDITIONAL": anotherGPUID,
}, },
expectedDevices: &anotherGPUID, expectedDevices: []string{anotherGPUID},
}, },
{ {
description: "All available swarm resource envvars are selected and override NVIDIA_VISIBLE_DEVICES if present", description: "All available swarm resource envvars are selected and override NVIDIA_VISIBLE_DEVICES if present",
@ -895,10 +892,7 @@ func TestGetDevicesFromEnvvar(t *testing.T) {
"DOCKER_RESOURCE_GPUS": thirdGPUID, "DOCKER_RESOURCE_GPUS": thirdGPUID,
"DOCKER_RESOURCE_GPUS_ADDITIONAL": anotherGPUID, "DOCKER_RESOURCE_GPUS_ADDITIONAL": anotherGPUID,
}, },
expectedDevices: func() *string { expectedDevices: []string{thirdGPUID, anotherGPUID},
result := fmt.Sprintf("%s,%s", thirdGPUID, anotherGPUID)
return &result
}(),
}, },
{ {
description: "DOCKER_RESOURCE_GPUS_ADDITIONAL or DOCKER_RESOURCE_GPUS override NVIDIA_VISIBLE_DEVICES if present", description: "DOCKER_RESOURCE_GPUS_ADDITIONAL or DOCKER_RESOURCE_GPUS override NVIDIA_VISIBLE_DEVICES if present",
@ -907,23 +901,17 @@ func TestGetDevicesFromEnvvar(t *testing.T) {
envNVVisibleDevices: gpuID, envNVVisibleDevices: gpuID,
"DOCKER_RESOURCE_GPUS_ADDITIONAL": anotherGPUID, "DOCKER_RESOURCE_GPUS_ADDITIONAL": anotherGPUID,
}, },
expectedDevices: &anotherGPUID, expectedDevices: []string{anotherGPUID},
}, },
} }
for i, tc := range tests { for _, tc := range tests {
t.Run(tc.description, func(t *testing.T) { t.Run(tc.description, func(t *testing.T) {
image, _ := image.New( image, _ := image.New(
image.WithEnvMap(tc.env), image.WithEnvMap(tc.env),
) )
devices := getDevicesFromEnvvar(image, tc.swarmResourceEnvvars) devices := getDevicesFromEnvvar(image, tc.swarmResourceEnvvars)
if tc.expectedDevices == nil { require.EqualValues(t, tc.expectedDevices, devices)
require.Nil(t, devices, "%d: %v", i, tc)
return
}
require.NotNil(t, devices, "%d: %v", i, tc)
require.Equal(t, *tc.expectedDevices, *devices, "%d: %v", i, tc)
}) })
} }
} }

View File

@ -120,8 +120,8 @@ func doPrestart() {
if cli.NoCgroups { if cli.NoCgroups {
args = append(args, "--no-cgroups") args = append(args, "--no-cgroups")
} }
if len(nvidia.Devices) > 0 { if devicesString := strings.Join(nvidia.Devices, ","); len(devicesString) > 0 {
args = append(args, fmt.Sprintf("--device=%s", nvidia.Devices)) args = append(args, fmt.Sprintf("--device=%s", devicesString))
} }
if len(nvidia.MigConfigDevices) > 0 { if len(nvidia.MigConfigDevices) > 0 {
args = append(args, fmt.Sprintf("--mig-config=%s", nvidia.MigConfigDevices)) args = append(args, fmt.Sprintf("--mig-config=%s", nvidia.MigConfigDevices))