Refactor the way we create CDI Hooks

Signed-off-by: Carlos Eduardo Arango Gutierrez <eduardoa@nvidia.com>
This commit is contained in:
Carlos Eduardo Arango Gutierrez
2025-05-16 18:44:59 +02:00
parent ac8f190c99
commit d08bfa8074
29 changed files with 207 additions and 205 deletions

View File

@@ -27,11 +27,11 @@ import (
// byPathHookDiscoverer discovers the entities required for injecting by-path DRM device links
type byPathHookDiscoverer struct {
logger logger.Interface
devRoot string
nvidiaCDIHookPath string
pciBusID string
deviceNodes discover.Discover
logger logger.Interface
devRoot string
hookCreator discover.HookCreator
pciBusID string
deviceNodes discover.Discover
}
var _ discover.Discover = (*byPathHookDiscoverer)(nil)
@@ -53,18 +53,9 @@ func (d *byPathHookDiscoverer) Hooks() ([]discover.Hook, error) {
return nil, nil
}
var args []string
for _, l := range links {
args = append(args, "--link", l)
}
hook := d.hookCreator.Create("create-symlinks", links...)
hook := discover.CreateNvidiaCDIHook(
d.nvidiaCDIHookPath,
"create-symlinks",
args...,
)
return []discover.Hook{hook}, nil
return hook.Hooks()
}
// Mounts returns an empty slice for a full GPU

View File

