From b5bea4839aba471ddd0f9c430e55739cb24c8eea Mon Sep 17 00:00:00 2001 From: Carlos Eduardo Arango Gutierrez Date: Wed, 28 May 2025 07:33:27 +0200 Subject: [PATCH] Unify GetDevices logic at internal/config/image Signed-off-by: Carlos Eduardo Arango Gutierrez --- .../container_config.go | 112 +++++++----------- .../container_config_test.go | 30 ++++- .../hook_test.go | 89 -------------- internal/config/image/builder.go | 63 +++++++++- internal/config/image/cuda_image.go | 35 +++++- internal/config/image/privileged.go | 10 +- internal/config/image/privileged_test.go | 60 ++++++++++ internal/modifier/cdi.go | 2 +- internal/modifier/gated.go | 5 +- 9 files changed, 231 insertions(+), 175 deletions(-) delete mode 100644 cmd/nvidia-container-runtime-hook/hook_test.go create mode 100644 internal/config/image/privileged_test.go diff --git a/cmd/nvidia-container-runtime-hook/container_config.go b/cmd/nvidia-container-runtime-hook/container_config.go index ae3aca75..864b4a94 100644 --- a/cmd/nvidia-container-runtime-hook/container_config.go +++ b/cmd/nvidia-container-runtime-hook/container_config.go @@ -81,7 +81,9 @@ type HookState struct { BundlePath string `json:"bundlePath"` } -func loadSpec(path string) (spec *Spec) { +func loadSpec(path string) *specs.Spec { + var spec Spec + f, err := os.Open(path) if err != nil { log.Panicln("could not open OCI spec:", err) @@ -100,85 +102,57 @@ func loadSpec(path string) (spec *Spec) { if spec.Root == nil { log.Panicln("Root is empty in OCI spec") } - return -} -func isPrivileged(s *Spec) bool { - if s.Process.Capabilities == nil { - return false + process := specs.Process{ + Env: spec.Process.Env, } - var caps []string // If v1.0.0-rc1 <= OCI version < v1.0.0-rc5 parse s.Process.Capabilities as: // github.com/opencontainers/runtime-spec/blob/v1.0.0-rc1/specs-go/config.go#L30-L54 - rc1cmp := semver.Compare("v"+*s.Version, "v1.0.0-rc1") - rc5cmp := semver.Compare("v"+*s.Version, "v1.0.0-rc5") + rc1cmp := semver.Compare("v"+*spec.Version, "v1.0.0-rc1") + rc5cmp := semver.Compare("v"+*spec.Version, "v1.0.0-rc5") if (rc1cmp == 1 || rc1cmp == 0) && (rc5cmp == -1) { - err := json.Unmarshal(*s.Process.Capabilities, &caps) + err := json.Unmarshal(*spec.Process.Capabilities, &caps) if err != nil { log.Panicln("could not decode Process.Capabilities in OCI spec:", err) } for _, c := range caps { if c == capSysAdmin { - return true + process.Capabilities = &specs.LinuxCapabilities{ + Bounding: caps, + } + break } } - return false + } else { + // Otherwise, parse s.Process.Capabilities as: + // github.com/opencontainers/runtime-spec/blob/v1.0.0/specs-go/config.go#L30-L54 + err := json.Unmarshal(*spec.Process.Capabilities, &process.Capabilities) + if err != nil { + log.Panicln("could not decode Process.Capabilities in OCI spec:", err) + } } - // Otherwise, parse s.Process.Capabilities as: - // github.com/opencontainers/runtime-spec/blob/v1.0.0/specs-go/config.go#L30-L54 - process := specs.Process{ - Env: s.Process.Env, + root := specs.Root{ + Path: spec.Root.Path, } - err := json.Unmarshal(*s.Process.Capabilities, &process.Capabilities) - if err != nil { - log.Panicln("could not decode Process.Capabilities in OCI spec:", err) + mounts := make([]specs.Mount, len(spec.Mounts)) + for i, m := range spec.Mounts { + mounts[i] = specs.Mount{ + Source: m.Source, + Destination: m.Destination, + Type: m.Type, + Options: m.Options, + } } - fullSpec := specs.Spec{ - Version: *s.Version, + return &specs.Spec{ + Version: *spec.Version, Process: &process, + Root: &root, + Mounts: mounts, } - - return image.IsPrivileged(&fullSpec) -} - -func getDevicesFromEnvvar(containerImage image.CUDA, swarmResourceEnvvars []string) []string { - // We check if the image has at least one of the Swarm resource envvars defined and use this - // if specified. - for _, envvar := range swarmResourceEnvvars { - if containerImage.HasEnvvar(envvar) { - return containerImage.DevicesFromEnvvars(swarmResourceEnvvars...).List() - } - } - - return containerImage.VisibleDevicesFromEnvVar() -} - -func (hookConfig *hookConfig) getDevices(image image.CUDA, privileged bool) []string { - // If enabled, try and get the device list from volume mounts first - if hookConfig.AcceptDeviceListAsVolumeMounts { - devices := image.VisibleDevicesFromMounts() - if len(devices) > 0 { - return devices - } - } - - // Fallback to reading from the environment variable if privileges are correct - devices := getDevicesFromEnvvar(image, hookConfig.getSwarmResourceEnvvars()) - if len(devices) == 0 { - return nil - } - if privileged || hookConfig.AcceptEnvvarUnprivileged { - return devices - } - - configName := hookConfig.getConfigOption("AcceptEnvvarUnprivileged") - log.Printf("Ignoring devices specified in NVIDIA_VISIBLE_DEVICES (privileged=%v, %v=%v) ", privileged, configName, hookConfig.AcceptEnvvarUnprivileged) - - return nil } func getMigConfigDevices(i image.CUDA) *string { @@ -225,7 +199,6 @@ func (hookConfig *hookConfig) getDriverCapabilities(cudaImage image.CUDA, legacy // We use the default driver capabilities by default. This is filtered to only include the // supported capabilities supportedDriverCapabilities := image.NewDriverCapabilities(hookConfig.SupportedDriverCapabilities) - capabilities := supportedDriverCapabilities.Intersection(image.DefaultDriverCapabilities) capsEnvSpecified := cudaImage.HasEnvvar(image.EnvVarNvidiaDriverCapabilities) @@ -251,7 +224,7 @@ func (hookConfig *hookConfig) getDriverCapabilities(cudaImage image.CUDA, legacy func (hookConfig *hookConfig) getNvidiaConfig(image image.CUDA, privileged bool) *nvidiaConfig { legacyImage := image.IsLegacy() - devices := hookConfig.getDevices(image, privileged) + devices := image.GetDevices(hookConfig.AcceptDeviceListAsVolumeMounts, hookConfig.AcceptEnvvarUnprivileged) if len(devices) == 0 { // empty devices means this is not a GPU container. return nil @@ -305,21 +278,26 @@ func (hookConfig *hookConfig) getContainerConfig() (config containerConfig) { } s := loadSpec(path.Join(b, "config.json")) - - image, err := image.New( + opts := []image.Option{ image.WithEnv(s.Process.Env), image.WithMounts(s.Mounts), + image.WithSpec(s), image.WithDisableRequire(hookConfig.DisableRequire), - ) + } + + if len(hookConfig.getSwarmResourceEnvvars()) > 0 { + opts = append(opts, image.WithSwarmResource(hookConfig.getSwarmResourceEnvvars()...)) + } + + i, err := image.New(opts...) if err != nil { log.Panicln(err) } - privileged := isPrivileged(s) return containerConfig{ Pid: h.Pid, Rootfs: s.Root.Path, - Image: image, - Nvidia: hookConfig.getNvidiaConfig(image, privileged), + Image: i, + Nvidia: hookConfig.getNvidiaConfig(i, i.IsPrivileged()), } } diff --git a/cmd/nvidia-container-runtime-hook/container_config_test.go b/cmd/nvidia-container-runtime-hook/container_config_test.go index 1f3858b1..9de85619 100644 --- a/cmd/nvidia-container-runtime-hook/container_config_test.go +++ b/cmd/nvidia-container-runtime-hook/container_config_test.go @@ -477,9 +477,18 @@ func TestGetNvidiaConfig(t *testing.T) { } for _, tc := range tests { t.Run(tc.description, func(t *testing.T) { - image, _ := image.New( + opts := []image.Option{ image.WithEnvMap(tc.env), - ) + image.WithPrivileged(tc.privileged), + } + + if tc.hookConfig != nil { + if tc.hookConfig.SwarmResource != "" { + opts = append(opts, image.WithSwarmResource(tc.hookConfig.SwarmResource)) + } + } + image, _ := image.New(opts...) + // Wrap the call to getNvidiaConfig() in a closure. var cfg *nvidiaConfig getConfig := func() { @@ -622,12 +631,13 @@ func TestDeviceListSourcePriority(t *testing.T) { }, ), image.WithMounts(tc.mountDevices), + image.WithPrivileged(tc.privileged), ) defaultConfig, _ := config.GetDefault() cfg := &hookConfig{defaultConfig} cfg.AcceptEnvvarUnprivileged = tc.acceptUnprivileged cfg.AcceptDeviceListAsVolumeMounts = tc.acceptMounts - devices = cfg.getDevices(image, tc.privileged) + devices = image.GetDevices(tc.acceptMounts, tc.acceptUnprivileged) } // For all other tests, just grab the devices and check the results @@ -843,10 +853,18 @@ func TestGetDevicesFromEnvvar(t *testing.T) { for _, tc := range tests { t.Run(tc.description, func(t *testing.T) { - image, _ := image.New( + opts := []image.Option{ image.WithEnvMap(tc.env), - ) - devices := getDevicesFromEnvvar(image, tc.swarmResourceEnvvars) + image.WithPrivileged(true), + } + + if len(tc.swarmResourceEnvvars) > 0 { + opts = append(opts, image.WithSwarmResource(tc.swarmResourceEnvvars...)) + } + + image, _ := image.New(opts...) + + devices := image.GetDevices(false, false) require.EqualValues(t, tc.expectedDevices, devices) }) } diff --git a/cmd/nvidia-container-runtime-hook/hook_test.go b/cmd/nvidia-container-runtime-hook/hook_test.go deleted file mode 100644 index d5449bab..00000000 --- a/cmd/nvidia-container-runtime-hook/hook_test.go +++ /dev/null @@ -1,89 +0,0 @@ -package main - -import ( - "encoding/json" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestIsPrivileged(t *testing.T) { - var tests = []struct { - spec string - expected bool - }{ - { - ` - { - "ociVersion": "1.0.0", - "process": { - "capabilities": { - "bounding": [ "CAP_SYS_ADMIN" ] - } - } - } - `, - true, - }, - { - ` - { - "ociVersion": "1.0.0", - "process": { - "capabilities": { - "bounding": [ "CAP_SYS_OTHER" ] - } - } - } - `, - false, - }, - { - ` - { - "ociVersion": "1.0.0", - "process": {} - } - `, - false, - }, - { - ` - { - "ociVersion": "1.0.0-rc2-dev", - "process": { - "capabilities": [ "CAP_SYS_ADMIN" ] - } - } - `, - true, - }, - { - ` - { - "ociVersion": "1.0.0-rc2-dev", - "process": { - "capabilities": [ "CAP_SYS_OTHER" ] - } - } - `, - false, - }, - { - ` - { - "ociVersion": "1.0.0-rc2-dev", - "process": {} - } - `, - false, - }, - } - for i, tc := range tests { - var spec Spec - _ = json.Unmarshal([]byte(tc.spec), &spec) - privileged := isPrivileged(&spec) - - require.Equal(t, tc.expected, privileged, "%d: %v", i, tc) - } -} diff --git a/internal/config/image/builder.go b/internal/config/image/builder.go index 332d017a..7632a2f9 100644 --- a/internal/config/image/builder.go +++ b/internal/config/image/builder.go @@ -24,9 +24,12 @@ import ( ) type builder struct { + spec *specs.Spec env map[string]string mounts []specs.Mount disableRequire bool + + swarmResourceEnvvars []string } // New creates a new CUDA image from the input options. @@ -51,8 +54,10 @@ func (b builder) build() (CUDA, error) { } c := CUDA{ - env: b.env, - mounts: b.mounts, + env: b.env, + mounts: b.mounts, + spec: b.spec, + swarmResourceEnvvars: b.swarmResourceEnvvars, } return c, nil } @@ -100,3 +105,57 @@ func WithMounts(mounts []specs.Mount) Option { return nil } } + +// WithSpec sets the OCI spec to use when creating the CUDA image. +func WithSpec(s *specs.Spec) Option { + return func(b *builder) error { + b.spec = s + return nil + } +} + +// WithSwarmResource sets the swarm resource for the CUDA image. +func WithSwarmResource(resource ...string) Option { + return func(b *builder) error { + if len(resource) == 0 { + return fmt.Errorf("swarm resource cannot be empty") + } + b.swarmResourceEnvvars = []string{} + // if resource is a single string, split it by comma + if len(resource) == 1 && strings.Contains(resource[0], ",") { + candidates := strings.Split(resource[0], ",") + for _, c := range candidates { + trimmed := strings.TrimSpace(c) + if len(trimmed) > 0 { + b.swarmResourceEnvvars = append(b.swarmResourceEnvvars, trimmed) + } + } + return nil + } + + b.swarmResourceEnvvars = append(b.swarmResourceEnvvars, resource...) + return nil + } +} + +// WithPrivileged sets the privileged option for the CUDA image. +// This is to allow testing the privileged mode of the container. +// DO NOT USE THIS IN PRODUCTION CODE. FOR TESTING PURPOSES ONLY. +func WithPrivileged(privileged bool) Option { + return func(b *builder) error { + b.spec = &specs.Spec{ + Process: &specs.Process{ + Capabilities: &specs.LinuxCapabilities{ + Bounding: []string{"CAP_SYS_FOO"}, + }, + }, + } + if privileged { + if b.spec.Process.Capabilities == nil { + b.spec.Process.Capabilities = &specs.LinuxCapabilities{} + } + b.spec.Process.Capabilities.Bounding = append(b.spec.Process.Capabilities.Bounding, "CAP_SYS_ADMIN") + } + return nil + } +} diff --git a/internal/config/image/cuda_image.go b/internal/config/image/cuda_image.go index d5bbc224..d7478863 100644 --- a/internal/config/image/cuda_image.go +++ b/internal/config/image/cuda_image.go @@ -38,8 +38,11 @@ const ( // a map of environment variable to values that can be used to perform lookups // such as requirements. type CUDA struct { + spec *specs.Spec env map[string]string mounts []specs.Mount + + swarmResourceEnvvars []string } // NewCUDAImageFromSpec creates a CUDA image from the input OCI runtime spec. @@ -53,6 +56,7 @@ func NewCUDAImageFromSpec(spec *specs.Spec) (CUDA, error) { return New( WithEnv(env), WithMounts(spec.Mounts), + WithSpec(spec), ) } @@ -83,6 +87,11 @@ func (i CUDA) IsLegacy() bool { return len(legacyCudaVersion) > 0 && len(cudaRequire) == 0 } +// IsSwarm returns whether the image is a Docker Swarm image. +func (i CUDA) IsSwarmResource() bool { + return len(i.swarmResourceEnvvars) > 0 +} + // GetRequirements returns the requirements from all NVIDIA_REQUIRE_ environment // variables. func (i CUDA) GetRequirements() ([]string, error) { @@ -219,6 +228,9 @@ func (i CUDA) OnlyFullyQualifiedCDIDevices() bool { // VisibleDevicesFromEnvVar returns the set of visible devices requested through // the NVIDIA_VISIBLE_DEVICES environment variable. func (i CUDA) VisibleDevicesFromEnvVar() []string { + if i.IsSwarmResource() { + return i.DevicesFromEnvvars(i.swarmResourceEnvvars...).List() + } return i.DevicesFromEnvvars(EnvVarNvidiaVisibleDevices).List() } @@ -238,7 +250,6 @@ func (i CUDA) VisibleDevicesFromMounts() []string { } // DevicesFromMounts returns a list of device specified as mounts. -// TODO: This should be merged with getDevicesFromMounts used in the NVIDIA Container Runtime func (i CUDA) DevicesFromMounts() []string { root := filepath.Clean(DeviceListAsVolumeMountsRoot) seen := make(map[string]bool) @@ -271,6 +282,28 @@ func (i CUDA) DevicesFromMounts() []string { return devices } +func (i CUDA) GetDevices(acceptDeviceListAsVolumeMounts, acceptEnvvarUnprivileged bool) []string { + // If enabled, try and get the device list from volume mounts first + if acceptDeviceListAsVolumeMounts { + devices := i.VisibleDevicesFromMounts() + if len(devices) > 0 { + return devices + } + } + + // Fallback to reading from the environment variable if privileges are correct + devices := i.VisibleDevicesFromEnvVar() + if len(devices) == 0 { + return nil + } + + if i.IsPrivileged() || acceptEnvvarUnprivileged { + return devices + } + + return nil +} + // CDIDevicesFromMounts returns a list of CDI devices specified as mounts on the image. func (i CUDA) CDIDevicesFromMounts() []string { var devices []string diff --git a/internal/config/image/privileged.go b/internal/config/image/privileged.go index a54598d6..5979cebf 100644 --- a/internal/config/image/privileged.go +++ b/internal/config/image/privileged.go @@ -16,17 +16,13 @@ package image -import ( - "github.com/opencontainers/runtime-spec/specs-go" -) - const ( capSysAdmin = "CAP_SYS_ADMIN" ) // IsPrivileged returns true if the container is a privileged container. -func IsPrivileged(s *specs.Spec) bool { - if s.Process.Capabilities == nil { +func (i CUDA) IsPrivileged() bool { + if i.spec.Process.Capabilities == nil { return false } @@ -34,7 +30,7 @@ func IsPrivileged(s *specs.Spec) bool { // CAP_SYS_ADMIN. This allows us to make sure that the container was // actually started as '--privileged', but also allow non-root users to // access the privileged NVIDIA capabilities. - for _, c := range s.Process.Capabilities.Bounding { + for _, c := range i.spec.Process.Capabilities.Bounding { if c == capSysAdmin { return true } diff --git a/internal/config/image/privileged_test.go b/internal/config/image/privileged_test.go new file mode 100644 index 00000000..009371f6 --- /dev/null +++ b/internal/config/image/privileged_test.go @@ -0,0 +1,60 @@ +/** +# Copyright (c) NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package image + +import ( + "testing" + + "github.com/opencontainers/runtime-spec/specs-go" + "github.com/stretchr/testify/require" +) + +func TestIsPrivileged(t *testing.T) { + var tests = []struct { + spec specs.Spec + expected bool + }{ + { + specs.Spec{ + Process: &specs.Process{ + Capabilities: &specs.LinuxCapabilities{ + Bounding: []string{"CAP_SYS_ADMIN"}, + }, + }, + }, + true, + }, + { + specs.Spec{ + Process: &specs.Process{ + Capabilities: &specs.LinuxCapabilities{ + Bounding: []string{"CAP_SYS_FOO"}, + }, + }, + }, + false, + }, + } + for i, tc := range tests { + image, _ := New( + WithSpec(&tc.spec), + ) + privileged := image.IsPrivileged() + + require.Equal(t, tc.expected, privileged, "%d: %v", i, tc) + } +} diff --git a/internal/modifier/cdi.go b/internal/modifier/cdi.go index 90cd481b..87fa597a 100644 --- a/internal/modifier/cdi.go +++ b/internal/modifier/cdi.go @@ -107,7 +107,7 @@ func getDevicesFromSpec(logger logger.Interface, ociSpec oci.Spec, cfg *config.C return nil, nil } - if cfg.AcceptEnvvarUnprivileged || image.IsPrivileged(rawSpec) { + if cfg.AcceptEnvvarUnprivileged || container.IsPrivileged() { return devices, nil } diff --git a/internal/modifier/gated.go b/internal/modifier/gated.go index a0239df8..9efd0f2b 100644 --- a/internal/modifier/gated.go +++ b/internal/modifier/gated.go @@ -37,8 +37,9 @@ import ( // // If not devices are selected, no changes are made. func NewFeatureGatedModifier(logger logger.Interface, cfg *config.Config, image image.CUDA, driver *root.Driver, hookCreator discover.HookCreator) (oci.SpecModifier, error) { - if devices := image.VisibleDevicesFromEnvVar(); len(devices) == 0 { - logger.Infof("No modification required; no devices requested") + devices := image.GetDevices(cfg.AcceptDeviceListAsVolumeMounts, cfg.AcceptEnvvarUnprivileged) + if len(devices) == 0 { + logger.Debugf("No modification required; no devices requested") return nil, nil }