Merge branch 'CNT-3931/add-spec-validation' into 'main'

Add nvcdi.spec for writing and validating CDI specifications

See merge request nvidia/container-toolkit/container-toolkit!306
This commit is contained in:
Evan Lezar 2023-03-06 08:52:56 +00:00
commit 19c20bb422
9 changed files with 397 additions and 104 deletions

View File

@ -18,7 +18,6 @@ package generate
import (
"fmt"
"io"
"os"
"path/filepath"
"strings"
@ -26,19 +25,16 @@ import (
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
"github.com/NVIDIA/nvidia-container-toolkit/internal/edits"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec"
"github.com/container-orchestrated-devices/container-device-interface/pkg/cdi"
specs "github.com/container-orchestrated-devices/container-device-interface/specs-go"
"github.com/sirupsen/logrus"
"github.com/urfave/cli/v2"
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device"
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml"
"sigs.k8s.io/yaml"
)
const (
formatJSON = "json"
formatYAML = "yaml"
allDeviceName = "all"
)
@ -88,7 +84,7 @@ func (m command) build() *cli.Command {
&cli.StringFlag{
Name: "format",
Usage: "The output format for the generated spec [json | yaml]. This overrides the format defined by the output file extension (if specified).",
Value: formatYAML,
Value: spec.FormatYAML,
Destination: &cfg.format,
},
&cli.StringFlag{
@ -118,11 +114,12 @@ func (m command) build() *cli.Command {
return &c
}
func (m command) validateFlags(r *cli.Context, cfg *config) error {
func (m command) validateFlags(c *cli.Context, cfg *config) error {
cfg.format = strings.ToLower(cfg.format)
switch cfg.format {
case formatJSON:
case formatYAML:
case spec.FormatJSON:
case spec.FormatYAML:
default:
return fmt.Errorf("invalid output format: %v", cfg.format)
}
@ -143,31 +140,6 @@ func (m command) validateFlags(r *cli.Context, cfg *config) error {
cfg.nvidiaCTKPath = discover.FindNvidiaCTK(m.logger, cfg.nvidiaCTKPath)
return nil
}
func (m command) run(c *cli.Context, cfg *config) error {
spec, err := m.generateSpec(cfg)
if err != nil {
return fmt.Errorf("failed to generate CDI spec: %v", err)
}
var outputTo io.Writer
if cfg.output == "" {
outputTo = os.Stdout
} else {
err := createParentDirsIfRequired(cfg.output)
if err != nil {
return fmt.Errorf("failed to create parent folders for output file: %v", err)
}
outputFile, err := os.Create(cfg.output)
if err != nil {
return fmt.Errorf("failed to create output file: %v", err)
}
defer outputFile.Close()
outputTo = outputFile
}
if outputFileFormat := formatFromFilename(cfg.output); outputFileFormat != "" {
m.logger.Debugf("Inferred output format as %q from output file name", outputFileFormat)
if !c.IsSet("format") {
@ -177,56 +149,40 @@ func (m command) run(c *cli.Context, cfg *config) error {
}
}
data, err := yaml.Marshal(spec)
if err != nil {
return fmt.Errorf("failed to marshal CDI spec: %v", err)
}
if strings.ToLower(cfg.format) == formatJSON {
data, err = yaml.YAMLToJSONStrict(data)
if err != nil {
return fmt.Errorf("failed to convert CDI spec from YAML to JSON: %v", err)
}
}
err = writeToOutput(cfg.format, data, outputTo)
if err != nil {
return fmt.Errorf("failed to write output: %v", err)
}
return nil
}
func (m command) run(c *cli.Context, cfg *config) error {
spec, err := m.generateSpec(cfg)
if err != nil {
return fmt.Errorf("failed to generate CDI spec: %v", err)
}
m.logger.Infof("Generated CDI spec with version %v", spec.Raw().Version)
if cfg.output == "" {
_, err := spec.WriteTo(os.Stdout)
if err != nil {
return fmt.Errorf("failed to write CDI spec to STDOUT: %v", err)
}
return nil
}
return spec.Save(cfg.output)
}
func formatFromFilename(filename string) string {
ext := filepath.Ext(filename)
switch strings.ToLower(ext) {
case ".json":
return formatJSON
case ".yaml":
return formatYAML
case ".yml":
return formatYAML
return spec.FormatJSON
case ".yaml", ".yml":
return spec.FormatYAML
}
return ""
}
func writeToOutput(format string, data []byte, output io.Writer) error {
if format == formatYAML {
_, err := output.Write([]byte("---\n"))
if err != nil {
return fmt.Errorf("failed to write YAML separator: %v", err)
}
}
_, err := output.Write(data)
if err != nil {
return fmt.Errorf("failed to write data: %v", err)
}
return nil
}
func (m command) generateSpec(cfg *config) (*specs.Spec, error) {
func (m command) generateSpec(cfg *config) (spec.Interface, error) {
deviceNamer, err := nvcdi.NewDeviceNamer(cfg.deviceNameStrategy)
if err != nil {
return nil, fmt.Errorf("failed to create device namer: %v", err)
@ -274,23 +230,13 @@ func (m command) generateSpec(cfg *config) (*specs.Spec, error) {
return nil, fmt.Errorf("failed to create edits common for entities: %v", err)
}
// We construct the spec and determine the minimum required version based on the specification.
spec := specs.Spec{
Version: "NOT_SET",
Kind: "nvidia.com/gpu",
Devices: deviceSpecs,
ContainerEdits: *commonEdits.ContainerEdits,
}
minVersion, err := cdi.MinimumRequiredVersion(&spec)
if err != nil {
return nil, fmt.Errorf("failed to get minumum required CDI spec version: %v", err)
}
m.logger.Infof("Using minimum required CDI spec version: %s", minVersion)
spec.Version = minVersion
return &spec, nil
return spec.New(
spec.WithVendor("nvidia.com"),
spec.WithClass("gpu"),
spec.WithDeviceSpecs(deviceSpecs),
spec.WithEdits(*commonEdits.ContainerEdits),
spec.WithFormat(cfg.format),
)
}
// MergeDeviceSpecs creates a device with the specified name which combines the edits from the previous devices.
@ -320,14 +266,3 @@ func MergeDeviceSpecs(deviceSpecs []specs.Device, mergedDeviceName string) (spec
}
return merged, nil
}
// createParentDirsIfRequired creates the parent folders of the specified path if requried.
// Note that MkdirAll does not specifically check whether the specified path is non-empty and raises an error if it is.
// The path will be empty if filename in the current folder is specified, for example
func createParentDirsIfRequired(filename string) error {
dir := filepath.Dir(filename)
if dir == "" {
return nil
}
return os.MkdirAll(dir, 0755)
}

View File

@ -17,6 +17,7 @@
package nvcdi
import (
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec"
"github.com/container-orchestrated-devices/container-device-interface/pkg/cdi"
"github.com/container-orchestrated-devices/container-device-interface/specs-go"
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device"
@ -33,6 +34,7 @@ const (
// Interface defines the API for the nvcdi package
type Interface interface {
GetSpec() (spec.Interface, error)
GetCommonEdits() (*cdi.ContainerEdits, error)
GetAllDeviceSpecs() ([]specs.Device, error)
GetGPUDeviceEdits(device.Device) (*cdi.ContainerEdits, error)

View File

@ -20,6 +20,7 @@ import (
"fmt"
"github.com/NVIDIA/nvidia-container-toolkit/internal/edits"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec"
"github.com/container-orchestrated-devices/container-device-interface/pkg/cdi"
"github.com/container-orchestrated-devices/container-device-interface/specs-go"
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device"
@ -29,6 +30,11 @@ type nvmllib nvcdilib
var _ Interface = (*nvmllib)(nil)
// GetSpec should not be called for nvmllib
func (l *nvmllib) GetSpec() (spec.Interface, error) {
return nil, fmt.Errorf("Unexpected call to nvmllib.GetSpec()")
}
// GetAllDeviceSpecs returns the device specs for all available devices.
func (l *nvmllib) GetAllDeviceSpecs() ([]specs.Device, error) {
var deviceSpecs []specs.Device

View File

@ -20,6 +20,7 @@ import (
"fmt"
"github.com/NVIDIA/nvidia-container-toolkit/internal/edits"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec"
"github.com/container-orchestrated-devices/container-device-interface/pkg/cdi"
"github.com/container-orchestrated-devices/container-device-interface/specs-go"
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device"
@ -29,6 +30,11 @@ type wsllib nvcdilib
var _ Interface = (*wsllib)(nil)
// GetSpec should not be called for wsllib
func (l *wsllib) GetSpec() (spec.Interface, error) {
return nil, fmt.Errorf("Unexpected call to wsllib.GetSpec()")
}
// GetAllDeviceSpecs returns the device specs for all available devices.
func (l *wsllib) GetAllDeviceSpecs() ([]specs.Device, error) {
device := newDXGDeviceDiscoverer(l.logger, l.driverRoot)

View File

@ -17,12 +17,20 @@
package nvcdi
import (
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec"
"github.com/sirupsen/logrus"
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device"
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/info"
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml"
)
type wrapper struct {
Interface
vendor string
class string
}
type nvcdilib struct {
logger *logrus.Logger
nvmllib nvml.Interface
@ -32,6 +40,9 @@ type nvcdilib struct {
driverRoot string
nvidiaCTKPath string
vendor string
class string
infolib info.Interface
}
@ -60,6 +71,7 @@ func New(opts ...Option) Interface {
l.infolib = info.New()
}
var lib Interface
switch l.resolveMode() {
case ModeNvml:
if l.nvmllib == nil {
@ -69,13 +81,41 @@ func New(opts ...Option) Interface {
l.devicelib = device.New(device.WithNvml(l.nvmllib))
}
return (*nvmllib)(l)
lib = (*nvmllib)(l)
case ModeWsl:
return (*wsllib)(l)
lib = (*wsllib)(l)
default:
// TODO: We would like to return an error here instead of panicking
panic("Unknown mode")
}
// TODO: We want an error here.
return nil
w := wrapper{
Interface: lib,
vendor: l.vendor,
class: l.class,
}
return &w
}
// GetSpec combines the device specs and common edits from the wrapped Interface to a single spec.Interface.
func (l *wrapper) GetSpec() (spec.Interface, error) {
deviceSpecs, err := l.GetAllDeviceSpecs()
if err != nil {
return nil, err
}
edits, err := l.GetCommonEdits()
if err != nil {
return nil, err
}
return spec.New(
spec.WithDeviceSpecs(deviceSpecs),
spec.WithEdits(*edits.ContainerEdits),
spec.WithVendor(l.vendor),
spec.WithClass(l.class),
)
}
// resolveMode resolves the mode for CDI spec generation based on the current system.

View File

@ -73,3 +73,17 @@ func WithMode(mode string) Option {
l.mode = mode
}
}
// WithVendor sets the vendor for the library
func WithVendor(vendor string) Option {
return func(o *nvcdilib) {
o.vendor = vendor
}
}
// WithClass sets the class for the library
func WithClass(class string) Option {
return func(o *nvcdilib) {
o.class = class
}
}

40
pkg/nvcdi/spec/api.go Normal file
View File

@ -0,0 +1,40 @@
/**
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package spec
import (
"io"
"github.com/container-orchestrated-devices/container-device-interface/specs-go"
)
const (
// DetectMinimumVersion is a constant that triggers a spec to detect the minimum required version.
DetectMinimumVersion = "DETECT_MINIMUM_VERSION"
// FormatJSON indicates a JSON output format
FormatJSON = "json"
// FormatYAML indicates a YAML output format
FormatYAML = "yaml"
)
// Interface is the interface for the spec API
type Interface interface {
io.WriterTo
Save(string) error
Raw() *specs.Spec
}

130
pkg/nvcdi/spec/builder.go Normal file
View File

@ -0,0 +1,130 @@
/**
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package spec
import (
"fmt"
"github.com/container-orchestrated-devices/container-device-interface/pkg/cdi"
"github.com/container-orchestrated-devices/container-device-interface/specs-go"
)
type builder struct {
raw *specs.Spec
version string
vendor string
class string
deviceSpecs []specs.Device
edits specs.ContainerEdits
format string
}
// newBuilder creates a new spec builder with the supplied options
func newBuilder(opts ...Option) *builder {
s := &builder{}
for _, opt := range opts {
opt(s)
}
if s.version == "" {
s.version = DetectMinimumVersion
}
if s.vendor == "" {
s.vendor = "nvidia.com"
}
if s.class == "" {
s.class = "gpu"
}
if s.format == "" {
s.format = FormatYAML
}
return s
}
// Build builds a CDI spec form the spec builder.
func (o *builder) Build() (*spec, error) {
raw := o.raw
if raw == nil {
raw = &specs.Spec{
Version: o.version,
Kind: fmt.Sprintf("%s/%s", o.vendor, o.class),
Devices: o.deviceSpecs,
ContainerEdits: o.edits,
}
}
if raw.Version == DetectMinimumVersion {
minVersion, err := cdi.MinimumRequiredVersion(raw)
if err != nil {
return nil, fmt.Errorf("failed to get minumum required CDI spec version: %v", err)
}
raw.Version = minVersion
}
s := spec{
Spec: raw,
format: o.format,
}
return &s, nil
}
// Option defines a function that can be used to configure the spec builder.
type Option func(*builder)
// WithDeviceSpecs sets the device specs for the spec builder
func WithDeviceSpecs(deviceSpecs []specs.Device) Option {
return func(o *builder) {
o.deviceSpecs = deviceSpecs
}
}
// WithEdits sets the container edits for the spec builder
func WithEdits(edits specs.ContainerEdits) Option {
return func(o *builder) {
o.edits = edits
}
}
// WithVersion sets the version for the spec builder
func WithVersion(version string) Option {
return func(o *builder) {
o.version = version
}
}
// WithVendor sets the vendor for the spec builder
func WithVendor(vendor string) Option {
return func(o *builder) {
o.vendor = vendor
}
}
// WithClass sets the class for the spec builder
func WithClass(class string) Option {
return func(o *builder) {
o.class = class
}
}
// WithFormat sets the output file format
func WithFormat(format string) Option {
return func(o *builder) {
o.format = format
}
}

120
pkg/nvcdi/spec/spec.go Normal file
View File

@ -0,0 +1,120 @@
/**
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package spec
import (
"fmt"
"io"
"os"
"path/filepath"
"github.com/container-orchestrated-devices/container-device-interface/pkg/cdi"
"github.com/container-orchestrated-devices/container-device-interface/specs-go"
)
type spec struct {
*specs.Spec
format string
}
var _ Interface = (*spec)(nil)
// New creates a new spec with the specified options.
func New(opts ...Option) (Interface, error) {
return newBuilder(opts...).Build()
}
// Save writes the spec to the specified path and overwrites the file if it exists.
func (s *spec) Save(path string) error {
path, err := s.normalizePath(path)
if err != nil {
return fmt.Errorf("failed to normalize path: %w", err)
}
specDir := filepath.Dir(path)
registry := cdi.GetRegistry(
cdi.WithAutoRefresh(false),
cdi.WithSpecDirs(specDir),
)
return registry.SpecDB().WriteSpec(s.Raw(), filepath.Base(path))
}
// WriteTo writes the spec to the specified writer.
func (s *spec) WriteTo(w io.Writer) (int64, error) {
name, err := cdi.GenerateNameForSpec(s.Raw())
if err != nil {
return 0, err
}
path, _ := s.normalizePath(name)
tmpFile, err := os.CreateTemp("", "*"+filepath.Base(path))
if err != nil {
return 0, err
}
defer os.Remove(tmpFile.Name())
if err := s.Save(tmpFile.Name()); err != nil {
return 0, err
}
err = tmpFile.Close()
if err != nil {
return 0, fmt.Errorf("failed to close temporary file: %w", err)
}
r, err := os.Open(tmpFile.Name())
if err != nil {
return 0, fmt.Errorf("failed to open temporary file: %w", err)
}
defer r.Close()
return io.Copy(w, r)
}
// Raw returns a pointer to the raw spec.
func (s *spec) Raw() *specs.Spec {
return s.Spec
}
// normalizePath ensures that the specified path has a supported extension
func (s *spec) normalizePath(path string) (string, error) {
if ext := filepath.Ext(path); ext != ".yaml" && ext != ".json" {
path += s.extension()
}
if filepath.Clean(filepath.Dir(path)) == "." {
pwd, err := os.Getwd()
if err != nil {
return path, fmt.Errorf("failed to get current working directory: %v", err)
}
path = filepath.Join(pwd, path)
}
return path, nil
}
func (s *spec) extension() string {
switch s.format {
case FormatJSON:
return ".json"
case FormatYAML:
return ".yaml"
}
return ".yaml"
}