diff --git a/cmd/nvidia-container-runtime-hook/container_config.go b/cmd/nvidia-container-runtime-hook/container_config.go index 1956ba54..531b2b42 100644 --- a/cmd/nvidia-container-runtime-hook/container_config.go +++ b/cmd/nvidia-container-runtime-hook/container_config.go @@ -164,10 +164,23 @@ func isPrivileged(s *Spec) bool { } func getDevicesFromEnvvar(image image.CUDA, swarmResourceEnvvars []string) *string { - // Build a list of envvars to consider. Note that the Swarm Resource envvars have a higher precedence. - envVars := append(swarmResourceEnvvars, envNVVisibleDevices) + // We check if the image has at least one of the Swarm resource envvars defined and use this + // if specified. + var hasSwarmEnvvar bool + for _, envvar := range swarmResourceEnvvars { + if _, exists := image[envvar]; exists { + hasSwarmEnvvar = true + break + } + } + + var devices []string + if hasSwarmEnvvar { + devices = image.DevicesFromEnvvars(swarmResourceEnvvars...).List() + } else { + devices = image.DevicesFromEnvvars(envNVVisibleDevices).List() + } - devices := image.DevicesFromEnvvars(envVars...).List() if len(devices) == 0 { return nil } diff --git a/cmd/nvidia-container-runtime-hook/container_config_test.go b/cmd/nvidia-container-runtime-hook/container_config_test.go index 59431d5c..eb0aec61 100644 --- a/cmd/nvidia-container-runtime-hook/container_config_test.go +++ b/cmd/nvidia-container-runtime-hook/container_config_test.go @@ -1,6 +1,7 @@ package main import ( + "fmt" "path/filepath" "testing" @@ -906,17 +907,20 @@ func TestGetDevicesFromEnvvar(t *testing.T) { expectedDevices: &anotherGPUID, }, { - description: "First available swarm resource envvar is selected and overrides NVIDIA_VISIBLE_DEVICES if present", + description: "All available swarm resource envvars are selected and override NVIDIA_VISIBLE_DEVICES if present", swarmResourceEnvvars: []string{"DOCKER_RESOURCE_GPUS", "DOCKER_RESOURCE_GPUS_ADDITIONAL"}, env: map[string]string{ envNVVisibleDevices: gpuID, 
"DOCKER_RESOURCE_GPUS": thirdGPUID, "DOCKER_RESOURCE_GPUS_ADDITIONAL": anotherGPUID, }, - expectedDevices: &thirdGPUID, + expectedDevices: func() *string { + result := fmt.Sprintf("%s,%s", thirdGPUID, anotherGPUID) + return &result + }(), }, { - description: "DOCKER_RESOURCE_GPUS_ADDITIONAL or DOCKER_RESOURCE_GPUS overrides NVIDIA_VISIBLE_DEVICES if present", + description: "DOCKER_RESOURCE_GPUS_ADDITIONAL or DOCKER_RESOURCE_GPUS override NVIDIA_VISIBLE_DEVICES if present", swarmResourceEnvvars: []string{"DOCKER_RESOURCE_GPUS", "DOCKER_RESOURCE_GPUS_ADDITIONAL"}, env: map[string]string{ envNVVisibleDevices: gpuID, diff --git a/internal/config/image/cuda_image.go b/internal/config/image/cuda_image.go index ebd3fdd2..b30e27fa 100644 --- a/internal/config/image/cuda_image.go +++ b/internal/config/image/cuda_image.go @@ -115,28 +115,35 @@ func (i CUDA) HasDisableRequire() bool { // DevicesFromEnvvars returns the devices requested by the image through environment variables func (i CUDA) DevicesFromEnvvars(envVars ...string) VisibleDevices { - // Grab a reference to devices from the first envvar - // in the list that actually exists in the environment. - var devices *string + // We concatenate all the devices from the specified envvars. + var isSet bool + var devices []string + requested := make(map[string]bool) for _, envVar := range envVars { if devs, ok := i[envVar]; ok { - devices = &devs - break + isSet = true + for _, d := range strings.Split(devs, ",") { + trimmed := strings.TrimSpace(d) + if len(trimmed) == 0 { + continue + } + devices = append(devices, trimmed) + requested[trimmed] = true + } } } // Environment variable unset with legacy image: default to "all". 
- if devices == nil && i.IsLegacy() { + if !isSet && len(devices) == 0 && i.IsLegacy() { return newVisibleDevices("all") } // Environment variable unset or empty or "void": return nil - if devices == nil || len(*devices) == 0 || *devices == "void" { + if len(devices) == 0 || requested["void"] { return newVisibleDevices("void") } - // Environment variable set to "none": reset to "". - return newVisibleDevices(*devices) + return newVisibleDevices(devices...) } // GetDriverCapabilities returns the requested driver capabilities. diff --git a/internal/config/image/devices.go b/internal/config/image/devices.go index 6f3d00b6..125bcc0e 100644 --- a/internal/config/image/devices.go +++ b/internal/config/image/devices.go @@ -33,18 +33,20 @@ var _ VisibleDevices = (*void)(nil) var _ VisibleDevices = (*devices)(nil) // newVisibleDevices creates a VisibleDevices based on the value of the specified envvar. -func newVisibleDevices(envvar string) VisibleDevices { - if envvar == "all" { - return all{} - } - if envvar == "none" { - return none{} - } - if envvar == "" || envvar == "void" { - return void{} +func newVisibleDevices(envvars ...string) VisibleDevices { + for _, envvar := range envvars { + if envvar == "all" { + return all{} + } + if envvar == "none" { + return none{} + } + if envvar == "" || envvar == "void" { + return void{} + } } - return newDevices(envvar) + return newDevices(envvars...) } type all struct{}