Make CDI device requests consistent with other methods

Following the refactoring of device request extraction, we can
now make CDI device requests consistent with other methods.

This change moves to using image.VisibleDevices instead of
separate calls to CDIDevicesFromMounts and VisibleDevicesFromEnvVar.
The handling of annotation-based requests will be addressed in a
follow-up.

Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
Evan Lezar 2025-06-05 15:48:09 +02:00
parent 27f5ec83de
commit 1b0f07a124
No known key found for this signature in database
5 changed files with 171 additions and 63 deletions

View File

@ -54,8 +54,12 @@ type CUDA struct {
// NewCUDAImageFromSpec creates a CUDA image from the input OCI runtime spec. // NewCUDAImageFromSpec creates a CUDA image from the input OCI runtime spec.
// The process environment is read (if present) to construc the CUDA Image. // The process environment is read (if present) to construc the CUDA Image.
func NewCUDAImageFromSpec(spec *specs.Spec, opts ...Option) (CUDA, error) { func NewCUDAImageFromSpec(spec *specs.Spec, opts ...Option) (CUDA, error) {
if spec == nil {
return New(opts...)
}
var env []string var env []string
if spec != nil && spec.Process != nil { if spec.Process != nil {
env = spec.Process.Env env = spec.Process.Env
} }
@ -212,19 +216,12 @@ func parseMajorMinorVersion(version string) (string, error) {
// OnlyFullyQualifiedCDIDevices returns true if all devices requested in the image are requested as CDI devices/ // OnlyFullyQualifiedCDIDevices returns true if all devices requested in the image are requested as CDI devices/
func (i CUDA) OnlyFullyQualifiedCDIDevices() bool { func (i CUDA) OnlyFullyQualifiedCDIDevices() bool {
var hasCDIdevice bool var hasCDIdevice bool
for _, device := range i.VisibleDevicesFromEnvVar() { for _, device := range i.VisibleDevices() {
if !parser.IsQualifiedName(device) { if !parser.IsQualifiedName(device) {
return false return false
} }
hasCDIdevice = true hasCDIdevice = true
} }
for _, device := range i.DevicesFromMounts() {
if !strings.HasPrefix(device, "cdi/") {
return false
}
hasCDIdevice = true
}
return hasCDIdevice return hasCDIdevice
} }
@ -276,20 +273,27 @@ func (i CUDA) VisibleDevicesFromEnvVar() []string {
// visibleDevicesFromMounts returns the set of visible devices requested as mounts. // visibleDevicesFromMounts returns the set of visible devices requested as mounts.
func (i CUDA) visibleDevicesFromMounts() []string { func (i CUDA) visibleDevicesFromMounts() []string {
var devices []string var devices []string
for _, device := range i.DevicesFromMounts() { for _, device := range i.requestsFromMounts() {
switch { switch {
case strings.HasPrefix(device, volumeMountDevicePrefixCDI):
continue
case strings.HasPrefix(device, volumeMountDevicePrefixImex): case strings.HasPrefix(device, volumeMountDevicePrefixImex):
continue continue
case strings.HasPrefix(device, volumeMountDevicePrefixCDI):
name, err := cdiDeviceMountRequest(device).qualifiedName()
if err != nil {
i.logger.Warningf("Ignoring invalid mount request for CDI device %v: %w", device, err)
continue
}
devices = append(devices, name)
default:
devices = append(devices, device)
} }
devices = append(devices, device)
} }
return devices return devices
} }
// DevicesFromMounts returns a list of device specified as mounts. // requestsFromMounts returns a list of device specified as mounts.
func (i CUDA) DevicesFromMounts() []string { func (i CUDA) requestsFromMounts() []string {
root := filepath.Clean(DeviceListAsVolumeMountsRoot) root := filepath.Clean(DeviceListAsVolumeMountsRoot)
seen := make(map[string]bool) seen := make(map[string]bool)
var devices []string var devices []string
@ -321,23 +325,30 @@ func (i CUDA) DevicesFromMounts() []string {
return devices return devices
} }
// CDIDevicesFromMounts returns a list of CDI devices specified as mounts on the image. // a cdiDeviceMountRequest represents a CDI device requests as a mount.
func (i CUDA) CDIDevicesFromMounts() []string { // Here the host path /dev/null is mounted to a particular path in the container.
var devices []string // The container path has the form:
for _, mountDevice := range i.DevicesFromMounts() { // /var/run/nvidia-container-devices/cdi/<vendor>/<class>/<device>
if !strings.HasPrefix(mountDevice, volumeMountDevicePrefixCDI) { // or
continue // /var/run/nvidia-container-devices/cdi/<vendor>/<class>=<device>
} type cdiDeviceMountRequest string
parts := strings.SplitN(strings.TrimPrefix(mountDevice, volumeMountDevicePrefixCDI), "/", 3)
if len(parts) != 3 { // qualifiedName returns the fully-qualified name of the CDI device.
continue func (m cdiDeviceMountRequest) qualifiedName() (string, error) {
} if !strings.HasPrefix(string(m), volumeMountDevicePrefixCDI) {
vendor := parts[0] return "", fmt.Errorf("invalid mount CDI device request: %s", m)
class := parts[1]
device := parts[2]
devices = append(devices, fmt.Sprintf("%s/%s=%s", vendor, class, device))
} }
return devices
requestedDevice := strings.TrimPrefix(string(m), volumeMountDevicePrefixCDI)
if parser.IsQualifiedName(requestedDevice) {
return requestedDevice, nil
}
parts := strings.SplitN(requestedDevice, "/", 3)
if len(parts) != 3 {
return "", fmt.Errorf("invalid mount CDI device request: %s", m)
}
return fmt.Sprintf("%s/%s=%s", parts[0], parts[1], parts[2]), nil
} }
// ImexChannelsFromEnvVar returns the list of IMEX channels requested for the image. // ImexChannelsFromEnvVar returns the list of IMEX channels requested for the image.
@ -352,7 +363,7 @@ func (i CUDA) ImexChannelsFromEnvVar() []string {
// ImexChannelsFromMounts returns the list of IMEX channels requested for the image. // ImexChannelsFromMounts returns the list of IMEX channels requested for the image.
func (i CUDA) ImexChannelsFromMounts() []string { func (i CUDA) ImexChannelsFromMounts() []string {
var channels []string var channels []string
for _, mountDevice := range i.DevicesFromMounts() { for _, mountDevice := range i.requestsFromMounts() {
if !strings.HasPrefix(mountDevice, volumeMountDevicePrefixImex) { if !strings.HasPrefix(mountDevice, volumeMountDevicePrefixImex) {
continue continue
} }

View File

@ -487,9 +487,9 @@ func TestGetVisibleDevicesFromMounts(t *testing.T) {
expectedDevices: []string{"GPU0-MIG0/0/1", "GPU1-MIG0/0/1"}, expectedDevices: []string{"GPU0-MIG0/0/1", "GPU1-MIG0/0/1"},
}, },
{ {
description: "cdi devices are ignored", description: "cdi devices are included",
mounts: makeTestMounts("GPU0", "cdi/nvidia.com/gpu=all", "GPU1"), mounts: makeTestMounts("GPU0", "nvidia.com/gpu=all", "GPU1"),
expectedDevices: []string{"GPU0", "GPU1"}, expectedDevices: []string{"GPU0", "nvidia.com/gpu=all", "GPU1"},
}, },
{ {
description: "imex devices are ignored", description: "imex devices are ignored",

View File

@ -184,7 +184,7 @@ func TestResolveAutoMode(t *testing.T) {
expectedMode: "legacy", expectedMode: "legacy",
}, },
{ {
description: "cdi mount and non-CDI envvar resolves to legacy", description: "cdi mount and non-CDI envvar resolves to cdi",
mode: "auto", mode: "auto",
envmap: map[string]string{ envmap: map[string]string{
"NVIDIA_VISIBLE_DEVICES": "0", "NVIDIA_VISIBLE_DEVICES": "0",
@ -197,6 +197,22 @@ func TestResolveAutoMode(t *testing.T) {
"tegra": false, "tegra": false,
"nvgpu": false, "nvgpu": false,
}, },
expectedMode: "cdi",
},
{
description: "non-cdi mount and CDI envvar resolves to legacy",
mode: "auto",
envmap: map[string]string{
"NVIDIA_VISIBLE_DEVICES": "nvidia.com/gpu=0",
},
mounts: []string{
"/var/run/nvidia-container-devices/0",
},
info: map[string]bool{
"nvml": true,
"tegra": false,
"nvgpu": false,
},
expectedMode: "legacy", expectedMode: "legacy",
}, },
} }
@ -232,6 +248,8 @@ func TestResolveAutoMode(t *testing.T) {
image, _ := image.New( image, _ := image.New(
image.WithEnvMap(tc.envmap), image.WithEnvMap(tc.envmap),
image.WithMounts(mounts), image.WithMounts(mounts),
image.WithAcceptDeviceListAsVolumeMounts(true),
image.WithAcceptEnvvarUnprivileged(true),
) )
mode := resolveMode(logger, tc.mode, image, properties) mode := resolveMode(logger, tc.mode, image, properties)
require.EqualValues(t, tc.expectedMode, mode) require.EqualValues(t, tc.expectedMode, mode)

View File

@ -66,57 +66,66 @@ func NewCDIModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Spe
} }
func getDevicesFromSpec(logger logger.Interface, ociSpec oci.Spec, cfg *config.Config) ([]string, error) { func getDevicesFromSpec(logger logger.Interface, ociSpec oci.Spec, cfg *config.Config) ([]string, error) {
cdiModifier := &cdiModifier{
logger: logger,
acceptDeviceListAsVolumeMounts: cfg.AcceptDeviceListAsVolumeMounts,
acceptEnvvarUnprivileged: cfg.AcceptEnvvarUnprivileged,
annotationPrefixes: cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.AnnotationPrefixes,
defaultKind: cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.DefaultKind,
}
return cdiModifier.getDevicesFromSpec(ociSpec)
}
// TODO: We should rename this type.
type cdiModifier struct {
logger logger.Interface
acceptDeviceListAsVolumeMounts bool
acceptEnvvarUnprivileged bool
annotationPrefixes []string
defaultKind string
}
func (c *cdiModifier) getDevicesFromSpec(ociSpec oci.Spec) ([]string, error) {
rawSpec, err := ociSpec.Load() rawSpec, err := ociSpec.Load()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to load OCI spec: %v", err) return nil, fmt.Errorf("failed to load OCI spec: %v", err)
} }
annotationDevices, err := getAnnotationDevices(cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.AnnotationPrefixes, rawSpec.Annotations) if rawSpec != nil {
if err != nil { annotationDevices, err := getAnnotationDevices(c.annotationPrefixes, rawSpec.Annotations)
return nil, fmt.Errorf("failed to parse container annotations: %v", err) if err != nil {
} return nil, fmt.Errorf("failed to parse container annotations: %v", err)
if len(annotationDevices) > 0 { }
return annotationDevices, nil if len(annotationDevices) > 0 {
return annotationDevices, nil
}
} }
container, err := image.NewCUDAImageFromSpec( container, err := image.NewCUDAImageFromSpec(
rawSpec, rawSpec,
image.WithLogger(logger), image.WithLogger(c.logger),
image.WithAcceptDeviceListAsVolumeMounts(c.acceptDeviceListAsVolumeMounts),
image.WithAcceptEnvvarUnprivileged(c.acceptEnvvarUnprivileged),
) )
if err != nil { if err != nil {
return nil, err return nil, err
} }
if cfg.AcceptDeviceListAsVolumeMounts {
mountDevices := container.CDIDevicesFromMounts()
if len(mountDevices) > 0 {
return mountDevices, nil
}
}
var devices []string var devices []string
seen := make(map[string]bool) seen := make(map[string]bool)
for _, name := range container.VisibleDevicesFromEnvVar() { for _, name := range container.VisibleDevices() {
if !parser.IsQualifiedName(name) { if !parser.IsQualifiedName(name) {
name = fmt.Sprintf("%s=%s", cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.DefaultKind, name) name = fmt.Sprintf("%s=%s", c.defaultKind, name)
} }
if seen[name] { if seen[name] {
logger.Debugf("Ignoring duplicate device %q", name) c.logger.Debugf("Ignoring duplicate device %q", name)
continue continue
} }
seen[name] = true
devices = append(devices, name) devices = append(devices, name)
} }
if len(devices) == 0 { return devices, nil
return nil, nil
}
if cfg.AcceptEnvvarUnprivileged || image.IsPrivileged((*image.OCISpec)(rawSpec)) {
return devices, nil
}
logger.Warningf("Ignoring devices specified in NVIDIA_VISIBLE_DEVICES: %v", devices)
return nil, nil
} }
// getAnnotationDevices returns a list of devices specified in the annotations. // getAnnotationDevices returns a list of devices specified in the annotations.

View File

@ -20,7 +20,11 @@ import (
"fmt" "fmt"
"testing" "testing"
"github.com/opencontainers/runtime-spec/specs-go"
testlog "github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
) )
func TestGetAnnotationDevices(t *testing.T) { func TestGetAnnotationDevices(t *testing.T) {
@ -90,3 +94,69 @@ func TestGetAnnotationDevices(t *testing.T) {
}) })
} }
} }
func TestGetDevicesFromSpec(t *testing.T) {
logger, _ := testlog.NewNullLogger()
testCases := []struct {
description string
input cdiModifier
spec *specs.Spec
expectedDevices []string
}{
{
description: "empty spec yields no devices",
},
{
description: "cdi devices from mounts",
input: cdiModifier{
defaultKind: "nvidia.com/gpu",
acceptEnvvarUnprivileged: true,
acceptDeviceListAsVolumeMounts: true,
},
spec: &specs.Spec{
Mounts: []specs.Mount{
{
Destination: "/var/run/nvidia-container-devices/cdi/nvidia.com/gpu/0",
Source: "/dev/null",
},
{
Destination: "/var/run/nvidia-container-devices/cdi/nvidia.com/gpu/1",
Source: "/dev/null",
},
},
},
expectedDevices: []string{"nvidia.com/gpu=0", "nvidia.com/gpu=1"},
},
{
description: "cdi devices from envvar",
input: cdiModifier{
defaultKind: "nvidia.com/gpu",
acceptEnvvarUnprivileged: true,
acceptDeviceListAsVolumeMounts: true,
},
spec: &specs.Spec{
Process: &specs.Process{
Env: []string{"NVIDIA_VISIBLE_DEVICES=0,example.com/class=device"},
},
},
expectedDevices: []string{"nvidia.com/gpu=0", "example.com/class=device"},
},
}
for _, tc := range testCases {
tc.input.logger = logger
spec := &oci.SpecMock{
LoadFunc: func() (*specs.Spec, error) {
return tc.spec, nil
},
}
t.Run(tc.description, func(t *testing.T) {
devices, err := tc.input.getDevicesFromSpec(spec)
require.NoError(t, err)
require.EqualValues(t, tc.expectedDevices, devices)
})
}
}