From 963250a58f58df469285a868f462485665f5c40b Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Fri, 22 Sep 2023 15:23:12 +0200 Subject: [PATCH] Refactor CSV discovery for testability This change improves the testibility of the CSV discoverer. This is done by adding injection points for mocks for library discovery and symlink resolution. Note that this highlights a bug in the current implementation where the library filter causes valid symlinks to be skipped. Signed-off-by: Evan Lezar --- internal/platform-support/tegra/csv.go | 46 ++++---- internal/platform-support/tegra/csv_test.go | 124 ++++++++++++++++++++ internal/platform-support/tegra/symlinks.go | 27 ++--- internal/platform-support/tegra/tegra.go | 29 ++++- 4 files changed, 186 insertions(+), 40 deletions(-) diff --git a/internal/platform-support/tegra/csv.go b/internal/platform-support/tegra/csv.go index 3da4b0dd..54ce7ee9 100644 --- a/internal/platform-support/tegra/csv.go +++ b/internal/platform-support/tegra/csv.go @@ -28,51 +28,45 @@ import ( // newDiscovererFromCSVFiles creates a discoverer for the specified CSV files. A logger is also supplied. // The constructed discoverer is comprised of a list, with each element in the list being associated with a // single CSV files. -func newDiscovererFromCSVFiles(logger logger.Interface, files []string, driverRoot string, nvidiaCTKPath string, librarySearchPaths []string) (discover.Discover, error) { - if len(files) == 0 { - logger.Warningf("No CSV files specified") +func (o tegraOptions) newDiscovererFromCSVFiles() (discover.Discover, error) { + if len(o.csvFiles) == 0 { + o.logger.Warningf("No CSV files specified") return discover.None{}, nil } - targetsByType := getTargetsFromCSVFiles(logger, files) + targetsByType := getTargetsFromCSVFiles(o.logger, o.csvFiles) devices := discover.NewDeviceDiscoverer( - logger, - lookup.NewCharDeviceLocator(lookup.WithLogger(logger), lookup.WithRoot(driverRoot)), - driverRoot, + o.logger, + lookup.NewCharDeviceLocator(lookup.WithLogger(o.logger), lookup.WithRoot(o.driverRoot)), + o.driverRoot, targetsByType[csv.MountSpecDev], ) directories := discover.NewMounts( - logger, - lookup.NewDirectoryLocator(lookup.WithLogger(logger), lookup.WithRoot(driverRoot)), - driverRoot, + o.logger, + lookup.NewDirectoryLocator(lookup.WithLogger(o.logger), lookup.WithRoot(o.driverRoot)), + o.driverRoot, targetsByType[csv.MountSpecDir], ) // Libraries and symlinks use the same locator. - searchPaths := append(librarySearchPaths, "/") - symlinkLocator := lookup.NewSymlinkLocator( - lookup.WithLogger(logger), - lookup.WithRoot(driverRoot), - lookup.WithSearchPaths(searchPaths...), - ) libraries := discover.NewMounts( - logger, - symlinkLocator, - driverRoot, + o.logger, + o.symlinkLocator, + o.driverRoot, targetsByType[csv.MountSpecLib], ) nonLibSymlinks := ignoreFilenamePatterns{"*.so", "*.so.[0-9]"}.Apply(targetsByType[csv.MountSpecSym]...) - logger.Debugf("Non-lib symlinks: %v", nonLibSymlinks) + o.logger.Debugf("Non-lib symlinks: %v", nonLibSymlinks) symlinks := discover.NewMounts( - logger, - symlinkLocator, - driverRoot, + o.logger, + o.symlinkLocator, + o.driverRoot, nonLibSymlinks, ) - createSymlinks := createCSVSymlinkHooks(logger, nonLibSymlinks, libraries, nvidiaCTKPath) + createSymlinks := o.createCSVSymlinkHooks(nonLibSymlinks, libraries) d := discover.Merge( devices, @@ -87,7 +81,9 @@ func newDiscovererFromCSVFiles(logger logger.Interface, files []string, driverRo // getTargetsFromCSVFiles returns the list of mount specs from the specified CSV files. // These are aggregated by mount spec type. -func getTargetsFromCSVFiles(logger logger.Interface, files []string) map[csv.MountSpecType][]string { +// TODO: We use a function variable here to allow this to be overridden for testing. +// This should be properly mocked. +var getTargetsFromCSVFiles = func(logger logger.Interface, files []string) map[csv.MountSpecType][]string { targetsByType := make(map[csv.MountSpecType][]string) for _, filename := range files { targets, err := loadCSVFile(logger, filename) diff --git a/internal/platform-support/tegra/csv_test.go b/internal/platform-support/tegra/csv_test.go index d01ce260..1f1cbd12 100644 --- a/internal/platform-support/tegra/csv_test.go +++ b/internal/platform-support/tegra/csv_test.go @@ -15,3 +15,127 @@ **/ package tegra + +import ( + "fmt" + "testing" + + "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" + "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" + "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" + testlog "github.com/sirupsen/logrus/hooks/test" + "github.com/stretchr/testify/require" + + "github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/tegra/csv" +) + +func TestDiscovererFromCSVFiles(t *testing.T) { + logger, _ := testlog.NewNullLogger() + testCases := []struct { + description string + moutSpecs map[csv.MountSpecType][]string + symlinkLocator lookup.Locator + symlinkChainLocator lookup.Locator + symlinkResolver func(string) (string, error) + expectedError error + expectedMounts []discover.Mount + expectedMountsError error + expectedHooks []discover.Hook + expectedHooksError error + }{ + { + // TODO: This current resolves to two mounts that are the same. + // These are deduplicated at a later stage. We could consider deduplicating earlier in the pipeline. + description: "symlink is resolved to target; mounts and symlink are created", + moutSpecs: map[csv.MountSpecType][]string{ + "lib": {"/usr/lib/aarch64-linux-gnu/tegra/libv4l2_nvargus.so"}, + "sym": {"/usr/lib/aarch64-linux-gnu/libv4l/plugins/nv/libv4l2_nvargus.so"}, + }, + symlinkLocator: &lookup.LocatorMock{ + LocateFunc: func(path string) ([]string, error) { + if path == "/usr/lib/aarch64-linux-gnu/libv4l/plugins/nv/libv4l2_nvargus.so" { + return []string{"/usr/lib/aarch64-linux-gnu/tegra/libv4l2_nvargus.so"}, nil + } + return []string{path}, nil + }, + }, + symlinkChainLocator: &lookup.LocatorMock{ + LocateFunc: func(path string) ([]string, error) { + if path == "/usr/lib/aarch64-linux-gnu/libv4l/plugins/nv/libv4l2_nvargus.so" { + return []string{"/usr/lib/aarch64-linux-gnu/libv4l/plugins/nv/libv4l2_nvargus.so", "/usr/lib/aarch64-linux-gnu/tegra/libv4l2_nvargus.so"}, nil + } + return nil, fmt.Errorf("Unexpected path: %v", path) + }, + }, + symlinkResolver: func(path string) (string, error) { + if path == "/usr/lib/aarch64-linux-gnu/libv4l/plugins/nv/libv4l2_nvargus.so" { + return "/usr/lib/aarch64-linux-gnu/tegra/libv4l2_nvargus.so", nil + } + return path, nil + }, + expectedMounts: []discover.Mount{ + { + Path: "/usr/lib/aarch64-linux-gnu/tegra/libv4l2_nvargus.so", + HostPath: "/usr/lib/aarch64-linux-gnu/tegra/libv4l2_nvargus.so", + Options: []string{"ro", "nosuid", "nodev", "bind"}, + }, + { + Path: "/usr/lib/aarch64-linux-gnu/tegra/libv4l2_nvargus.so", + HostPath: "/usr/lib/aarch64-linux-gnu/tegra/libv4l2_nvargus.so", + Options: []string{"ro", "nosuid", "nodev", "bind"}, + }, + }, + expectedHooks: []discover.Hook{ + { + Lifecycle: "createContainer", + Path: "/usr/bin/nvidia-ctk", + Args: []string{ + "nvidia-ctk", + "hook", + "create-symlinks", + "--link", + "/usr/lib/aarch64-linux-gnu/tegra/libv4l2_nvargus.so::/usr/lib/aarch64-linux-gnu/libv4l/plugins/nv/libv4l2_nvargus.so", + }, + }, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.description, func(t *testing.T) { + defer setGetTargetsFromCSVFiles(tc.moutSpecs)() + + o := tegraOptions{ + logger: logger, + nvidiaCTKPath: "/usr/bin/nvidia-ctk", + csvFiles: []string{"dummy"}, + symlinkLocator: tc.symlinkLocator, + symlinkChainLocator: tc.symlinkChainLocator, + resolveSymlink: tc.symlinkResolver, + } + + d, err := o.newDiscovererFromCSVFiles() + require.ErrorIs(t, err, tc.expectedError) + + hooks, err := d.Hooks() + require.ErrorIs(t, err, tc.expectedHooksError) + require.EqualValues(t, tc.expectedHooks, hooks) + + mounts, err := d.Mounts() + require.ErrorIs(t, err, tc.expectedMountsError) + require.EqualValues(t, tc.expectedMounts, mounts) + + }) + } +} + +func setGetTargetsFromCSVFiles(ovverride map[csv.MountSpecType][]string) func() { + original := getTargetsFromCSVFiles + getTargetsFromCSVFiles = func(logger logger.Interface, files []string) map[csv.MountSpecType][]string { + return ovverride + } + + return func() { + getTargetsFromCSVFiles = original + } +} diff --git a/internal/platform-support/tegra/symlinks.go b/internal/platform-support/tegra/symlinks.go index 505565d5..c64138db 100644 --- a/internal/platform-support/tegra/symlinks.go +++ b/internal/platform-support/tegra/symlinks.go @@ -24,25 +24,29 @@ import ( "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" - "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/symlinks" ) type symlinkHook struct { discover.None logger logger.Interface - driverRoot string nvidiaCTKPath string targets []string mountsFrom discover.Discover + + // The following can be overridden for testing + symlinkChainLocator lookup.Locator + resolveSymlink func(string) (string, error) } // createCSVSymlinkHooks creates a discoverer for a hook that creates required symlinks in the container -func createCSVSymlinkHooks(logger logger.Interface, targets []string, mounts discover.Discover, nvidiaCTKPath string) discover.Discover { +func (o tegraOptions) createCSVSymlinkHooks(targets []string, mounts discover.Discover) discover.Discover { return symlinkHook{ - logger: logger, - nvidiaCTKPath: nvidiaCTKPath, - targets: targets, - mountsFrom: mounts, + logger: o.logger, + nvidiaCTKPath: o.nvidiaCTKPath, + targets: targets, + mountsFrom: mounts, + symlinkChainLocator: o.symlinkChainLocator, + resolveSymlink: o.resolveSymlink, } } @@ -105,14 +109,9 @@ func (d symlinkHook) getSpecificLinks() ([]string, error) { // getSymlinkCandidates returns a list of symlinks that are candidates for being created. func (d symlinkHook) getSymlinkCandidates() []string { - chainLocator := lookup.NewSymlinkChainLocator( - lookup.WithLogger(d.logger), - lookup.WithRoot(d.driverRoot), - ) - var candidates []string for _, target := range d.targets { - reslovedSymlinkChain, err := chainLocator.Locate(target) + reslovedSymlinkChain, err := d.symlinkChainLocator.Locate(target) if err != nil { d.logger.Warningf("Failed to locate symlink %v", target) continue @@ -127,7 +126,7 @@ func (d symlinkHook) getCSVFileSymlinks() []string { created := make(map[string]bool) // candidates is a list of absolute paths to symlinks in a chain, or the final target of the chain. for _, candidate := range d.getSymlinkCandidates() { - target, err := symlinks.Resolve(candidate) + target, err := d.resolveSymlink(candidate) if err != nil { d.logger.Debugf("Skipping invalid link: %v", err) continue diff --git a/internal/platform-support/tegra/tegra.go b/internal/platform-support/tegra/tegra.go index 38c0ffe8..695f8414 100644 --- a/internal/platform-support/tegra/tegra.go +++ b/internal/platform-support/tegra/tegra.go @@ -22,6 +22,7 @@ import ( "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" + "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/symlinks" ) type tegraOptions struct { @@ -30,6 +31,12 @@ type tegraOptions struct { driverRoot string nvidiaCTKPath string librarySearchPaths []string + + // The following can be overridden for testing + symlinkLocator lookup.Locator + symlinkChainLocator lookup.Locator + // TODO: This should be replaced by a regular mock + resolveSymlink func(string) (string, error) } // Option defines a functional option for configuring a Tegra discoverer. @@ -42,7 +49,27 @@ func New(opts ...Option) (discover.Discover, error) { opt(o) } - csvDiscoverer, err := newDiscovererFromCSVFiles(o.logger, o.csvFiles, o.driverRoot, o.nvidiaCTKPath, o.librarySearchPaths) + if o.symlinkLocator == nil { + searchPaths := append(o.librarySearchPaths, "/") + o.symlinkLocator = lookup.NewSymlinkLocator( + lookup.WithLogger(o.logger), + lookup.WithRoot(o.driverRoot), + lookup.WithSearchPaths(searchPaths...), + ) + } + + if o.symlinkChainLocator == nil { + o.symlinkChainLocator = lookup.NewSymlinkChainLocator( + lookup.WithLogger(o.logger), + lookup.WithRoot(o.driverRoot), + ) + } + + if o.resolveSymlink == nil { + o.resolveSymlink = symlinks.Resolve + } + + csvDiscoverer, err := o.newDiscovererFromCSVFiles() if err != nil { return nil, fmt.Errorf("failed to create CSV discoverer: %v", err) }