diff --git a/cmd/nvidia-ctk/cdi/generate/generate.go b/cmd/nvidia-ctk/cdi/generate/generate.go index 24ad57a4..8a6f061b 100644 --- a/cmd/nvidia-ctk/cdi/generate/generate.go +++ b/cmd/nvidia-ctk/cdi/generate/generate.go @@ -23,8 +23,8 @@ import ( "strings" "github.com/NVIDIA/nvidia-container-toolkit/internal/config" - "github.com/NVIDIA/nvidia-container-toolkit/internal/discover/csv" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" + "github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/tegra/csv" "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi" "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec" "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/transform" @@ -51,7 +51,8 @@ type options struct { class string csv struct { - files cli.StringSlice + files cli.StringSlice + librarySearchPaths cli.StringSlice } } @@ -134,6 +135,11 @@ func (m command) build() *cli.Command { Value: cli.NewStringSlice(csv.DefaultFileList()...), Destination: &opts.csv.files, }, + &cli.StringSliceFlag{ + Name: "csv.library-search-path", + Usage: "Specify the path to search for libraries when discovering the entities that should be included in the CDI specification. This currently only affects CDI mode", + Destination: &opts.csv.librarySearchPaths, + }, } return &c @@ -227,6 +233,7 @@ func (m command) generateSpec(opts *options) (spec.Interface, error) { nvcdi.WithDeviceNamer(deviceNamer), nvcdi.WithMode(string(opts.mode)), nvcdi.WithCSVFiles(opts.csv.files.Value()), + nvcdi.WithLibrarySearchPaths(opts.csv.librarySearchPaths.Value()), ) if err != nil { return nil, fmt.Errorf("failed to create CDI library: %v", err) diff --git a/cmd/nvidia-ctk/hook/create-symlinks/create-symlinks.go b/cmd/nvidia-ctk/hook/create-symlinks/create-symlinks.go index 3aa5dc9f..16ef3ed5 100644 --- a/cmd/nvidia-ctk/hook/create-symlinks/create-symlinks.go +++ b/cmd/nvidia-ctk/hook/create-symlinks/create-symlinks.go @@ -22,11 +22,11 @@ import ( "path/filepath" "strings" - "github.com/NVIDIA/nvidia-container-toolkit/internal/discover/csv" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/symlinks" "github.com/NVIDIA/nvidia-container-toolkit/internal/oci" + "github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/tegra/csv" "github.com/urfave/cli/v2" ) diff --git a/internal/discover/csv.go b/internal/discover/csv.go deleted file mode 100644 index 2d3b4018..00000000 --- a/internal/discover/csv.go +++ /dev/null @@ -1,110 +0,0 @@ -/** -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -**/ - -package discover - -import ( - "fmt" - - "github.com/NVIDIA/nvidia-container-toolkit/internal/discover/csv" - "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" - "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" -) - -// NewFromCSVFiles creates a discoverer for the specified CSV files. A logger is also supplied. -// The constructed discoverer is comprised of a list, with each element in the list being associated with a -// single CSV files. -func NewFromCSVFiles(logger logger.Interface, files []string, driverRoot string) (Discover, error) { - if len(files) == 0 { - logger.Warningf("No CSV files specified") - return None{}, nil - } - - symlinkLocator := lookup.NewSymlinkLocator( - lookup.WithLogger(logger), - lookup.WithRoot(driverRoot), - ) - locators := map[csv.MountSpecType]lookup.Locator{ - csv.MountSpecDev: lookup.NewCharDeviceLocator(lookup.WithLogger(logger), lookup.WithRoot(driverRoot)), - csv.MountSpecDir: lookup.NewDirectoryLocator(lookup.WithLogger(logger), lookup.WithRoot(driverRoot)), - // Libraries and symlinks are handled in the same way - csv.MountSpecLib: symlinkLocator, - csv.MountSpecSym: symlinkLocator, - } - - var mountSpecs []*csv.MountSpec - for _, filename := range files { - targets, err := loadCSVFile(logger, filename) - if err != nil { - logger.Warningf("Skipping CSV file %v: %v", filename, err) - continue - } - mountSpecs = append(mountSpecs, targets...) - } - - return newFromMountSpecs(logger, locators, driverRoot, mountSpecs) -} - -// loadCSVFile loads the specified CSV file and returns the list of mount specs -func loadCSVFile(logger logger.Interface, filename string) ([]*csv.MountSpec, error) { - // Create a discoverer for each file-kind combination - targets, err := csv.NewCSVFileParser(logger, filename).Parse() - if err != nil { - return nil, fmt.Errorf("failed to parse CSV file: %v", err) - } - if len(targets) == 0 { - return nil, fmt.Errorf("CSV file is empty") - } - - return targets, nil -} - -// newFromMountSpecs creates a discoverer for the CSV file. A logger is also supplied. -// A list of csvDiscoverers is returned, with each being associated with a single MountSpecType. -func newFromMountSpecs(logger logger.Interface, locators map[csv.MountSpecType]lookup.Locator, driverRoot string, targets []*csv.MountSpec) (Discover, error) { - if len(targets) == 0 { - return &None{}, nil - } - - var discoverers []Discover - var mountSpecTypes []csv.MountSpecType - candidatesByType := make(map[csv.MountSpecType][]string) - for _, t := range targets { - if _, exists := candidatesByType[t.Type]; !exists { - mountSpecTypes = append(mountSpecTypes, t.Type) - } - candidatesByType[t.Type] = append(candidatesByType[t.Type], t.Path) - } - - for _, t := range mountSpecTypes { - locator, exists := locators[t] - if !exists { - return nil, fmt.Errorf("no locator defined for '%v'", t) - } - - var m Discover - switch t { - case csv.MountSpecDev: - m = NewDeviceDiscoverer(logger, locator, driverRoot, candidatesByType[t]) - default: - m = NewMounts(logger, locator, driverRoot, candidatesByType[t]) - } - discoverers = append(discoverers, m) - - } - - return &list{discoverers: discoverers}, nil -} diff --git a/internal/discover/csv_test.go b/internal/discover/csv_test.go deleted file mode 100644 index f3c6dfeb..00000000 --- a/internal/discover/csv_test.go +++ /dev/null @@ -1,142 +0,0 @@ -/** -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -**/ - -package discover - -import ( - "fmt" - "testing" - - "github.com/NVIDIA/nvidia-container-toolkit/internal/discover/csv" - "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" - testlog "github.com/sirupsen/logrus/hooks/test" - "github.com/stretchr/testify/require" -) - -func TestNewFromMountSpec(t *testing.T) { - logger, _ := testlog.NewNullLogger() - - locators := map[csv.MountSpecType]lookup.Locator{ - "dev": &lookup.LocatorMock{}, - "lib": &lookup.LocatorMock{}, - } - - testCases := []struct { - description string - root string - targets []*csv.MountSpec - expectedError error - expectedDiscoverer Discover - }{ - { - description: "empty targets returns None discoverer list", - expectedDiscoverer: &None{}, - }, - { - description: "unexpected locator returns error", - targets: []*csv.MountSpec{ - { - Type: "foo", - Path: "bar", - }, - }, - expectedError: fmt.Errorf("no locator defined for foo"), - }, - { - description: "creates discoverers based on type", - targets: []*csv.MountSpec{ - { - Type: "dev", - Path: "dev0", - }, - { - Type: "lib", - Path: "lib0", - }, - { - Type: "dev", - Path: "dev1", - }, - }, - expectedDiscoverer: &list{ - discoverers: []Discover{ - (*charDevices)( - &mounts{ - logger: logger, - lookup: locators["dev"], - root: "/", - required: []string{"dev0", "dev1"}, - }, - ), - &mounts{ - logger: logger, - lookup: locators["lib"], - root: "/", - required: []string{"lib0"}, - }, - }, - }, - }, - { - description: "sets root", - targets: []*csv.MountSpec{ - { - Type: "dev", - Path: "dev0", - }, - { - Type: "lib", - Path: "lib0", - }, - { - Type: "dev", - Path: "dev1", - }, - }, - root: "/some/root", - expectedDiscoverer: &list{ - discoverers: []Discover{ - (*charDevices)( - &mounts{ - logger: logger, - lookup: locators["dev"], - root: "/some/root", - required: []string{"dev0", "dev1"}, - }, - ), - &mounts{ - logger: logger, - lookup: locators["lib"], - root: "/some/root", - required: []string{"lib0"}, - }, - }, - }, - }, - } - - for _, tc := range testCases { - t.Run(tc.description, func(t *testing.T) { - discoverer, err := newFromMountSpecs(logger, locators, tc.root, tc.targets) - if tc.expectedError != nil { - require.Error(t, err) - return - } - require.NoError(t, err) - require.EqualValues(t, tc.expectedDiscoverer, discoverer) - }) - } -} diff --git a/internal/modifier/cdi.go b/internal/modifier/cdi.go index a5e0f3d0..69e2a9fa 100644 --- a/internal/modifier/cdi.go +++ b/internal/modifier/cdi.go @@ -23,9 +23,9 @@ import ( "github.com/NVIDIA/nvidia-container-toolkit/internal/config" "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" + "github.com/NVIDIA/nvidia-container-toolkit/internal/modifier/cdi" "github.com/NVIDIA/nvidia-container-toolkit/internal/oci" - "github.com/container-orchestrated-devices/container-device-interface/pkg/cdi" - "github.com/opencontainers/runtime-spec/specs-go" + "github.com/container-orchestrated-devices/container-device-interface/pkg/parser" ) type cdiModifier struct { @@ -48,13 +48,11 @@ func NewCDIModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Spe } logger.Debugf("Creating CDI modifier for devices: %v", devices) - m := cdiModifier{ - logger: logger, - specDirs: cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.SpecDirs, - devices: devices, - } - - return m, nil + return cdi.New( + cdi.WithLogger(logger), + cdi.WithDevices(devices...), + cdi.WithSpecDirs(cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.SpecDirs...), + ) } func getDevicesFromSpec(logger logger.Interface, ociSpec oci.Spec, cfg *config.Config) ([]string, error) { @@ -80,7 +78,7 @@ func getDevicesFromSpec(logger logger.Interface, ociSpec oci.Spec, cfg *config.C var devices []string seen := make(map[string]bool) for _, name := range envDevices.List() { - if !cdi.IsQualifiedName(name) { + if !parser.IsQualifiedName(name) { name = fmt.Sprintf("%s=%s", cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.DefaultKind, name) } if seen[name] { @@ -121,7 +119,7 @@ func getAnnotationDevices(prefixes []string, annotations map[string]string) ([]s var annotationDevices []string for key, devices := range devicesByKey { for _, device := range devices { - if !cdi.IsQualifiedName(device) { + if !parser.IsQualifiedName(device) { return nil, fmt.Errorf("invalid device name %q in annotation %q", device, key) } if seen[device] { @@ -134,22 +132,3 @@ func getAnnotationDevices(prefixes []string, annotations map[string]string) ([]s return annotationDevices, nil } - -// Modify loads the CDI registry and injects the specified CDI devices into the OCI runtime specification. -func (m cdiModifier) Modify(spec *specs.Spec) error { - registry := cdi.GetRegistry( - cdi.WithSpecDirs(m.specDirs...), - cdi.WithAutoRefresh(false), - ) - if err := registry.Refresh(); err != nil { - m.logger.Debugf("The following error was triggered when refreshing the CDI registry: %v", err) - } - - m.logger.Debugf("Injecting devices using CDI: %v", m.devices) - _, err := registry.InjectDevices(spec, m.devices...) - if err != nil { - return fmt.Errorf("failed to inject CDI devices: %v", err) - } - - return nil -} diff --git a/internal/modifier/cdi/builder.go b/internal/modifier/cdi/builder.go new file mode 100644 index 00000000..1a491f19 --- /dev/null +++ b/internal/modifier/cdi/builder.go @@ -0,0 +1,106 @@ +/** +# Copyright (c) NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package cdi + +import ( + "fmt" + + "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" + "github.com/NVIDIA/nvidia-container-toolkit/internal/oci" + "github.com/container-orchestrated-devices/container-device-interface/pkg/cdi" + "github.com/container-orchestrated-devices/container-device-interface/specs-go" +) + +type builder struct { + logger logger.Interface + specDirs []string + devices []string + cdiSpec *specs.Spec +} + +// Option represents a functional option for creating a CDI mofifier. +type Option func(*builder) + +// New creates a new CDI modifier. +func New(opts ...Option) (oci.SpecModifier, error) { + b := &builder{} + for _, opt := range opts { + opt(b) + } + if b.logger == nil { + b.logger = logger.New() + } + return b.build() +} + +// build uses the applied options and constructs a CDI modifier using the builder. +func (m builder) build() (oci.SpecModifier, error) { + if len(m.devices) == 0 && m.cdiSpec == nil { + return nil, nil + } + + if m.cdiSpec != nil { + modifier := fromCDISpec{ + cdiSpec: &cdi.Spec{Spec: m.cdiSpec}, + } + return modifier, nil + } + + registry, err := cdi.NewCache( + cdi.WithAutoRefresh(false), + cdi.WithSpecDirs(m.specDirs...), + ) + if err != nil { + return nil, fmt.Errorf("failed to create CDI registry: %v", err) + } + + modifier := fromRegistry{ + logger: m.logger, + registry: registry, + devices: m.devices, + } + + return modifier, nil +} + +// WithLogger sets the logger for the CDI modifier builder. +func WithLogger(logger logger.Interface) Option { + return func(b *builder) { + b.logger = logger + } +} + +// WithSpecDirs sets the spec directories for the CDI modifier builder. +func WithSpecDirs(specDirs ...string) Option { + return func(b *builder) { + b.specDirs = specDirs + } +} + +// WithDevices sets the devices for the CDI modifier builder. +func WithDevices(devices ...string) Option { + return func(b *builder) { + b.devices = devices + } +} + +// WithSpec sets the spec for the CDI modifier builder. +func WithSpec(spec *specs.Spec) Option { + return func(b *builder) { + b.cdiSpec = spec + } +} diff --git a/internal/modifier/cdi/registry.go b/internal/modifier/cdi/registry.go new file mode 100644 index 00000000..ad3666c1 --- /dev/null +++ b/internal/modifier/cdi/registry.go @@ -0,0 +1,62 @@ +/** +# Copyright (c) NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package cdi + +import ( + "errors" + "fmt" + + "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" + "github.com/NVIDIA/nvidia-container-toolkit/internal/oci" + "github.com/container-orchestrated-devices/container-device-interface/pkg/cdi" + "github.com/opencontainers/runtime-spec/specs-go" +) + +// fromRegistry represents the modifications performed using a CDI registry. +type fromRegistry struct { + logger logger.Interface + registry *cdi.Cache + devices []string +} + +var _ oci.SpecModifier = (*fromRegistry)(nil) + +// Modify applies the mofiications defined by the CDI registry to the incomming OCI spec. +func (m fromRegistry) Modify(spec *specs.Spec) error { + if err := m.registry.Refresh(); err != nil { + m.logger.Debugf("The following error was triggered when refreshing the CDI registry: %v", err) + } + + m.logger.Debugf("Injecting devices using CDI: %v", m.devices) + _, err := m.registry.InjectDevices(spec, m.devices...) + if err != nil { + var refreshErrors []error + for _, rerrs := range m.registry.GetErrors() { + refreshErrors = append(refreshErrors, rerrs...) + } + if rerr := errors.Join(refreshErrors...); rerr != nil { + // We log the errors that may have been generated while refreshing the CDI registry. + // These may be due to malformed specifications or device name conflicts that could be + // the cause of an injection failure. + m.logger.Warningf("Refreshing the CDI registry generated errors: %v", rerr) + } + + return fmt.Errorf("failed to inject CDI devices: %v", err) + } + + return nil +} diff --git a/internal/modifier/cdi/spec.go b/internal/modifier/cdi/spec.go new file mode 100644 index 00000000..57c65e9a --- /dev/null +++ b/internal/modifier/cdi/spec.go @@ -0,0 +1,46 @@ +/** +# Copyright (c) NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package cdi + +import ( + "fmt" + + "github.com/NVIDIA/nvidia-container-toolkit/internal/oci" + "github.com/container-orchestrated-devices/container-device-interface/pkg/cdi" + "github.com/opencontainers/runtime-spec/specs-go" +) + +// fromCDISpec represents the modifications performed from a raw CDI spec. +type fromCDISpec struct { + cdiSpec *cdi.Spec +} + +var _ oci.SpecModifier = (*fromCDISpec)(nil) + +// Modify applies the mofiications defined by the raw CDI spec to the incomming OCI spec. +func (m fromCDISpec) Modify(spec *specs.Spec) error { + for _, device := range m.cdiSpec.Devices { + cdiDevice := cdi.Device{ + Device: &device, + } + if err := cdiDevice.ApplyEdits(spec); err != nil { + return fmt.Errorf("failed to apply edits for device %q: %v", cdiDevice.GetQualifiedName(), err) + } + } + + return m.cdiSpec.ApplyEdits(spec) +} diff --git a/internal/modifier/csv.go b/internal/modifier/csv.go index 0a834d0d..56adcdc7 100644 --- a/internal/modifier/csv.go +++ b/internal/modifier/csv.go @@ -23,11 +23,12 @@ import ( "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" "github.com/NVIDIA/nvidia-container-toolkit/internal/cuda" "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" - "github.com/NVIDIA/nvidia-container-toolkit/internal/discover/csv" - "github.com/NVIDIA/nvidia-container-toolkit/internal/discover/tegra" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" + "github.com/NVIDIA/nvidia-container-toolkit/internal/modifier/cdi" "github.com/NVIDIA/nvidia-container-toolkit/internal/oci" + "github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/tegra/csv" "github.com/NVIDIA/nvidia-container-toolkit/internal/requirements" + "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi" ) // csvMode represents the modifications as performed by the csv runtime mode @@ -65,24 +66,33 @@ func NewCSVModifier(logger logger.Interface, cfg *config.Config, image image.CUD csvFiles = csv.BaseFilesOnly(csvFiles) } - d, err := tegra.New( - tegra.WithLogger(logger), - tegra.WithDriverRoot(cfg.NVIDIAContainerCLIConfig.Root), - tegra.WithNVIDIACTKPath(cfg.NVIDIACTKConfig.Path), - tegra.WithCSVFiles(csvFiles), + cdilib, err := nvcdi.New( + nvcdi.WithLogger(logger), + nvcdi.WithDriverRoot(cfg.NVIDIAContainerCLIConfig.Root), + nvcdi.WithNVIDIACTKPath(cfg.NVIDIACTKConfig.Path), + nvcdi.WithMode(nvcdi.ModeCSV), + nvcdi.WithCSVFiles(csvFiles), ) if err != nil { - return nil, fmt.Errorf("failed to construct discoverer: %v", err) + return nil, fmt.Errorf("failed to construct CDI library: %v", err) } - discoverModifier, err := NewModifierFromDiscoverer(logger, d) + spec, err := cdilib.GetSpec() if err != nil { - return nil, fmt.Errorf("failed to construct modifier: %v", err) + return nil, fmt.Errorf("failed to get CDI spec: %v", err) + } + + cdiModifier, err := cdi.New( + cdi.WithLogger(logger), + cdi.WithSpec(spec.Raw()), + ) + if err != nil { + return nil, fmt.Errorf("failed to construct CDI modifier: %v", err) } modifiers := Merge( nvidiaContainerRuntimeHookRemover{logger}, - discoverModifier, + cdiModifier, ) return modifiers, nil diff --git a/internal/platform-support/tegra/csv.go b/internal/platform-support/tegra/csv.go new file mode 100644 index 00000000..3da4b0dd --- /dev/null +++ b/internal/platform-support/tegra/csv.go @@ -0,0 +1,117 @@ +/** +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package tegra + +import ( + "fmt" + + "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" + "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" + "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" + "github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/tegra/csv" +) + +// newDiscovererFromCSVFiles creates a discoverer for the specified CSV files. A logger is also supplied. +// The constructed discoverer is comprised of a list, with each element in the list being associated with a +// single CSV files. +func newDiscovererFromCSVFiles(logger logger.Interface, files []string, driverRoot string, nvidiaCTKPath string, librarySearchPaths []string) (discover.Discover, error) { + if len(files) == 0 { + logger.Warningf("No CSV files specified") + return discover.None{}, nil + } + + targetsByType := getTargetsFromCSVFiles(logger, files) + + devices := discover.NewDeviceDiscoverer( + logger, + lookup.NewCharDeviceLocator(lookup.WithLogger(logger), lookup.WithRoot(driverRoot)), + driverRoot, + targetsByType[csv.MountSpecDev], + ) + + directories := discover.NewMounts( + logger, + lookup.NewDirectoryLocator(lookup.WithLogger(logger), lookup.WithRoot(driverRoot)), + driverRoot, + targetsByType[csv.MountSpecDir], + ) + + // Libraries and symlinks use the same locator. + searchPaths := append(librarySearchPaths, "/") + symlinkLocator := lookup.NewSymlinkLocator( + lookup.WithLogger(logger), + lookup.WithRoot(driverRoot), + lookup.WithSearchPaths(searchPaths...), + ) + libraries := discover.NewMounts( + logger, + symlinkLocator, + driverRoot, + targetsByType[csv.MountSpecLib], + ) + + nonLibSymlinks := ignoreFilenamePatterns{"*.so", "*.so.[0-9]"}.Apply(targetsByType[csv.MountSpecSym]...) + logger.Debugf("Non-lib symlinks: %v", nonLibSymlinks) + symlinks := discover.NewMounts( + logger, + symlinkLocator, + driverRoot, + nonLibSymlinks, + ) + createSymlinks := createCSVSymlinkHooks(logger, nonLibSymlinks, libraries, nvidiaCTKPath) + + d := discover.Merge( + devices, + directories, + libraries, + symlinks, + createSymlinks, + ) + + return d, nil +} + +// getTargetsFromCSVFiles returns the list of mount specs from the specified CSV files. +// These are aggregated by mount spec type. +func getTargetsFromCSVFiles(logger logger.Interface, files []string) map[csv.MountSpecType][]string { + targetsByType := make(map[csv.MountSpecType][]string) + for _, filename := range files { + targets, err := loadCSVFile(logger, filename) + if err != nil { + logger.Warningf("Skipping CSV file %v: %v", filename, err) + continue + } + for _, t := range targets { + targetsByType[t.Type] = append(targetsByType[t.Type], t.Path) + } + } + return targetsByType +} + +// loadCSVFile loads the specified CSV file and returns the list of mount specs +func loadCSVFile(logger logger.Interface, filename string) ([]*csv.MountSpec, error) { + // Create a discoverer for each file-kind combination + targets, err := csv.NewCSVFileParser(logger, filename).Parse() + if err != nil { + return nil, fmt.Errorf("failed to parse CSV file: %v", err) + } + if len(targets) == 0 { + return nil, fmt.Errorf("CSV file is empty") + } + + return targets, nil +} diff --git a/internal/discover/csv/csv.go b/internal/platform-support/tegra/csv/csv.go similarity index 100% rename from internal/discover/csv/csv.go rename to internal/platform-support/tegra/csv/csv.go diff --git a/internal/discover/csv/csv_test.go b/internal/platform-support/tegra/csv/csv_test.go similarity index 100% rename from internal/discover/csv/csv_test.go rename to internal/platform-support/tegra/csv/csv_test.go diff --git a/internal/discover/csv/mount_spec.go b/internal/platform-support/tegra/csv/mount_spec.go similarity index 100% rename from internal/discover/csv/mount_spec.go rename to internal/platform-support/tegra/csv/mount_spec.go diff --git a/internal/discover/csv/mount_spec_test.go b/internal/platform-support/tegra/csv/mount_spec_test.go similarity index 100% rename from internal/discover/csv/mount_spec_test.go rename to internal/platform-support/tegra/csv/mount_spec_test.go diff --git a/internal/platform-support/tegra/csv_test.go b/internal/platform-support/tegra/csv_test.go new file mode 100644 index 00000000..d01ce260 --- /dev/null +++ b/internal/platform-support/tegra/csv_test.go @@ -0,0 +1,17 @@ +/** +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package tegra diff --git a/internal/platform-support/tegra/filter.go b/internal/platform-support/tegra/filter.go new file mode 100644 index 00000000..7d6e8e15 --- /dev/null +++ b/internal/platform-support/tegra/filter.go @@ -0,0 +1,41 @@ +/** +# Copyright (c) NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package tegra + +import "path/filepath" + +type ignoreFilenamePatterns []string + +func (d ignoreFilenamePatterns) Match(name string) bool { + for _, pattern := range d { + if match, _ := filepath.Match(pattern, filepath.Base(name)); match { + return true + } + } + return false +} + +func (d ignoreFilenamePatterns) Apply(input ...string) []string { + var filtered []string + for _, name := range input { + if d.Match(name) { + continue + } + filtered = append(filtered, name) + } + return filtered +} diff --git a/internal/platform-support/tegra/filter_test.go b/internal/platform-support/tegra/filter_test.go new file mode 100644 index 00000000..5cca505d --- /dev/null +++ b/internal/platform-support/tegra/filter_test.go @@ -0,0 +1,29 @@ +/** +# Copyright (c) NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package tegra + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestIgnorePatterns(t *testing.T) { + filtered := ignoreFilenamePatterns{"*.so", "*.so.[0-9]"}.Apply("/foo/bar/libsomething.so", "libsometing.so", "libsometing.so.1", "libsometing.so.1.2.3") + + require.ElementsMatch(t, []string{"libsometing.so.1.2.3"}, filtered) +} diff --git a/internal/discover/symlinks.go b/internal/platform-support/tegra/symlinks.go similarity index 75% rename from internal/discover/symlinks.go rename to internal/platform-support/tegra/symlinks.go index 1ed8adda..505565d5 100644 --- a/internal/discover/symlinks.go +++ b/internal/platform-support/tegra/symlinks.go @@ -14,60 +14,51 @@ # limitations under the License. **/ -package discover +package tegra import ( "fmt" "path/filepath" "strings" - "github.com/NVIDIA/nvidia-container-toolkit/internal/discover/csv" + "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/symlinks" ) type symlinkHook struct { - None + discover.None logger logger.Interface driverRoot string nvidiaCTKPath string - csvFiles []string - mountsFrom Discover + targets []string + mountsFrom discover.Discover } -// NewCreateSymlinksHook creates a discoverer for a hook that creates required symlinks in the container -func NewCreateSymlinksHook(logger logger.Interface, csvFiles []string, mounts Discover, nvidiaCTKPath string) (Discover, error) { - d := symlinkHook{ +// createCSVSymlinkHooks creates a discoverer for a hook that creates required symlinks in the container +func createCSVSymlinkHooks(logger logger.Interface, targets []string, mounts discover.Discover, nvidiaCTKPath string) discover.Discover { + return symlinkHook{ logger: logger, nvidiaCTKPath: nvidiaCTKPath, - csvFiles: csvFiles, + targets: targets, mountsFrom: mounts, } - - return &d, nil } // Hooks returns a hook to create the symlinks from the required CSV files -func (d symlinkHook) Hooks() ([]Hook, error) { +func (d symlinkHook) Hooks() ([]discover.Hook, error) { specificLinks, err := d.getSpecificLinks() if err != nil { return nil, fmt.Errorf("failed to determine specific links: %v", err) } csvSymlinks := d.getCSVFileSymlinks() - var args []string - for _, link := range append(csvSymlinks, specificLinks...) { - args = append(args, "--link", link) - } - hook := CreateNvidiaCTKHook( + return discover.CreateCreateSymlinkHook( d.nvidiaCTKPath, - "create-symlinks", - args..., - ) - - return []Hook{hook}, nil + append(csvSymlinks, specificLinks...), + ).Hooks() } // getSpecificLinks returns the required specic links that need to be created @@ -112,36 +103,30 @@ func (d symlinkHook) getSpecificLinks() ([]string, error) { return links, nil } -func (d symlinkHook) getCSVFileSymlinks() []string { +// getSymlinkCandidates returns a list of symlinks that are candidates for being created. +func (d symlinkHook) getSymlinkCandidates() []string { chainLocator := lookup.NewSymlinkChainLocator( lookup.WithLogger(d.logger), lookup.WithRoot(d.driverRoot), ) var candidates []string - for _, file := range d.csvFiles { - mountSpecs, err := csv.NewCSVFileParser(d.logger, file).Parse() + for _, target := range d.targets { + reslovedSymlinkChain, err := chainLocator.Locate(target) if err != nil { - d.logger.Debugf("Skipping CSV file %v: %v", file, err) + d.logger.Warningf("Failed to locate symlink %v", target) continue } - - for _, ms := range mountSpecs { - if ms.Type != csv.MountSpecSym { - continue - } - targets, err := chainLocator.Locate(ms.Path) - if err != nil { - d.logger.Warningf("Failed to locate symlink %v", ms.Path) - } - candidates = append(candidates, targets...) - } + candidates = append(candidates, reslovedSymlinkChain...) } + return candidates +} +func (d symlinkHook) getCSVFileSymlinks() []string { var links []string created := make(map[string]bool) // candidates is a list of absolute paths to symlinks in a chain, or the final target of the chain. - for _, candidate := range candidates { + for _, candidate := range d.getSymlinkCandidates() { target, err := symlinks.Resolve(candidate) if err != nil { d.logger.Debugf("Skipping invalid link: %v", err) diff --git a/internal/discover/tegra/tegra.go b/internal/platform-support/tegra/tegra.go similarity index 83% rename from internal/discover/tegra/tegra.go rename to internal/platform-support/tegra/tegra.go index 3091c500..019d2730 100644 --- a/internal/discover/tegra/tegra.go +++ b/internal/platform-support/tegra/tegra.go @@ -25,10 +25,11 @@ import ( ) type tegraOptions struct { - logger logger.Interface - csvFiles []string - driverRoot string - nvidiaCTKPath string + logger logger.Interface + csvFiles []string + driverRoot string + nvidiaCTKPath string + librarySearchPaths []string } // Option defines a functional option for configuring a Tegra discoverer. @@ -41,16 +42,11 @@ func New(opts ...Option) (discover.Discover, error) { opt(o) } - csvDiscoverer, err := discover.NewFromCSVFiles(o.logger, o.csvFiles, o.driverRoot) + csvDiscoverer, err := newDiscovererFromCSVFiles(o.logger, o.csvFiles, o.driverRoot, o.nvidiaCTKPath, o.librarySearchPaths) if err != nil { return nil, fmt.Errorf("failed to create CSV discoverer: %v", err) } - createSymlinksHook, err := discover.NewCreateSymlinksHook(o.logger, o.csvFiles, csvDiscoverer, o.nvidiaCTKPath) - if err != nil { - return nil, fmt.Errorf("failed to create symlink hook discoverer: %v", err) - } - ldcacheUpdateHook, err := discover.NewLDCacheUpdateHook(o.logger, csvDiscoverer, o.nvidiaCTKPath) if err != nil { return nil, fmt.Errorf("failed to create ldcach update hook discoverer: %v", err) @@ -68,7 +64,6 @@ func New(opts ...Option) (discover.Discover, error) { d := discover.Merge( csvDiscoverer, - createSymlinksHook, // The ldcacheUpdateHook is added last to ensure that the created symlinks are included ldcacheUpdateHook, tegraSystemMounts, @@ -104,3 +99,10 @@ func WithNVIDIACTKPath(nvidiaCTKPath string) Option { o.nvidiaCTKPath = nvidiaCTKPath } } + +// WithLibrarySearchPaths sets the library search paths for the discoverer. +func WithLibrarySearchPaths(librarySearchPaths ...string) Option { + return func(o *tegraOptions) { + o.librarySearchPaths = librarySearchPaths + } +} diff --git a/pkg/nvcdi/lib-csv.go b/pkg/nvcdi/lib-csv.go index 997724bf..5ae17964 100644 --- a/pkg/nvcdi/lib-csv.go +++ b/pkg/nvcdi/lib-csv.go @@ -20,8 +20,8 @@ import ( "fmt" "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" - "github.com/NVIDIA/nvidia-container-toolkit/internal/discover/tegra" "github.com/NVIDIA/nvidia-container-toolkit/internal/edits" + "github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/tegra" "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec" "github.com/container-orchestrated-devices/container-device-interface/pkg/cdi" "github.com/container-orchestrated-devices/container-device-interface/specs-go" @@ -44,6 +44,7 @@ func (l *csvlib) GetAllDeviceSpecs() ([]specs.Device, error) { tegra.WithDriverRoot(l.driverRoot), tegra.WithNVIDIACTKPath(l.nvidiaCTKPath), tegra.WithCSVFiles(l.csvFiles), + tegra.WithLibrarySearchPaths(l.librarySearchPaths...), ) if err != nil { return nil, fmt.Errorf("failed to create discoverer for CSV files: %v", err) diff --git a/pkg/nvcdi/lib.go b/pkg/nvcdi/lib.go index ad5dd270..30d12aae 100644 --- a/pkg/nvcdi/lib.go +++ b/pkg/nvcdi/lib.go @@ -19,8 +19,8 @@ package nvcdi import ( "fmt" - "github.com/NVIDIA/nvidia-container-toolkit/internal/discover/csv" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" + "github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/tegra/csv" "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec" "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/transform" "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device" @@ -38,13 +38,14 @@ type wrapper struct { } type nvcdilib struct { - logger logger.Interface - nvmllib nvml.Interface - mode string - devicelib device.Interface - deviceNamer DeviceNamer - driverRoot string - nvidiaCTKPath string + logger logger.Interface + nvmllib nvml.Interface + mode string + devicelib device.Interface + deviceNamer DeviceNamer + driverRoot string + nvidiaCTKPath string + librarySearchPaths []string csvFiles []string diff --git a/pkg/nvcdi/options.go b/pkg/nvcdi/options.go index 9ac772ba..f354e6df 100644 --- a/pkg/nvcdi/options.go +++ b/pkg/nvcdi/options.go @@ -103,3 +103,11 @@ func WithCSVFiles(csvFiles []string) Option { o.csvFiles = csvFiles } } + +// WithLibrarySearchPaths sets the library search paths. +// This is currently only used for CSV-mode. +func WithLibrarySearchPaths(paths []string) Option { + return func(o *nvcdilib) { + o.librarySearchPaths = paths + } +}