From 98a860c5646ab49570f12d4fe5ba7cbd71f897ef Mon Sep 17 00:00:00 2001 From: Carlos Eduardo Arango Gutierrez Date: Wed, 28 May 2025 07:33:27 +0200 Subject: [PATCH] Unify GetDevices logic at internal/config/image Signed-off-by: Carlos Eduardo Arango Gutierrez Signed-off-by: Evan Lezar --- .../container_config.go | 90 ++++++------------- .../container_config_test.go | 39 +++++--- .../hook_config.go | 20 ----- .../hook_config_test.go | 50 ----------- internal/config/image/builder.go | 56 ++++++++++-- internal/config/image/cuda_image.go | 54 ++++++++--- internal/config/image/cuda_image_test.go | 6 +- internal/config/image/privileged.go | 39 +++++--- internal/config/image/privileged_test.go | 57 ++++++++++++ internal/modifier/cdi.go | 4 +- internal/modifier/csv.go | 2 +- internal/modifier/gated.go | 2 +- internal/modifier/graphics.go | 2 +- internal/modifier/graphics_test.go | 1 + 14 files changed, 238 insertions(+), 184 deletions(-) create mode 100644 internal/config/image/privileged_test.go diff --git a/cmd/nvidia-container-runtime-hook/container_config.go b/cmd/nvidia-container-runtime-hook/container_config.go index ae3aca75..7e9bd8d3 100644 --- a/cmd/nvidia-container-runtime-hook/container_config.go +++ b/cmd/nvidia-container-runtime-hook/container_config.go @@ -13,10 +13,6 @@ import ( "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" ) -const ( - capSysAdmin = "CAP_SYS_ADMIN" -) - type nvidiaConfig struct { Devices []string MigConfigDevices string @@ -103,9 +99,9 @@ func loadSpec(path string) (spec *Spec) { return } -func isPrivileged(s *Spec) bool { - if s.Process.Capabilities == nil { - return false +func (s *Spec) GetCapabilities() []string { + if s == nil || s.Process == nil || s.Process.Capabilities == nil { + return nil } var caps []string @@ -118,67 +114,22 @@ func isPrivileged(s *Spec) bool { if err != nil { log.Panicln("could not decode Process.Capabilities in OCI spec:", err) } - for _, c := range caps { - if c == capSysAdmin { - return true - } - } - return false + return caps } // Otherwise, parse s.Process.Capabilities as: // github.com/opencontainers/runtime-spec/blob/v1.0.0/specs-go/config.go#L30-L54 - process := specs.Process{ - Env: s.Process.Env, - } - - err := json.Unmarshal(*s.Process.Capabilities, &process.Capabilities) + capabilities := specs.LinuxCapabilities{} + err := json.Unmarshal(*s.Process.Capabilities, &capabilities) if err != nil { log.Panicln("could not decode Process.Capabilities in OCI spec:", err) } - fullSpec := specs.Spec{ - Version: *s.Version, - Process: &process, - } - - return image.IsPrivileged(&fullSpec) + return image.OCISpecCapabilities(capabilities).GetCapabilities() } -func getDevicesFromEnvvar(containerImage image.CUDA, swarmResourceEnvvars []string) []string { - // We check if the image has at least one of the Swarm resource envvars defined and use this - // if specified. - for _, envvar := range swarmResourceEnvvars { - if containerImage.HasEnvvar(envvar) { - return containerImage.DevicesFromEnvvars(swarmResourceEnvvars...).List() - } - } - - return containerImage.VisibleDevicesFromEnvVar() -} - -func (hookConfig *hookConfig) getDevices(image image.CUDA, privileged bool) []string { - // If enabled, try and get the device list from volume mounts first - if hookConfig.AcceptDeviceListAsVolumeMounts { - devices := image.VisibleDevicesFromMounts() - if len(devices) > 0 { - return devices - } - } - - // Fallback to reading from the environment variable if privileges are correct - devices := getDevicesFromEnvvar(image, hookConfig.getSwarmResourceEnvvars()) - if len(devices) == 0 { - return nil - } - if privileged || hookConfig.AcceptEnvvarUnprivileged { - return devices - } - - configName := hookConfig.getConfigOption("AcceptEnvvarUnprivileged") - log.Printf("Ignoring devices specified in NVIDIA_VISIBLE_DEVICES (privileged=%v, %v=%v) ", privileged, configName, hookConfig.AcceptEnvvarUnprivileged) - - return nil +func isPrivileged(s *Spec) bool { + return image.IsPrivileged(s) } func getMigConfigDevices(i image.CUDA) *string { @@ -225,7 +176,6 @@ func (hookConfig *hookConfig) getDriverCapabilities(cudaImage image.CUDA, legacy // We use the default driver capabilities by default. This is filtered to only include the // supported capabilities supportedDriverCapabilities := image.NewDriverCapabilities(hookConfig.SupportedDriverCapabilities) - capabilities := supportedDriverCapabilities.Intersection(image.DefaultDriverCapabilities) capsEnvSpecified := cudaImage.HasEnvvar(image.EnvVarNvidiaDriverCapabilities) @@ -251,7 +201,7 @@ func (hookConfig *hookConfig) getDriverCapabilities(cudaImage image.CUDA, legacy func (hookConfig *hookConfig) getNvidiaConfig(image image.CUDA, privileged bool) *nvidiaConfig { legacyImage := image.IsLegacy() - devices := hookConfig.getDevices(image, privileged) + devices := image.VisibleDevices() if len(devices) == 0 { // empty devices means this is not a GPU container. return nil @@ -306,20 +256,30 @@ func (hookConfig *hookConfig) getContainerConfig() (config containerConfig) { s := loadSpec(path.Join(b, "config.json")) - image, err := image.New( + privileged := isPrivileged(s) + + opts := []image.Option{ image.WithEnv(s.Process.Env), image.WithMounts(s.Mounts), + image.WithPrivileged(privileged), image.WithDisableRequire(hookConfig.DisableRequire), - ) + image.WithAcceptDeviceListAsVolumeMounts(hookConfig.AcceptDeviceListAsVolumeMounts), + image.WithAcceptEnvvarUnprivileged(hookConfig.AcceptEnvvarUnprivileged), + } + + if hookConfig.SwarmResource == "" { + opts = append(opts, image.WithVisibleDevicesEnvVars(hookConfig.SwarmResource)) + } + + i, err := image.New(opts...) if err != nil { log.Panicln(err) } - privileged := isPrivileged(s) return containerConfig{ Pid: h.Pid, Rootfs: s.Root.Path, - Image: image, - Nvidia: hookConfig.getNvidiaConfig(image, privileged), + Image: i, + Nvidia: hookConfig.getNvidiaConfig(i, privileged), } } diff --git a/cmd/nvidia-container-runtime-hook/container_config_test.go b/cmd/nvidia-container-runtime-hook/container_config_test.go index 1f3858b1..6a4b8194 100644 --- a/cmd/nvidia-container-runtime-hook/container_config_test.go +++ b/cmd/nvidia-container-runtime-hook/container_config_test.go @@ -477,9 +477,19 @@ func TestGetNvidiaConfig(t *testing.T) { } for _, tc := range tests { t.Run(tc.description, func(t *testing.T) { - image, _ := image.New( + opts := []image.Option{ image.WithEnvMap(tc.env), - ) + image.WithPrivileged(tc.privileged), + image.WithAcceptEnvvarUnprivileged(true), + } + + if tc.hookConfig != nil { + if tc.hookConfig.SwarmResource != "" { + opts = append(opts, image.WithVisibleDevicesEnvVars(tc.hookConfig.SwarmResource)) + } + } + image, _ := image.New(opts...) + // Wrap the call to getNvidiaConfig() in a closure. var cfg *nvidiaConfig getConfig := func() { @@ -622,12 +632,11 @@ func TestDeviceListSourcePriority(t *testing.T) { }, ), image.WithMounts(tc.mountDevices), + image.WithPrivileged(tc.privileged), + image.WithAcceptDeviceListAsVolumeMounts(tc.acceptMounts), + image.WithAcceptEnvvarUnprivileged(tc.acceptUnprivileged), ) - defaultConfig, _ := config.GetDefault() - cfg := &hookConfig{defaultConfig} - cfg.AcceptEnvvarUnprivileged = tc.acceptUnprivileged - cfg.AcceptDeviceListAsVolumeMounts = tc.acceptMounts - devices = cfg.getDevices(image, tc.privileged) + devices = image.VisibleDevices() } // For all other tests, just grab the devices and check the results @@ -843,10 +852,20 @@ func TestGetDevicesFromEnvvar(t *testing.T) { for _, tc := range tests { t.Run(tc.description, func(t *testing.T) { - image, _ := image.New( + opts := []image.Option{ image.WithEnvMap(tc.env), - ) - devices := getDevicesFromEnvvar(image, tc.swarmResourceEnvvars) + image.WithPrivileged(true), + image.WithAcceptDeviceListAsVolumeMounts(false), + image.WithAcceptEnvvarUnprivileged(false), + } + + if len(tc.swarmResourceEnvvars) > 0 { + opts = append(opts, image.WithVisibleDevicesEnvVars(tc.swarmResourceEnvvars...)) + } + + image, _ := image.New(opts...) + + devices := image.VisibleDevices() require.EqualValues(t, tc.expectedDevices, devices) }) } diff --git a/cmd/nvidia-container-runtime-hook/hook_config.go b/cmd/nvidia-container-runtime-hook/hook_config.go index 3ad8e3b1..bca0be0d 100644 --- a/cmd/nvidia-container-runtime-hook/hook_config.go +++ b/cmd/nvidia-container-runtime-hook/hook_config.go @@ -6,7 +6,6 @@ import ( "os" "path" "reflect" - "strings" "github.com/NVIDIA/nvidia-container-toolkit/internal/config" "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" @@ -86,25 +85,6 @@ func (c hookConfig) getConfigOption(fieldName string) string { return v } -// getSwarmResourceEnvvars returns the swarm resource envvars for the config. -func (c *hookConfig) getSwarmResourceEnvvars() []string { - if c.SwarmResource == "" { - return nil - } - - candidates := strings.Split(c.SwarmResource, ",") - - var envvars []string - for _, c := range candidates { - trimmed := strings.TrimSpace(c) - if len(trimmed) > 0 { - envvars = append(envvars, trimmed) - } - } - - return envvars -} - // nvidiaContainerCliCUDACompatModeFlags returns required --cuda-compat-mode // flag(s) depending on the hook and runtime configurations. func (c *hookConfig) nvidiaContainerCliCUDACompatModeFlags() []string { diff --git a/cmd/nvidia-container-runtime-hook/hook_config_test.go b/cmd/nvidia-container-runtime-hook/hook_config_test.go index 19147ecf..744e19b2 100644 --- a/cmd/nvidia-container-runtime-hook/hook_config_test.go +++ b/cmd/nvidia-container-runtime-hook/hook_config_test.go @@ -23,7 +23,6 @@ import ( "github.com/stretchr/testify/require" - "github.com/NVIDIA/nvidia-container-toolkit/internal/config" "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" ) @@ -107,52 +106,3 @@ func TestGetHookConfig(t *testing.T) { }) } } - -func TestGetSwarmResourceEnvvars(t *testing.T) { - testCases := []struct { - value string - expected []string - }{ - { - value: "", - expected: nil, - }, - { - value: " ", - expected: nil, - }, - { - value: "single", - expected: []string{"single"}, - }, - { - value: "single ", - expected: []string{"single"}, - }, - { - value: "one,two", - expected: []string{"one", "two"}, - }, - { - value: "one ,two", - expected: []string{"one", "two"}, - }, - { - value: "one, two", - expected: []string{"one", "two"}, - }, - } - - for i, tc := range testCases { - t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { - c := &hookConfig{ - Config: &config.Config{ - SwarmResource: tc.value, - }, - } - - envvars := c.getSwarmResourceEnvvars() - require.EqualValues(t, tc.expected, envvars) - }) - } -} diff --git a/internal/config/image/builder.go b/internal/config/image/builder.go index 332d017a..bdcd8e28 100644 --- a/internal/config/image/builder.go +++ b/internal/config/image/builder.go @@ -24,11 +24,14 @@ import ( ) type builder struct { - env map[string]string - mounts []specs.Mount + CUDA + disableRequire bool } +// Option is a functional option for creating a CUDA image. +type Option func(*builder) error + // New creates a new CUDA image from the input options. func New(opt ...Option) (CUDA, error) { b := &builder{} @@ -50,15 +53,22 @@ func (b builder) build() (CUDA, error) { b.env[EnvVarNvidiaDisableRequire] = "true" } - c := CUDA{ - env: b.env, - mounts: b.mounts, - } - return c, nil + return b.CUDA, nil } -// Option is a functional option for creating a CUDA image. -type Option func(*builder) error +func WithAcceptDeviceListAsVolumeMounts(acceptDeviceListAsVolumeMounts bool) Option { + return func(b *builder) error { + b.acceptDeviceListAsVolumeMounts = acceptDeviceListAsVolumeMounts + return nil + } +} + +func WithAcceptEnvvarUnprivileged(acceptEnvvarUnprivileged bool) Option { + return func(b *builder) error { + b.acceptEnvvarUnprivileged = acceptEnvvarUnprivileged + return nil + } +} // WithDisableRequire sets the disable require option. func WithDisableRequire(disableRequire bool) Option { @@ -100,3 +110,31 @@ func WithMounts(mounts []specs.Mount) Option { return nil } } + +// WithPrivileged sets whether an image is privileged or not. +func WithPrivileged(isPrivileged bool) Option { + return func(b *builder) error { + b.isPrivileged = isPrivileged + return nil + } +} + +// WithVisibleDevicesEnvVars sets the swarm resource for the CUDA image. +func WithVisibleDevicesEnvVars(visibleDevicesEnvVars ...string) Option { + return func(b *builder) error { + // if resource is a single string, split it by comma + if len(visibleDevicesEnvVars) == 1 && strings.Contains(visibleDevicesEnvVars[0], ",") { + candidates := strings.Split(visibleDevicesEnvVars[0], ",") + for _, c := range candidates { + trimmed := strings.TrimSpace(c) + if len(trimmed) > 0 { + b.visibleDevicesEnvVars = append(b.visibleDevicesEnvVars, trimmed) + } + } + return nil + } + + b.visibleDevicesEnvVars = append(b.visibleDevicesEnvVars, visibleDevicesEnvVars...) + return nil + } +} diff --git a/internal/config/image/cuda_image.go b/internal/config/image/cuda_image.go index d5bbc224..f8489e62 100644 --- a/internal/config/image/cuda_image.go +++ b/internal/config/image/cuda_image.go @@ -38,8 +38,15 @@ const ( // a map of environment variable to values that can be used to perform lookups // such as requirements. type CUDA struct { - env map[string]string + visibleDevicesEnvVars []string + + env map[string]string + mounts []specs.Mount + + acceptDeviceListAsVolumeMounts bool + acceptEnvvarUnprivileged bool + isPrivileged bool } // NewCUDAImageFromSpec creates a CUDA image from the input OCI runtime spec. @@ -53,12 +60,13 @@ func NewCUDAImageFromSpec(spec *specs.Spec) (CUDA, error) { return New( WithEnv(env), WithMounts(spec.Mounts), + WithPrivileged(IsPrivileged((*OCISpec)(spec))), ) } -// NewCUDAImageFromEnv creates a CUDA image from the input environment. The environment +// newCUDAImageFromEnv creates a CUDA image from the input environment. The environment // is a list of strings of the form ENVAR=VALUE. -func NewCUDAImageFromEnv(env []string) (CUDA, error) { +func newCUDAImageFromEnv(env []string) (CUDA, error) { return New(WithEnv(env)) } @@ -155,7 +163,7 @@ func (i CUDA) DevicesFromEnvvars(envVars ...string) VisibleDevices { // GetDriverCapabilities returns the requested driver capabilities. func (i CUDA) GetDriverCapabilities() DriverCapabilities { - env := i.env[EnvVarNvidiaDriverCapabilities] + env := i.Getenv(EnvVarNvidiaDriverCapabilities) capabilities := make(DriverCapabilities) for _, c := range strings.Split(env, ",") { @@ -166,7 +174,7 @@ func (i CUDA) GetDriverCapabilities() DriverCapabilities { } func (i CUDA) legacyVersion() (string, error) { - cudaVersion := i.env[EnvVarCudaVersion] + cudaVersion := i.Getenv(EnvVarCudaVersion) majorMinor, err := parseMajorMinorVersion(cudaVersion) if err != nil { return "", fmt.Errorf("invalid CUDA version %v: %v", cudaVersion, err) @@ -200,7 +208,7 @@ func parseMajorMinorVersion(version string) (string, error) { // OnlyFullyQualifiedCDIDevices returns true if all devices requested in the image are requested as CDI devices/ func (i CUDA) OnlyFullyQualifiedCDIDevices() bool { var hasCDIdevice bool - for _, device := range i.VisibleDevicesFromEnvVar() { + for _, device := range i.visibleDevicesFromEnvVar() { if !parser.IsQualifiedName(device) { return false } @@ -216,14 +224,17 @@ func (i CUDA) OnlyFullyQualifiedCDIDevices() bool { return hasCDIdevice } -// VisibleDevicesFromEnvVar returns the set of visible devices requested through +// visibleDevicesFromEnvVar returns the set of visible devices requested through // the NVIDIA_VISIBLE_DEVICES environment variable. -func (i CUDA) VisibleDevicesFromEnvVar() []string { +func (i CUDA) visibleDevicesFromEnvVar() []string { + if len(i.visibleDevicesEnvVars) > 0 { + return i.DevicesFromEnvvars(i.visibleDevicesEnvVars...).List() + } return i.DevicesFromEnvvars(EnvVarNvidiaVisibleDevices).List() } -// VisibleDevicesFromMounts returns the set of visible devices requested as mounts. -func (i CUDA) VisibleDevicesFromMounts() []string { +// visibleDevicesFromMounts returns the set of visible devices requested as mounts. +func (i CUDA) visibleDevicesFromMounts() []string { var devices []string for _, device := range i.DevicesFromMounts() { switch { @@ -238,7 +249,6 @@ func (i CUDA) VisibleDevicesFromMounts() []string { } // DevicesFromMounts returns a list of device specified as mounts. -// TODO: This should be merged with getDevicesFromMounts used in the NVIDIA Container Runtime func (i CUDA) DevicesFromMounts() []string { root := filepath.Clean(DeviceListAsVolumeMountsRoot) seen := make(map[string]bool) @@ -271,6 +281,28 @@ func (i CUDA) DevicesFromMounts() []string { return devices } +func (i CUDA) VisibleDevices() []string { + // If enabled, try and get the device list from volume mounts first + if i.acceptDeviceListAsVolumeMounts { + devices := i.visibleDevicesFromMounts() + if len(devices) > 0 { + return devices + } + } + + // Fallback to reading from the environment variable if privileges are correct + devices := i.visibleDevicesFromEnvVar() + if len(devices) == 0 { + return nil + } + + if i.isPrivileged || i.acceptEnvvarUnprivileged { + return devices + } + + return nil +} + // CDIDevicesFromMounts returns a list of CDI devices specified as mounts on the image. func (i CUDA) CDIDevicesFromMounts() []string { var devices []string diff --git a/internal/config/image/cuda_image_test.go b/internal/config/image/cuda_image_test.go index 3b77d333..a0f15899 100644 --- a/internal/config/image/cuda_image_test.go +++ b/internal/config/image/cuda_image_test.go @@ -122,7 +122,7 @@ func TestGetRequirements(t *testing.T) { for _, tc := range testCases { t.Run(tc.description, func(t *testing.T) { - image, err := NewCUDAImageFromEnv(tc.env) + image, err := newCUDAImageFromEnv(tc.env) require.NoError(t, err) requirements, err := image.GetRequirements() @@ -198,7 +198,7 @@ func TestGetVisibleDevicesFromMounts(t *testing.T) { for _, tc := range tests { t.Run(tc.description, func(t *testing.T) { image, _ := New(WithMounts(tc.mounts)) - require.Equal(t, tc.expectedDevices, image.VisibleDevicesFromMounts()) + require.Equal(t, tc.expectedDevices, image.visibleDevicesFromMounts()) }) } } @@ -224,7 +224,7 @@ func TestImexChannelsFromEnvVar(t *testing.T) { for _, tc := range testCases { for id, baseEnvvars := range map[string][]string{"": nil, "legacy": {"CUDA_VERSION=1.2.3"}} { t.Run(tc.description+id, func(t *testing.T) { - i, err := NewCUDAImageFromEnv(append(baseEnvvars, tc.env...)) + i, err := newCUDAImageFromEnv(append(baseEnvvars, tc.env...)) require.NoError(t, err) channels := i.ImexChannelsFromEnvVar() diff --git a/internal/config/image/privileged.go b/internal/config/image/privileged.go index a54598d6..fa67ec1a 100644 --- a/internal/config/image/privileged.go +++ b/internal/config/image/privileged.go @@ -16,28 +16,45 @@ package image -import ( - "github.com/opencontainers/runtime-spec/specs-go" -) +import "github.com/opencontainers/runtime-spec/specs-go" const ( capSysAdmin = "CAP_SYS_ADMIN" ) +type CapabilitiesGetter interface { + GetCapabilities() []string +} + +type OCISpec specs.Spec + +type OCISpecCapabilities specs.LinuxCapabilities + // IsPrivileged returns true if the container is a privileged container. -func IsPrivileged(s *specs.Spec) bool { - if s.Process.Capabilities == nil { +func IsPrivileged(s CapabilitiesGetter) bool { + if s == nil { return false } - - // We only make sure that the bounding capabibility set has - // CAP_SYS_ADMIN. This allows us to make sure that the container was - // actually started as '--privileged', but also allow non-root users to - // access the privileged NVIDIA capabilities. - for _, c := range s.Process.Capabilities.Bounding { + for _, c := range s.GetCapabilities() { if c == capSysAdmin { return true } } + return false } + +func (s OCISpec) GetCapabilities() []string { + if s.Process == nil || s.Process.Capabilities == nil { + return nil + } + return (*OCISpecCapabilities)(s.Process.Capabilities).GetCapabilities() +} + +func (c OCISpecCapabilities) GetCapabilities() []string { + // We only make sure that the bounding capabibility set has + // CAP_SYS_ADMIN. This allows us to make sure that the container was + // actually started as '--privileged', but also allow non-root users to + // access the privileged NVIDIA capabilities. + return c.Bounding +} diff --git a/internal/config/image/privileged_test.go b/internal/config/image/privileged_test.go new file mode 100644 index 00000000..328c496a --- /dev/null +++ b/internal/config/image/privileged_test.go @@ -0,0 +1,57 @@ +/** +# Copyright (c) NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package image + +import ( + "testing" + + "github.com/opencontainers/runtime-spec/specs-go" + "github.com/stretchr/testify/require" +) + +func TestIsPrivileged(t *testing.T) { + var tests = []struct { + spec specs.Spec + expected bool + }{ + { + specs.Spec{ + Process: &specs.Process{ + Capabilities: &specs.LinuxCapabilities{ + Bounding: []string{"CAP_SYS_ADMIN"}, + }, + }, + }, + true, + }, + { + specs.Spec{ + Process: &specs.Process{ + Capabilities: &specs.LinuxCapabilities{ + Bounding: []string{"CAP_SYS_FOO"}, + }, + }, + }, + false, + }, + } + for i, tc := range tests { + privileged := IsPrivileged((*OCISpec)(&tc.spec)) + + require.Equal(t, tc.expected, privileged, "%d: %v", i, tc) + } +} diff --git a/internal/modifier/cdi.go b/internal/modifier/cdi.go index 90cd481b..bfea1964 100644 --- a/internal/modifier/cdi.go +++ b/internal/modifier/cdi.go @@ -92,7 +92,7 @@ func getDevicesFromSpec(logger logger.Interface, ociSpec oci.Spec, cfg *config.C var devices []string seen := make(map[string]bool) - for _, name := range container.VisibleDevicesFromEnvVar() { + for _, name := range container.VisibleDevices() { if !parser.IsQualifiedName(name) { name = fmt.Sprintf("%s=%s", cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.DefaultKind, name) } @@ -107,7 +107,7 @@ func getDevicesFromSpec(logger logger.Interface, ociSpec oci.Spec, cfg *config.C return nil, nil } - if cfg.AcceptEnvvarUnprivileged || image.IsPrivileged(rawSpec) { + if cfg.AcceptEnvvarUnprivileged || image.IsPrivileged((*image.OCISpec)(rawSpec)) { return devices, nil } diff --git a/internal/modifier/csv.go b/internal/modifier/csv.go index d3dc84e7..1f8a12f8 100644 --- a/internal/modifier/csv.go +++ b/internal/modifier/csv.go @@ -33,7 +33,7 @@ import ( // NewCSVModifier creates a modifier that applies modications to an OCI spec if required by the runtime wrapper. // The modifications are defined by CSV MountSpecs. func NewCSVModifier(logger logger.Interface, cfg *config.Config, container image.CUDA) (oci.SpecModifier, error) { - if devices := container.VisibleDevicesFromEnvVar(); len(devices) == 0 { + if devices := container.VisibleDevices(); len(devices) == 0 { logger.Infof("No modification required; no devices requested") return nil, nil } diff --git a/internal/modifier/gated.go b/internal/modifier/gated.go index a0239df8..e96946a4 100644 --- a/internal/modifier/gated.go +++ b/internal/modifier/gated.go @@ -37,7 +37,7 @@ import ( // // If not devices are selected, no changes are made. func NewFeatureGatedModifier(logger logger.Interface, cfg *config.Config, image image.CUDA, driver *root.Driver, hookCreator discover.HookCreator) (oci.SpecModifier, error) { - if devices := image.VisibleDevicesFromEnvVar(); len(devices) == 0 { + if devices := image.VisibleDevices(); len(devices) == 0 { logger.Infof("No modification required; no devices requested") return nil, nil } diff --git a/internal/modifier/graphics.go b/internal/modifier/graphics.go index 6e602d7a..e949bd27 100644 --- a/internal/modifier/graphics.go +++ b/internal/modifier/graphics.go @@ -65,7 +65,7 @@ func NewGraphicsModifier(logger logger.Interface, cfg *config.Config, containerI // requiresGraphicsModifier determines whether a graphics modifier is required. func requiresGraphicsModifier(cudaImage image.CUDA) (bool, string) { - if devices := cudaImage.VisibleDevicesFromEnvVar(); len(devices) == 0 { + if devices := cudaImage.VisibleDevices(); len(devices) == 0 { return false, "no devices requested" } diff --git a/internal/modifier/graphics_test.go b/internal/modifier/graphics_test.go index 186af48a..8bf3a16c 100644 --- a/internal/modifier/graphics_test.go +++ b/internal/modifier/graphics_test.go @@ -92,6 +92,7 @@ func TestGraphicsModifier(t *testing.T) { t.Run(tc.description, func(t *testing.T) { image, _ := image.New( image.WithEnvMap(tc.envmap), + image.WithAcceptEnvvarUnprivileged(true), ) required, _ := requiresGraphicsModifier(image) require.EqualValues(t, tc.expectedRequired, required)