diff --git a/cmd/nvidia-ctk/cdi/generate/generate.go b/cmd/nvidia-ctk/cdi/generate/generate.go index 6f896952..37bb562e 100644 --- a/cmd/nvidia-ctk/cdi/generate/generate.go +++ b/cmd/nvidia-ctk/cdi/generate/generate.go @@ -49,6 +49,9 @@ type config struct { targetDriverRoot string nvidiaCTKPath string mode string + kind string + vendor string + class string } // NewCommand constructs a generate-cdi command with the specified logger @@ -116,6 +119,13 @@ func (m command) build() *cli.Command { Value: "", Destination: &cfg.targetDriverRoot, }, + &cli.StringFlag{ + Name: "kind", + Aliases: []string{"cdi-kind"}, + Usage: "the vendor string to use for the generated CDI specification.", + Value: "nvidia.com/gpu", + Destination: &cfg.kind, + }, } return &c @@ -161,6 +171,16 @@ func (m command) validateFlags(c *cli.Context, cfg *config) error { } } + vendor, class := cdi.ParseQualifier(cfg.kind) + if err := cdi.ValidateVendorName(vendor); err != nil { + return fmt.Errorf("invalid CDI vendor name: %v", err) + } + if err := cdi.ValidateClassName(class); err != nil { + return fmt.Errorf("invalid CDI class name: %v", err) + } + cfg.vendor = vendor + cfg.class = class + return nil } @@ -236,8 +256,8 @@ func (m command) generateSpec(cfg *config) (spec.Interface, error) { } spec, err := spec.New( - spec.WithVendor("nvidia.com"), - spec.WithClass("gpu"), + spec.WithVendor(cfg.vendor), + spec.WithClass(cfg.class), spec.WithDeviceSpecs(deviceSpecs), spec.WithEdits(*commonEdits.ContainerEdits), spec.WithFormat(cfg.format),