diff --git a/cmd/nvidia-ctk/cdi/generate/generate.go b/cmd/nvidia-ctk/cdi/generate/generate.go index b284c75e..c235c90f 100644 --- a/cmd/nvidia-ctk/cdi/generate/generate.go +++ b/cmd/nvidia-ctk/cdi/generate/generate.go @@ -18,7 +18,6 @@ package generate import ( "fmt" - "io" "os" "path/filepath" "strings" @@ -26,19 +25,16 @@ import ( "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" "github.com/NVIDIA/nvidia-container-toolkit/internal/edits" "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi" + "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec" "github.com/container-orchestrated-devices/container-device-interface/pkg/cdi" specs "github.com/container-orchestrated-devices/container-device-interface/specs-go" "github.com/sirupsen/logrus" "github.com/urfave/cli/v2" "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device" "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml" - "sigs.k8s.io/yaml" ) const ( - formatJSON = "json" - formatYAML = "yaml" - allDeviceName = "all" ) @@ -88,7 +84,7 @@ func (m command) build() *cli.Command { &cli.StringFlag{ Name: "format", Usage: "The output format for the generated spec [json | yaml]. 
This overrides the format defined by the output file extension (if specified).", - Value: formatYAML, + Value: spec.FormatYAML, Destination: &cfg.format, }, &cli.StringFlag{ @@ -118,11 +114,12 @@ func (m command) build() *cli.Command { return &c } -func (m command) validateFlags(r *cli.Context, cfg *config) error { +func (m command) validateFlags(c *cli.Context, cfg *config) error { + cfg.format = strings.ToLower(cfg.format) switch cfg.format { - case formatJSON: - case formatYAML: + case spec.FormatJSON: + case spec.FormatYAML: default: return fmt.Errorf("invalid output format: %v", cfg.format) } @@ -143,31 +140,6 @@ func (m command) validateFlags(r *cli.Context, cfg *config) error { cfg.nvidiaCTKPath = discover.FindNvidiaCTK(m.logger, cfg.nvidiaCTKPath) - return nil -} - -func (m command) run(c *cli.Context, cfg *config) error { - spec, err := m.generateSpec(cfg) - if err != nil { - return fmt.Errorf("failed to generate CDI spec: %v", err) - } - - var outputTo io.Writer - if cfg.output == "" { - outputTo = os.Stdout - } else { - err := createParentDirsIfRequired(cfg.output) - if err != nil { - return fmt.Errorf("failed to create parent folders for output file: %v", err) - } - outputFile, err := os.Create(cfg.output) - if err != nil { - return fmt.Errorf("failed to create output file: %v", err) - } - defer outputFile.Close() - outputTo = outputFile - } - if outputFileFormat := formatFromFilename(cfg.output); outputFileFormat != "" { m.logger.Debugf("Inferred output format as %q from output file name", outputFileFormat) if !c.IsSet("format") { @@ -177,56 +149,40 @@ func (m command) run(c *cli.Context, cfg *config) error { } } - data, err := yaml.Marshal(spec) - if err != nil { - return fmt.Errorf("failed to marshal CDI spec: %v", err) - } - - if strings.ToLower(cfg.format) == formatJSON { - data, err = yaml.YAMLToJSONStrict(data) - if err != nil { - return fmt.Errorf("failed to convert CDI spec from YAML to JSON: %v", err) - } - } - - err = 
writeToOutput(cfg.format, data, outputTo) - if err != nil { - return fmt.Errorf("failed to write output: %v", err) - } - return nil } +func (m command) run(c *cli.Context, cfg *config) error { + spec, err := m.generateSpec(cfg) + if err != nil { + return fmt.Errorf("failed to generate CDI spec: %v", err) + } + m.logger.Infof("Generated CDI spec with version %v", spec.Raw().Version) + + if cfg.output == "" { + _, err := spec.WriteTo(os.Stdout) + if err != nil { + return fmt.Errorf("failed to write CDI spec to STDOUT: %v", err) + } + return nil + } + + return spec.Save(cfg.output) +} + func formatFromFilename(filename string) string { ext := filepath.Ext(filename) switch strings.ToLower(ext) { case ".json": - return formatJSON - case ".yaml": - return formatYAML - case ".yml": - return formatYAML + return spec.FormatJSON + case ".yaml", ".yml": + return spec.FormatYAML } return "" } -func writeToOutput(format string, data []byte, output io.Writer) error { - if format == formatYAML { - _, err := output.Write([]byte("---\n")) - if err != nil { - return fmt.Errorf("failed to write YAML separator: %v", err) - } - } - _, err := output.Write(data) - if err != nil { - return fmt.Errorf("failed to write data: %v", err) - } - - return nil -} - -func (m command) generateSpec(cfg *config) (*specs.Spec, error) { +func (m command) generateSpec(cfg *config) (spec.Interface, error) { deviceNamer, err := nvcdi.NewDeviceNamer(cfg.deviceNameStrategy) if err != nil { return nil, fmt.Errorf("failed to create device namer: %v", err) @@ -274,23 +230,13 @@ func (m command) generateSpec(cfg *config) (*specs.Spec, error) { return nil, fmt.Errorf("failed to create edits common for entities: %v", err) } - // We construct the spec and determine the minimum required version based on the specification. 
- spec := specs.Spec{ - Version: "NOT_SET", - Kind: "nvidia.com/gpu", - Devices: deviceSpecs, - ContainerEdits: *commonEdits.ContainerEdits, - } - - minVersion, err := cdi.MinimumRequiredVersion(&spec) - if err != nil { - return nil, fmt.Errorf("failed to get minumum required CDI spec version: %v", err) - } - m.logger.Infof("Using minimum required CDI spec version: %s", minVersion) - - spec.Version = minVersion - - return &spec, nil + return spec.New( + spec.WithVendor("nvidia.com"), + spec.WithClass("gpu"), + spec.WithDeviceSpecs(deviceSpecs), + spec.WithEdits(*commonEdits.ContainerEdits), + spec.WithFormat(cfg.format), + ) } // MergeDeviceSpecs creates a device with the specified name which combines the edits from the previous devices. @@ -320,14 +266,3 @@ func MergeDeviceSpecs(deviceSpecs []specs.Device, mergedDeviceName string) (spec } return merged, nil } - -// createParentDirsIfRequired creates the parent folders of the specified path if requried. -// Note that MkdirAll does not specifically check whether the specified path is non-empty and raises an error if it is. 
-// The path will be empty if filename in the current folder is specified, for example -func createParentDirsIfRequired(filename string) error { - dir := filepath.Dir(filename) - if dir == "" { - return nil - } - return os.MkdirAll(dir, 0755) -} diff --git a/pkg/nvcdi/api.go b/pkg/nvcdi/api.go index 267010e9..85bace99 100644 --- a/pkg/nvcdi/api.go +++ b/pkg/nvcdi/api.go @@ -17,6 +17,7 @@ package nvcdi import ( + "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec" "github.com/container-orchestrated-devices/container-device-interface/pkg/cdi" "github.com/container-orchestrated-devices/container-device-interface/specs-go" "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device" @@ -33,6 +34,7 @@ const ( // Interface defines the API for the nvcdi package type Interface interface { + GetSpec() (spec.Interface, error) GetCommonEdits() (*cdi.ContainerEdits, error) GetAllDeviceSpecs() ([]specs.Device, error) GetGPUDeviceEdits(device.Device) (*cdi.ContainerEdits, error) diff --git a/pkg/nvcdi/lib-nvml.go b/pkg/nvcdi/lib-nvml.go index aaca382e..8fa29c11 100644 --- a/pkg/nvcdi/lib-nvml.go +++ b/pkg/nvcdi/lib-nvml.go @@ -20,6 +20,7 @@ import ( "fmt" "github.com/NVIDIA/nvidia-container-toolkit/internal/edits" + "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec" "github.com/container-orchestrated-devices/container-device-interface/pkg/cdi" "github.com/container-orchestrated-devices/container-device-interface/specs-go" "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device" @@ -29,6 +30,11 @@ type nvmllib nvcdilib var _ Interface = (*nvmllib)(nil) +// GetSpec should not be called for nvmllib +func (l *nvmllib) GetSpec() (spec.Interface, error) { + return nil, fmt.Errorf("Unexpected call to nvmllib.GetSpec()") +} + // GetAllDeviceSpecs returns the device specs for all available devices. 
func (l *nvmllib) GetAllDeviceSpecs() ([]specs.Device, error) { var deviceSpecs []specs.Device diff --git a/pkg/nvcdi/lib-wsl.go b/pkg/nvcdi/lib-wsl.go index d901995c..937d1d33 100644 --- a/pkg/nvcdi/lib-wsl.go +++ b/pkg/nvcdi/lib-wsl.go @@ -20,6 +20,7 @@ import ( "fmt" "github.com/NVIDIA/nvidia-container-toolkit/internal/edits" + "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec" "github.com/container-orchestrated-devices/container-device-interface/pkg/cdi" "github.com/container-orchestrated-devices/container-device-interface/specs-go" "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device" @@ -29,6 +30,11 @@ type wsllib nvcdilib var _ Interface = (*wsllib)(nil) +// GetSpec should not be called for wsllib +func (l *wsllib) GetSpec() (spec.Interface, error) { + return nil, fmt.Errorf("Unexpected call to wsllib.GetSpec()") +} + // GetAllDeviceSpecs returns the device specs for all available devices. func (l *wsllib) GetAllDeviceSpecs() ([]specs.Device, error) { device := newDXGDeviceDiscoverer(l.logger, l.driverRoot) diff --git a/pkg/nvcdi/lib.go b/pkg/nvcdi/lib.go index be5554eb..c13a2ab5 100644 --- a/pkg/nvcdi/lib.go +++ b/pkg/nvcdi/lib.go @@ -17,12 +17,20 @@ package nvcdi import ( + "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec" "github.com/sirupsen/logrus" "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device" "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/info" "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml" ) +type wrapper struct { + Interface + + vendor string + class string +} + type nvcdilib struct { logger *logrus.Logger nvmllib nvml.Interface @@ -32,6 +40,9 @@ type nvcdilib struct { driverRoot string nvidiaCTKPath string + vendor string + class string + infolib info.Interface } @@ -60,6 +71,7 @@ func New(opts ...Option) Interface { l.infolib = info.New() } + var lib Interface switch l.resolveMode() { case ModeNvml: if l.nvmllib == nil { @@ -69,13 +81,41 @@ func New(opts ...Option) Interface { l.devicelib = 
device.New(device.WithNvml(l.nvmllib)) } - return (*nvmllib)(l) + lib = (*nvmllib)(l) case ModeWsl: - return (*wsllib)(l) + lib = (*wsllib)(l) + default: + // TODO: We would like to return an error here instead of panicking + panic("Unknown mode") } - // TODO: We want an error here. - return nil + w := wrapper{ + Interface: lib, + vendor: l.vendor, + class: l.class, + } + return &w +} + +// GetSpec combines the device specs and common edits from the wrapped Interface to a single spec.Interface. +func (l *wrapper) GetSpec() (spec.Interface, error) { + deviceSpecs, err := l.GetAllDeviceSpecs() + if err != nil { + return nil, err + } + + edits, err := l.GetCommonEdits() + if err != nil { + return nil, err + } + + return spec.New( + spec.WithDeviceSpecs(deviceSpecs), + spec.WithEdits(*edits.ContainerEdits), + spec.WithVendor(l.vendor), + spec.WithClass(l.class), + ) + } // resolveMode resolves the mode for CDI spec generation based on the current system. diff --git a/pkg/nvcdi/options.go b/pkg/nvcdi/options.go index 317cace2..19aef93c 100644 --- a/pkg/nvcdi/options.go +++ b/pkg/nvcdi/options.go @@ -73,3 +73,17 @@ func WithMode(mode string) Option { l.mode = mode } } + +// WithVendor sets the vendor for the library +func WithVendor(vendor string) Option { + return func(o *nvcdilib) { + o.vendor = vendor + } +} + +// WithClass sets the class for the library +func WithClass(class string) Option { + return func(o *nvcdilib) { + o.class = class + } +} diff --git a/pkg/nvcdi/spec/api.go b/pkg/nvcdi/spec/api.go new file mode 100644 index 00000000..a72c6dff --- /dev/null +++ b/pkg/nvcdi/spec/api.go @@ -0,0 +1,40 @@ +/** +# Copyright (c) NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+**/
+
+package spec
+
+import (
+	"io"
+
+	"github.com/container-orchestrated-devices/container-device-interface/specs-go"
+)
+
+const (
+	// DetectMinimumVersion is a sentinel version: a spec built with this version has it replaced by the minimum CDI version its contents require.
+	DetectMinimumVersion = "DETECT_MINIMUM_VERSION"
+
+	// FormatJSON indicates a JSON output format
+	FormatJSON = "json"
+	// FormatYAML indicates a YAML output format
+	FormatYAML = "yaml"
+)
+
+// Interface is the API of a generated CDI spec: it can be streamed to an io.Writer, saved to a path, or accessed as the underlying specs.Spec.
+type Interface interface {
+	io.WriterTo        // WriteTo(io.Writer) (int64, error)
+	Save(string) error // writes the spec to the given path
+	Raw() *specs.Spec  // returns the underlying CDI spec
+}
diff --git a/pkg/nvcdi/spec/builder.go b/pkg/nvcdi/spec/builder.go
new file mode 100644
index 00000000..bdf68fe2
--- /dev/null
+++ b/pkg/nvcdi/spec/builder.go
@@ -0,0 +1,130 @@
+/**
+# Copyright (c) NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+**/
+
+package spec
+
+import (
+	"fmt"
+
+	"github.com/container-orchestrated-devices/container-device-interface/pkg/cdi"
+	"github.com/container-orchestrated-devices/container-device-interface/specs-go"
+)
+
+type builder struct {
+	raw         *specs.Spec          // pre-built spec; when nil, Build assembles one from the fields below
+	version     string               // CDI spec version, or DetectMinimumVersion
+	vendor      string               // vendor part of the spec kind ("<vendor>/<class>")
+	class       string               // class part of the spec kind ("<vendor>/<class>")
+	deviceSpecs []specs.Device       // devices to include in the spec
+	edits       specs.ContainerEdits // spec-level container edits
+	format      string               // output format recorded on the resulting spec
+}
+
+// newBuilder creates a spec builder from the supplied options; unset fields default to DetectMinimumVersion, "nvidia.com", "gpu" and FormatYAML.
+func newBuilder(opts ...Option) *builder {
+	s := &builder{}
+	for _, opt := range opts {
+		opt(s)
+	}
+	if s.version == "" {
+		s.version = DetectMinimumVersion
+	}
+	if s.vendor == "" {
+		s.vendor = "nvidia.com"
+	}
+	if s.class == "" {
+		s.class = "gpu"
+	}
+	if s.format == "" {
+		s.format = FormatYAML
+	}
+
+	return s
+}
+
+// Build builds a CDI spec from the spec builder, replacing a DetectMinimumVersion version with the minimum version the spec requires.
+func (o *builder) Build() (*spec, error) {
+	raw := o.raw
+
+	if raw == nil { // no pre-built spec: assemble one from the configured fields
+		raw = &specs.Spec{
+			Version:        o.version,
+			Kind:           fmt.Sprintf("%s/%s", o.vendor, o.class),
+			Devices:        o.deviceSpecs,
+			ContainerEdits: o.edits,
+		}
+	}
+
+	if raw.Version == DetectMinimumVersion {
+		minVersion, err := cdi.MinimumRequiredVersion(raw)
+		if err != nil {
+			return nil, fmt.Errorf("failed to get minimum required CDI spec version: %w", err)
+		}
+		raw.Version = minVersion
+	}
+
+	s := spec{
+		Spec:   raw,
+		format: o.format,
+	}
+
+	return &s, nil
+}
+
+// Option defines a function that can be used to configure the spec builder.
+type Option func(*builder)
+
+// WithDeviceSpecs returns an Option that stores the given device specs on the builder.
+func WithDeviceSpecs(deviceSpecs []specs.Device) Option {
+	return func(b *builder) {
+		b.deviceSpecs = deviceSpecs
+	}
+}
+
+// WithEdits returns an Option that stores the given container edits on the builder.
+func WithEdits(edits specs.ContainerEdits) Option {
+	return func(b *builder) {
+		b.edits = edits
+	}
+}
+
+// WithVersion returns an Option that stores the given spec version on the builder.
+func WithVersion(version string) Option {
+	return func(b *builder) {
+		b.version = version
+	}
+}
+
+// WithVendor returns an Option that stores the given vendor on the builder.
+func WithVendor(vendor string) Option {
+	return func(b *builder) {
+		b.vendor = vendor
+	}
+}
+
+// WithClass returns an Option that stores the given class on the builder.
+func WithClass(class string) Option {
+	return func(b *builder) {
+		b.class = class
+	}
+}
+
+// WithFormat returns an Option that stores the given output file format on the builder.
+func WithFormat(format string) Option {
+	return func(b *builder) {
+		b.format = format
+	}
+}
diff --git a/pkg/nvcdi/spec/spec.go b/pkg/nvcdi/spec/spec.go
new file mode 100644
index 00000000..2bb26a71
--- /dev/null
+++ b/pkg/nvcdi/spec/spec.go
@@ -0,0 +1,120 @@
+/**
+# Copyright (c) NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+**/
+
+package spec
+
+import (
+	"fmt"
+	"io"
+	"os"
+	"path/filepath"
+
+	"github.com/container-orchestrated-devices/container-device-interface/pkg/cdi"
+	"github.com/container-orchestrated-devices/container-device-interface/specs-go"
+)
+
+type spec struct {
+	*specs.Spec
+	format string
+}
+
+var _ Interface = (*spec)(nil)
+
+// New creates a new spec with the specified options; on failure it returns a nil Interface rather than one wrapping a nil *spec.
+func New(opts ...Option) (Interface, error) {
+	s, err := newBuilder(opts...).Build()
+	if err != nil {
+		return nil, err
+	}
+	return s, nil
+}
+
+// Save writes the spec to the specified path (after normalizePath has been applied to it) and overwrites the file if it exists.
+func (s *spec) Save(path string) error {
+	path, err := s.normalizePath(path)
+	if err != nil {
+		return fmt.Errorf("failed to normalize path: %w", err)
+	}
+
+	registry := cdi.GetRegistry(cdi.WithAutoRefresh(false), cdi.WithSpecDirs(filepath.Dir(path)))
+
+	return registry.SpecDB().WriteSpec(s.Raw(), filepath.Base(path))
+}
+
+// WriteTo writes the spec to the specified writer by saving it to a temporary file and copying that file's contents to w.
+func (s *spec) WriteTo(w io.Writer) (int64, error) {
+	name, err := cdi.GenerateNameForSpec(s.Raw())
+	if err != nil {
+		return 0, err
+	}
+
+	path, _ := s.normalizePath(name) // error ignored: only the base name is used, and normalizePath returns it (with extension) even on failure
+	tmpFile, err := os.CreateTemp("", "*"+filepath.Base(path))
+	if err != nil {
+		return 0, err
+	}
+	defer os.Remove(tmpFile.Name())
+
+	// Close the handle before saving: Save only needs the file name, and closing first avoids leaking the descriptor when Save fails.
+	if err := tmpFile.Close(); err != nil {
+		return 0, fmt.Errorf("failed to close temporary file: %w", err)
+	}
+
+	if err := s.Save(tmpFile.Name()); err != nil {
+		return 0, err
+	}
+
+	r, err := os.Open(tmpFile.Name())
+	if err != nil {
+		return 0, fmt.Errorf("failed to open temporary file: %w", err)
+	}
+	defer r.Close()
+
+	return io.Copy(w, r)
+}
+
+// Raw returns a pointer to the raw spec.
+func (s *spec) Raw() *specs.Spec {
+	return s.Spec
+}
+
+// normalizePath ensures that the specified path has a supported extension (appending the one for the spec's format when it has neither .yaml nor .json) and joins paths whose directory is "." onto the current working directory.
+func (s *spec) normalizePath(path string) (string, error) {
+	if ext := filepath.Ext(path); ext != ".yaml" && ext != ".json" {
+		path += s.extension()
+	}
+
+	if filepath.Clean(filepath.Dir(path)) == "." {
+		pwd, err := os.Getwd()
+		if err != nil {
+			return path, fmt.Errorf("failed to get current working directory: %w", err)
+		}
+		path = filepath.Join(pwd, path)
+	}
+
+	return path, nil
+}
+
+// extension returns the file extension for the spec's format; any format other than FormatJSON maps to ".yaml".
+func (s *spec) extension() string {
+	switch s.format {
+	case FormatJSON:
+		return ".json"
+	case FormatYAML:
+		return ".yaml"
+	}
+	return ".yaml"
+}