diff --git a/cmd/nvidia-ctk/system/create-device-nodes/create-device-nodes.go b/cmd/nvidia-ctk/system/create-device-nodes/create-device-nodes.go index b9c58a39..37ed45b0 100644 --- a/cmd/nvidia-ctk/system/create-device-nodes/create-device-nodes.go +++ b/cmd/nvidia-ctk/system/create-device-nodes/create-device-nodes.go @@ -31,7 +31,8 @@ type command struct { } type options struct { - driverRoot string + root string + devRoot string dryRun bool @@ -65,11 +66,21 @@ func (m command) build() *cli.Command { c.Flags = []cli.Flag{ &cli.StringFlag{ - Name: "driver-root", - Usage: "the path to the driver root. Device nodes will be created at `DRIVER_ROOT`/dev", + Name: "root", + // TODO: Remove this alias + Aliases: []string{"driver-root"}, + Usage: "the path to to the root to use to load the kernel modules. This root must be a chrootable path. " + + "If device nodes to be created these will be created at `ROOT`/dev unless an alternative path is specified", Value: "/", - Destination: &opts.driverRoot, - EnvVars: []string{"NVIDIA_DRIVER_ROOT", "DRIVER_ROOT"}, + Destination: &opts.root, + // TODO: Remove the NVIDIA_DRIVER_ROOT and DRIVER_ROOT envvars. + EnvVars: []string{"ROOT", "NVIDIA_DRIVER_ROOT", "DRIVER_ROOT"}, + }, + &cli.StringFlag{ + Name: "dev-root", + Usage: "specify the root where `/dev` is located. If this is not specified, the root is assumed.", + Destination: &opts.devRoot, + EnvVars: []string{"NVIDIA_DEV_ROOT", "DEV_ROOT"}, }, &cli.BoolFlag{ Name: "control-devices", @@ -83,7 +94,7 @@ func (m command) build() *cli.Command { }, &cli.BoolFlag{ Name: "dry-run", - Usage: "if set, the command will not create any symlinks.", + Usage: "if set, the command will not perform any operations", Value: false, Destination: &opts.dryRun, EnvVars: []string{"DRY_RUN"}, @@ -94,6 +105,10 @@ func (m command) build() *cli.Command { } func (m command) validateFlags(r *cli.Context, opts *options) error { + if opts.devRoot == "" && opts.root != "" { + m.logger.Infof("Using dev-root %q", opts.root) + opts.devRoot = opts.root + } return nil } @@ -102,7 +117,7 @@ func (m command) run(c *cli.Context, opts *options) error { modules := nvmodules.New( nvmodules.WithLogger(m.logger), nvmodules.WithDryRun(opts.dryRun), - nvmodules.WithRoot(opts.driverRoot), + nvmodules.WithRoot(opts.root), ) if err := modules.LoadAll(); err != nil { return fmt.Errorf("failed to load NVIDIA kernel modules: %v", err) @@ -113,12 +128,12 @@ func (m command) run(c *cli.Context, opts *options) error { devices, err := nvdevices.New( nvdevices.WithLogger(m.logger), nvdevices.WithDryRun(opts.dryRun), - nvdevices.WithDevRoot(opts.driverRoot), + nvdevices.WithDevRoot(opts.devRoot), ) if err != nil { return err } - m.logger.Infof("Creating control device nodes at %s", opts.driverRoot) + m.logger.Infof("Creating control device nodes at %s", opts.devRoot) if err := devices.CreateNVIDIAControlDevices(); err != nil { return fmt.Errorf("failed to create NVIDIA control device nodes: %v", err) }