diff --git a/cmd/nvidia-ctk/cdi/generate/generate.go b/cmd/nvidia-ctk/cdi/generate/generate.go index 86316977..c9a146fb 100644 --- a/cmd/nvidia-ctk/cdi/generate/generate.go +++ b/cmd/nvidia-ctk/cdi/generate/generate.go @@ -18,7 +18,6 @@ package generate import ( "fmt" - "io" "os" "path/filepath" "strings" @@ -26,13 +25,13 @@ import ( "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" "github.com/NVIDIA/nvidia-container-toolkit/internal/edits" "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi" + "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec" "github.com/container-orchestrated-devices/container-device-interface/pkg/cdi" specs "github.com/container-orchestrated-devices/container-device-interface/specs-go" "github.com/sirupsen/logrus" "github.com/urfave/cli/v2" "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device" "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml" - "sigs.k8s.io/yaml" ) const ( @@ -118,7 +117,8 @@ func (m command) build() *cli.Command { return &c } -func (m command) validateFlags(r *cli.Context, cfg *config) error { +func (m command) validateFlags(c *cli.Context, cfg *config) error { + cfg.format = strings.ToLower(cfg.format) switch cfg.format { case formatJSON: @@ -143,31 +143,6 @@ func (m command) validateFlags(r *cli.Context, cfg *config) error { cfg.nvidiaCTKPath = discover.FindNvidiaCTK(m.logger, cfg.nvidiaCTKPath) - return nil -} - -func (m command) run(c *cli.Context, cfg *config) error { - spec, err := m.generateSpec(cfg) - if err != nil { - return fmt.Errorf("failed to generate CDI spec: %v", err) - } - - var outputTo io.Writer - if cfg.output == "" { - outputTo = os.Stdout - } else { - err := createParentDirsIfRequired(cfg.output) - if err != nil { - return fmt.Errorf("failed to create parent folders for output file: %v", err) - } - outputFile, err := os.Create(cfg.output) - if err != nil { - return fmt.Errorf("failed to create output file: %v", err) - } - defer outputFile.Close() - outputTo = outputFile - } - if outputFileFormat := formatFromFilename(cfg.output); outputFileFormat != "" { m.logger.Debugf("Inferred output format as %q from output file name", outputFileFormat) if !c.IsSet("format") { @@ -177,25 +152,29 @@ func (m command) run(c *cli.Context, cfg *config) error { } } - data, err := yaml.Marshal(spec) - if err != nil { - return fmt.Errorf("failed to marshal CDI spec: %v", err) - } + return nil +} - if strings.ToLower(cfg.format) == formatJSON { - data, err = yaml.YAMLToJSONStrict(data) +func (m command) run(c *cli.Context, cfg *config) error { + spec, err := m.generateSpec(cfg) + if err != nil { + return fmt.Errorf("failed to generate CDI spec: %v", err) + } + m.logger.Infof("Generated CDI spec with version", spec.Raw().Version) + + if cfg.output == "" { + _, err := spec.WriteTo(os.Stdout) if err != nil { - return fmt.Errorf("failed to convert CDI spec from YAML to JSON: %v", err) + return fmt.Errorf("failed to write CDI spec to STDOUT: %v", err) } return nil } - err = writeToOutput(cfg.format, data, outputTo) + err = createParentDirsIfRequired(cfg.output) if err != nil { - return fmt.Errorf("failed to write output: %v", err) + return fmt.Errorf("failed to create parent folders for output file: %v", err) } - - return nil + return spec.Save(cfg.output) } func formatFromFilename(filename string) string { @@ -212,22 +191,7 @@ func formatFromFilename(filename string) string { return "" } -func writeToOutput(format string, data []byte, output io.Writer) error { - if format == formatYAML { - _, err := output.Write([]byte("---\n")) - if err != nil { - return fmt.Errorf("failed to write YAML separator: %v", err) - } - } - _, err := output.Write(data) - if err != nil { - return fmt.Errorf("failed to write data: %v", err) - } - - return nil -} - -func (m command) generateSpec(cfg *config) (*specs.Spec, error) { +func (m command) generateSpec(cfg *config) (spec.Interface, error) { deviceNamer, err := nvcdi.NewDeviceNamer(cfg.deviceNameStrategy) if err != nil { return nil, fmt.Errorf("failed to create device namer: %v", err) @@ -275,23 +239,13 @@ func (m command) generateSpec(cfg *config) (*specs.Spec, error) { return nil, fmt.Errorf("failed to create edits common for entities: %v", err) } - // We construct the spec and determine the minimum required version based on the specification. - spec := specs.Spec{ - Version: "NOT_SET", - Kind: "nvidia.com/gpu", - Devices: deviceSpecs, - ContainerEdits: *commonEdits.ContainerEdits, - } - - minVersion, err := cdi.MinimumRequiredVersion(&spec) - if err != nil { - return nil, fmt.Errorf("failed to get minumum required CDI spec version: %v", err) - } - m.logger.Infof("Using minimum required CDI spec version: %s", minVersion) - - spec.Version = minVersion - - return &spec, nil + return spec.New( + spec.WithVendor("nvidia.com"), + spec.WithClass("gpu"), + spec.WithDeviceSpecs(deviceSpecs), + spec.WithEdits(*commonEdits.ContainerEdits), + spec.WithFormat(cfg.format), + ) } // MergeDeviceSpecs creates a device with the specified name which combines the edits from the previous devices.