mirror of
https://github.com/NVIDIA/nvidia-container-toolkit
synced 2024-11-22 00:08:11 +00:00
Refactor CSV discovery for testability
This change improves the testibility of the CSV discoverer. This is done by adding injection points for mocks for library discovery and symlink resolution. Note that this highlights a bug in the current implementation where the library filter causes valid symlinks to be skipped. Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
parent
be570fce65
commit
963250a58f
@ -28,51 +28,45 @@ import (
|
||||
// newDiscovererFromCSVFiles creates a discoverer for the specified CSV files. A logger is also supplied.
|
||||
// The constructed discoverer is comprised of a list, with each element in the list being associated with a
|
||||
// single CSV files.
|
||||
func newDiscovererFromCSVFiles(logger logger.Interface, files []string, driverRoot string, nvidiaCTKPath string, librarySearchPaths []string) (discover.Discover, error) {
|
||||
if len(files) == 0 {
|
||||
logger.Warningf("No CSV files specified")
|
||||
func (o tegraOptions) newDiscovererFromCSVFiles() (discover.Discover, error) {
|
||||
if len(o.csvFiles) == 0 {
|
||||
o.logger.Warningf("No CSV files specified")
|
||||
return discover.None{}, nil
|
||||
}
|
||||
|
||||
targetsByType := getTargetsFromCSVFiles(logger, files)
|
||||
targetsByType := getTargetsFromCSVFiles(o.logger, o.csvFiles)
|
||||
|
||||
devices := discover.NewDeviceDiscoverer(
|
||||
logger,
|
||||
lookup.NewCharDeviceLocator(lookup.WithLogger(logger), lookup.WithRoot(driverRoot)),
|
||||
driverRoot,
|
||||
o.logger,
|
||||
lookup.NewCharDeviceLocator(lookup.WithLogger(o.logger), lookup.WithRoot(o.driverRoot)),
|
||||
o.driverRoot,
|
||||
targetsByType[csv.MountSpecDev],
|
||||
)
|
||||
|
||||
directories := discover.NewMounts(
|
||||
logger,
|
||||
lookup.NewDirectoryLocator(lookup.WithLogger(logger), lookup.WithRoot(driverRoot)),
|
||||
driverRoot,
|
||||
o.logger,
|
||||
lookup.NewDirectoryLocator(lookup.WithLogger(o.logger), lookup.WithRoot(o.driverRoot)),
|
||||
o.driverRoot,
|
||||
targetsByType[csv.MountSpecDir],
|
||||
)
|
||||
|
||||
// Libraries and symlinks use the same locator.
|
||||
searchPaths := append(librarySearchPaths, "/")
|
||||
symlinkLocator := lookup.NewSymlinkLocator(
|
||||
lookup.WithLogger(logger),
|
||||
lookup.WithRoot(driverRoot),
|
||||
lookup.WithSearchPaths(searchPaths...),
|
||||
)
|
||||
libraries := discover.NewMounts(
|
||||
logger,
|
||||
symlinkLocator,
|
||||
driverRoot,
|
||||
o.logger,
|
||||
o.symlinkLocator,
|
||||
o.driverRoot,
|
||||
targetsByType[csv.MountSpecLib],
|
||||
)
|
||||
|
||||
nonLibSymlinks := ignoreFilenamePatterns{"*.so", "*.so.[0-9]"}.Apply(targetsByType[csv.MountSpecSym]...)
|
||||
logger.Debugf("Non-lib symlinks: %v", nonLibSymlinks)
|
||||
o.logger.Debugf("Non-lib symlinks: %v", nonLibSymlinks)
|
||||
symlinks := discover.NewMounts(
|
||||
logger,
|
||||
symlinkLocator,
|
||||
driverRoot,
|
||||
o.logger,
|
||||
o.symlinkLocator,
|
||||
o.driverRoot,
|
||||
nonLibSymlinks,
|
||||
)
|
||||
createSymlinks := createCSVSymlinkHooks(logger, nonLibSymlinks, libraries, nvidiaCTKPath)
|
||||
createSymlinks := o.createCSVSymlinkHooks(nonLibSymlinks, libraries)
|
||||
|
||||
d := discover.Merge(
|
||||
devices,
|
||||
@ -87,7 +81,9 @@ func newDiscovererFromCSVFiles(logger logger.Interface, files []string, driverRo
|
||||
|
||||
// getTargetsFromCSVFiles returns the list of mount specs from the specified CSV files.
|
||||
// These are aggregated by mount spec type.
|
||||
func getTargetsFromCSVFiles(logger logger.Interface, files []string) map[csv.MountSpecType][]string {
|
||||
// TODO: We use a function variable here to allow this to be overridden for testing.
|
||||
// This should be properly mocked.
|
||||
var getTargetsFromCSVFiles = func(logger logger.Interface, files []string) map[csv.MountSpecType][]string {
|
||||
targetsByType := make(map[csv.MountSpecType][]string)
|
||||
for _, filename := range files {
|
||||
targets, err := loadCSVFile(logger, filename)
|
||||
|
@ -15,3 +15,127 @@
|
||||
**/
|
||||
|
||||
package tegra
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup"
|
||||
testlog "github.com/sirupsen/logrus/hooks/test"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/tegra/csv"
|
||||
)
|
||||
|
||||
func TestDiscovererFromCSVFiles(t *testing.T) {
|
||||
logger, _ := testlog.NewNullLogger()
|
||||
testCases := []struct {
|
||||
description string
|
||||
moutSpecs map[csv.MountSpecType][]string
|
||||
symlinkLocator lookup.Locator
|
||||
symlinkChainLocator lookup.Locator
|
||||
symlinkResolver func(string) (string, error)
|
||||
expectedError error
|
||||
expectedMounts []discover.Mount
|
||||
expectedMountsError error
|
||||
expectedHooks []discover.Hook
|
||||
expectedHooksError error
|
||||
}{
|
||||
{
|
||||
// TODO: This current resolves to two mounts that are the same.
|
||||
// These are deduplicated at a later stage. We could consider deduplicating earlier in the pipeline.
|
||||
description: "symlink is resolved to target; mounts and symlink are created",
|
||||
moutSpecs: map[csv.MountSpecType][]string{
|
||||
"lib": {"/usr/lib/aarch64-linux-gnu/tegra/libv4l2_nvargus.so"},
|
||||
"sym": {"/usr/lib/aarch64-linux-gnu/libv4l/plugins/nv/libv4l2_nvargus.so"},
|
||||
},
|
||||
symlinkLocator: &lookup.LocatorMock{
|
||||
LocateFunc: func(path string) ([]string, error) {
|
||||
if path == "/usr/lib/aarch64-linux-gnu/libv4l/plugins/nv/libv4l2_nvargus.so" {
|
||||
return []string{"/usr/lib/aarch64-linux-gnu/tegra/libv4l2_nvargus.so"}, nil
|
||||
}
|
||||
return []string{path}, nil
|
||||
},
|
||||
},
|
||||
symlinkChainLocator: &lookup.LocatorMock{
|
||||
LocateFunc: func(path string) ([]string, error) {
|
||||
if path == "/usr/lib/aarch64-linux-gnu/libv4l/plugins/nv/libv4l2_nvargus.so" {
|
||||
return []string{"/usr/lib/aarch64-linux-gnu/libv4l/plugins/nv/libv4l2_nvargus.so", "/usr/lib/aarch64-linux-gnu/tegra/libv4l2_nvargus.so"}, nil
|
||||
}
|
||||
return nil, fmt.Errorf("Unexpected path: %v", path)
|
||||
},
|
||||
},
|
||||
symlinkResolver: func(path string) (string, error) {
|
||||
if path == "/usr/lib/aarch64-linux-gnu/libv4l/plugins/nv/libv4l2_nvargus.so" {
|
||||
return "/usr/lib/aarch64-linux-gnu/tegra/libv4l2_nvargus.so", nil
|
||||
}
|
||||
return path, nil
|
||||
},
|
||||
expectedMounts: []discover.Mount{
|
||||
{
|
||||
Path: "/usr/lib/aarch64-linux-gnu/tegra/libv4l2_nvargus.so",
|
||||
HostPath: "/usr/lib/aarch64-linux-gnu/tegra/libv4l2_nvargus.so",
|
||||
Options: []string{"ro", "nosuid", "nodev", "bind"},
|
||||
},
|
||||
{
|
||||
Path: "/usr/lib/aarch64-linux-gnu/tegra/libv4l2_nvargus.so",
|
||||
HostPath: "/usr/lib/aarch64-linux-gnu/tegra/libv4l2_nvargus.so",
|
||||
Options: []string{"ro", "nosuid", "nodev", "bind"},
|
||||
},
|
||||
},
|
||||
expectedHooks: []discover.Hook{
|
||||
{
|
||||
Lifecycle: "createContainer",
|
||||
Path: "/usr/bin/nvidia-ctk",
|
||||
Args: []string{
|
||||
"nvidia-ctk",
|
||||
"hook",
|
||||
"create-symlinks",
|
||||
"--link",
|
||||
"/usr/lib/aarch64-linux-gnu/tegra/libv4l2_nvargus.so::/usr/lib/aarch64-linux-gnu/libv4l/plugins/nv/libv4l2_nvargus.so",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.description, func(t *testing.T) {
|
||||
defer setGetTargetsFromCSVFiles(tc.moutSpecs)()
|
||||
|
||||
o := tegraOptions{
|
||||
logger: logger,
|
||||
nvidiaCTKPath: "/usr/bin/nvidia-ctk",
|
||||
csvFiles: []string{"dummy"},
|
||||
symlinkLocator: tc.symlinkLocator,
|
||||
symlinkChainLocator: tc.symlinkChainLocator,
|
||||
resolveSymlink: tc.symlinkResolver,
|
||||
}
|
||||
|
||||
d, err := o.newDiscovererFromCSVFiles()
|
||||
require.ErrorIs(t, err, tc.expectedError)
|
||||
|
||||
hooks, err := d.Hooks()
|
||||
require.ErrorIs(t, err, tc.expectedHooksError)
|
||||
require.EqualValues(t, tc.expectedHooks, hooks)
|
||||
|
||||
mounts, err := d.Mounts()
|
||||
require.ErrorIs(t, err, tc.expectedMountsError)
|
||||
require.EqualValues(t, tc.expectedMounts, mounts)
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func setGetTargetsFromCSVFiles(ovverride map[csv.MountSpecType][]string) func() {
|
||||
original := getTargetsFromCSVFiles
|
||||
getTargetsFromCSVFiles = func(logger logger.Interface, files []string) map[csv.MountSpecType][]string {
|
||||
return ovverride
|
||||
}
|
||||
|
||||
return func() {
|
||||
getTargetsFromCSVFiles = original
|
||||
}
|
||||
}
|
||||
|
@ -24,25 +24,29 @@ import (
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/symlinks"
|
||||
)
|
||||
|
||||
type symlinkHook struct {
|
||||
discover.None
|
||||
logger logger.Interface
|
||||
driverRoot string
|
||||
nvidiaCTKPath string
|
||||
targets []string
|
||||
mountsFrom discover.Discover
|
||||
|
||||
// The following can be overridden for testing
|
||||
symlinkChainLocator lookup.Locator
|
||||
resolveSymlink func(string) (string, error)
|
||||
}
|
||||
|
||||
// createCSVSymlinkHooks creates a discoverer for a hook that creates required symlinks in the container
|
||||
func createCSVSymlinkHooks(logger logger.Interface, targets []string, mounts discover.Discover, nvidiaCTKPath string) discover.Discover {
|
||||
func (o tegraOptions) createCSVSymlinkHooks(targets []string, mounts discover.Discover) discover.Discover {
|
||||
return symlinkHook{
|
||||
logger: logger,
|
||||
nvidiaCTKPath: nvidiaCTKPath,
|
||||
targets: targets,
|
||||
mountsFrom: mounts,
|
||||
logger: o.logger,
|
||||
nvidiaCTKPath: o.nvidiaCTKPath,
|
||||
targets: targets,
|
||||
mountsFrom: mounts,
|
||||
symlinkChainLocator: o.symlinkChainLocator,
|
||||
resolveSymlink: o.resolveSymlink,
|
||||
}
|
||||
}
|
||||
|
||||
@ -105,14 +109,9 @@ func (d symlinkHook) getSpecificLinks() ([]string, error) {
|
||||
|
||||
// getSymlinkCandidates returns a list of symlinks that are candidates for being created.
|
||||
func (d symlinkHook) getSymlinkCandidates() []string {
|
||||
chainLocator := lookup.NewSymlinkChainLocator(
|
||||
lookup.WithLogger(d.logger),
|
||||
lookup.WithRoot(d.driverRoot),
|
||||
)
|
||||
|
||||
var candidates []string
|
||||
for _, target := range d.targets {
|
||||
reslovedSymlinkChain, err := chainLocator.Locate(target)
|
||||
reslovedSymlinkChain, err := d.symlinkChainLocator.Locate(target)
|
||||
if err != nil {
|
||||
d.logger.Warningf("Failed to locate symlink %v", target)
|
||||
continue
|
||||
@ -127,7 +126,7 @@ func (d symlinkHook) getCSVFileSymlinks() []string {
|
||||
created := make(map[string]bool)
|
||||
// candidates is a list of absolute paths to symlinks in a chain, or the final target of the chain.
|
||||
for _, candidate := range d.getSymlinkCandidates() {
|
||||
target, err := symlinks.Resolve(candidate)
|
||||
target, err := d.resolveSymlink(candidate)
|
||||
if err != nil {
|
||||
d.logger.Debugf("Skipping invalid link: %v", err)
|
||||
continue
|
||||
|
@ -22,6 +22,7 @@ import (
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/symlinks"
|
||||
)
|
||||
|
||||
type tegraOptions struct {
|
||||
@ -30,6 +31,12 @@ type tegraOptions struct {
|
||||
driverRoot string
|
||||
nvidiaCTKPath string
|
||||
librarySearchPaths []string
|
||||
|
||||
// The following can be overridden for testing
|
||||
symlinkLocator lookup.Locator
|
||||
symlinkChainLocator lookup.Locator
|
||||
// TODO: This should be replaced by a regular mock
|
||||
resolveSymlink func(string) (string, error)
|
||||
}
|
||||
|
||||
// Option defines a functional option for configuring a Tegra discoverer.
|
||||
@ -42,7 +49,27 @@ func New(opts ...Option) (discover.Discover, error) {
|
||||
opt(o)
|
||||
}
|
||||
|
||||
csvDiscoverer, err := newDiscovererFromCSVFiles(o.logger, o.csvFiles, o.driverRoot, o.nvidiaCTKPath, o.librarySearchPaths)
|
||||
if o.symlinkLocator == nil {
|
||||
searchPaths := append(o.librarySearchPaths, "/")
|
||||
o.symlinkLocator = lookup.NewSymlinkLocator(
|
||||
lookup.WithLogger(o.logger),
|
||||
lookup.WithRoot(o.driverRoot),
|
||||
lookup.WithSearchPaths(searchPaths...),
|
||||
)
|
||||
}
|
||||
|
||||
if o.symlinkChainLocator == nil {
|
||||
o.symlinkChainLocator = lookup.NewSymlinkChainLocator(
|
||||
lookup.WithLogger(o.logger),
|
||||
lookup.WithRoot(o.driverRoot),
|
||||
)
|
||||
}
|
||||
|
||||
if o.resolveSymlink == nil {
|
||||
o.resolveSymlink = symlinks.Resolve
|
||||
}
|
||||
|
||||
csvDiscoverer, err := o.newDiscovererFromCSVFiles()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create CSV discoverer: %v", err)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user