diff --git a/internal/platform-support/tegra/csv.go b/internal/platform-support/tegra/csv.go index a4552c59..3a66f471 100644 --- a/internal/platform-support/tegra/csv.go +++ b/internal/platform-support/tegra/csv.go @@ -28,35 +28,73 @@ import ( // newDiscovererFromCSVFiles creates a discoverer for the specified CSV files. A logger is also supplied. // The constructed discoverer is comprised of a list, with each element in the list being associated with a // single CSV files. -func newDiscovererFromCSVFiles(logger logger.Interface, files []string, driverRoot string) (discover.Discover, error) { +func newDiscovererFromCSVFiles(logger logger.Interface, files []string, driverRoot string, nvidiaCTKPath string) (discover.Discover, error) { if len(files) == 0 { logger.Warningf("No CSV files specified") return discover.None{}, nil } - symlinkLocator := lookup.NewSymlinkLocator( - lookup.WithLogger(logger), - lookup.WithRoot(driverRoot), - ) - locators := map[csv.MountSpecType]lookup.Locator{ - csv.MountSpecDev: lookup.NewCharDeviceLocator(lookup.WithLogger(logger), lookup.WithRoot(driverRoot)), - csv.MountSpecDir: lookup.NewDirectoryLocator(lookup.WithLogger(logger), lookup.WithRoot(driverRoot)), - // Libraries and symlinks are handled in the same way - csv.MountSpecLib: symlinkLocator, - csv.MountSpecSym: symlinkLocator, - } + targetsByType := getTargetsFromCSVFiles(logger, files) - var mountSpecs []*csv.MountSpec + devices := discover.NewDeviceDiscoverer( + logger, + lookup.NewCharDeviceLocator(lookup.WithLogger(logger), lookup.WithRoot(driverRoot)), + driverRoot, + targetsByType[csv.MountSpecDev], + ) + + directories := discover.NewMounts( + logger, + lookup.NewDirectoryLocator(lookup.WithLogger(logger), lookup.WithRoot(driverRoot)), + driverRoot, + targetsByType[csv.MountSpecDir], + ) + + // Libraries and symlinks use the same locator. + symlinkLocator := lookup.NewSymlinkLocator(lookup.WithLogger(logger), lookup.WithRoot(driverRoot)) + libraries := discover.NewMounts( + logger, + symlinkLocator, + driverRoot, + targetsByType[csv.MountSpecLib], + ) + + nonLibSymlinks := ignoreFilenamePatterns{"*.so", "*.so.[0-9]"}.Apply(targetsByType[csv.MountSpecSym]...) + logger.Debugf("Non-lib symlinks: %v", nonLibSymlinks) + symlinks := discover.NewMounts( + logger, + symlinkLocator, + driverRoot, + nonLibSymlinks, + ) + createSymlinks := createCSVSymlinkHooks(logger, nonLibSymlinks, libraries, nvidiaCTKPath) + + d := discover.Merge( + devices, + directories, + libraries, + symlinks, + createSymlinks, + ) + + return d, nil +} + +// getTargetsFromCSVFiles returns the list of mount specs from the specified CSV files. +// These are aggregated by mount spec type. +func getTargetsFromCSVFiles(logger logger.Interface, files []string) map[csv.MountSpecType][]string { + targetsByType := make(map[csv.MountSpecType][]string) for _, filename := range files { targets, err := loadCSVFile(logger, filename) if err != nil { logger.Warningf("Skipping CSV file %v: %v", filename, err) continue } - mountSpecs = append(mountSpecs, targets...) + for _, t := range targets { + targetsByType[t.Type] = append(targetsByType[t.Type], t.Path) + } } - - return newFromMountSpecs(logger, locators, driverRoot, mountSpecs) + return targetsByType } // loadCSVFile loads the specified CSV file and returns the list of mount specs @@ -72,40 +110,3 @@ func loadCSVFile(logger logger.Interface, filename string) ([]*csv.MountSpec, er return targets, nil } - -// newFromMountSpecs creates a discoverer for the CSV file. A logger is also supplied. -// A list of csvDiscoverers is returned, with each being associated with a single MountSpecType. -func newFromMountSpecs(logger logger.Interface, locators map[csv.MountSpecType]lookup.Locator, driverRoot string, targets []*csv.MountSpec) (discover.Discover, error) { - if len(targets) == 0 { - return &discover.None{}, nil - } - - var discoverers []discover.Discover - var mountSpecTypes []csv.MountSpecType - candidatesByType := make(map[csv.MountSpecType][]string) - for _, t := range targets { - if _, exists := candidatesByType[t.Type]; !exists { - mountSpecTypes = append(mountSpecTypes, t.Type) - } - candidatesByType[t.Type] = append(candidatesByType[t.Type], t.Path) - } - - for _, t := range mountSpecTypes { - locator, exists := locators[t] - if !exists { - return nil, fmt.Errorf("no locator defined for '%v'", t) - } - - var m discover.Discover - switch t { - case csv.MountSpecDev: - m = discover.NewDeviceDiscoverer(logger, locator, driverRoot, candidatesByType[t]) - default: - m = discover.NewMounts(logger, locator, driverRoot, candidatesByType[t]) - } - discoverers = append(discoverers, m) - - } - - return discover.Merge(discoverers...), nil -} diff --git a/internal/platform-support/tegra/csv_test.go b/internal/platform-support/tegra/csv_test.go index 6889e200..d01ce260 100644 --- a/internal/platform-support/tegra/csv_test.go +++ b/internal/platform-support/tegra/csv_test.go @@ -15,105 +15,3 @@ **/ package tegra - -import ( - "fmt" - "testing" - - "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" - "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" - "github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/tegra/csv" - testlog "github.com/sirupsen/logrus/hooks/test" - "github.com/stretchr/testify/require" -) - -func TestNewFromMountSpec(t *testing.T) { - logger, _ := testlog.NewNullLogger() - - locators := map[csv.MountSpecType]lookup.Locator{ - "dev": &lookup.LocatorMock{ - LocateFunc: func(pattern string) ([]string, error) { - return []string{"/dev/" + pattern}, nil - }, - }, - "lib": &lookup.LocatorMock{ - LocateFunc: func(pattern string) ([]string, error) { - return []string{"/lib/" + pattern}, nil - }, - }, - } - - testCases := []struct { - description string - root string - targets []*csv.MountSpec - expectedError error - expectedDevices []discover.Device - expectedMounts []discover.Mount - expectedHooks []discover.Hook - }{ - { - description: "empty targets returns None discoverer list", - expectedDevices: []discover.Device{}, - expectedMounts: []discover.Mount{}, - expectedHooks: []discover.Hook{}, - }, - { - description: "unexpected locator returns error", - targets: []*csv.MountSpec{ - { - Type: "foo", - Path: "bar", - }, - }, - expectedError: fmt.Errorf("no locator defined for foo"), - }, - { - description: "creates discoverers based on type", - targets: []*csv.MountSpec{ - { - Type: "dev", - Path: "dev0", - }, - { - Type: "lib", - Path: "lib0", - }, - { - Type: "dev", - Path: "dev1", - }, - }, - expectedDevices: []discover.Device{ - {Path: "/dev/dev0", HostPath: "/dev/dev0"}, - {Path: "/dev/dev1", HostPath: "/dev/dev1"}, - }, - expectedMounts: []discover.Mount{ - {Path: "/lib/lib0", HostPath: "/lib/lib0", Options: []string{"ro", "nosuid", "nodev", "bind"}}, - }, - }, - } - - for _, tc := range testCases { - t.Run(tc.description, func(t *testing.T) { - discoverer, err := newFromMountSpecs(logger, locators, tc.root, tc.targets) - if tc.expectedError != nil { - require.Error(t, err) - return - } - require.NoError(t, err) - - devices, err := discoverer.Devices() - require.NoError(t, err) - require.EqualValues(t, tc.expectedDevices, devices) - - mounts, err := discoverer.Mounts() - require.NoError(t, err) - require.EqualValues(t, tc.expectedMounts, mounts) - - hooks, err := discoverer.Hooks() - require.NoError(t, err) - require.EqualValues(t, tc.expectedHooks, hooks) - }) - } -} diff --git a/internal/platform-support/tegra/filter.go b/internal/platform-support/tegra/filter.go new file mode 100644 index 00000000..7d6e8e15 --- /dev/null +++ b/internal/platform-support/tegra/filter.go @@ -0,0 +1,41 @@ +/** +# Copyright (c) NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package tegra + +import "path/filepath" + +type ignoreFilenamePatterns []string + +func (d ignoreFilenamePatterns) Match(name string) bool { + for _, pattern := range d { + if match, _ := filepath.Match(pattern, filepath.Base(name)); match { + return true + } + } + return false +} + +func (d ignoreFilenamePatterns) Apply(input ...string) []string { + var filtered []string + for _, name := range input { + if d.Match(name) { + continue + } + filtered = append(filtered, name) + } + return filtered +} diff --git a/internal/platform-support/tegra/filter_test.go b/internal/platform-support/tegra/filter_test.go new file mode 100644 index 00000000..5cca505d --- /dev/null +++ b/internal/platform-support/tegra/filter_test.go @@ -0,0 +1,29 @@ +/** +# Copyright (c) NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package tegra + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestIgnorePatterns(t *testing.T) { + filtered := ignoreFilenamePatterns{"*.so", "*.so.[0-9]"}.Apply("/foo/bar/libsomething.so", "libsometing.so", "libsometing.so.1", "libsometing.so.1.2.3") + + require.ElementsMatch(t, []string{"libsometing.so.1.2.3"}, filtered) +} diff --git a/internal/platform-support/tegra/symlinks.go b/internal/platform-support/tegra/symlinks.go index 607e96c6..505565d5 100644 --- a/internal/platform-support/tegra/symlinks.go +++ b/internal/platform-support/tegra/symlinks.go @@ -25,7 +25,6 @@ import ( "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/symlinks" - "github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/tegra/csv" ) type symlinkHook struct { @@ -33,20 +32,18 @@ type symlinkHook struct { logger logger.Interface driverRoot string nvidiaCTKPath string - csvFiles []string + targets []string mountsFrom discover.Discover } // createCSVSymlinkHooks creates a discoverer for a hook that creates required symlinks in the container -func createCSVSymlinkHooks(logger logger.Interface, csvFiles []string, mounts discover.Discover, nvidiaCTKPath string) (discover.Discover, error) { - d := symlinkHook{ +func createCSVSymlinkHooks(logger logger.Interface, targets []string, mounts discover.Discover, nvidiaCTKPath string) discover.Discover { + return symlinkHook{ logger: logger, nvidiaCTKPath: nvidiaCTKPath, - csvFiles: csvFiles, + targets: targets, mountsFrom: mounts, } - - return &d, nil } // Hooks returns a hook to create the symlinks from the required CSV files @@ -106,36 +103,30 @@ func (d symlinkHook) getSpecificLinks() ([]string, error) { return links, nil } -func (d symlinkHook) getCSVFileSymlinks() []string { +// getSymlinkCandidates returns a list of symlinks that are candidates for being created. +func (d symlinkHook) getSymlinkCandidates() []string { chainLocator := lookup.NewSymlinkChainLocator( lookup.WithLogger(d.logger), lookup.WithRoot(d.driverRoot), ) var candidates []string - for _, file := range d.csvFiles { - mountSpecs, err := csv.NewCSVFileParser(d.logger, file).Parse() + for _, target := range d.targets { + reslovedSymlinkChain, err := chainLocator.Locate(target) if err != nil { - d.logger.Debugf("Skipping CSV file %v: %v", file, err) + d.logger.Warningf("Failed to locate symlink %v", target) continue } - - for _, ms := range mountSpecs { - if ms.Type != csv.MountSpecSym { - continue - } - targets, err := chainLocator.Locate(ms.Path) - if err != nil { - d.logger.Warningf("Failed to locate symlink %v", ms.Path) - } - candidates = append(candidates, targets...) - } + candidates = append(candidates, reslovedSymlinkChain...) } + return candidates +} +func (d symlinkHook) getCSVFileSymlinks() []string { var links []string created := make(map[string]bool) // candidates is a list of absolute paths to symlinks in a chain, or the final target of the chain. - for _, candidate := range candidates { + for _, candidate := range d.getSymlinkCandidates() { target, err := symlinks.Resolve(candidate) if err != nil { d.logger.Debugf("Skipping invalid link: %v", err) diff --git a/internal/platform-support/tegra/tegra.go b/internal/platform-support/tegra/tegra.go index bf75853b..a52cb6f4 100644 --- a/internal/platform-support/tegra/tegra.go +++ b/internal/platform-support/tegra/tegra.go @@ -41,16 +41,11 @@ func New(opts ...Option) (discover.Discover, error) { opt(o) } - csvDiscoverer, err := newDiscovererFromCSVFiles(o.logger, o.csvFiles, o.driverRoot) + csvDiscoverer, err := newDiscovererFromCSVFiles(o.logger, o.csvFiles, o.driverRoot, o.nvidiaCTKPath) if err != nil { return nil, fmt.Errorf("failed to create CSV discoverer: %v", err) } - createSymlinksHook, err := createCSVSymlinkHooks(o.logger, o.csvFiles, csvDiscoverer, o.nvidiaCTKPath) - if err != nil { - return nil, fmt.Errorf("failed to create symlink hook discoverer: %v", err) - } - ldcacheUpdateHook, err := discover.NewLDCacheUpdateHook(o.logger, csvDiscoverer, o.nvidiaCTKPath) if err != nil { return nil, fmt.Errorf("failed to create ldcach update hook discoverer: %v", err) @@ -68,7 +63,6 @@ func New(opts ...Option) (discover.Discover, error) { d := discover.Merge( csvDiscoverer, - createSymlinksHook, // The ldcacheUpdateHook is added last to ensure that the created symlinks are included ldcacheUpdateHook, tegraSystemMounts,