diff --git a/internal/config/image/cuda_image.go b/internal/config/image/cuda_image.go index 7d5bf7c1..49184dbd 100644 --- a/internal/config/image/cuda_image.go +++ b/internal/config/image/cuda_image.go @@ -270,7 +270,7 @@ func (i CUDA) VisibleDevices() []string { } // Get the Fallback to reading from the environment variable if privileges are correct - envVarDeviceRequests := i.VisibleDevicesFromEnvVar() + envVarDeviceRequests := i.visibleDevicesFromEnvVar() if len(envVarDeviceRequests) == 0 { return nil } @@ -322,11 +322,11 @@ func (i CUDA) cdiDeviceRequestsFromAnnotations() []string { return devices } -// VisibleDevicesFromEnvVar returns the set of visible devices requested through environment variables. +// visibleDevicesFromEnvVar returns the set of visible devices requested through environment variables. // If any of the preferredVisibleDeviceEnvVars are present in the image, they // are used to determine the visible devices. If this is not the case, the // NVIDIA_VISIBLE_DEVICES environment variable is used. -func (i CUDA) VisibleDevicesFromEnvVar() []string { +func (i CUDA) visibleDevicesFromEnvVar() []string { envVars := i.visibleEnvVars() return i.DevicesFromEnvvars(envVars...).List() } diff --git a/internal/config/image/cuda_image_test.go b/internal/config/image/cuda_image_test.go index 7302a695..aad69d31 100644 --- a/internal/config/image/cuda_image_test.go +++ b/internal/config/image/cuda_image_test.go @@ -429,7 +429,7 @@ func TestGetDevicesFromEnvvar(t *testing.T) { ) require.NoError(t, err) - devices := image.VisibleDevicesFromEnvVar() + devices := image.visibleDevicesFromEnvVar() require.EqualValues(t, tc.expectedDevices, devices) }) } @@ -508,13 +508,15 @@ func TestGetVisibleDevicesFromMounts(t *testing.T) { func TestVisibleDevices(t *testing.T) { var tests = []struct { - description string - mountDevices []specs.Mount - envvarDevices string - privileged bool - acceptUnprivileged bool - acceptMounts bool - expectedDevices []string + description string + mountDevices []specs.Mount + envvarDevices string + privileged bool + acceptUnprivileged bool + acceptMounts bool + preferredVisibleDeviceEnvVars []string + env map[string]string + expectedDevices []string }{ { description: "Mount devices, unprivileged, no accept unprivileged", @@ -597,20 +599,92 @@ func TestVisibleDevices(t *testing.T) { acceptMounts: false, expectedDevices: nil, }, + // New test cases for visibleEnvVars functionality + { + description: "preferred env var set and present in env, privileged", + mountDevices: nil, + envvarDevices: "", + privileged: true, + acceptUnprivileged: false, + acceptMounts: true, + preferredVisibleDeviceEnvVars: []string{"DOCKER_RESOURCE_GPUS"}, + env: map[string]string{ + "DOCKER_RESOURCE_GPUS": "GPU-12345", + }, + expectedDevices: []string{"GPU-12345"}, + }, + { + description: "preferred env var set and present in env, unprivileged but accepted", + mountDevices: nil, + envvarDevices: "", + privileged: false, + acceptUnprivileged: true, + acceptMounts: true, + preferredVisibleDeviceEnvVars: []string{"DOCKER_RESOURCE_GPUS"}, + env: map[string]string{ + "DOCKER_RESOURCE_GPUS": "GPU-12345", + }, + expectedDevices: []string{"GPU-12345"}, + }, + { + description: "preferred env var set and present in env, unprivileged and not accepted", + mountDevices: nil, + envvarDevices: "", + privileged: false, + acceptUnprivileged: false, + acceptMounts: true, + preferredVisibleDeviceEnvVars: []string{"DOCKER_RESOURCE_GPUS"}, + env: map[string]string{ + "DOCKER_RESOURCE_GPUS": "GPU-12345", + }, + expectedDevices: nil, + }, + { + description: "multiple preferred env vars, both present, privileged", + mountDevices: nil, + envvarDevices: "", + privileged: true, + acceptUnprivileged: false, + acceptMounts: true, + preferredVisibleDeviceEnvVars: []string{"DOCKER_RESOURCE_GPUS", "DOCKER_RESOURCE_GPUS_ADDITIONAL"}, + env: map[string]string{ + "DOCKER_RESOURCE_GPUS": "GPU-12345", + "DOCKER_RESOURCE_GPUS_ADDITIONAL": "GPU-67890", + }, + expectedDevices: []string{"GPU-12345", "GPU-67890"}, + }, + { + description: "preferred env var not present, fallback to NVIDIA_VISIBLE_DEVICES, privileged", + mountDevices: nil, + envvarDevices: "GPU-12345", + privileged: true, + acceptUnprivileged: false, + acceptMounts: true, + preferredVisibleDeviceEnvVars: []string{"DOCKER_RESOURCE_GPUS"}, + env: map[string]string{ + EnvVarNvidiaVisibleDevices: "GPU-12345", + }, + expectedDevices: []string{"GPU-12345"}, + }, } for _, tc := range tests { t.Run(tc.description, func(t *testing.T) { - // Wrap the call to getDevices() in a closure. + // Create env map with both NVIDIA_VISIBLE_DEVICES and any additional env vars + env := make(map[string]string) + if tc.envvarDevices != "" { + env[EnvVarNvidiaVisibleDevices] = tc.envvarDevices + } + for k, v := range tc.env { + env[k] = v + } + image, err := New( - WithEnvMap( - map[string]string{ - EnvVarNvidiaVisibleDevices: tc.envvarDevices, - }, - ), + WithEnvMap(env), WithMounts(tc.mountDevices), WithPrivileged(tc.privileged), WithAcceptDeviceListAsVolumeMounts(tc.acceptMounts), WithAcceptEnvvarUnprivileged(tc.acceptUnprivileged), + WithPreferredVisibleDevicesEnvVars(tc.preferredVisibleDeviceEnvVars...), ) require.NoError(t, err) require.Equal(t, tc.expectedDevices, image.VisibleDevices()) diff --git a/internal/modifier/csv.go b/internal/modifier/csv.go index d3dc84e7..1f8a12f8 100644 --- a/internal/modifier/csv.go +++ b/internal/modifier/csv.go @@ -33,7 +33,7 @@ import ( // NewCSVModifier creates a modifier that applies modications to an OCI spec if required by the runtime wrapper. // The modifications are defined by CSV MountSpecs. func NewCSVModifier(logger logger.Interface, cfg *config.Config, container image.CUDA) (oci.SpecModifier, error) { - if devices := container.VisibleDevicesFromEnvVar(); len(devices) == 0 { + if devices := container.VisibleDevices(); len(devices) == 0 { logger.Infof("No modification required; no devices requested") return nil, nil } diff --git a/internal/modifier/gated.go b/internal/modifier/gated.go index a0239df8..e96946a4 100644 --- a/internal/modifier/gated.go +++ b/internal/modifier/gated.go @@ -37,7 +37,7 @@ import ( // // If not devices are selected, no changes are made. func NewFeatureGatedModifier(logger logger.Interface, cfg *config.Config, image image.CUDA, driver *root.Driver, hookCreator discover.HookCreator) (oci.SpecModifier, error) { - if devices := image.VisibleDevicesFromEnvVar(); len(devices) == 0 { + if devices := image.VisibleDevices(); len(devices) == 0 { logger.Infof("No modification required; no devices requested") return nil, nil } diff --git a/internal/modifier/graphics.go b/internal/modifier/graphics.go index 6e602d7a..cd024a89 100644 --- a/internal/modifier/graphics.go +++ b/internal/modifier/graphics.go @@ -29,9 +29,10 @@ import ( // NewGraphicsModifier constructs a modifier that injects graphics-related modifications into an OCI runtime specification. // The value of the NVIDIA_DRIVER_CAPABILITIES environment variable is checked to determine if this modification should be made. -func NewGraphicsModifier(logger logger.Interface, cfg *config.Config, containerImage image.CUDA, driver *root.Driver, hookCreator discover.HookCreator) (oci.SpecModifier, error) { - if required, reason := requiresGraphicsModifier(containerImage); !required { - logger.Infof("No graphics modifier required: %v", reason) +func NewGraphicsModifier(logger logger.Interface, cfg *config.Config, container image.CUDA, driver *root.Driver, hookCreator discover.HookCreator) (oci.SpecModifier, error) { + devices, reason := requiresGraphicsModifier(container) + if len(devices) == 0 { + logger.Infof("No graphics modifier required; %v", reason) return nil, nil } @@ -48,7 +49,7 @@ func NewGraphicsModifier(logger logger.Interface, cfg *config.Config, containerI devRoot := driver.Root drmNodes, err := discover.NewDRMNodesDiscoverer( logger, - containerImage.DevicesFromEnvvars(image.EnvVarNvidiaVisibleDevices), + image.NewVisibleDevices(devices...), devRoot, hookCreator, ) @@ -64,14 +65,15 @@ func NewGraphicsModifier(logger logger.Interface, cfg *config.Config, containerI } // requiresGraphicsModifier determines whether a graphics modifier is required. -func requiresGraphicsModifier(cudaImage image.CUDA) (bool, string) { - if devices := cudaImage.VisibleDevicesFromEnvVar(); len(devices) == 0 { - return false, "no devices requested" +func requiresGraphicsModifier(cudaImage image.CUDA) ([]string, string) { + devices := cudaImage.VisibleDevices() + if len(devices) == 0 { + return nil, "no devices requested" } if !cudaImage.GetDriverCapabilities().Any(image.DriverCapabilityGraphics, image.DriverCapabilityDisplay) { - return false, "no required capabilities requested" + return nil, "no required capabilities requested" } - return true, "" + return devices, "" } diff --git a/internal/modifier/graphics_test.go b/internal/modifier/graphics_test.go index 186af48a..540fee5a 100644 --- a/internal/modifier/graphics_test.go +++ b/internal/modifier/graphics_test.go @@ -26,9 +26,9 @@ import ( func TestGraphicsModifier(t *testing.T) { testCases := []struct { - description string - envmap map[string]string - expectedRequired bool + description string + envmap map[string]string + expectedDevices []string }{ { description: "empty image does not create modifier", @@ -52,7 +52,7 @@ func TestGraphicsModifier(t *testing.T) { "NVIDIA_VISIBLE_DEVICES": "all", "NVIDIA_DRIVER_CAPABILITIES": "all", }, - expectedRequired: true, + expectedDevices: []string{"all"}, }, { description: "devices with graphics capability creates modifier", @@ -60,7 +60,7 @@ func TestGraphicsModifier(t *testing.T) { "NVIDIA_VISIBLE_DEVICES": "all", "NVIDIA_DRIVER_CAPABILITIES": "graphics", }, - expectedRequired: true, + expectedDevices: []string{"all"}, }, { description: "devices with compute,graphics capability creates modifier", @@ -68,7 +68,7 @@ func TestGraphicsModifier(t *testing.T) { "NVIDIA_VISIBLE_DEVICES": "all", "NVIDIA_DRIVER_CAPABILITIES": "compute,graphics", }, - expectedRequired: true, + expectedDevices: []string{"all"}, }, { description: "devices with display capability creates modifier", @@ -76,7 +76,7 @@ func TestGraphicsModifier(t *testing.T) { "NVIDIA_VISIBLE_DEVICES": "all", "NVIDIA_DRIVER_CAPABILITIES": "display", }, - expectedRequired: true, + expectedDevices: []string{"all"}, }, { description: "devices with display,graphics capability creates modifier", @@ -84,7 +84,7 @@ func TestGraphicsModifier(t *testing.T) { "NVIDIA_VISIBLE_DEVICES": "all", "NVIDIA_DRIVER_CAPABILITIES": "display,graphics", }, - expectedRequired: true, + expectedDevices: []string{"all"}, }, } @@ -94,7 +94,7 @@ func TestGraphicsModifier(t *testing.T) { image.WithEnvMap(tc.envmap), ) required, _ := requiresGraphicsModifier(image) - require.EqualValues(t, tc.expectedRequired, required) + require.EqualValues(t, tc.expectedDevices, required) }) } }