diff --git a/internal/platform-support/tegra/csv.go b/internal/platform-support/tegra/csv.go index 3da4b0dd..54ce7ee9 100644 --- a/internal/platform-support/tegra/csv.go +++ b/internal/platform-support/tegra/csv.go @@ -28,51 +28,45 @@ import ( // newDiscovererFromCSVFiles creates a discoverer for the specified CSV files. A logger is also supplied. // The constructed discoverer is comprised of a list, with each element in the list being associated with a // single CSV files. -func newDiscovererFromCSVFiles(logger logger.Interface, files []string, driverRoot string, nvidiaCTKPath string, librarySearchPaths []string) (discover.Discover, error) { - if len(files) == 0 { - logger.Warningf("No CSV files specified") +func (o tegraOptions) newDiscovererFromCSVFiles() (discover.Discover, error) { + if len(o.csvFiles) == 0 { + o.logger.Warningf("No CSV files specified") return discover.None{}, nil } - targetsByType := getTargetsFromCSVFiles(logger, files) + targetsByType := getTargetsFromCSVFiles(o.logger, o.csvFiles) devices := discover.NewDeviceDiscoverer( - logger, - lookup.NewCharDeviceLocator(lookup.WithLogger(logger), lookup.WithRoot(driverRoot)), - driverRoot, + o.logger, + lookup.NewCharDeviceLocator(lookup.WithLogger(o.logger), lookup.WithRoot(o.driverRoot)), + o.driverRoot, targetsByType[csv.MountSpecDev], ) directories := discover.NewMounts( - logger, - lookup.NewDirectoryLocator(lookup.WithLogger(logger), lookup.WithRoot(driverRoot)), - driverRoot, + o.logger, + lookup.NewDirectoryLocator(lookup.WithLogger(o.logger), lookup.WithRoot(o.driverRoot)), + o.driverRoot, targetsByType[csv.MountSpecDir], ) // Libraries and symlinks use the same locator. - searchPaths := append(librarySearchPaths, "/") - symlinkLocator := lookup.NewSymlinkLocator( - lookup.WithLogger(logger), - lookup.WithRoot(driverRoot), - lookup.WithSearchPaths(searchPaths...), - ) libraries := discover.NewMounts( - logger, - symlinkLocator, - driverRoot, + o.logger, + o.symlinkLocator, + o.driverRoot, targetsByType[csv.MountSpecLib], ) nonLibSymlinks := ignoreFilenamePatterns{"*.so", "*.so.[0-9]"}.Apply(targetsByType[csv.MountSpecSym]...) - logger.Debugf("Non-lib symlinks: %v", nonLibSymlinks) + o.logger.Debugf("Non-lib symlinks: %v", nonLibSymlinks) symlinks := discover.NewMounts( - logger, - symlinkLocator, - driverRoot, + o.logger, + o.symlinkLocator, + o.driverRoot, nonLibSymlinks, ) - createSymlinks := createCSVSymlinkHooks(logger, nonLibSymlinks, libraries, nvidiaCTKPath) + createSymlinks := o.createCSVSymlinkHooks(nonLibSymlinks, libraries) d := discover.Merge( devices, @@ -87,7 +81,9 @@ func newDiscovererFromCSVFiles(logger logger.Interface, files []string, driverRo // getTargetsFromCSVFiles returns the list of mount specs from the specified CSV files. // These are aggregated by mount spec type. -func getTargetsFromCSVFiles(logger logger.Interface, files []string) map[csv.MountSpecType][]string { +// TODO: We use a function variable here to allow this to be overridden for testing. +// This should be properly mocked. +var getTargetsFromCSVFiles = func(logger logger.Interface, files []string) map[csv.MountSpecType][]string { targetsByType := make(map[csv.MountSpecType][]string) for _, filename := range files { targets, err := loadCSVFile(logger, filename) diff --git a/internal/platform-support/tegra/csv_test.go b/internal/platform-support/tegra/csv_test.go index d01ce260..1f1cbd12 100644 --- a/internal/platform-support/tegra/csv_test.go +++ b/internal/platform-support/tegra/csv_test.go @@ -15,3 +15,127 @@ **/ package tegra + +import ( + "fmt" + "testing" + + "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" + "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" + "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" + testlog "github.com/sirupsen/logrus/hooks/test" + "github.com/stretchr/testify/require" + + "github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/tegra/csv" +) + +func TestDiscovererFromCSVFiles(t *testing.T) { + logger, _ := testlog.NewNullLogger() + testCases := []struct { + description string + moutSpecs map[csv.MountSpecType][]string + symlinkLocator lookup.Locator + symlinkChainLocator lookup.Locator + symlinkResolver func(string) (string, error) + expectedError error + expectedMounts []discover.Mount + expectedMountsError error + expectedHooks []discover.Hook + expectedHooksError error + }{ + { + // TODO: This current resolves to two mounts that are the same. + // These are deduplicated at a later stage. We could consider deduplicating earlier in the pipeline. + description: "symlink is resolved to target; mounts and symlink are created", + moutSpecs: map[csv.MountSpecType][]string{ + "lib": {"/usr/lib/aarch64-linux-gnu/tegra/libv4l2_nvargus.so"}, + "sym": {"/usr/lib/aarch64-linux-gnu/libv4l/plugins/nv/libv4l2_nvargus.so"}, + }, + symlinkLocator: &lookup.LocatorMock{ + LocateFunc: func(path string) ([]string, error) { + if path == "/usr/lib/aarch64-linux-gnu/libv4l/plugins/nv/libv4l2_nvargus.so" { + return []string{"/usr/lib/aarch64-linux-gnu/tegra/libv4l2_nvargus.so"}, nil + } + return []string{path}, nil + }, + }, + symlinkChainLocator: &lookup.LocatorMock{ + LocateFunc: func(path string) ([]string, error) { + if path == "/usr/lib/aarch64-linux-gnu/libv4l/plugins/nv/libv4l2_nvargus.so" { + return []string{"/usr/lib/aarch64-linux-gnu/libv4l/plugins/nv/libv4l2_nvargus.so", "/usr/lib/aarch64-linux-gnu/tegra/libv4l2_nvargus.so"}, nil + } + return nil, fmt.Errorf("Unexpected path: %v", path) + }, + }, + symlinkResolver: func(path string) (string, error) { + if path == "/usr/lib/aarch64-linux-gnu/libv4l/plugins/nv/libv4l2_nvargus.so" { + return "/usr/lib/aarch64-linux-gnu/tegra/libv4l2_nvargus.so", nil + } + return path, nil + }, + expectedMounts: []discover.Mount{ + { + Path: "/usr/lib/aarch64-linux-gnu/tegra/libv4l2_nvargus.so", + HostPath: "/usr/lib/aarch64-linux-gnu/tegra/libv4l2_nvargus.so", + Options: []string{"ro", "nosuid", "nodev", "bind"}, + }, + { + Path: "/usr/lib/aarch64-linux-gnu/tegra/libv4l2_nvargus.so", + HostPath: "/usr/lib/aarch64-linux-gnu/tegra/libv4l2_nvargus.so", + Options: []string{"ro", "nosuid", "nodev", "bind"}, + }, + }, + expectedHooks: []discover.Hook{ + { + Lifecycle: "createContainer", + Path: "/usr/bin/nvidia-ctk", + Args: []string{ + "nvidia-ctk", + "hook", + "create-symlinks", + "--link", + "/usr/lib/aarch64-linux-gnu/tegra/libv4l2_nvargus.so::/usr/lib/aarch64-linux-gnu/libv4l/plugins/nv/libv4l2_nvargus.so", + }, + }, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.description, func(t *testing.T) { + defer setGetTargetsFromCSVFiles(tc.moutSpecs)() + + o := tegraOptions{ + logger: logger, + nvidiaCTKPath: "/usr/bin/nvidia-ctk", + csvFiles: []string{"dummy"}, + symlinkLocator: tc.symlinkLocator, + symlinkChainLocator: tc.symlinkChainLocator, + resolveSymlink: tc.symlinkResolver, + } + + d, err := o.newDiscovererFromCSVFiles() + require.ErrorIs(t, err, tc.expectedError) + + hooks, err := d.Hooks() + require.ErrorIs(t, err, tc.expectedHooksError) + require.EqualValues(t, tc.expectedHooks, hooks) + + mounts, err := d.Mounts() + require.ErrorIs(t, err, tc.expectedMountsError) + require.EqualValues(t, tc.expectedMounts, mounts) + + }) + } +} + +func setGetTargetsFromCSVFiles(ovverride map[csv.MountSpecType][]string) func() { + original := getTargetsFromCSVFiles + getTargetsFromCSVFiles = func(logger logger.Interface, files []string) map[csv.MountSpecType][]string { + return ovverride + } + + return func() { + getTargetsFromCSVFiles = original + } +} diff --git a/internal/platform-support/tegra/symlinks.go b/internal/platform-support/tegra/symlinks.go index 505565d5..c64138db 100644 --- a/internal/platform-support/tegra/symlinks.go +++ b/internal/platform-support/tegra/symlinks.go @@ -24,25 +24,29 @@ import ( "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" - "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/symlinks" ) type symlinkHook struct { discover.None logger logger.Interface - driverRoot string nvidiaCTKPath string targets []string mountsFrom discover.Discover + + // The following can be overridden for testing + symlinkChainLocator lookup.Locator + resolveSymlink func(string) (string, error) } // createCSVSymlinkHooks creates a discoverer for a hook that creates required symlinks in the container -func createCSVSymlinkHooks(logger logger.Interface, targets []string, mounts discover.Discover, nvidiaCTKPath string) discover.Discover { +func (o tegraOptions) createCSVSymlinkHooks(targets []string, mounts discover.Discover) discover.Discover { return symlinkHook{ - logger: logger, - nvidiaCTKPath: nvidiaCTKPath, - targets: targets, - mountsFrom: mounts, + logger: o.logger, + nvidiaCTKPath: o.nvidiaCTKPath, + targets: targets, + mountsFrom: mounts, + symlinkChainLocator: o.symlinkChainLocator, + resolveSymlink: o.resolveSymlink, } } @@ -105,14 +109,9 @@ func (d symlinkHook) getSpecificLinks() ([]string, error) { // getSymlinkCandidates returns a list of symlinks that are candidates for being created. func (d symlinkHook) getSymlinkCandidates() []string { - chainLocator := lookup.NewSymlinkChainLocator( - lookup.WithLogger(d.logger), - lookup.WithRoot(d.driverRoot), - ) - var candidates []string for _, target := range d.targets { - reslovedSymlinkChain, err := chainLocator.Locate(target) + reslovedSymlinkChain, err := d.symlinkChainLocator.Locate(target) if err != nil { d.logger.Warningf("Failed to locate symlink %v", target) continue @@ -127,7 +126,7 @@ func (d symlinkHook) getCSVFileSymlinks() []string { created := make(map[string]bool) // candidates is a list of absolute paths to symlinks in a chain, or the final target of the chain. for _, candidate := range d.getSymlinkCandidates() { - target, err := symlinks.Resolve(candidate) + target, err := d.resolveSymlink(candidate) if err != nil { d.logger.Debugf("Skipping invalid link: %v", err) continue diff --git a/internal/platform-support/tegra/tegra.go b/internal/platform-support/tegra/tegra.go index 38c0ffe8..695f8414 100644 --- a/internal/platform-support/tegra/tegra.go +++ b/internal/platform-support/tegra/tegra.go @@ -22,6 +22,7 @@ import ( "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" + "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/symlinks" ) type tegraOptions struct { @@ -30,6 +31,12 @@ type tegraOptions struct { driverRoot string nvidiaCTKPath string librarySearchPaths []string + + // The following can be overridden for testing + symlinkLocator lookup.Locator + symlinkChainLocator lookup.Locator + // TODO: This should be replaced by a regular mock + resolveSymlink func(string) (string, error) } // Option defines a functional option for configuring a Tegra discoverer. @@ -42,7 +49,27 @@ func New(opts ...Option) (discover.Discover, error) { opt(o) } - csvDiscoverer, err := newDiscovererFromCSVFiles(o.logger, o.csvFiles, o.driverRoot, o.nvidiaCTKPath, o.librarySearchPaths) + if o.symlinkLocator == nil { + searchPaths := append(o.librarySearchPaths, "/") + o.symlinkLocator = lookup.NewSymlinkLocator( + lookup.WithLogger(o.logger), + lookup.WithRoot(o.driverRoot), + lookup.WithSearchPaths(searchPaths...), + ) + } + + if o.symlinkChainLocator == nil { + o.symlinkChainLocator = lookup.NewSymlinkChainLocator( + lookup.WithLogger(o.logger), + lookup.WithRoot(o.driverRoot), + ) + } + + if o.resolveSymlink == nil { + o.resolveSymlink = symlinks.Resolve + } + + csvDiscoverer, err := o.newDiscovererFromCSVFiles() if err != nil { return nil, fmt.Errorf("failed to create CSV discoverer: %v", err) }