Merge branch 'CNT-3931/add-spec-validation' into 'main'

Add nvcdi.spec for writing and validating CDI specifications

See merge request nvidia/container-toolkit/container-toolkit!306
This commit is contained in:
Evan Lezar 2023-03-06 08:52:56 +00:00
commit 19c20bb422
9 changed files with 397 additions and 104 deletions

View File

@ -18,7 +18,6 @@ package generate
import ( import (
"fmt" "fmt"
"io"
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
@ -26,19 +25,16 @@ import (
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover" "github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
"github.com/NVIDIA/nvidia-container-toolkit/internal/edits" "github.com/NVIDIA/nvidia-container-toolkit/internal/edits"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi" "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec"
"github.com/container-orchestrated-devices/container-device-interface/pkg/cdi" "github.com/container-orchestrated-devices/container-device-interface/pkg/cdi"
specs "github.com/container-orchestrated-devices/container-device-interface/specs-go" specs "github.com/container-orchestrated-devices/container-device-interface/specs-go"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/urfave/cli/v2" "github.com/urfave/cli/v2"
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device" "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device"
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml" "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml"
"sigs.k8s.io/yaml"
) )
const ( const (
formatJSON = "json"
formatYAML = "yaml"
allDeviceName = "all" allDeviceName = "all"
) )
@ -88,7 +84,7 @@ func (m command) build() *cli.Command {
&cli.StringFlag{ &cli.StringFlag{
Name: "format", Name: "format",
Usage: "The output format for the generated spec [json | yaml]. This overrides the format defined by the output file extension (if specified).", Usage: "The output format for the generated spec [json | yaml]. This overrides the format defined by the output file extension (if specified).",
Value: formatYAML, Value: spec.FormatYAML,
Destination: &cfg.format, Destination: &cfg.format,
}, },
&cli.StringFlag{ &cli.StringFlag{
@ -118,11 +114,12 @@ func (m command) build() *cli.Command {
return &c return &c
} }
func (m command) validateFlags(r *cli.Context, cfg *config) error { func (m command) validateFlags(c *cli.Context, cfg *config) error {
cfg.format = strings.ToLower(cfg.format) cfg.format = strings.ToLower(cfg.format)
switch cfg.format { switch cfg.format {
case formatJSON: case spec.FormatJSON:
case formatYAML: case spec.FormatYAML:
default: default:
return fmt.Errorf("invalid output format: %v", cfg.format) return fmt.Errorf("invalid output format: %v", cfg.format)
} }
@ -143,31 +140,6 @@ func (m command) validateFlags(r *cli.Context, cfg *config) error {
cfg.nvidiaCTKPath = discover.FindNvidiaCTK(m.logger, cfg.nvidiaCTKPath) cfg.nvidiaCTKPath = discover.FindNvidiaCTK(m.logger, cfg.nvidiaCTKPath)
return nil
}
func (m command) run(c *cli.Context, cfg *config) error {
spec, err := m.generateSpec(cfg)
if err != nil {
return fmt.Errorf("failed to generate CDI spec: %v", err)
}
var outputTo io.Writer
if cfg.output == "" {
outputTo = os.Stdout
} else {
err := createParentDirsIfRequired(cfg.output)
if err != nil {
return fmt.Errorf("failed to create parent folders for output file: %v", err)
}
outputFile, err := os.Create(cfg.output)
if err != nil {
return fmt.Errorf("failed to create output file: %v", err)
}
defer outputFile.Close()
outputTo = outputFile
}
if outputFileFormat := formatFromFilename(cfg.output); outputFileFormat != "" { if outputFileFormat := formatFromFilename(cfg.output); outputFileFormat != "" {
m.logger.Debugf("Inferred output format as %q from output file name", outputFileFormat) m.logger.Debugf("Inferred output format as %q from output file name", outputFileFormat)
if !c.IsSet("format") { if !c.IsSet("format") {
@ -177,56 +149,40 @@ func (m command) run(c *cli.Context, cfg *config) error {
} }
} }
data, err := yaml.Marshal(spec)
if err != nil {
return fmt.Errorf("failed to marshal CDI spec: %v", err)
}
if strings.ToLower(cfg.format) == formatJSON {
data, err = yaml.YAMLToJSONStrict(data)
if err != nil {
return fmt.Errorf("failed to convert CDI spec from YAML to JSON: %v", err)
}
}
err = writeToOutput(cfg.format, data, outputTo)
if err != nil {
return fmt.Errorf("failed to write output: %v", err)
}
return nil return nil
} }
// run generates the CDI specification described by cfg and emits it.
//
// When cfg.output is empty the spec is written to STDOUT; otherwise it is
// saved to the path named by cfg.output. Any failure is returned wrapped so
// that callers can still inspect the underlying cause.
func (m command) run(c *cli.Context, cfg *config) error {
	// The local is deliberately not named "spec" so that it does not shadow
	// the imported spec package.
	cdiSpec, err := m.generateSpec(cfg)
	if err != nil {
		return fmt.Errorf("failed to generate CDI spec: %w", err)
	}
	m.logger.Infof("Generated CDI spec with version %v", cdiSpec.Raw().Version)

	if cfg.output == "" {
		if _, err := cdiSpec.WriteTo(os.Stdout); err != nil {
			return fmt.Errorf("failed to write CDI spec to STDOUT: %w", err)
		}
		return nil
	}

	if err := cdiSpec.Save(cfg.output); err != nil {
		return fmt.Errorf("failed to save CDI spec to %q: %w", cfg.output, err)
	}
	return nil
}
func formatFromFilename(filename string) string { func formatFromFilename(filename string) string {
ext := filepath.Ext(filename) ext := filepath.Ext(filename)
switch strings.ToLower(ext) { switch strings.ToLower(ext) {
case ".json": case ".json":
return formatJSON return spec.FormatJSON
case ".yaml": case ".yaml", ".yml":
return formatYAML return spec.FormatYAML
case ".yml":
return formatYAML
} }
return "" return ""
} }
func writeToOutput(format string, data []byte, output io.Writer) error { func (m command) generateSpec(cfg *config) (spec.Interface, error) {
if format == formatYAML {
_, err := output.Write([]byte("---\n"))
if err != nil {
return fmt.Errorf("failed to write YAML separator: %v", err)
}
}
_, err := output.Write(data)
if err != nil {
return fmt.Errorf("failed to write data: %v", err)
}
return nil
}
func (m command) generateSpec(cfg *config) (*specs.Spec, error) {
deviceNamer, err := nvcdi.NewDeviceNamer(cfg.deviceNameStrategy) deviceNamer, err := nvcdi.NewDeviceNamer(cfg.deviceNameStrategy)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create device namer: %v", err) return nil, fmt.Errorf("failed to create device namer: %v", err)
@ -274,23 +230,13 @@ func (m command) generateSpec(cfg *config) (*specs.Spec, error) {
return nil, fmt.Errorf("failed to create edits common for entities: %v", err) return nil, fmt.Errorf("failed to create edits common for entities: %v", err)
} }
// We construct the spec and determine the minimum required version based on the specification. return spec.New(
spec := specs.Spec{ spec.WithVendor("nvidia.com"),
Version: "NOT_SET", spec.WithClass("gpu"),
Kind: "nvidia.com/gpu", spec.WithDeviceSpecs(deviceSpecs),
Devices: deviceSpecs, spec.WithEdits(*commonEdits.ContainerEdits),
ContainerEdits: *commonEdits.ContainerEdits, spec.WithFormat(cfg.format),
} )
minVersion, err := cdi.MinimumRequiredVersion(&spec)
if err != nil {
return nil, fmt.Errorf("failed to get minumum required CDI spec version: %v", err)
}
m.logger.Infof("Using minimum required CDI spec version: %s", minVersion)
spec.Version = minVersion
return &spec, nil
} }
// MergeDeviceSpecs creates a device with the specified name which combines the edits from the previous devices. // MergeDeviceSpecs creates a device with the specified name which combines the edits from the previous devices.
@ -320,14 +266,3 @@ func MergeDeviceSpecs(deviceSpecs []specs.Device, mergedDeviceName string) (spec
} }
return merged, nil return merged, nil
} }
// createParentDirsIfRequired creates the parent folders of the specified path if required.
// Note that MkdirAll does not specifically check whether the specified path is non-empty and raises an error if it is.
// The path will be empty if filename in the current folder is specified, for example
func createParentDirsIfRequired(filename string) error {
dir := filepath.Dir(filename)
if dir == "" {
return nil
}
return os.MkdirAll(dir, 0755)
}

