diff --git a/tools/container/toolkit/toolkit.go b/tools/container/toolkit/toolkit.go index 1db3a6f4..0de78b25 100644 --- a/tools/container/toolkit/toolkit.go +++ b/tools/container/toolkit/toolkit.go @@ -70,6 +70,8 @@ type options struct { cdiVendor string cdiClass string + createDeviceNodes cli.StringSlice + acceptNVIDIAVisibleDevicesWhenUnprivileged bool acceptNVIDIAVisibleDevicesAsVolumeMounts bool @@ -224,6 +226,13 @@ func main() { Hidden: true, Destination: &opts.ignoreErrors, }, + &cli.StringSliceFlag{ + Name: "create-device-nodes", + Usage: "(Only applicable with --cdi-enabled) specifies which device nodes should be created. If any one of the options is set to '' or 'none', no device nodes will be created.", + Value: cli.NewStringSlice("control"), + Destination: &opts.createDeviceNodes, + EnvVars: []string{"CREATE_DEVICE_NODES"}, + }, } // Update the subcommand flags with the common subcommand flags @@ -252,6 +261,29 @@ func validateOptions(c *cli.Context, opts *options) error { opts.cdiVendor = vendor opts.cdiClass = class + if opts.cdiEnabled && opts.cdiOutputDir == "" { + log.Warning("Skipping CDI spec generation (no output directory specified)") + opts.cdiEnabled = false + } + + isDisabled := false + for _, mode := range opts.createDeviceNodes.Value() { + if mode != "" && mode != "none" && mode != "control" { + return fmt.Errorf("invalid --create-device-nodes value: %v", mode) + } + if mode == "" || mode == "none" { + isDisabled = true + break + } + } + if !opts.cdiEnabled && !isDisabled { + log.Info("disabling device node creation since --cdi-enabled=false") + isDisabled = true + } + if isDisabled { + opts.createDeviceNodes = *cli.NewStringSlice() + } + return nil } @@ -330,6 +362,13 @@ func Install(cli *cli.Context, opts *options) error { log.Errorf("Ignoring error: %v", fmt.Errorf("error installing NVIDIA container toolkit config: %v", err)) } + err = createDeviceNodes(opts) + if err != nil && !opts.ignoreErrors { + return fmt.Errorf("error creating device nodes: %v", err) + } else if err != nil { + log.Errorf("Ignoring error: %v", fmt.Errorf("error creating device nodes: %v", err)) + } + return generateCDISpec(opts, nvidiaCTKPath) } @@ -677,27 +716,37 @@ func createDirectories(dir ...string) error { return nil } -// generateCDISpec generates a CDI spec for use in managemnt containers -func generateCDISpec(opts *options, nvidiaCTKPath string) error { - if !opts.cdiEnabled { - return nil - } - if opts.cdiOutputDir == "" { - log.Info("Skipping CDI spec generation (no output directory specified)") +func createDeviceNodes(opts *options) error { + modes := opts.createDeviceNodes.Value() + if len(modes) == 0 { return nil } - log.Infof("Creating control device nodes at %v", opts.DriverRootCtrPath) devices, err := nvdevices.New( nvdevices.WithDevRoot(opts.DriverRootCtrPath), ) if err != nil { return fmt.Errorf("failed to create library: %v", err) } - if err := devices.CreateNVIDIAControlDevices(); err != nil { - return fmt.Errorf("failed to create control device nodes: %v", err) - } + for _, mode := range modes { + log.Infof("Creating %v device nodes at %v", mode, opts.DriverRootCtrPath) + if mode != "control" { + log.Warningf("Unrecognised device mode: %v", mode) + continue + } + if err := devices.CreateNVIDIAControlDevices(); err != nil { + return fmt.Errorf("failed to create control device nodes: %v", err) + } + } + return nil +} + +// generateCDISpec generates a CDI spec for use in managemnt containers +func generateCDISpec(opts *options, nvidiaCTKPath string) error { + if !opts.cdiEnabled { + return nil + } log.Info("Generating CDI spec for management containers") cdilib, err := nvcdi.New( nvcdi.WithMode(nvcdi.ModeManagement),