diff --git a/cmd/nvidia-ctk/cdi/generate/generate.go b/cmd/nvidia-ctk/cdi/generate/generate.go index 544f1219..cd4f52f0 100644 --- a/cmd/nvidia-ctk/cdi/generate/generate.go +++ b/cmd/nvidia-ctk/cdi/generate/generate.go @@ -49,6 +49,7 @@ type command struct { type config struct { output string format string + root string } // NewCommand constructs a generate-cdi command with the specified logger @@ -87,6 +88,11 @@ func (m command) build() *cli.Command { Value: formatYAML, Destination: &cfg.format, }, + &cli.StringFlag{ + Name: "root", + Usage: "Specify the root to use when discovering the entities that should be included in the CDI specification.", + Destination: &cfg.root, + }, } return &c @@ -105,7 +111,7 @@ func (m command) validateFlags(r *cli.Context, cfg *config) error { } func (m command) run(c *cli.Context, cfg *config) error { - spec, err := m.generateSpec() + spec, err := m.generateSpec(cfg.root) if err != nil { return fmt.Errorf("failed to generate CDI spec: %v", err) } @@ -184,7 +190,7 @@ func writeToOutput(format string, data []byte, output io.Writer) error { return nil } -func (m command) generateSpec() (*specs.Spec, error) { +func (m command) generateSpec(root string) (*specs.Spec, error) { nvmllib := nvml.New() if r := nvmllib.Init(); r != nvml.SUCCESS { return nil, r @@ -193,7 +199,7 @@ func (m command) generateSpec() (*specs.Spec, error) { devicelib := device.New(device.WithNvml(nvmllib)) - deviceSpecs, err := m.generateDeviceSpecs(devicelib) + deviceSpecs, err := m.generateDeviceSpecs(devicelib, root) if err != nil { return nil, fmt.Errorf("failed to create device CDI specs: %v", err) } @@ -204,7 +210,7 @@ func (m command) generateSpec() (*specs.Spec, error) { allEdits := cdi.ContainerEdits{} - ipcs, err := NewIPCDiscoverer(m.logger, "") + ipcs, err := NewIPCDiscoverer(m.logger, root) if err != nil { return nil, fmt.Errorf("failed to create discoverer for IPC sockets: %v", err) } @@ -220,12 +226,12 @@ func (m command) generateSpec() (*specs.Spec, error) { allEdits.Append(ipcEdits) - common, err := NewCommonDiscoverer(m.logger, "", nvmllib) + common, err := NewCommonDiscoverer(m.logger, root, nvmllib) if err != nil { return nil, fmt.Errorf("failed to create discoverer for common entities: %v", err) } - deviceFolderPermissionHooks, err := NewDeviceFolderPermissionHookDiscoverer(m.logger, "", deviceSpecs) + deviceFolderPermissionHooks, err := NewDeviceFolderPermissionHookDiscoverer(m.logger, root, deviceSpecs) if err != nil { return nil, fmt.Errorf("failed to generated permission hooks for device nodes: %v", err) } @@ -249,7 +255,7 @@ func (m command) generateSpec() (*specs.Spec, error) { return &spec, nil } -func (m command) generateDeviceSpecs(devicelib device.Interface) ([]specs.Device, error) { +func (m command) generateDeviceSpecs(devicelib device.Interface, root string) ([]specs.Device, error) { var deviceSpecs []specs.Device err := devicelib.VisitDevices(func(i int, d device.Device) error { @@ -260,7 +266,7 @@ func (m command) generateDeviceSpecs(devicelib device.Interface) ([]specs.Device if isMigEnabled { return nil } - device, err := NewFullGPUDiscoverer(m.logger, "", d) + device, err := NewFullGPUDiscoverer(m.logger, root, d) if err != nil { return fmt.Errorf("failed to create device: %v", err) }