diff --git a/pkg/nvcdi/spec/api.go b/pkg/nvcdi/spec/api.go index b91072c5..918f56f2 100644 --- a/pkg/nvcdi/spec/api.go +++ b/pkg/nvcdi/spec/api.go @@ -24,6 +24,8 @@ import ( const ( // DetectMinimumVersion is a constant that triggers a spec to detect the minimum required version. + // + // Deprecated: DetectMinimumVersion is deprecated and will be removed. DetectMinimumVersion = "DETECT_MINIMUM_VERSION" // FormatJSON indicates a JSON output format diff --git a/pkg/nvcdi/spec/builder.go b/pkg/nvcdi/spec/builder.go index 8fb80a68..ec8cc144 100644 --- a/pkg/nvcdi/spec/builder.go +++ b/pkg/nvcdi/spec/builder.go @@ -39,6 +39,8 @@ type builder struct { mergedDeviceOptions []transform.MergedDeviceOption noSimplify bool permissions os.FileMode + + transformOnSave transform.Transformer } // newBuilder creates a new spec builder with the supplied options @@ -47,15 +49,23 @@ func newBuilder(opts ...Option) *builder { for _, opt := range opts { opt(s) } + if s.raw != nil { s.noSimplify = true vendor, class := parser.ParseQualifier(s.raw.Kind) - s.vendor = vendor - s.class = class + if s.vendor == "" { + s.vendor = vendor + } + if s.class == "" { + s.class = class + } + if s.version == "" || s.version == DetectMinimumVersion { + s.version = s.raw.Version + } } - - if s.version == "" { - s.version = DetectMinimumVersion + if s.version == "" || s.version == DetectMinimumVersion { + s.transformOnSave = &setMinimumRequiredVersion{} + s.version = cdi.CurrentVersion } if s.vendor == "" { s.vendor = "nvidia.com" @@ -83,13 +93,8 @@ func (o *builder) Build() (*spec, error) { ContainerEdits: o.edits, } } - - if raw.Version == DetectMinimumVersion { - minVersion, err := cdi.MinimumRequiredVersion(raw) - if err != nil { - return nil, fmt.Errorf("failed to get minimum required CDI spec version: %v", err) - } - raw.Version = minVersion + if raw.Version == "" { + raw.Version = o.version } if !o.noSimplify { @@ -110,11 +115,11 @@ func (o *builder) Build() (*spec, error) { } s := spec{ - Spec: raw, - format: o.format, - permissions: o.permissions, + Spec: raw, + format: o.format, + permissions: o.permissions, + transformOnSave: o.transformOnSave, } - return &s, nil } diff --git a/pkg/nvcdi/spec/set-minimum-version.go b/pkg/nvcdi/spec/set-minimum-version.go new file mode 100644 index 00000000..69969c0b --- /dev/null +++ b/pkg/nvcdi/spec/set-minimum-version.go @@ -0,0 +1,35 @@ +/** +# Copyright 2024 NVIDIA CORPORATION +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package spec + +import ( + "fmt" + + "tags.cncf.io/container-device-interface/pkg/cdi" + "tags.cncf.io/container-device-interface/specs-go" +) + +type setMinimumRequiredVersion struct{} + +func (d setMinimumRequiredVersion) Transform(spec *specs.Spec) error { + minVersion, err := cdi.MinimumRequiredVersion(spec) + if err != nil { + return fmt.Errorf("failed to get minimum required CDI spec version: %v", err) + } + spec.Version = minVersion + return nil +} diff --git a/pkg/nvcdi/spec/spec.go b/pkg/nvcdi/spec/spec.go index c27c4de3..28cccc51 100644 --- a/pkg/nvcdi/spec/spec.go +++ b/pkg/nvcdi/spec/spec.go @@ -24,12 +24,15 @@ import ( "tags.cncf.io/container-device-interface/pkg/cdi" "tags.cncf.io/container-device-interface/specs-go" + + "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/transform" ) type spec struct { *specs.Spec - format string - permissions os.FileMode + format string + permissions os.FileMode + transformOnSave transform.Transformer } var _ Interface = (*spec)(nil) @@ -41,6 +44,12 @@ func New(opts ...Option) (Interface, error) { // Save writes the spec to the specified path and overwrites the file if it exists. func (s *spec) Save(path string) error { + if s.transformOnSave != nil { + err := s.transformOnSave.Transform(s.Raw()) + if err != nil { + return fmt.Errorf("error applying transform: %w", err) + } + } path, err := s.normalizePath(path) if err != nil { return fmt.Errorf("failed to normalize path: %w", err) diff --git a/pkg/nvcdi/spec/spec_test.go b/pkg/nvcdi/spec/spec_test.go new file mode 100644 index 00000000..15c84070 --- /dev/null +++ b/pkg/nvcdi/spec/spec_test.go @@ -0,0 +1,181 @@ +/** +# Copyright 2024 NVIDIA CORPORATION +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package spec + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/require" + "tags.cncf.io/container-device-interface/specs-go" + + "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/transform" + "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/transform/root" +) + +func TestSpec(t *testing.T) { + testCases := []struct { + description string + options []Option + expectedNewError error + transform transform.Transformer + expectedSpec string + }{ + { + description: "default options return empty spec", + expectedSpec: `--- +cdiVersion: 0.3.0 +containerEdits: {} +devices: null +kind: nvidia.com/gpu +`, + }, + { + description: "version is overridden", + options: []Option{WithVersion("0.5.0")}, + expectedSpec: `--- +cdiVersion: 0.5.0 +containerEdits: {} +devices: null +kind: nvidia.com/gpu +`, + }, + { + description: "raw spec is used as is", + options: []Option{WithRawSpec( + &specs.Spec{ + Version: "0.5.0", + Kind: "nvidia.com/gpu", + ContainerEdits: specs.ContainerEdits{ + Env: []string{"FOO=bar"}, + }, + }, + )}, + expectedSpec: `--- +cdiVersion: 0.5.0 +containerEdits: + env: + - FOO=bar +devices: null +kind: nvidia.com/gpu +`, + }, + { + description: "raw spec with no version uses minimum version", + options: []Option{WithRawSpec( + &specs.Spec{ + Kind: "nvidia.com/gpu", + ContainerEdits: specs.ContainerEdits{ + Env: []string{"FOO=bar"}, + }, + }, + )}, + expectedSpec: `--- +cdiVersion: 0.3.0 +containerEdits: + env: + - FOO=bar +devices: null +kind: nvidia.com/gpu +`, + }, + { + description: "spec with host dev path uses 0.5.0 version", + options: []Option{WithRawSpec( + &specs.Spec{ + Kind: "nvidia.com/gpu", + ContainerEdits: specs.ContainerEdits{ + Env: []string{"FOO=bar"}, + DeviceNodes: []*specs.DeviceNode{ + { + HostPath: "/some/dev/dev0", + Path: "/dev/dev0", + }, + }, + }, + }, + )}, + expectedSpec: `--- +cdiVersion: 0.5.0 +containerEdits: + deviceNodes: + - hostPath: /some/dev/dev0 + path: /dev/dev0 + env: + - FOO=bar +devices: null +kind: nvidia.com/gpu +`, + }, + { + description: "transformed spec uses minimum version", + options: []Option{WithRawSpec( + &specs.Spec{ + Kind: "nvidia.com/gpu", + ContainerEdits: specs.ContainerEdits{ + Env: []string{"FOO=bar"}, + DeviceNodes: []*specs.DeviceNode{ + { + HostPath: "/some/dev/dev0", + Path: "/dev/dev0", + }, + }, + }, + }, + )}, + transform: transform.Merge( + root.New( + root.WithRoot("/some/dev/"), + root.WithTargetRoot("/dev/"), + root.WithRelativeTo("host"), + ), + transform.NewSimplifier(), + ), + expectedSpec: `--- +cdiVersion: 0.5.0 +containerEdits: + deviceNodes: + - hostPath: /dev/dev0 + path: /dev/dev0 + env: + - FOO=bar +devices: null +kind: nvidia.com/gpu +`, + }, + } + + for _, tc := range testCases { + t.Run(tc.description, func(t *testing.T) { + + s, err := New(tc.options...) + require.ErrorIs(t, err, tc.expectedNewError) + + if tc.transform != nil { + err := tc.transform.Transform(s.Raw()) + require.NoError(t, err) + } + + buf := new(bytes.Buffer) + + _, err = s.WriteTo(buf) + require.NoError(t, err) + + require.EqualValues(t, tc.expectedSpec, buf.String()) + }) + } +}