@@ -58,11 +58,11 @@ func (o *options) newNvmlDGPUDiscoverer(d requiredInfo) (discover.Discover, erro
)
byPathHooks := &byPathHookDiscoverer{
logger: o.logger,
devRoot: o.devRoot,
nvidiaCDIHookPath: o.nvidiaCDIHookPath,
pciBusID: pciBusID,
deviceNodes: deviceNodes,
logger: o.logger,
devRoot: o.devRoot,
hookCreator: o.hookCreator,
pciBusID: pciBusID,
deviceNodes: deviceNodes,
}
dd := discover.Merge(

View File

@@ -28,12 +28,12 @@ import (
)
type nvsandboxutilsDGPU struct {
lib nvsandboxutils.Interface
uuid string
devRoot string
isMig bool
nvidiaCDIHookPath string
deviceLinks []string
lib nvsandboxutils.Interface
uuid string
devRoot string
isMig bool
hookCreator discover.HookCreator
deviceLinks []string
}
var _ discover.Discover = (*nvsandboxutilsDGPU)(nil)
@@ -53,11 +53,11 @@ func (o *options) newNvsandboxutilsDGPUDiscoverer(d UUIDer) (discover.Discover,
}
nvd := nvsandboxutilsDGPU{
lib: o.nvsandboxutilslib,
uuid: uuid,
devRoot: strings.TrimSuffix(filepath.Clean(o.devRoot), "/dev"),
isMig: o.isMigDevice,
nvidiaCDIHookPath: o.nvidiaCDIHookPath,
lib: o.nvsandboxutilslib,
uuid: uuid,
devRoot: strings.TrimSuffix(filepath.Clean(o.devRoot), "/dev"),
isMig: o.isMigDevice,
hookCreator: o.hookCreator,
}
return &nvd, nil
@@ -112,18 +112,9 @@ func (d *nvsandboxutilsDGPU) Hooks() ([]discover.Hook, error) {
return nil, nil
}
var args []string
for _, l := range d.deviceLinks {
args = append(args, "--link", l)
}
hook := d.hookCreator.Create("create-symlinks", d.deviceLinks...)
hook := discover.CreateNvidiaCDIHook(
d.nvidiaCDIHookPath,
"create-symlinks",
args...,
)
return []discover.Hook{hook}, nil
return hook.Hooks()
}
func (d *nvsandboxutilsDGPU) Mounts() ([]discover.Mount, error) {

View File

@@ -17,15 +17,16 @@
package dgpu
import (
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/NVIDIA/nvidia-container-toolkit/internal/nvcaps"
"github.com/NVIDIA/nvidia-container-toolkit/internal/nvsandboxutils"
)
type options struct {
logger logger.Interface
devRoot string
nvidiaCDIHookPath string
logger logger.Interface
devRoot string
hookCreator discover.HookCreator
isMigDevice bool
// migCaps stores the MIG capabilities for the system.
@@ -52,10 +53,10 @@ func WithLogger(logger logger.Interface) Option {
}
}
// WithNVIDIACDIHookPath sets the path to the NVIDIA Container Toolkit CLI path for the library
func WithNVIDIACDIHookPath(path string) Option {
// WithHookCreator sets the hook creator for the library
func WithHookCreator(hookCreator discover.HookCreator) Option {
return func(l *options) {
l.nvidiaCDIHookPath = path
l.hookCreator = hookCreator
}
}

View File

@@ -59,7 +59,7 @@ func (o tegraOptions) newDiscovererFromCSVFiles() (discover.Discover, error) {
targetsByType[csv.MountSpecLib],
),
"",
o.nvidiaCDIHookPath,
o.hookCreator,
)
// We process the explicitly requested symlinks.

View File

@@ -184,10 +184,11 @@ func TestDiscovererFromCSVFiles(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
defer setGetTargetsFromCSVFiles(tc.moutSpecs)()
hookCreator := discover.NewHookCreator("/usr/bin/nvidia-cdi-hook")
o := tegraOptions{
logger: logger,
nvidiaCDIHookPath: "/usr/bin/nvidia-cdi-hook",
hookCreator: hookCreator,
csvFiles: []string{"dummy"},
ignorePatterns: tc.ignorePatterns,
symlinkLocator: tc.symlinkLocator,

View File

@@ -26,9 +26,9 @@ import (
type symlinkHook struct {
discover.None
logger logger.Interface
nvidiaCDIHookPath string
targets []string
logger logger.Interface
hookCreator discover.HookCreator
targets []string
// The following can be overridden for testing
symlinkChainLocator lookup.Locator
@@ -39,7 +39,7 @@ type symlinkHook struct {
func (o tegraOptions) createCSVSymlinkHooks(targets []string) discover.Discover {
return symlinkHook{
logger: o.logger,
nvidiaCDIHookPath: o.nvidiaCDIHookPath,
hookCreator: o.hookCreator,
targets: targets,
symlinkChainLocator: o.symlinkChainLocator,
resolveSymlink: o.resolveSymlink,
@@ -48,10 +48,7 @@ func (o tegraOptions) createCSVSymlinkHooks(targets []string) discover.Discover
// Hooks returns a hook to create the symlinks from the required CSV files
func (d symlinkHook) Hooks() ([]discover.Hook, error) {
return discover.CreateCreateSymlinkHook(
d.nvidiaCDIHookPath,
d.getCSVFileSymlinks(),
).Hooks()
return d.hookCreator.Create("create-symlinks", d.getCSVFileSymlinks()...).Hooks()
}
// getSymlinkCandidates returns a list of symlinks that are candidates for being created.

View File

@@ -30,7 +30,7 @@ type tegraOptions struct {
csvFiles []string
driverRoot string
devRoot string
nvidiaCDIHookPath string
hookCreator discover.HookCreator
ldconfigPath string
librarySearchPaths []string
ignorePatterns ignoreMountSpecPatterns
@@ -80,7 +80,7 @@ func New(opts ...Option) (discover.Discover, error) {
return nil, fmt.Errorf("failed to create CSV discoverer: %v", err)
}
ldcacheUpdateHook, err := discover.NewLDCacheUpdateHook(o.logger, csvDiscoverer, o.nvidiaCDIHookPath, o.ldconfigPath)
ldcacheUpdateHook, err := discover.NewLDCacheUpdateHook(o.logger, csvDiscoverer, o.hookCreator, o.ldconfigPath)
if err != nil {
return nil, fmt.Errorf("failed to create ldcach update hook discoverer: %v", err)
}
@@ -133,10 +133,10 @@ func WithCSVFiles(csvFiles []string) Option {
}
}
// WithNVIDIACDIHookPath sets the path to the nvidia-cdi-hook binary.
func WithNVIDIACDIHookPath(nvidiaCDIHookPath string) Option {
// WithHookCreator sets the hook creator for the discoverer.
func WithHookCreator(hookCreator discover.HookCreator) Option {
return func(o *tegraOptions) {
o.nvidiaCDIHookPath = nvidiaCDIHookPath
o.hookCreator = hookCreator
}
}