Create device nodes in JIT-CDI mode
Some checks failed
CI Pipeline / code-scanning (push) Has been cancelled
CI Pipeline / variables (push) Has been cancelled
CI Pipeline / golang (push) Has been cancelled
CI Pipeline / image (push) Has been cancelled
CI Pipeline / e2e-test (push) Has been cancelled

Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
Evan Lezar 2025-03-09 13:49:56 +02:00
parent 76b6d4d38f
commit 1cfaef4b01
No known key found for this signature in database
11 changed files with 82 additions and 53 deletions

View File

@ -45,7 +45,7 @@ func New(opts ...Option) Devices {
type Option func(*builder) type Option func(*builder)
// WithDeviceToMajor specifies an explicit device name to major number map. // WithDeviceToMajor specifies an explicit device name to major number map.
func WithDeviceToMajor(deviceToMajor map[string]int) Option { func WithDeviceToMajor(deviceToMajor map[string]uint32) Option {
return func(b *builder) { return func(b *builder) {
b.asMap = make(devices) b.asMap = make(devices)
for name, major := range deviceToMajor { for name, major := range deviceToMajor {

View File

@ -45,7 +45,7 @@ const (
type Name string type Name string
// Major represents a device major as specified under /proc/devices // Major represents a device major as specified under /proc/devices
type Major int type Major uint32
// Devices represents the set of devices under /proc/devices // Devices represents the set of devices under /proc/devices
// //
@ -130,8 +130,8 @@ func nvidiaDeviceFrom(reader io.Reader) (Devices, error) {
return nvidiaDevices, nil return nvidiaDevices, nil
} }
func devicesFrom(reader io.Reader) map[string]int { func devicesFrom(reader io.Reader) map[string]uint32 {
allDevices := make(map[string]int) allDevices := make(map[string]uint32)
scanner := bufio.NewScanner(reader) scanner := bufio.NewScanner(reader)
for scanner.Scan() { for scanner.Scan() {
device, major, err := processProcDeviceLine(scanner.Text()) device, major, err := processProcDeviceLine(scanner.Text())
@ -143,11 +143,11 @@ func devicesFrom(reader io.Reader) map[string]int {
return allDevices return allDevices
} }
func processProcDeviceLine(line string) (string, int, error) { func processProcDeviceLine(line string) (string, uint32, error) {
trimmed := strings.TrimSpace(line) trimmed := strings.TrimSpace(line)
var name string var name string
var major int var major uint32
n, _ := fmt.Sscanf(trimmed, "%d %s", &major, &name) n, _ := fmt.Sscanf(trimmed, "%d %s", &major, &name)
if n == 2 { if n == 2 {

View File

@ -22,12 +22,15 @@ import (
"tags.cncf.io/container-device-interface/pkg/parser" "tags.cncf.io/container-device-interface/pkg/parser"
"github.com/NVIDIA/go-nvlib/pkg/nvlib/device"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config" "github.com/NVIDIA/nvidia-container-toolkit/internal/config"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root"
"github.com/NVIDIA/nvidia-container-toolkit/internal/modifier/cdi" "github.com/NVIDIA/nvidia-container-toolkit/internal/modifier/cdi"
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci" "github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
"github.com/NVIDIA/nvidia-container-toolkit/internal/system/nvdevices"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi" "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec" "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec"
) )
@ -198,12 +201,14 @@ func generateAutomaticCDISpec(logger logger.Interface, cfg *config.Config, drive
logger.Warningf("Ignoring error(s) loading kernel modules: %v", err) logger.Warningf("Ignoring error(s) loading kernel modules: %v", err)
} }
identifiers := []string{} var identifiers []string
for _, device := range devices { for _, device := range devices {
_, _, id := parser.ParseDevice(device) _, _, id := parser.ParseDevice(device)
identifiers = append(identifiers, id) identifiers = append(identifiers, id)
} }
tryCreateDeviceNodes(logger, driver, identifiers...)
deviceSpecs, err := cdilib.GetDeviceSpecsByID(identifiers...) deviceSpecs, err := cdilib.GetDeviceSpecsByID(identifiers...)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get CDI device specs: %w", err) return nil, fmt.Errorf("failed to get CDI device specs: %w", err)
@ -221,3 +226,27 @@ func generateAutomaticCDISpec(logger logger.Interface, cfg *config.Config, drive
spec.WithClass("gpu"), spec.WithClass("gpu"),
) )
} }
// tryCreateDeviceNodes makes a best-effort attempt to create the NVIDIA device
// nodes needed for the requested identifiers.
//
// A device-node library is constructed with the supplied logger and the dev
// root taken from driver.Root. The control device nodes and the nvidia-caps
// control device nodes are requested first, followed by the device nodes for
// each entry in identifiers. The function returns nothing: every failure is
// reported as a warning through logger. Only a failure to construct the
// library itself stops processing early.
func tryCreateDeviceNodes(logger logger.Interface, driver *root.Driver, identifiers ...string) {
	nodeCreator, err := nvdevices.New(
		nvdevices.WithLogger(logger),
		nvdevices.WithDevRoot(driver.Root),
	)
	if err != nil {
		// Without the library no device node can be created at all.
		logger.Warningf("Failed to create devices library: %v", err)
		return
	}

	err = nodeCreator.CreateNVIDIAControlDevices()
	if err != nil {
		logger.Warningf("Failed to create control devices: %v", err)
	}

	err = nodeCreator.CreateNVIDIACapsControlDeviceNodes()
	if err != nil {
		logger.Warningf("Failed to create nvidia-caps control devices: %v", err)
	}

	// A failure for one identifier must not prevent the remaining identifiers
	// from being processed.
	for i := range identifiers {
		id := device.Identifier(identifiers[i])
		err = nodeCreator.CreateDeviceNodes(id)
		if err == nil {
			continue
		}
		logger.Warningf("Error creating device nodes for %v: %v", id, err)
	}
}

View File

@ -37,7 +37,7 @@ const (
) )
// MigMinor represents the minor number of a MIG device // MigMinor represents the minor number of a MIG device
type MigMinor int type MigMinor uint32
// MigCap represents the path to a MIG cap file. // MigCap represents the path to a MIG cap file.
// These are listed in /proc/driver/nvidia-caps/mig-minors and have one of the // These are listed in /proc/driver/nvidia-caps/mig-minors and have one of the
@ -144,7 +144,7 @@ func processMigMinorsLine(line string) (MigCap, MigMinor, error) {
return "", 0, fmt.Errorf("invalid MIG minors line: '%v'", line) return "", 0, fmt.Errorf("invalid MIG minors line: '%v'", line)
} }
minor, err := strconv.Atoi(parts[1]) minor, err := strconv.ParseUint(parts[1], 10, 32)
if err != nil { if err != nil {
return "", 0, fmt.Errorf("error reading MIG minor from '%v': %v", line, err) return "", 0, fmt.Errorf("error reading MIG minor from '%v': %v", line, err)
} }

View File

@ -4,9 +4,8 @@
package oci package oci
import ( import (
"sync"
"github.com/opencontainers/runtime-spec/specs-go" "github.com/opencontainers/runtime-spec/specs-go"
"sync"
) )
// Ensure, that SpecMock does implement Spec. // Ensure, that SpecMock does implement Spec.

View File

@ -59,7 +59,7 @@ func (m *Interface) CreateNVIDIACapsControlDeviceNodes() error {
continue continue
} }
deviceNodePath := migMinor.DevicePath() deviceNodePath := migMinor.DevicePath()
if err := m.createDeviceNode(deviceNodePath, int(capsMajor), int(migMinor)); err != nil { if err := m.createDeviceNode(deviceNodePath, capsMajor, uint32(migMinor)); err != nil {
errs = errors.Join(errs, fmt.Errorf("failed to create nvidia-caps device node %v: %w", deviceNodePath, err)) errs = errors.Join(errs, fmt.Errorf("failed to create nvidia-caps device node %v: %w", deviceNodePath, err))
} }
} }
@ -82,12 +82,12 @@ func (m *Interface) createControlDeviceNode(node controlDeviceNode) error {
return fmt.Errorf("failed to determine minor: %w", err) return fmt.Errorf("failed to determine minor: %w", err)
} }
return m.createDeviceNode(node.path(), int(major), int(minor)) return m.createDeviceNode(node.path(), major, minor)
} }
// controlDeviceNodeMajor returns the major number for the specified NVIDIA control device node. // controlDeviceNodeMajor returns the major number for the specified NVIDIA control device node.
// If the device node is not supported, an error is returned. // If the device node is not supported, an error is returned.
func (m *Interface) controlDeviceNodeMajor(node controlDeviceNode) (int64, error) { func (m *Interface) controlDeviceNodeMajor(node controlDeviceNode) (devices.Major, error) {
var valid bool var valid bool
var major devices.Major var major devices.Major
switch node { switch node {
@ -98,7 +98,7 @@ func (m *Interface) controlDeviceNodeMajor(node controlDeviceNode) (int64, error
} }
if valid { if valid {
return int64(major), nil return major, nil
} }
return 0, errInvalidDeviceNode return 0, errInvalidDeviceNode
@ -106,7 +106,7 @@ func (m *Interface) controlDeviceNodeMajor(node controlDeviceNode) (int64, error
// controlDeviceNodeMinor returns the minor number for the specified NVIDIA control device node. // controlDeviceNodeMinor returns the minor number for the specified NVIDIA control device node.
// If the device node is not supported, an error is returned. // If the device node is not supported, an error is returned.
func (m *Interface) controlDeviceNodeMinor(node controlDeviceNode) (int64, error) { func (m *Interface) controlDeviceNodeMinor(node controlDeviceNode) (uint32, error) {
switch node { switch node {
case "nvidia-modeset": case "nvidia-modeset":
return devices.NVIDIAModesetMinor, nil return devices.NVIDIAModesetMinor, nil

View File

@ -89,25 +89,25 @@ func New(opts ...Option) (*Interface, error) {
func (m *Interface) CreateDeviceNodes(id device.Identifier) error { func (m *Interface) CreateDeviceNodes(id device.Identifier) error {
switch { switch {
case id.IsGpuIndex(): case id.IsGpuIndex():
index, err := strconv.Atoi(string(id)) index, err := strconv.ParseUint(string(id), 10, 32)
if err != nil { if err != nil {
return fmt.Errorf("invalid GPU index: %v", id) return fmt.Errorf("invalid GPU index: %v", id)
} }
return m.createGPUDeviceNode(index) return m.createGPUDeviceNode(uint32(index))
case id.IsMigIndex(): case id.IsMigIndex():
indices := strings.Split(string(id), ":") indices := strings.Split(string(id), ":")
if len(indices) != 2 { if len(indices) != 2 {
return fmt.Errorf("invalid MIG index %v", id) return fmt.Errorf("invalid MIG index %v", id)
} }
gpuIndex, err := strconv.Atoi(indices[0]) gpuIndex, err := strconv.ParseUint(indices[0], 10, 32)
if err != nil { if err != nil {
return fmt.Errorf("invalid parent index %v: %w", indices[0], err) return fmt.Errorf("invalid parent index %v: %w", indices[0], err)
} }
if err := m.createGPUDeviceNode(gpuIndex); err != nil { if err := m.createGPUDeviceNode(uint32(gpuIndex)); err != nil {
return fmt.Errorf("failed to create parent device node: %w", err) return fmt.Errorf("failed to create parent device node: %w", err)
} }
return m.createMigDeviceNodes(gpuIndex) return m.createMigDeviceNodes(uint32(gpuIndex))
case id.IsGpuUUID(), id.IsMigUUID(), id == "all": case id.IsGpuUUID(), id.IsMigUUID(), id == "all":
return m.createAllGPUDeviceNodes() return m.createAllGPUDeviceNodes()
default: default:
@ -117,7 +117,7 @@ func (m *Interface) CreateDeviceNodes(id device.Identifier) error {
// createDeviceNode creates the specified device node with the required major and minor numbers. // createDeviceNode creates the specified device node with the required major and minor numbers.
// If a devRoot is configured, this is prepended to the path. // If a devRoot is configured, this is prepended to the path.
func (m *Interface) createDeviceNode(path string, major int, minor int) error { func (m *Interface) createDeviceNode(path string, major devices.Major, minor uint32) error {
path = filepath.Join(m.devRoot, path) path = filepath.Join(m.devRoot, path)
return m.Mknode(path, major, minor) return m.Mknode(path, uint32(major), minor)
} }

View File

@ -26,28 +26,28 @@ import (
"github.com/NVIDIA/nvidia-container-toolkit/internal/info/proc/devices" "github.com/NVIDIA/nvidia-container-toolkit/internal/info/proc/devices"
) )
func (m *Interface) createGPUDeviceNode(gpuIndex int) error { func (m *Interface) createGPUDeviceNode(gpuIndex uint32) error {
major, exists := m.Get(devices.NVIDIAGPU) major, exists := m.Get(devices.NVIDIAGPU)
if !exists { if !exists {
return fmt.Errorf("failed to determine device major; nvidia kernel module may not be loaded") return fmt.Errorf("failed to determine device major; nvidia kernel module may not be loaded")
} }
deviceNodePath := fmt.Sprintf("/dev/nvidia%d", gpuIndex) deviceNodePath := fmt.Sprintf("/dev/nvidia%d", gpuIndex)
if err := m.createDeviceNode(deviceNodePath, int(major), gpuIndex); err != nil { if err := m.createDeviceNode(deviceNodePath, major, uint32(gpuIndex)); err != nil {
return fmt.Errorf("failed to create device node %v: %w", deviceNodePath, err) return fmt.Errorf("failed to create device node %v: %w", deviceNodePath, err)
} }
return nil return nil
} }
func (m *Interface) createMigDeviceNodes(gpuIndex int) error { func (m *Interface) createMigDeviceNodes(gpuIndex uint32) error {
capsMajor, exists := m.Get("nvidia-caps") capsMajor, exists := m.Get("nvidia-caps")
if !exists { if !exists {
return nil return nil
} }
var errs error var errs error
for _, capsDeviceMinor := range m.migCaps.FilterForGPU(gpuIndex) { for _, capsDeviceMinor := range m.migCaps.FilterForGPU(int(gpuIndex)) {
capDevicePath := capsDeviceMinor.DevicePath() capDevicePath := capsDeviceMinor.DevicePath()
err := m.createDeviceNode(capDevicePath, int(capsMajor), int(capsDeviceMinor)) err := m.createDeviceNode(capDevicePath, capsMajor, uint32(capsDeviceMinor))
errs = errors.Join(errs, fmt.Errorf("failed to create %v: %w", capDevicePath, err)) errs = errors.Join(errs, fmt.Errorf("failed to create %v: %w", capDevicePath, err))
} }
return errs return errs
@ -62,13 +62,13 @@ func (m *Interface) createAllGPUDeviceNodes() error {
return fmt.Errorf("failed to get GPU information from PCI: %w", err) return fmt.Errorf("failed to get GPU information from PCI: %w", err)
} }
count := len(gpus) count := uint32(len(gpus))
if count == 0 { if count == 0 {
return nil return nil
} }
var errs error var errs error
for gpuIndex := 0; gpuIndex < count; gpuIndex++ { for gpuIndex := uint32(0); gpuIndex < count; gpuIndex++ {
errs = errors.Join(errs, m.createGPUDeviceNode(gpuIndex)) errs = errors.Join(errs, m.createGPUDeviceNode(gpuIndex))
errs = errors.Join(errs, m.createMigDeviceNodes(gpuIndex)) errs = errors.Join(errs, m.createMigDeviceNodes(gpuIndex))
} }

View File

@ -25,16 +25,18 @@ import (
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
) )
//go:generate moq -stub -out mknod_mock.go . mknoder type mint uint32
//go:generate moq -fmt=goimports -rm -stub -out mknod_mock.go . mknoder
type mknoder interface { type mknoder interface {
Mknode(string, int, int) error Mknode(string, uint32, uint32) error
} }
type mknodLogger struct { type mknodLogger struct {
logger.Interface logger.Interface
} }
func (m *mknodLogger) Mknode(path string, major, minor int) error { func (m *mknodLogger) Mknode(path string, major uint32, minor uint32) error {
m.Infof("Running: mknod --mode=0666 %s c %d %d", path, major, minor) m.Infof("Running: mknod --mode=0666 %s c %d %d", path, major, minor)
return nil return nil
} }
@ -43,7 +45,7 @@ type mknodUnix struct {
logger logger.Interface logger logger.Interface
} }
func (m *mknodUnix) Mknode(path string, major, minor int) error { func (m *mknodUnix) Mknode(path string, major uint32, minor uint32) error {
// TODO: Ensure that the existing device node has the correct properties. // TODO: Ensure that the existing device node has the correct properties.
if _, err := os.Stat(path); err == nil { if _, err := os.Stat(path); err == nil {
m.logger.Infof("Skipping: %s already exists", path) m.logger.Infof("Skipping: %s already exists", path)
@ -52,7 +54,7 @@ func (m *mknodUnix) Mknode(path string, major, minor int) error {
return fmt.Errorf("failed to stat %s: %v", path, err) return fmt.Errorf("failed to stat %s: %v", path, err)
} }
err := unix.Mknod(path, unix.S_IFCHR, int(unix.Mkdev(uint32(major), uint32(minor)))) err := unix.Mknod(path, unix.S_IFCHR, int(unix.Mkdev(major, minor)))
if err != nil { if err != nil {
return err return err
} }

View File

@ -17,7 +17,7 @@ var _ mknoder = &mknoderMock{}
// //
// // make and configure a mocked mknoder // // make and configure a mocked mknoder
// mockedmknoder := &mknoderMock{ // mockedmknoder := &mknoderMock{
// MknodeFunc: func(s string, n1 int, n2 int) error { // MknodeFunc: func(s string, v1 uint32, v2 uint32) error {
// panic("mock out the Mknode method") // panic("mock out the Mknode method")
// }, // },
// } // }
@ -28,7 +28,7 @@ var _ mknoder = &mknoderMock{}
// } // }
type mknoderMock struct { type mknoderMock struct {
// MknodeFunc mocks the Mknode method. // MknodeFunc mocks the Mknode method.
MknodeFunc func(s string, n1 int, n2 int) error MknodeFunc func(s string, v1 uint32, v2 uint32) error
// calls tracks calls to the methods. // calls tracks calls to the methods.
calls struct { calls struct {
@ -36,25 +36,25 @@ type mknoderMock struct {
Mknode []struct { Mknode []struct {
// S is the s argument value. // S is the s argument value.
S string S string
// N1 is the n1 argument value. // V1 is the v1 argument value.
N1 int V1 uint32
// N2 is the n2 argument value. // V2 is the v2 argument value.
N2 int V2 uint32
} }
} }
lockMknode sync.RWMutex lockMknode sync.RWMutex
} }
// Mknode calls MknodeFunc. // Mknode calls MknodeFunc.
func (mock *mknoderMock) Mknode(s string, n1 int, n2 int) error { func (mock *mknoderMock) Mknode(s string, v1 uint32, v2 uint32) error {
callInfo := struct { callInfo := struct {
S string S string
N1 int V1 uint32
N2 int V2 uint32
}{ }{
S: s, S: s,
N1: n1, V1: v1,
N2: n2, V2: v2,
} }
mock.lockMknode.Lock() mock.lockMknode.Lock()
mock.calls.Mknode = append(mock.calls.Mknode, callInfo) mock.calls.Mknode = append(mock.calls.Mknode, callInfo)
@ -65,7 +65,7 @@ func (mock *mknoderMock) Mknode(s string, n1 int, n2 int) error {
) )
return errOut return errOut
} }
return mock.MknodeFunc(s, n1, n2) return mock.MknodeFunc(s, v1, v2)
} }
// MknodeCalls gets all the calls that were made to Mknode. // MknodeCalls gets all the calls that were made to Mknode.
@ -74,13 +74,13 @@ func (mock *mknoderMock) Mknode(s string, n1 int, n2 int) error {
// len(mockedmknoder.MknodeCalls()) // len(mockedmknoder.MknodeCalls())
func (mock *mknoderMock) MknodeCalls() []struct { func (mock *mknoderMock) MknodeCalls() []struct {
S string S string
N1 int V1 uint32
N2 int V2 uint32
} { } {
var calls []struct { var calls []struct {
S string S string
N1 int V1 uint32
N2 int V2 uint32
} }
mock.lockMknode.RLock() mock.lockMknode.RLock()
calls = mock.calls.Mknode calls = mock.calls.Mknode

View File

@ -4,9 +4,8 @@
package nvcdi package nvcdi
import ( import (
"sync"
"github.com/NVIDIA/go-nvml/pkg/nvml" "github.com/NVIDIA/go-nvml/pkg/nvml"
"sync"
) )
// Ensure, that nvmlUUIDerMock does implement nvmlUUIDer. // Ensure, that nvmlUUIDerMock does implement nvmlUUIDer.