diff --git a/cmd/nvidia-ctk/system/create-dev-char-symlinks/create-dev-char-symlinks.go b/cmd/nvidia-ctk/system/create-dev-char-symlinks/create-dev-char-symlinks.go index 03b9de78..8bd1e563 100644 --- a/cmd/nvidia-ctk/system/create-dev-char-symlinks/create-dev-char-symlinks.go +++ b/cmd/nvidia-ctk/system/create-dev-char-symlinks/create-dev-char-symlinks.go @@ -44,6 +44,7 @@ type config struct { dryRun bool watch bool createAll bool + createDeviceNodes bool loadKernelModules bool } @@ -105,6 +106,12 @@ func (m command) build() *cli.Command { Destination: &cfg.loadKernelModules, EnvVars: []string{"LOAD_KERNEL_MODULES"}, }, + &cli.BoolFlag{ + Name: "create-device-nodes", + Usage: "Create the NVIDIA control device nodes in the driver root if they do not exist. This is only applicable when --create-all is set", + Destination: &cfg.createDeviceNodes, + EnvVars: []string{"CREATE_DEVICE_NODES"}, + }, &cli.BoolFlag{ Name: "dry-run", Usage: "If set, the command will not create any symlinks.", @@ -127,6 +134,11 @@ func (m command) validateFlags(r *cli.Context, cfg *config) error { cfg.loadKernelModules = false } + if cfg.createDeviceNodes && !cfg.createAll { + m.logger.Warn("create-device-nodes is only applicable when create-all is set; ignoring") + cfg.createDeviceNodes = false + } + return nil } @@ -151,6 +163,7 @@ func (m command) run(c *cli.Context, cfg *config) error { WithDryRun(cfg.dryRun), WithCreateAll(cfg.createAll), WithLoadKernelModules(cfg.loadKernelModules), + WithCreateDeviceNodes(cfg.createDeviceNodes), ) if err != nil { return fmt.Errorf("failed to create symlink creator: %v", err) @@ -206,6 +219,7 @@ type linkCreator struct { devCharPath string dryRun bool createAll bool + createDeviceNodes bool loadKernelModules bool } @@ -233,17 +247,8 @@ func NewSymlinkCreator(opts ...Option) (Creator, error) { c.devCharPath = defaultDevCharPath } - if c.loadKernelModules { - s, err := system.New( - system.WithLogger(c.logger), - system.WithDryRun(c.dryRun), - ) - if err != nil { - return nil, err - } - if err := s.LoadNVIDIAKernelModules(); err != nil { - return nil, fmt.Errorf("failed to load NVIDIA kernel modules: %v", err) - } + if err := c.setup(); err != nil { + return nil, err } if c.createAll { @@ -258,6 +263,34 @@ func NewSymlinkCreator(opts ...Option) (Creator, error) { return c, nil } +func (m linkCreator) setup() error { + if !m.loadKernelModules && !m.createDeviceNodes { + return nil + } + + s, err := system.New( + system.WithLogger(m.logger), + system.WithDryRun(m.dryRun), + ) + if err != nil { + return err + } + + if m.loadKernelModules { + if err := s.LoadNVIDIAKernelModules(); err != nil { + return fmt.Errorf("failed to load NVIDIA kernel modules: %v", err) + } + } + + if m.createDeviceNodes { + if err := s.CreateNVIDIAControlDeviceNodesAt(m.driverRoot); err != nil { + return fmt.Errorf("failed to create NVIDIA device nodes: %v", err) + } + } + + return nil +} + // WithDriverRoot sets the driver root path. func WithDriverRoot(root string) Option { return func(c *linkCreator) { @@ -300,6 +333,13 @@ func WithLoadKernelModules(loadKernelModules bool) Option { } } +// WithCreateDeviceNodes sets the createDeviceNodes flag for the linkCreator. +func WithCreateDeviceNodes(createDeviceNodes bool) Option { + return func(lc *linkCreator) { + lc.createDeviceNodes = createDeviceNodes + } +} + // CreateLinks creates symlinks for all NVIDIA device nodes found in the driver root. func (m linkCreator) CreateLinks() error { deviceNodes, err := m.lister.DeviceNodes()