View File

@ -17,6 +17,7 @@
package nvcdi package nvcdi
import ( import (
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec"
"github.com/container-orchestrated-devices/container-device-interface/pkg/cdi" "github.com/container-orchestrated-devices/container-device-interface/pkg/cdi"
"github.com/container-orchestrated-devices/container-device-interface/specs-go" "github.com/container-orchestrated-devices/container-device-interface/specs-go"
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device" "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device"
@ -33,6 +34,7 @@ const (
// Interface defines the API for the nvcdi package // Interface defines the API for the nvcdi package
type Interface interface { type Interface interface {
GetSpec() (spec.Interface, error)
GetCommonEdits() (*cdi.ContainerEdits, error) GetCommonEdits() (*cdi.ContainerEdits, error)
GetAllDeviceSpecs() ([]specs.Device, error) GetAllDeviceSpecs() ([]specs.Device, error)
GetGPUDeviceEdits(device.Device) (*cdi.ContainerEdits, error) GetGPUDeviceEdits(device.Device) (*cdi.ContainerEdits, error)

View File

@ -20,6 +20,7 @@ import (
"fmt" "fmt"
"github.com/NVIDIA/nvidia-container-toolkit/internal/edits" "github.com/NVIDIA/nvidia-container-toolkit/internal/edits"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec"
"github.com/container-orchestrated-devices/container-device-interface/pkg/cdi" "github.com/container-orchestrated-devices/container-device-interface/pkg/cdi"
"github.com/container-orchestrated-devices/container-device-interface/specs-go" "github.com/container-orchestrated-devices/container-device-interface/specs-go"
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device" "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device"
@ -29,6 +30,11 @@ type nvmllib nvcdilib
var _ Interface = (*nvmllib)(nil) var _ Interface = (*nvmllib)(nil)
// GetSpec should not be called for nvmllib.
// It exists only to satisfy Interface and always returns a nil spec together
// with an error.
func (l *nvmllib) GetSpec() (spec.Interface, error) {
	// Error strings start with a lower-case letter by Go convention.
	return nil, fmt.Errorf("unexpected call to nvmllib.GetSpec()")
}
// GetAllDeviceSpecs returns the device specs for all available devices. // GetAllDeviceSpecs returns the device specs for all available devices.
func (l *nvmllib) GetAllDeviceSpecs() ([]specs.Device, error) { func (l *nvmllib) GetAllDeviceSpecs() ([]specs.Device, error) {
var deviceSpecs []specs.Device var deviceSpecs []specs.Device

View File

@ -20,6 +20,7 @@ import (
"fmt" "fmt"
"github.com/NVIDIA/nvidia-container-toolkit/internal/edits" "github.com/NVIDIA/nvidia-container-toolkit/internal/edits"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec"
"github.com/container-orchestrated-devices/container-device-interface/pkg/cdi" "github.com/container-orchestrated-devices/container-device-interface/pkg/cdi"
"github.com/container-orchestrated-devices/container-device-interface/specs-go" "github.com/container-orchestrated-devices/container-device-interface/specs-go"
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device" "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device"
@ -29,6 +30,11 @@ type wsllib nvcdilib
var _ Interface = (*wsllib)(nil) var _ Interface = (*wsllib)(nil)
// GetSpec should not be called for wsllib.
// It exists only to satisfy Interface and always returns a nil spec together
// with an error.
func (l *wsllib) GetSpec() (spec.Interface, error) {
	// Error strings start with a lower-case letter by Go convention.
	return nil, fmt.Errorf("unexpected call to wsllib.GetSpec()")
}
// GetAllDeviceSpecs returns the device specs for all available devices. // GetAllDeviceSpecs returns the device specs for all available devices.
func (l *wsllib) GetAllDeviceSpecs() ([]specs.Device, error) { func (l *wsllib) GetAllDeviceSpecs() ([]specs.Device, error) {
device := newDXGDeviceDiscoverer(l.logger, l.driverRoot) device := newDXGDeviceDiscoverer(l.logger, l.driverRoot)

View File

@ -17,12 +17,20 @@
package nvcdi package nvcdi
import ( import (
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device" "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device"
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/info" "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/info"
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml" "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml"
) )
// wrapper decorates an Interface implementation with the vendor and class
// that its GetSpec method passes to the spec builder.
type wrapper struct {
	// Interface is the wrapped library; its methods are promoted onto wrapper
	// unless wrapper defines its own.
	Interface

	// vendor and class are forwarded to spec.WithVendor and spec.WithClass.
	vendor, class string
}
type nvcdilib struct { type nvcdilib struct {
logger *logrus.Logger logger *logrus.Logger
nvmllib nvml.Interface nvmllib nvml.Interface
@ -32,6 +40,9 @@ type nvcdilib struct {
driverRoot string driverRoot string
nvidiaCTKPath string nvidiaCTKPath string
vendor string
class string
infolib info.Interface infolib info.Interface
} }
@ -60,6 +71,7 @@ func New(opts ...Option) Interface {
l.infolib = info.New() l.infolib = info.New()
} }
var lib Interface
switch l.resolveMode() { switch l.resolveMode() {
case ModeNvml: case ModeNvml:
if l.nvmllib == nil { if l.nvmllib == nil {
@ -69,13 +81,41 @@ func New(opts ...Option) Interface {
l.devicelib = device.New(device.WithNvml(l.nvmllib)) l.devicelib = device.New(device.WithNvml(l.nvmllib))
} }
return (*nvmllib)(l) lib = (*nvmllib)(l)
case ModeWsl: case ModeWsl:
return (*wsllib)(l) lib = (*wsllib)(l)
default:
// TODO: We would like to return an error here instead of panicking
panic("Unknown mode")
} }
// TODO: We want an error here. w := wrapper{
return nil Interface: lib,
vendor: l.vendor,
class: l.class,
}
return &w
}
// GetSpec combines the device specs and common edits from the wrapped Interface to a single spec.Interface.
func (l *wrapper) GetSpec() (spec.Interface, error) {
deviceSpecs, err := l.GetAllDeviceSpecs()
if err != nil {
return nil, err
}
edits, err := l.GetCommonEdits()
if err != nil {
return nil, err
}
return spec.New(
spec.WithDeviceSpecs(deviceSpecs),
spec.WithEdits(*edits.ContainerEdits),
spec.WithVendor(l.vendor),
spec.WithClass(l.class),
)
} }
// resolveMode resolves the mode for CDI spec generation based on the current system. // resolveMode resolves the mode for CDI spec generation based on the current system.

View File

@ -73,3 +73,17 @@ func WithMode(mode string) Option {
l.mode = mode l.mode = mode
} }
} }
// WithVendor returns an Option that records the given vendor on the library.
func WithVendor(vendor string) Option {
	setVendor := func(l *nvcdilib) {
		l.vendor = vendor
	}
	return setVendor
}
// WithClass returns an Option that records the given class on the library.
func WithClass(class string) Option {
	setClass := func(l *nvcdilib) {
		l.class = class
	}
	return setClass
}

40
pkg/nvcdi/spec/api.go Normal file
View File

@ -0,0 +1,40 @@
/**
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package spec
import (
"io"
"github.com/container-orchestrated-devices/container-device-interface/specs-go"
)
// Well-known values understood by the spec builder and writer.
const (
	// FormatJSON selects JSON as the output format.
	FormatJSON = "json"
	// FormatYAML selects YAML as the output format.
	FormatYAML = "yaml"

	// DetectMinimumVersion is a placeholder version. A spec whose version is
	// set to this value has the minimum required version detected when it is
	// built.
	DetectMinimumVersion = "DETECT_MINIMUM_VERSION"
)
// Interface is the API offered by a CDI spec created by this package.
type Interface interface {
	// WriteTo (via io.WriterTo) writes the spec to the supplied writer.
	io.WriterTo

	// Raw returns a pointer to the underlying raw CDI spec.
	Raw() *specs.Spec

	// Save writes the spec to the file at path.
	Save(path string) error
}

130
pkg/nvcdi/spec/builder.go Normal file
View File

@ -0,0 +1,130 @@
/**
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package spec
import (
"fmt"
"github.com/container-orchestrated-devices/container-device-interface/pkg/cdi"
"github.com/container-orchestrated-devices/container-device-interface/specs-go"
)
// builder collects the settings from which Build assembles a spec.
type builder struct {
	// raw, when non-nil, is used by Build as the spec instead of assembling
	// one from the fields below.
	raw *specs.Spec

	// version is the spec version, and vendor and class are joined as
	// "<vendor>/<class>" to form the spec kind.
	version, vendor, class string

	// deviceSpecs and edits become the Devices and ContainerEdits of the
	// assembled spec.
	deviceSpecs []specs.Device
	edits       specs.ContainerEdits

	// format is the output format handed to the resulting spec.
	format string
}
// newBuilder returns a builder with every supplied option applied in order.
// Any of version, vendor, class or format that is still empty afterwards is
// filled in with its default.
func newBuilder(opts ...Option) *builder {
	b := new(builder)
	for i := range opts {
		opts[i](b)
	}

	// Each entry pairs a string field with the value used when it is empty.
	defaults := []struct {
		field    *string
		fallback string
	}{
		{&b.version, DetectMinimumVersion},
		{&b.vendor, "nvidia.com"},
		{&b.class, "gpu"},
		{&b.format, FormatYAML},
	}
	for _, d := range defaults {
		if *d.field == "" {
			*d.field = d.fallback
		}
	}

	return b
}
// Build builds a CDI spec from the spec builder.
//
// If the builder holds no raw spec, one is assembled from the configured
// version, kind ("<vendor>/<class>"), device specs and container edits. A
// version equal to DetectMinimumVersion is replaced by the result of
// cdi.MinimumRequiredVersion for that spec. An error is returned if that
// version cannot be determined.
func (o *builder) Build() (*spec, error) {
	raw := o.raw
	if raw == nil {
		raw = &specs.Spec{
			Version:        o.version,
			Kind:           fmt.Sprintf("%s/%s", o.vendor, o.class),
			Devices:        o.deviceSpecs,
			ContainerEdits: o.edits,
		}
	}

	if raw.Version == DetectMinimumVersion {
		minVersion, err := cdi.MinimumRequiredVersion(raw)
		if err != nil {
			// %w keeps the underlying error available to errors.Is / errors.As.
			return nil, fmt.Errorf("failed to get minimum required CDI spec version: %w", err)
		}
		raw.Version = minVersion
	}

	s := spec{
		Spec:   raw,
		format: o.format,
	}
	return &s, nil
}
// Option defines a function that can be used to configure the spec builder.
// newBuilder applies the options in the order given, before filling in
// defaults for any of version, vendor, class or format that are left empty.
type Option func(*builder)
// WithDeviceSpecs returns an Option that stores the given device specs on the
// spec builder. The slice is stored as passed, without copying.
func WithDeviceSpecs(deviceSpecs []specs.Device) Option {
	return func(b *builder) { b.deviceSpecs = deviceSpecs }
}
// WithEdits returns an Option that stores the given container edits on the
// spec builder.
func WithEdits(edits specs.ContainerEdits) Option {
	return func(b *builder) { b.edits = edits }
}
// WithVersion returns an Option that stores the given spec version on the
// spec builder.
func WithVersion(version string) Option {
	return func(b *builder) { b.version = version }
}
// WithVendor returns an Option that stores the given vendor on the spec
// builder.
func WithVendor(vendor string) Option {
	return func(b *builder) { b.vendor = vendor }
}
// WithClass returns an Option that stores the given class on the spec
// builder.
func WithClass(class string) Option {
	return func(b *builder) { b.class = class }
}
// WithFormat returns an Option that stores the given output file format on
// the spec builder.
func WithFormat(format string) Option {
	return func(b *builder) { b.format = format }
}

120
pkg/nvcdi/spec/spec.go Normal file
View File

@ -0,0 +1,120 @@
/**
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package spec
import (
"fmt"
"io"
"os"
"path/filepath"
"github.com/container-orchestrated-devices/container-device-interface/pkg/cdi"
"github.com/container-orchestrated-devices/container-device-interface/specs-go"
)
// spec pairs a raw CDI specification with the output format selected for it.
type spec struct {
	// Spec is the embedded raw CDI specification; Raw returns this pointer.
	*specs.Spec
	// format determines the file extension chosen by extension: FormatJSON
	// gives ".json", any other value gives ".yaml".
	format string
}

// Compile-time check that *spec implements Interface.
var _ Interface = (*spec)(nil)
// New creates a new spec with the specified options.
//
// On failure the returned Interface is a literal nil. Returning the builder's
// (*spec, error) result directly would wrap a nil *spec in a non-nil
// Interface, so a caller's `s != nil` check would wrongly succeed.
func New(opts ...Option) (Interface, error) {
	s, err := newBuilder(opts...).Build()
	if err != nil {
		return nil, err
	}
	return s, nil
}
// Save writes the spec to the specified path and overwrites the file if it exists.
//
// The path is first normalized (see normalizePath). The spec is then written
// through a CDI registry whose only spec directory is the directory of the
// normalized path, using the base name of that path as the spec name.
func (s *spec) Save(path string) error {
	target, err := s.normalizePath(path)
	if err != nil {
		return fmt.Errorf("failed to normalize path: %w", err)
	}

	dir, name := filepath.Dir(target), filepath.Base(target)

	// Auto refresh is switched off for this registry.
	registry := cdi.GetRegistry(
		cdi.WithAutoRefresh(false),
		cdi.WithSpecDirs(dir),
	)
	db := registry.SpecDB()
	return db.WriteSpec(s.Raw(), name)
}
// WriteTo writes the spec to the specified writer.
//
// The spec is first saved to a temporary file, whose name is derived from the
// name generated for the spec, and the file contents are then copied to w. The
// temporary file is removed before returning. The result is the number of
// bytes copied and any error encountered.
func (s *spec) WriteTo(w io.Writer) (int64, error) {
	name, err := cdi.GenerateNameForSpec(s.Raw())
	if err != nil {
		return 0, err
	}

	// Only the base name is used below, and normalizePath still returns the
	// path with its extension when the working directory lookup fails, so the
	// error can be ignored here.
	path, _ := s.normalizePath(name)

	tmpFile, err := os.CreateTemp("", "*"+filepath.Base(path))
	if err != nil {
		return 0, err
	}
	tmpName := tmpFile.Name()
	defer os.Remove(tmpName)

	// Save works on the file name, so the handle is not needed. Closing it
	// here rather than after Save ensures it is not leaked when Save fails.
	if err := tmpFile.Close(); err != nil {
		return 0, fmt.Errorf("failed to close temporary file: %w", err)
	}

	if err := s.Save(tmpName); err != nil {
		return 0, err
	}

	r, err := os.Open(tmpName)
	if err != nil {
		return 0, fmt.Errorf("failed to open temporary file: %w", err)
	}
	defer r.Close()

	return io.Copy(w, r)
}
// Raw returns a pointer to the raw spec.
// The pointer is the embedded *specs.Spec itself rather than a copy.
func (s *spec) Raw() *specs.Spec {
	return s.Spec
}
// normalizePath returns path with a supported extension and, for a path
// without a directory component, joined onto the current working directory.
//
// If the extension is neither ".yaml" nor ".json", the extension for the
// spec's format is appended. If the working directory cannot be determined,
// the path (including any appended extension) is returned with the error.
func (s *spec) normalizePath(path string) (string, error) {
	switch filepath.Ext(path) {
	case ".yaml", ".json":
		// Already a supported extension; keep the path as is.
	default:
		path += s.extension()
	}

	if filepath.Clean(filepath.Dir(path)) != "." {
		return path, nil
	}

	pwd, err := os.Getwd()
	if err != nil {
		return path, fmt.Errorf("failed to get current working directory: %v", err)
	}
	return filepath.Join(pwd, path), nil
}
// extension returns the file extension for the spec's format: ".json" for
// FormatJSON and ".yaml" for every other value, including FormatYAML.
func (s *spec) extension() string {
	if s.format == FormatJSON {
		return ".json"
	}
	return ".yaml"
}