From 89321edae6f77c5529e55ee86d56860bd33c9cc9 Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Wed, 22 Feb 2023 16:19:22 +0200 Subject: [PATCH] Add top-level GetSpec function to nvcdi API Signed-off-by: Evan Lezar --- cmd/nvidia-ctk/cdi/generate/generate.go | 1 + pkg/nvcdi/lib.go | 6 +- pkg/nvcdi/spec/api.go | 12 ++- pkg/nvcdi/spec/builder.go | 127 ++++++++++++++++++++++++ pkg/nvcdi/spec/spec.go | 96 +++++++++++++----- 5 files changed, 216 insertions(+), 26 deletions(-) create mode 100644 pkg/nvcdi/spec/builder.go diff --git a/cmd/nvidia-ctk/cdi/generate/generate.go b/cmd/nvidia-ctk/cdi/generate/generate.go index b284c75e..86316977 100644 --- a/cmd/nvidia-ctk/cdi/generate/generate.go +++ b/cmd/nvidia-ctk/cdi/generate/generate.go @@ -187,6 +187,7 @@ func (m command) run(c *cli.Context, cfg *config) error { if err != nil { return fmt.Errorf("failed to convert CDI spec from YAML to JSON: %v", err) } + return nil } err = writeToOutput(cfg.format, data, outputTo) diff --git a/pkg/nvcdi/lib.go b/pkg/nvcdi/lib.go index aa90e396..451c1831 100644 --- a/pkg/nvcdi/lib.go +++ b/pkg/nvcdi/lib.go @@ -98,7 +98,11 @@ func (l *wrapper) GetSpec() (spec.Interface, error) { return nil, err } - return spec.New(deviceSpecs, *edits.ContainerEdits) + return spec.New( + spec.WithDeviceSpecs(deviceSpecs), + spec.WithEdits(*edits.ContainerEdits), + ) + } // resolveMode resolves the mode for CDI spec generation based on the current system. diff --git a/pkg/nvcdi/spec/api.go b/pkg/nvcdi/spec/api.go index 2d80bb43..e79216c2 100644 --- a/pkg/nvcdi/spec/api.go +++ b/pkg/nvcdi/spec/api.go @@ -16,10 +16,20 @@ package spec -import "github.com/container-orchestrated-devices/container-device-interface/specs-go" +import ( + "io" + + "github.com/container-orchestrated-devices/container-device-interface/specs-go" +) + +const ( + // DetectMinimumVersion is a constant that triggers a spec to detect the minimum required version. + DetectMinimumVersion = "DETECT_MINIMUM_VERSION" +) // Interface is the interface for the spec API type Interface interface { + io.WriterTo Save(string) error Raw() *specs.Spec } diff --git a/pkg/nvcdi/spec/builder.go b/pkg/nvcdi/spec/builder.go new file mode 100644 index 00000000..e8fc4b71 --- /dev/null +++ b/pkg/nvcdi/spec/builder.go @@ -0,0 +1,127 @@ +/** +# Copyright (c) NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package spec + +import ( + "fmt" + + "github.com/container-orchestrated-devices/container-device-interface/pkg/cdi" + "github.com/container-orchestrated-devices/container-device-interface/specs-go" +) + +type builder struct { + raw *specs.Spec + version string + vendor string + class string + deviceSpecs []specs.Device + edits specs.ContainerEdits + format string +} + +// NewBuilder creates a new spec builder with the supplied options +func NewBuilder(opts ...Option) *builder { + s := &builder{} + for _, opt := range opts { + opt(s) + } + if s.version == "" { + s.version = DetectMinimumVersion + } + if s.vendor == "" { + s.vendor = "nvidia.com" + } + if s.class == "" { + s.class = "gpu" + } + + return s +} + +// Build builds a CDI spec form the spec builder. +func (o *builder) Build() (*spec, error) { + raw := o.raw + + if raw == nil { + raw = &specs.Spec{ + Version: o.version, + Kind: fmt.Sprintf("%s/%s", o.vendor, o.class), + Devices: o.deviceSpecs, + ContainerEdits: o.edits, + } + } + + if raw.Version == DetectMinimumVersion { + minVersion, err := cdi.MinimumRequiredVersion(raw) + if err != nil { + return nil, fmt.Errorf("failed to get minumum required CDI spec version: %v", err) + } + raw.Version = minVersion + } + + s := spec{ + Spec: raw, + format: o.format, + } + + return &s, nil +} + +// Option defines a function that can be used to configure the spec builder. +type Option func(*builder) + +// WithDeviceSpecs sets the device specs for the spec builder +func WithDeviceSpecs(deviceSpecs []specs.Device) Option { + return func(o *builder) { + o.deviceSpecs = deviceSpecs + } +} + +// WithEdits sets the container edits for the spec builder +func WithEdits(edits specs.ContainerEdits) Option { + return func(o *builder) { + o.edits = edits + } +} + +// WithVersion sets the version for the spec builder +func WithVersion(version string) Option { + return func(o *builder) { + o.version = version + } +} + +// WithVendor sets the vendor for the spec builder +func WithVendor(vendor string) Option { + return func(o *builder) { + o.vendor = vendor + } +} + +// WithClass sets the class for the spec builder +func WithClass(class string) Option { + return func(o *builder) { + o.class = class + } +} + +// WithFormat sets the output file format +func WithFormat(format string) Option { + return func(o *builder) { + o.format = format + } +} diff --git a/pkg/nvcdi/spec/spec.go b/pkg/nvcdi/spec/spec.go index 231892ce..961993a2 100644 --- a/pkg/nvcdi/spec/spec.go +++ b/pkg/nvcdi/spec/spec.go @@ -17,44 +17,92 @@ package spec import ( - "fmt" + "io" + "os" + "path/filepath" "github.com/container-orchestrated-devices/container-device-interface/pkg/cdi" "github.com/container-orchestrated-devices/container-device-interface/specs-go" ) -type spec specs.Spec +type spec struct { + *specs.Spec + format string +} var _ Interface = (*spec)(nil) -// New creates a new spec with the specified deivice specs and edits. -func New(deviceSpecs []specs.Device, edits specs.ContainerEdits) (Interface, error) { - s := specs.Spec{ - // TODO: Should be set through an option - Version: "NOT_SET", - // TODO: Should be set through an option - Kind: "nvidia.com/gpu", - // TODO: Should be set through an option - Devices: deviceSpecs, - // TODO: Should be set through an option - ContainerEdits: edits, - } - - minVersion, err := cdi.MinimumRequiredVersion(&s) - if err != nil { - return nil, fmt.Errorf("failed to get minumum required CDI spec version: %v", err) - } - s.Version = minVersion - - return (*spec)(&s), nil +// New creates a new spec with the specified options. +func New(opts ...Option) (Interface, error) { + return NewBuilder(opts...).Build() } // Save writes the spec to the specified path and overwrites the file if it exists. func (s *spec) Save(path string) error { - return cdi.WriteSpec(s, path, true) + path = s.normalizePath(path) + + specDir := filepath.Dir(path) + registry := cdi.GetRegistry( + cdi.WithAutoRefresh(false), + cdi.WithSpecDirs(specDir), + ) + + return registry.SpecDB().WriteSpec(s.Raw(), filepath.Base(path)) +} + +// WriteTo writes the spec to the specified writer. +func (s *spec) WriteTo(w io.Writer) (int64, error) { + name, err := cdi.GenerateNameForSpec(s.Raw()) + if err != nil { + return 0, err + } + + path := s.normalizePath(name) + tmpFile, err := os.CreateTemp("", "*"+filepath.Base(path)) + if err != nil { + return 0, err + } + defer os.Remove(tmpFile.Name()) + + if err := s.Save(tmpFile.Name()); err != nil { + return 0, err + } + + err = tmpFile.Close() + if err != nil { + return 0, fmt.Errorf("failed to close temporary file: %w", err) + } + + r, err := os.Open(tmpFile.Name()) + if err != nil { + return 0, fmt.Errorf("failed to open temporary file: %w", err) + } + defer r.Close() + + return io.Copy(w, r) } // Raw returns a pointer to the raw spec. func (s *spec) Raw() *specs.Spec { - return (*specs.Spec)(s) + return s.Spec +} + +// normalizePath ensures that the specified path has a supported extension +func (s *spec) normalizePath(path string) string { + if ext := filepath.Ext(path); ext != ".yaml" && ext != ".json" { + path += s.extension() + } + + return path +} + +func (s *spec) extension() string { + switch s.format { + case "json": + return ".json" + case "yaml", "yml": + return ".yaml" + } + + return ".yaml" }