diff --git a/CHANGELOG.md b/CHANGELOG.md index 40cd2097..7cde0fe8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,8 @@ * Add `nvidia-container-runtime-hook.path` config option to specify NVIDIA Container Runtime Hook path explicitly. * Fix bug in creation of `/dev/char` symlinks by failing operation if kernel modules are not loaded. +* Add option to load kernel modules when creating device nodes +* Add option to create device nodes when creating `/dev/char` symlinks ## v1.13.1 diff --git a/cmd/nvidia-ctk/system/create-dev-char-symlinks/create-dev-char-symlinks.go b/cmd/nvidia-ctk/system/create-dev-char-symlinks/create-dev-char-symlinks.go index f84734b3..8bd1e563 100644 --- a/cmd/nvidia-ctk/system/create-dev-char-symlinks/create-dev-char-symlinks.go +++ b/cmd/nvidia-ctk/system/create-dev-char-symlinks/create-dev-char-symlinks.go @@ -24,6 +24,7 @@ import ( "strings" "syscall" + "github.com/NVIDIA/nvidia-container-toolkit/internal/system" "github.com/fsnotify/fsnotify" "github.com/sirupsen/logrus" "github.com/urfave/cli/v2" @@ -38,11 +39,13 @@ type command struct { } type config struct { - devCharPath string - driverRoot string - dryRun bool - watch bool - createAll bool + devCharPath string + driverRoot string + dryRun bool + watch bool + createAll bool + createDeviceNodes bool + loadKernelModules bool } // NewCommand constructs a command sub-command with the specified logger @@ -97,6 +100,18 @@ func (m command) build() *cli.Command { Destination: &cfg.createAll, EnvVars: []string{"CREATE_ALL"}, }, + &cli.BoolFlag{ + Name: "load-kernel-modules", + Usage: "Load the NVIDIA kernel modules before creating symlinks. This is only applicable when --create-all is set.", + Destination: &cfg.loadKernelModules, + EnvVars: []string{"LOAD_KERNEL_MODULES"}, + }, + &cli.BoolFlag{ + Name: "create-device-nodes", + Usage: "Create the NVIDIA control device nodes in the driver root if they do not exist. This is only applicable when --create-all is set", + Destination: &cfg.createDeviceNodes, + EnvVars: []string{"CREATE_DEVICE_NODES"}, + }, &cli.BoolFlag{ Name: "dry-run", Usage: "If set, the command will not create any symlinks.", @@ -114,6 +129,16 @@ func (m command) validateFlags(r *cli.Context, cfg *config) error { return fmt.Errorf("create-all and watch are mutually exclusive") } + if cfg.loadKernelModules && !cfg.createAll { + m.logger.Warn("load-kernel-modules is only applicable when create-all is set; ignoring") + cfg.loadKernelModules = false + } + + if cfg.createDeviceNodes && !cfg.createAll { + m.logger.Warn("create-device-nodes is only applicable when create-all is set; ignoring") + cfg.createDeviceNodes = false + } + return nil } @@ -137,6 +162,8 @@ func (m command) run(c *cli.Context, cfg *config) error { WithDriverRoot(cfg.driverRoot), WithDryRun(cfg.dryRun), WithCreateAll(cfg.createAll), + WithLoadKernelModules(cfg.loadKernelModules), + WithCreateDeviceNodes(cfg.createDeviceNodes), ) if err != nil { return fmt.Errorf("failed to create symlink creator: %v", err) @@ -186,12 +213,14 @@ create: } type linkCreator struct { - logger *logrus.Logger - lister nodeLister - driverRoot string - devCharPath string - dryRun bool - createAll bool + logger *logrus.Logger + lister nodeLister + driverRoot string + devCharPath string + dryRun bool + createAll bool + createDeviceNodes bool + loadKernelModules bool } // Creator is an interface for creating symlinks to /dev/nv* devices in /dev/char. @@ -218,6 +247,10 @@ func NewSymlinkCreator(opts ...Option) (Creator, error) { c.devCharPath = defaultDevCharPath } + if err := c.setup(); err != nil { + return nil, err + } + if c.createAll { lister, err := newAllPossible(c.logger, c.driverRoot) if err != nil { @@ -230,6 +263,34 @@ func NewSymlinkCreator(opts ...Option) (Creator, error) { return c, nil } +func (m linkCreator) setup() error { + if !m.loadKernelModules && !m.createDeviceNodes { + return nil + } + + s, err := system.New( + system.WithLogger(m.logger), + system.WithDryRun(m.dryRun), + ) + if err != nil { + return err + } + + if m.loadKernelModules { + if err := s.LoadNVIDIAKernelModules(); err != nil { + return fmt.Errorf("failed to load NVIDIA kernel modules: %v", err) + } + } + + if m.createDeviceNodes { + if err := s.CreateNVIDIAControlDeviceNodesAt(m.driverRoot); err != nil { + return fmt.Errorf("failed to create NVIDIA device nodes: %v", err) + } + } + + return nil +} + // WithDriverRoot sets the driver root path. func WithDriverRoot(root string) Option { return func(c *linkCreator) { @@ -265,6 +326,20 @@ func WithCreateAll(createAll bool) Option { } } +// WithLoadKernelModules sets the loadKernelModules flag for the linkCreator. +func WithLoadKernelModules(loadKernelModules bool) Option { + return func(lc *linkCreator) { + lc.loadKernelModules = loadKernelModules + } +} + +// WithCreateDeviceNodes sets the createDeviceNodes flag for the linkCreator. +func WithCreateDeviceNodes(createDeviceNodes bool) Option { + return func(lc *linkCreator) { + lc.createDeviceNodes = createDeviceNodes + } +} + // CreateLinks creates symlinks for all NVIDIA device nodes found in the driver root. func (m linkCreator) CreateLinks() error { deviceNodes, err := m.lister.DeviceNodes() diff --git a/cmd/nvidia-ctk/system/create-device-nodes/create-device-nodes.go b/cmd/nvidia-ctk/system/create-device-nodes/create-device-nodes.go index 939c8525..13861f76 100644 --- a/cmd/nvidia-ctk/system/create-device-nodes/create-device-nodes.go +++ b/cmd/nvidia-ctk/system/create-device-nodes/create-device-nodes.go @@ -34,6 +34,8 @@ type options struct { dryRun bool control bool + + loadKernelModules bool } // NewCommand constructs a command sub-command with the specified logger @@ -72,6 +74,11 @@ func (m command) build() *cli.Command { Usage: "create all control device nodes: nvidiactl, nvidia-modeset, nvidia-uvm, nvidia-uvm-tools", Destination: &opts.control, }, + &cli.BoolFlag{ + Name: "load-kernel-modules", + Usage: "load the NVIDIA Kernel Modules before creating devices nodes", + Destination: &opts.loadKernelModules, + }, &cli.BoolFlag{ Name: "dry-run", Usage: "if set, the command will not create any symlinks.", @@ -92,6 +99,7 @@ func (m command) run(c *cli.Context, opts *options) error { s, err := system.New( system.WithLogger(m.logger), system.WithDryRun(opts.dryRun), + system.WithLoadKernelModules(opts.loadKernelModules), ) if err != nil { return fmt.Errorf("failed to create library: %v", err) diff --git a/internal/system/options.go b/internal/system/options.go index fb0fbb38..de3bf21d 100644 --- a/internal/system/options.go +++ b/internal/system/options.go @@ -34,3 +34,10 @@ func WithDryRun(dryRun bool) Option { i.dryRun = dryRun } } + +// WithLoadKernelModules sets the load kernel modules flag +func WithLoadKernelModules(loadKernelModules bool) Option { + return func(i *Interface) { + i.loadKernelModules = loadKernelModules + } +} diff --git a/internal/system/system.go b/internal/system/system.go index d3ad63eb..fe745160 100644 --- a/internal/system/system.go +++ b/internal/system/system.go @@ -19,6 +19,7 @@ package system import ( "fmt" "os" + "os/exec" "path/filepath" "strings" @@ -29,10 +30,10 @@ import ( // Interface is the interface for the system command type Interface struct { - logger *logrus.Logger - dryRun bool - - nvidiaDevices nvidiaDevices + logger *logrus.Logger + dryRun bool + loadKernelModules bool + nvidiaDevices nvidiaDevices } // New constructs a system command with the specified options @@ -44,6 +45,12 @@ func New(opts ...Option) (*Interface, error) { opt(i) } + if i.loadKernelModules { + if err := i.LoadNVIDIAKernelModules(); err != nil { + return nil, fmt.Errorf("failed to load kernel modules: %v", err) + } + } + devices, err := devices.GetNVIDIADevices() if err != nil { return nil, fmt.Errorf("failed to create devices info: %v", err) @@ -108,6 +115,26 @@ func (m *Interface) createDeviceNode(path string, major int, minor int) error { return unix.Chmod(path, 0666) } +// LoadNVIDIAKernelModules loads the NVIDIA kernel modules. +func (m *Interface) LoadNVIDIAKernelModules() error { + modules := []string{"nvidia", "nvidia-uvm", "nvidia-modeset"} + + for _, module := range modules { + if m.dryRun { + m.logger.Infof("Running: /sbin/modprobe %s", module) + continue + } + cmd := exec.Command("/sbin/modprobe", module) + + if output, err := cmd.CombinedOutput(); err != nil { + m.logger.Debugf("Failed to load kernel module %s: %v", module, string(output)) + return fmt.Errorf("failed to load kernel module %s: %v", module, err) + } + } + + return nil +} + type nvidiaDevices struct { devices.Devices }