diff --git a/internal/modifier/cdi.go b/internal/modifier/cdi.go index a5e0f3d0..69e2a9fa 100644 --- a/internal/modifier/cdi.go +++ b/internal/modifier/cdi.go @@ -23,9 +23,9 @@ import ( "github.com/NVIDIA/nvidia-container-toolkit/internal/config" "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" + "github.com/NVIDIA/nvidia-container-toolkit/internal/modifier/cdi" "github.com/NVIDIA/nvidia-container-toolkit/internal/oci" - "github.com/container-orchestrated-devices/container-device-interface/pkg/cdi" - "github.com/opencontainers/runtime-spec/specs-go" + "github.com/container-orchestrated-devices/container-device-interface/pkg/parser" ) type cdiModifier struct { @@ -48,13 +48,11 @@ func NewCDIModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Spe } logger.Debugf("Creating CDI modifier for devices: %v", devices) - m := cdiModifier{ - logger: logger, - specDirs: cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.SpecDirs, - devices: devices, - } - - return m, nil + return cdi.New( + cdi.WithLogger(logger), + cdi.WithDevices(devices...), + cdi.WithSpecDirs(cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.SpecDirs...), + ) } func getDevicesFromSpec(logger logger.Interface, ociSpec oci.Spec, cfg *config.Config) ([]string, error) { @@ -80,7 +78,7 @@ func getDevicesFromSpec(logger logger.Interface, ociSpec oci.Spec, cfg *config.C var devices []string seen := make(map[string]bool) for _, name := range envDevices.List() { - if !cdi.IsQualifiedName(name) { + if !parser.IsQualifiedName(name) { name = fmt.Sprintf("%s=%s", cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.DefaultKind, name) } if seen[name] { @@ -121,7 +119,7 @@ func getAnnotationDevices(prefixes []string, annotations map[string]string) ([]s var annotationDevices []string for key, devices := range devicesByKey { for _, device := range devices { - if !cdi.IsQualifiedName(device) { + if !parser.IsQualifiedName(device) { return nil, fmt.Errorf("invalid device name %q in annotation %q", device, key) } if seen[device] { @@ -134,22 +132,3 @@ func getAnnotationDevices(prefixes []string, annotations map[string]string) ([]s return annotationDevices, nil } - -// Modify loads the CDI registry and injects the specified CDI devices into the OCI runtime specification. -func (m cdiModifier) Modify(spec *specs.Spec) error { - registry := cdi.GetRegistry( - cdi.WithSpecDirs(m.specDirs...), - cdi.WithAutoRefresh(false), - ) - if err := registry.Refresh(); err != nil { - m.logger.Debugf("The following error was triggered when refreshing the CDI registry: %v", err) - } - - m.logger.Debugf("Injecting devices using CDI: %v", m.devices) - _, err := registry.InjectDevices(spec, m.devices...) - if err != nil { - return fmt.Errorf("failed to inject CDI devices: %v", err) - } - - return nil -} diff --git a/internal/modifier/cdi/builder.go b/internal/modifier/cdi/builder.go new file mode 100644 index 00000000..1a491f19 --- /dev/null +++ b/internal/modifier/cdi/builder.go @@ -0,0 +1,106 @@ +/** +# Copyright (c) NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package cdi + +import ( + "fmt" + + "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" + "github.com/NVIDIA/nvidia-container-toolkit/internal/oci" + "github.com/container-orchestrated-devices/container-device-interface/pkg/cdi" + "github.com/container-orchestrated-devices/container-device-interface/specs-go" +) + +type builder struct { + logger logger.Interface + specDirs []string + devices []string + cdiSpec *specs.Spec +} + +// Option represents a functional option for creating a CDI mofifier. +type Option func(*builder) + +// New creates a new CDI modifier. +func New(opts ...Option) (oci.SpecModifier, error) { + b := &builder{} + for _, opt := range opts { + opt(b) + } + if b.logger == nil { + b.logger = logger.New() + } + return b.build() +} + +// build uses the applied options and constructs a CDI modifier using the builder. +func (m builder) build() (oci.SpecModifier, error) { + if len(m.devices) == 0 && m.cdiSpec == nil { + return nil, nil + } + + if m.cdiSpec != nil { + modifier := fromCDISpec{ + cdiSpec: &cdi.Spec{Spec: m.cdiSpec}, + } + return modifier, nil + } + + registry, err := cdi.NewCache( + cdi.WithAutoRefresh(false), + cdi.WithSpecDirs(m.specDirs...), + ) + if err != nil { + return nil, fmt.Errorf("failed to create CDI registry: %v", err) + } + + modifier := fromRegistry{ + logger: m.logger, + registry: registry, + devices: m.devices, + } + + return modifier, nil +} + +// WithLogger sets the logger for the CDI modifier builder. +func WithLogger(logger logger.Interface) Option { + return func(b *builder) { + b.logger = logger + } +} + +// WithSpecDirs sets the spec directories for the CDI modifier builder. +func WithSpecDirs(specDirs ...string) Option { + return func(b *builder) { + b.specDirs = specDirs + } +} + +// WithDevices sets the devices for the CDI modifier builder. +func WithDevices(devices ...string) Option { + return func(b *builder) { + b.devices = devices + } +} + +// WithSpec sets the spec for the CDI modifier builder. +func WithSpec(spec *specs.Spec) Option { + return func(b *builder) { + b.cdiSpec = spec + } +} diff --git a/internal/modifier/cdi/registry.go b/internal/modifier/cdi/registry.go new file mode 100644 index 00000000..ad3666c1 --- /dev/null +++ b/internal/modifier/cdi/registry.go @@ -0,0 +1,62 @@ +/** +# Copyright (c) NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package cdi + +import ( + "errors" + "fmt" + + "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" + "github.com/NVIDIA/nvidia-container-toolkit/internal/oci" + "github.com/container-orchestrated-devices/container-device-interface/pkg/cdi" + "github.com/opencontainers/runtime-spec/specs-go" +) + +// fromRegistry represents the modifications performed using a CDI registry. +type fromRegistry struct { + logger logger.Interface + registry *cdi.Cache + devices []string +} + +var _ oci.SpecModifier = (*fromRegistry)(nil) + +// Modify applies the mofiications defined by the CDI registry to the incomming OCI spec. +func (m fromRegistry) Modify(spec *specs.Spec) error { + if err := m.registry.Refresh(); err != nil { + m.logger.Debugf("The following error was triggered when refreshing the CDI registry: %v", err) + } + + m.logger.Debugf("Injecting devices using CDI: %v", m.devices) + _, err := m.registry.InjectDevices(spec, m.devices...) + if err != nil { + var refreshErrors []error + for _, rerrs := range m.registry.GetErrors() { + refreshErrors = append(refreshErrors, rerrs...) + } + if rerr := errors.Join(refreshErrors...); rerr != nil { + // We log the errors that may have been generated while refreshing the CDI registry. + // These may be due to malformed specifications or device name conflicts that could be + // the cause of an injection failure. + m.logger.Warningf("Refreshing the CDI registry generated errors: %v", rerr) + } + + return fmt.Errorf("failed to inject CDI devices: %v", err) + } + + return nil +} diff --git a/internal/modifier/cdi/spec.go b/internal/modifier/cdi/spec.go new file mode 100644 index 00000000..57c65e9a --- /dev/null +++ b/internal/modifier/cdi/spec.go @@ -0,0 +1,46 @@ +/** +# Copyright (c) NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package cdi + +import ( + "fmt" + + "github.com/NVIDIA/nvidia-container-toolkit/internal/oci" + "github.com/container-orchestrated-devices/container-device-interface/pkg/cdi" + "github.com/opencontainers/runtime-spec/specs-go" +) + +// fromCDISpec represents the modifications performed from a raw CDI spec. +type fromCDISpec struct { + cdiSpec *cdi.Spec +} + +var _ oci.SpecModifier = (*fromCDISpec)(nil) + +// Modify applies the mofiications defined by the raw CDI spec to the incomming OCI spec. +func (m fromCDISpec) Modify(spec *specs.Spec) error { + for _, device := range m.cdiSpec.Devices { + cdiDevice := cdi.Device{ + Device: &device, + } + if err := cdiDevice.ApplyEdits(spec); err != nil { + return fmt.Errorf("failed to apply edits for device %q: %v", cdiDevice.GetQualifiedName(), err) + } + } + + return m.cdiSpec.ApplyEdits(spec) +}