Refactor CSV discovery for testability

This change improves the testibility of the CSV discoverer.
This is done by adding injection points for mocks for library discovery and
symlink resolution.

Note that this highlights a bug in the current implementation where the
library filter causes valid symlinks to be skipped.

Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
Evan Lezar 2023-09-22 15:23:12 +02:00
parent be570fce65
commit 963250a58f
4 changed files with 186 additions and 40 deletions

View File

@ -28,51 +28,45 @@ import (
// newDiscovererFromCSVFiles creates a discoverer for the specified CSV files. A logger is also supplied.
// The constructed discoverer is comprised of a list, with each element in the list being associated with a
// single CSV files.
func newDiscovererFromCSVFiles(logger logger.Interface, files []string, driverRoot string, nvidiaCTKPath string, librarySearchPaths []string) (discover.Discover, error) {
if len(files) == 0 {
logger.Warningf("No CSV files specified")
func (o tegraOptions) newDiscovererFromCSVFiles() (discover.Discover, error) {
if len(o.csvFiles) == 0 {
o.logger.Warningf("No CSV files specified")
return discover.None{}, nil
}
targetsByType := getTargetsFromCSVFiles(logger, files)
targetsByType := getTargetsFromCSVFiles(o.logger, o.csvFiles)
devices := discover.NewDeviceDiscoverer(
logger,
lookup.NewCharDeviceLocator(lookup.WithLogger(logger), lookup.WithRoot(driverRoot)),
driverRoot,
o.logger,
lookup.NewCharDeviceLocator(lookup.WithLogger(o.logger), lookup.WithRoot(o.driverRoot)),
o.driverRoot,
targetsByType[csv.MountSpecDev],
)
directories := discover.NewMounts(
logger,
lookup.NewDirectoryLocator(lookup.WithLogger(logger), lookup.WithRoot(driverRoot)),
driverRoot,
o.logger,
lookup.NewDirectoryLocator(lookup.WithLogger(o.logger), lookup.WithRoot(o.driverRoot)),
o.driverRoot,
targetsByType[csv.MountSpecDir],
)
// Libraries and symlinks use the same locator.
searchPaths := append(librarySearchPaths, "/")
symlinkLocator := lookup.NewSymlinkLocator(
lookup.WithLogger(logger),
lookup.WithRoot(driverRoot),
lookup.WithSearchPaths(searchPaths...),
)
libraries := discover.NewMounts(
logger,
symlinkLocator,
driverRoot,
o.logger,
o.symlinkLocator,
o.driverRoot,
targetsByType[csv.MountSpecLib],
)
nonLibSymlinks := ignoreFilenamePatterns{"*.so", "*.so.[0-9]"}.Apply(targetsByType[csv.MountSpecSym]...)
logger.Debugf("Non-lib symlinks: %v", nonLibSymlinks)
o.logger.Debugf("Non-lib symlinks: %v", nonLibSymlinks)
symlinks := discover.NewMounts(
logger,
symlinkLocator,
driverRoot,
o.logger,
o.symlinkLocator,
o.driverRoot,
nonLibSymlinks,
)
createSymlinks := createCSVSymlinkHooks(logger, nonLibSymlinks, libraries, nvidiaCTKPath)
createSymlinks := o.createCSVSymlinkHooks(nonLibSymlinks, libraries)
d := discover.Merge(
devices,
@ -87,7 +81,9 @@ func newDiscovererFromCSVFiles(logger logger.Interface, files []string, driverRo
// getTargetsFromCSVFiles returns the list of mount specs from the specified CSV files.
// These are aggregated by mount spec type.
func getTargetsFromCSVFiles(logger logger.Interface, files []string) map[csv.MountSpecType][]string {
// TODO: We use a function variable here to allow this to be overridden for testing.
// This should be properly mocked.
var getTargetsFromCSVFiles = func(logger logger.Interface, files []string) map[csv.MountSpecType][]string {
targetsByType := make(map[csv.MountSpecType][]string)
for _, filename := range files {
targets, err := loadCSVFile(logger, filename)

View File

@ -15,3 +15,127 @@
**/
package tegra
import (
"fmt"
"testing"
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup"
testlog "github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/require"
"github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/tegra/csv"
)
func TestDiscovererFromCSVFiles(t *testing.T) {
logger, _ := testlog.NewNullLogger()
testCases := []struct {
description string
moutSpecs map[csv.MountSpecType][]string
symlinkLocator lookup.Locator
symlinkChainLocator lookup.Locator
symlinkResolver func(string) (string, error)
expectedError error
expectedMounts []discover.Mount
expectedMountsError error
expectedHooks []discover.Hook
expectedHooksError error
}{
{
// TODO: This current resolves to two mounts that are the same.
// These are deduplicated at a later stage. We could consider deduplicating earlier in the pipeline.
description: "symlink is resolved to target; mounts and symlink are created",
moutSpecs: map[csv.MountSpecType][]string{
"lib": {"/usr/lib/aarch64-linux-gnu/tegra/libv4l2_nvargus.so"},
"sym": {"/usr/lib/aarch64-linux-gnu/libv4l/plugins/nv/libv4l2_nvargus.so"},
},
symlinkLocator: &lookup.LocatorMock{
LocateFunc: func(path string) ([]string, error) {
if path == "/usr/lib/aarch64-linux-gnu/libv4l/plugins/nv/libv4l2_nvargus.so" {
return []string{"/usr/lib/aarch64-linux-gnu/tegra/libv4l2_nvargus.so"}, nil
}
return []string{path}, nil
},
},
symlinkChainLocator: &lookup.LocatorMock{
LocateFunc: func(path string) ([]string, error) {
if path == "/usr/lib/aarch64-linux-gnu/libv4l/plugins/nv/libv4l2_nvargus.so" {
return []string{"/usr/lib/aarch64-linux-gnu/libv4l/plugins/nv/libv4l2_nvargus.so", "/usr/lib/aarch64-linux-gnu/tegra/libv4l2_nvargus.so"}, nil
}
return nil, fmt.Errorf("Unexpected path: %v", path)
},
},
symlinkResolver: func(path string) (string, error) {
if path == "/usr/lib/aarch64-linux-gnu/libv4l/plugins/nv/libv4l2_nvargus.so" {
return "/usr/lib/aarch64-linux-gnu/tegra/libv4l2_nvargus.so", nil
}
return path, nil
},
expectedMounts: []discover.Mount{
{
Path: "/usr/lib/aarch64-linux-gnu/tegra/libv4l2_nvargus.so",
HostPath: "/usr/lib/aarch64-linux-gnu/tegra/libv4l2_nvargus.so",
Options: []string{"ro", "nosuid", "nodev", "bind"},
},
{
Path: "/usr/lib/aarch64-linux-gnu/tegra/libv4l2_nvargus.so",
HostPath: "/usr/lib/aarch64-linux-gnu/tegra/libv4l2_nvargus.so",
Options: []string{"ro", "nosuid", "nodev", "bind"},
},
},
expectedHooks: []discover.Hook{
{
Lifecycle: "createContainer",
Path: "/usr/bin/nvidia-ctk",
Args: []string{
"nvidia-ctk",
"hook",
"create-symlinks",
"--link",
"/usr/lib/aarch64-linux-gnu/tegra/libv4l2_nvargus.so::/usr/lib/aarch64-linux-gnu/libv4l/plugins/nv/libv4l2_nvargus.so",
},
},
},
},
}
for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
defer setGetTargetsFromCSVFiles(tc.moutSpecs)()
o := tegraOptions{
logger: logger,
nvidiaCTKPath: "/usr/bin/nvidia-ctk",
csvFiles: []string{"dummy"},
symlinkLocator: tc.symlinkLocator,
symlinkChainLocator: tc.symlinkChainLocator,
resolveSymlink: tc.symlinkResolver,
}
d, err := o.newDiscovererFromCSVFiles()
require.ErrorIs(t, err, tc.expectedError)
hooks, err := d.Hooks()
require.ErrorIs(t, err, tc.expectedHooksError)
require.EqualValues(t, tc.expectedHooks, hooks)
mounts, err := d.Mounts()
require.ErrorIs(t, err, tc.expectedMountsError)
require.EqualValues(t, tc.expectedMounts, mounts)
})
}
}
func setGetTargetsFromCSVFiles(ovverride map[csv.MountSpecType][]string) func() {
original := getTargetsFromCSVFiles
getTargetsFromCSVFiles = func(logger logger.Interface, files []string) map[csv.MountSpecType][]string {
return ovverride
}
return func() {
getTargetsFromCSVFiles = original
}
}

View File

@ -24,25 +24,29 @@ import (
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup"
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/symlinks"
)
type symlinkHook struct {
discover.None
logger logger.Interface
driverRoot string
nvidiaCTKPath string
targets []string
mountsFrom discover.Discover
// The following can be overridden for testing
symlinkChainLocator lookup.Locator
resolveSymlink func(string) (string, error)
}
// createCSVSymlinkHooks creates a discoverer for a hook that creates required symlinks in the container
func createCSVSymlinkHooks(logger logger.Interface, targets []string, mounts discover.Discover, nvidiaCTKPath string) discover.Discover {
func (o tegraOptions) createCSVSymlinkHooks(targets []string, mounts discover.Discover) discover.Discover {
return symlinkHook{
logger: logger,
nvidiaCTKPath: nvidiaCTKPath,
targets: targets,
mountsFrom: mounts,
logger: o.logger,
nvidiaCTKPath: o.nvidiaCTKPath,
targets: targets,
mountsFrom: mounts,
symlinkChainLocator: o.symlinkChainLocator,
resolveSymlink: o.resolveSymlink,
}
}
@ -105,14 +109,9 @@ func (d symlinkHook) getSpecificLinks() ([]string, error) {
// getSymlinkCandidates returns a list of symlinks that are candidates for being created.
func (d symlinkHook) getSymlinkCandidates() []string {
chainLocator := lookup.NewSymlinkChainLocator(
lookup.WithLogger(d.logger),
lookup.WithRoot(d.driverRoot),
)
var candidates []string
for _, target := range d.targets {
reslovedSymlinkChain, err := chainLocator.Locate(target)
reslovedSymlinkChain, err := d.symlinkChainLocator.Locate(target)
if err != nil {
d.logger.Warningf("Failed to locate symlink %v", target)
continue
@ -127,7 +126,7 @@ func (d symlinkHook) getCSVFileSymlinks() []string {
created := make(map[string]bool)
// candidates is a list of absolute paths to symlinks in a chain, or the final target of the chain.
for _, candidate := range d.getSymlinkCandidates() {
target, err := symlinks.Resolve(candidate)
target, err := d.resolveSymlink(candidate)
if err != nil {
d.logger.Debugf("Skipping invalid link: %v", err)
continue

View File

@ -22,6 +22,7 @@ import (
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup"
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/symlinks"
)
type tegraOptions struct {
@ -30,6 +31,12 @@ type tegraOptions struct {
driverRoot string
nvidiaCTKPath string
librarySearchPaths []string
// The following can be overridden for testing
symlinkLocator lookup.Locator
symlinkChainLocator lookup.Locator
// TODO: This should be replaced by a regular mock
resolveSymlink func(string) (string, error)
}
// Option defines a functional option for configuring a Tegra discoverer.
@ -42,7 +49,27 @@ func New(opts ...Option) (discover.Discover, error) {
opt(o)
}
csvDiscoverer, err := newDiscovererFromCSVFiles(o.logger, o.csvFiles, o.driverRoot, o.nvidiaCTKPath, o.librarySearchPaths)
if o.symlinkLocator == nil {
searchPaths := append(o.librarySearchPaths, "/")
o.symlinkLocator = lookup.NewSymlinkLocator(
lookup.WithLogger(o.logger),
lookup.WithRoot(o.driverRoot),
lookup.WithSearchPaths(searchPaths...),
)
}
if o.symlinkChainLocator == nil {
o.symlinkChainLocator = lookup.NewSymlinkChainLocator(
lookup.WithLogger(o.logger),
lookup.WithRoot(o.driverRoot),
)
}
if o.resolveSymlink == nil {
o.resolveSymlink = symlinks.Resolve
}
csvDiscoverer, err := o.newDiscovererFromCSVFiles()
if err != nil {
return nil, fmt.Errorf("failed to create CSV discoverer: %v", err)
}