diff --git a/cmd/nvidia-ctk/cdi/generate/generate.go b/cmd/nvidia-ctk/cdi/generate/generate.go index 77165e71..eb869d28 100644 --- a/cmd/nvidia-ctk/cdi/generate/generate.go +++ b/cmd/nvidia-ctk/cdi/generate/generate.go @@ -36,6 +36,8 @@ import ( ) const ( + discoveryModeNVML = "nvml" + formatJSON = "json" formatYAML = "yaml" ) @@ -50,6 +52,7 @@ type config struct { deviceNameStrategy string driverRoot string nvidiaCTKPath string + discoveryMode string } // NewCommand constructs a generate-cdi command with the specified logger @@ -88,6 +91,12 @@ func (m command) build() *cli.Command { Value: formatYAML, Destination: &cfg.format, }, + &cli.StringFlag{ + Name: "discovery-mode", + Usage: "The mode to use when discovering the available entities. One of [nvml]", + Value: discoveryModeNVML, + Destination: &cfg.discoveryMode, + }, &cli.StringFlag{ Name: "device-name-strategy", Usage: "Specify the strategy for generating device names. One of [index | uuid | type-index]", @@ -118,6 +127,13 @@ func (m command) validateFlags(r *cli.Context, cfg *config) error { return fmt.Errorf("invalid output format: %v", cfg.format) } + cfg.discoveryMode = strings.ToLower(cfg.discoveryMode) + switch cfg.discoveryMode { + case discoveryModeNVML: + default: + return fmt.Errorf("invalid discovery mode: %v", cfg.discoveryMode) + } + _, err := nvcdi.NewDeviceNamer(cfg.deviceNameStrategy) if err != nil { return err @@ -229,6 +245,7 @@ func (m command) generateSpec(cfg *config) (*specs.Spec, error) { nvcdi.WithDeviceNamer(deviceNamer), nvcdi.WithDeviceLib(devicelib), nvcdi.WithNvmlLib(nvmllib), + nvcdi.WithMode(string(cfg.discoveryMode)), ) deviceSpecs, err := cdilib.GetAllDeviceSpecs() @@ -298,3 +315,5 @@ func createParentDirsIfRequired(filename string) error { } return os.MkdirAll(dir, 0755) } + +type discoveryMode string