Merge pull request #1090 from ArangoGutierrez/hookcreator
Some checks are pending
CI Pipeline / code-scanning (push) Waiting to run
CI Pipeline / variables (push) Waiting to run
CI Pipeline / golang (push) Waiting to run
CI Pipeline / image (push) Blocked by required conditions
CI Pipeline / e2e-test (push) Blocked by required conditions

Refactor the way we handle Hook Creation
This commit is contained in:
Evan Lezar 2025-05-21 14:23:51 +02:00 committed by GitHub
commit f93d96a0de
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
31 changed files with 203 additions and 215 deletions

7
.gitignore vendored
View File

@ -4,11 +4,8 @@
*.swo *.swo
/coverage.out* /coverage.out*
/tests/output/ /tests/output/
/nvidia-container-runtime /nvidia-*
/nvidia-container-runtime.*
/nvidia-container-runtime-hook
/nvidia-container-toolkit
/nvidia-ctk
/shared-* /shared-*
/release-* /release-*
/bin /bin
/toolkit-test

View File

@ -9,16 +9,12 @@ import (
// NewCUDACompatHookDiscoverer creates a discoverer for a enable-cuda-compat hook. // NewCUDACompatHookDiscoverer creates a discoverer for a enable-cuda-compat hook.
// This hook is responsible for setting up CUDA compatibility in the container and depends on the host driver version. // This hook is responsible for setting up CUDA compatibility in the container and depends on the host driver version.
func NewCUDACompatHookDiscoverer(logger logger.Interface, nvidiaCDIHookPath string, driver *root.Driver) Discover { func NewCUDACompatHookDiscoverer(logger logger.Interface, hookCreator HookCreator, driver *root.Driver) Discover {
_, cudaVersionPattern := getCUDALibRootAndVersionPattern(logger, driver) _, cudaVersionPattern := getCUDALibRootAndVersionPattern(logger, driver)
var args []string var args []string
if !strings.Contains(cudaVersionPattern, "*") { if !strings.Contains(cudaVersionPattern, "*") {
args = append(args, "--host-driver-version="+cudaVersionPattern) args = append(args, "--host-driver-version="+cudaVersionPattern)
} }
return CreateNvidiaCDIHook( return hookCreator.Create("enable-cuda-compat", args...)
nvidiaCDIHookPath,
"enable-cuda-compat",
args...,
)
} }

View File

@ -36,21 +36,21 @@ import (
// TODO: The logic for creating DRM devices should be consolidated between this // TODO: The logic for creating DRM devices should be consolidated between this
// and the logic for generating CDI specs for a single device. This is only used // and the logic for generating CDI specs for a single device. This is only used
// when applying OCI spec modifications to an incoming spec in "legacy" mode. // when applying OCI spec modifications to an incoming spec in "legacy" mode.
func NewDRMNodesDiscoverer(logger logger.Interface, devices image.VisibleDevices, devRoot string, nvidiaCDIHookPath string) (Discover, error) { func NewDRMNodesDiscoverer(logger logger.Interface, devices image.VisibleDevices, devRoot string, hookCreator HookCreator) (Discover, error) {
drmDeviceNodes, err := newDRMDeviceDiscoverer(logger, devices, devRoot) drmDeviceNodes, err := newDRMDeviceDiscoverer(logger, devices, devRoot)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create DRM device discoverer: %v", err) return nil, fmt.Errorf("failed to create DRM device discoverer: %v", err)
} }
drmByPathSymlinks := newCreateDRMByPathSymlinks(logger, drmDeviceNodes, devRoot, nvidiaCDIHookPath) drmByPathSymlinks := newCreateDRMByPathSymlinks(logger, drmDeviceNodes, devRoot, hookCreator)
discover := Merge(drmDeviceNodes, drmByPathSymlinks) discover := Merge(drmDeviceNodes, drmByPathSymlinks)
return discover, nil return discover, nil
} }
// NewGraphicsMountsDiscoverer creates a discoverer for the mounts required by graphics tools such as vulkan. // NewGraphicsMountsDiscoverer creates a discoverer for the mounts required by graphics tools such as vulkan.
func NewGraphicsMountsDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCDIHookPath string) (Discover, error) { func NewGraphicsMountsDiscoverer(logger logger.Interface, driver *root.Driver, hookCreator HookCreator) (Discover, error) {
libraries := newGraphicsLibrariesDiscoverer(logger, driver, nvidiaCDIHookPath) libraries := newGraphicsLibrariesDiscoverer(logger, driver, hookCreator)
configs := NewMounts( configs := NewMounts(
logger, logger,
@ -95,13 +95,13 @@ func newVulkanConfigsDiscover(logger logger.Interface, driver *root.Driver) Disc
type graphicsDriverLibraries struct { type graphicsDriverLibraries struct {
Discover Discover
logger logger.Interface logger logger.Interface
nvidiaCDIHookPath string hookCreator HookCreator
} }
var _ Discover = (*graphicsDriverLibraries)(nil) var _ Discover = (*graphicsDriverLibraries)(nil)
func newGraphicsLibrariesDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCDIHookPath string) Discover { func newGraphicsLibrariesDiscoverer(logger logger.Interface, driver *root.Driver, hookCreator HookCreator) Discover {
cudaLibRoot, cudaVersionPattern := getCUDALibRootAndVersionPattern(logger, driver) cudaLibRoot, cudaVersionPattern := getCUDALibRootAndVersionPattern(logger, driver)
libraries := NewMounts( libraries := NewMounts(
@ -140,9 +140,9 @@ func newGraphicsLibrariesDiscoverer(logger logger.Interface, driver *root.Driver
) )
return &graphicsDriverLibraries{ return &graphicsDriverLibraries{
Discover: Merge(libraries, xorgLibraries), Discover: Merge(libraries, xorgLibraries),
logger: logger, logger: logger,
nvidiaCDIHookPath: nvidiaCDIHookPath, hookCreator: hookCreator,
} }
} }
@ -203,9 +203,9 @@ func (d graphicsDriverLibraries) Hooks() ([]Hook, error) {
return nil, nil return nil, nil
} }
hooks := CreateCreateSymlinkHook(d.nvidiaCDIHookPath, links) hook := d.hookCreator.Create("create-symlinks", links...)
return hooks.Hooks() return hook.Hooks()
} }
// isDriverLibrary checks whether the specified filename is a specific driver library. // isDriverLibrary checks whether the specified filename is a specific driver library.
@ -275,19 +275,19 @@ func buildXOrgSearchPaths(libRoot string) []string {
type drmDevicesByPath struct { type drmDevicesByPath struct {
None None
logger logger.Interface logger logger.Interface
nvidiaCDIHookPath string hookCreator HookCreator
devRoot string devRoot string
devicesFrom Discover devicesFrom Discover
} }
// newCreateDRMByPathSymlinks creates a discoverer for a hook to create the by-path symlinks for DRM devices discovered by the specified devices discoverer // newCreateDRMByPathSymlinks creates a discoverer for a hook to create the by-path symlinks for DRM devices discovered by the specified devices discoverer
func newCreateDRMByPathSymlinks(logger logger.Interface, devices Discover, devRoot string, nvidiaCDIHookPath string) Discover { func newCreateDRMByPathSymlinks(logger logger.Interface, devices Discover, devRoot string, hookCreator HookCreator) Discover {
d := drmDevicesByPath{ d := drmDevicesByPath{
logger: logger, logger: logger,
nvidiaCDIHookPath: nvidiaCDIHookPath, hookCreator: hookCreator,
devRoot: devRoot, devRoot: devRoot,
devicesFrom: devices, devicesFrom: devices,
} }
return &d return &d
@ -315,13 +315,9 @@ func (d drmDevicesByPath) Hooks() ([]Hook, error) {
args = append(args, "--link", l) args = append(args, "--link", l)
} }
hook := CreateNvidiaCDIHook( hook := d.hookCreator.Create("create-symlinks", args...)
d.nvidiaCDIHookPath,
"create-symlinks",
args...,
)
return []Hook{hook}, nil return hook.Hooks()
} }
// getSpecificLinkArgs returns the required specific links that need to be created // getSpecificLinkArgs returns the required specific links that need to be created

View File

@ -25,6 +25,7 @@ import (
func TestGraphicsLibrariesDiscoverer(t *testing.T) { func TestGraphicsLibrariesDiscoverer(t *testing.T) {
logger, _ := testlog.NewNullLogger() logger, _ := testlog.NewNullLogger()
hookCreator := NewHookCreator("/usr/bin/nvidia-cdi-hook")
testCases := []struct { testCases := []struct {
description string description string
@ -136,9 +137,9 @@ func TestGraphicsLibrariesDiscoverer(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) { t.Run(tc.description, func(t *testing.T) {
d := &graphicsDriverLibraries{ d := &graphicsDriverLibraries{
Discover: tc.libraries, Discover: tc.libraries,
logger: logger, logger: logger,
nvidiaCDIHookPath: "/usr/bin/nvidia-cdi-hook", hookCreator: hookCreator,
} }
devices, err := d.Devices() devices, err := d.Devices()

View File

@ -25,54 +25,66 @@ import (
var _ Discover = (*Hook)(nil) var _ Discover = (*Hook)(nil)
// Devices returns an empty list of devices for a Hook discoverer. // Devices returns an empty list of devices for a Hook discoverer.
func (h Hook) Devices() ([]Device, error) { func (h *Hook) Devices() ([]Device, error) {
return nil, nil return nil, nil
} }
// Mounts returns an empty list of mounts for a Hook discoverer. // Mounts returns an empty list of mounts for a Hook discoverer.
func (h Hook) Mounts() ([]Mount, error) { func (h *Hook) Mounts() ([]Mount, error) {
return nil, nil return nil, nil
} }
// Hooks allows the Hook type to also implement the Discoverer interface. // Hooks allows the Hook type to also implement the Discoverer interface.
// It returns a single hook // It returns a single hook
func (h Hook) Hooks() ([]Hook, error) { func (h *Hook) Hooks() ([]Hook, error) {
return []Hook{h}, nil if h == nil {
} return nil, nil
// CreateCreateSymlinkHook creates a hook which creates a symlink from link -> target.
func CreateCreateSymlinkHook(nvidiaCDIHookPath string, links []string) Discover {
if len(links) == 0 {
return None{}
} }
var args []string return []Hook{*h}, nil
for _, link := range links { }
args = append(args, "--link", link)
// Option is a function that configures the nvcdilib
type Option func(*CDIHook)
type CDIHook struct {
nvidiaCDIHookPath string
}
type HookCreator interface {
Create(string, ...string) *Hook
}
func NewHookCreator(nvidiaCDIHookPath string) HookCreator {
CDIHook := &CDIHook{
nvidiaCDIHookPath: nvidiaCDIHookPath,
} }
return CreateNvidiaCDIHook(
nvidiaCDIHookPath, return CDIHook
"create-symlinks",
args...,
)
} }
// CreateNvidiaCDIHook creates a hook which invokes the NVIDIA Container CLI hook subcommand. func (c CDIHook) Create(name string, args ...string) *Hook {
func CreateNvidiaCDIHook(nvidiaCDIHookPath string, hookName string, additionalArgs ...string) Hook { if name == "create-symlinks" {
return cdiHook(nvidiaCDIHookPath).Create(hookName, additionalArgs...) if len(args) == 0 {
} return nil
}
type cdiHook string links := []string{}
for _, arg := range args {
links = append(links, "--link", arg)
}
args = links
}
func (c cdiHook) Create(name string, args ...string) Hook { return &Hook{
return Hook{
Lifecycle: cdi.CreateContainerHook, Lifecycle: cdi.CreateContainerHook,
Path: string(c), Path: c.nvidiaCDIHookPath,
Args: append(c.requiredArgs(name), args...), Args: append(c.requiredArgs(name), args...),
} }
} }
func (c cdiHook) requiredArgs(name string) []string {
base := filepath.Base(string(c)) func (c CDIHook) requiredArgs(name string) []string {
base := filepath.Base(c.nvidiaCDIHookPath)
if base == "nvidia-ctk" { if base == "nvidia-ctk" {
return []string{base, "hook", name} return []string{base, "hook", name}
} }

View File

@ -25,12 +25,12 @@ import (
) )
// NewLDCacheUpdateHook creates a discoverer that updates the ldcache for the specified mounts. A logger can also be specified // NewLDCacheUpdateHook creates a discoverer that updates the ldcache for the specified mounts. A logger can also be specified
func NewLDCacheUpdateHook(logger logger.Interface, mounts Discover, nvidiaCDIHookPath, ldconfigPath string) (Discover, error) { func NewLDCacheUpdateHook(logger logger.Interface, mounts Discover, hookCreator HookCreator, ldconfigPath string) (Discover, error) {
d := ldconfig{ d := ldconfig{
logger: logger, logger: logger,
nvidiaCDIHookPath: nvidiaCDIHookPath, hookCreator: hookCreator,
ldconfigPath: ldconfigPath, ldconfigPath: ldconfigPath,
mountsFrom: mounts, mountsFrom: mounts,
} }
return &d, nil return &d, nil
@ -38,10 +38,10 @@ func NewLDCacheUpdateHook(logger logger.Interface, mounts Discover, nvidiaCDIHoo
type ldconfig struct { type ldconfig struct {
None None
logger logger.Interface logger logger.Interface
nvidiaCDIHookPath string hookCreator HookCreator
ldconfigPath string ldconfigPath string
mountsFrom Discover mountsFrom Discover
} }
// Hooks checks the required mounts for libraries and returns a hook to update the LDcache for the discovered paths. // Hooks checks the required mounts for libraries and returns a hook to update the LDcache for the discovered paths.
@ -50,16 +50,18 @@ func (d ldconfig) Hooks() ([]Hook, error) {
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to discover mounts for ldcache update: %v", err) return nil, fmt.Errorf("failed to discover mounts for ldcache update: %v", err)
} }
h := CreateLDCacheUpdateHook(
d.nvidiaCDIHookPath, h := createLDCacheUpdateHook(
d.hookCreator,
d.ldconfigPath, d.ldconfigPath,
getLibraryPaths(mounts), getLibraryPaths(mounts),
) )
return []Hook{h}, nil
return h.Hooks()
} }
// CreateLDCacheUpdateHook locates the NVIDIA Container Toolkit CLI and creates a hook for updating the LD Cache // createLDCacheUpdateHook locates the NVIDIA Container Toolkit CLI and creates a hook for updating the LD Cache
func CreateLDCacheUpdateHook(executable string, ldconfig string, libraries []string) Hook { func createLDCacheUpdateHook(hookCreator HookCreator, ldconfig string, libraries []string) *Hook {
var args []string var args []string
if ldconfig != "" { if ldconfig != "" {
@ -70,13 +72,7 @@ func CreateLDCacheUpdateHook(executable string, ldconfig string, libraries []str
args = append(args, "--folder", f) args = append(args, "--folder", f)
} }
hook := CreateNvidiaCDIHook( return hookCreator.Create("update-ldcache", args...)
executable,
"update-ldcache",
args...,
)
return hook
} }
// getLibraryPaths extracts the library dirs from the specified mounts // getLibraryPaths extracts the library dirs from the specified mounts

View File

@ -31,6 +31,7 @@ const (
func TestLDCacheUpdateHook(t *testing.T) { func TestLDCacheUpdateHook(t *testing.T) {
logger, _ := testlog.NewNullLogger() logger, _ := testlog.NewNullLogger()
hookCreator := NewHookCreator(testNvidiaCDIHookPath)
testCases := []struct { testCases := []struct {
description string description string
@ -97,7 +98,7 @@ func TestLDCacheUpdateHook(t *testing.T) {
Lifecycle: "createContainer", Lifecycle: "createContainer",
} }
d, err := NewLDCacheUpdateHook(logger, mountMock, testNvidiaCDIHookPath, tc.ldconfigPath) d, err := NewLDCacheUpdateHook(logger, mountMock, hookCreator, tc.ldconfigPath)
require.NoError(t, err) require.NoError(t, err)
hooks, err := d.Hooks() hooks, err := d.Hooks()

View File

@ -23,20 +23,20 @@ import (
type additionalSymlinks struct { type additionalSymlinks struct {
Discover Discover
version string version string
nvidiaCDIHookPath string hookCreator HookCreator
} }
// WithDriverDotSoSymlinks decorates the provided discoverer. // WithDriverDotSoSymlinks decorates the provided discoverer.
// A hook is added that checks for specific driver symlinks that need to be created. // A hook is added that checks for specific driver symlinks that need to be created.
func WithDriverDotSoSymlinks(mounts Discover, version string, nvidiaCDIHookPath string) Discover { func WithDriverDotSoSymlinks(mounts Discover, version string, hookCreator HookCreator) Discover {
if version == "" { if version == "" {
version = "*.*" version = "*.*"
} }
return &additionalSymlinks{ return &additionalSymlinks{
Discover: mounts, Discover: mounts,
nvidiaCDIHookPath: nvidiaCDIHookPath, hookCreator: hookCreator,
version: version, version: version,
} }
} }
@ -73,8 +73,12 @@ func (d *additionalSymlinks) Hooks() ([]Hook, error) {
return hooks, nil return hooks, nil
} }
hook := CreateCreateSymlinkHook(d.nvidiaCDIHookPath, links).(Hook) createSymlinkHooks, err := d.hookCreator.Create("create-symlinks", links...).Hooks()
return append(hooks, hook), nil if err != nil {
return nil, fmt.Errorf("failed to create symlink hook: %v", err)
}
return append(hooks, createSymlinkHooks...), nil
} }
// getLinksForMount maps the path to created links if any. // getLinksForMount maps the path to created links if any.

View File

@ -306,12 +306,13 @@ func TestWithWithDriverDotSoSymlinks(t *testing.T) {
}, },
} }
hookCreator := NewHookCreator("/path/to/nvidia-cdi-hook")
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) { t.Run(tc.description, func(t *testing.T) {
d := WithDriverDotSoSymlinks( d := WithDriverDotSoSymlinks(
tc.discover, tc.discover,
tc.version, tc.version,
"/path/to/nvidia-cdi-hook", hookCreator,
) )
devices, err := d.Devices() devices, err := d.Devices()

View File

@ -36,7 +36,7 @@ import (
// NVIDIA_GDRCOPY=enabled // NVIDIA_GDRCOPY=enabled
// //
// If not devices are selected, no changes are made. // If not devices are selected, no changes are made.
func NewFeatureGatedModifier(logger logger.Interface, cfg *config.Config, image image.CUDA, driver *root.Driver) (oci.SpecModifier, error) { func NewFeatureGatedModifier(logger logger.Interface, cfg *config.Config, image image.CUDA, driver *root.Driver, hookCreator discover.HookCreator) (oci.SpecModifier, error) {
if devices := image.VisibleDevicesFromEnvVar(); len(devices) == 0 { if devices := image.VisibleDevicesFromEnvVar(); len(devices) == 0 {
logger.Infof("No modification required; no devices requested") logger.Infof("No modification required; no devices requested")
return nil, nil return nil, nil
@ -81,7 +81,7 @@ func NewFeatureGatedModifier(logger logger.Interface, cfg *config.Config, image
// If the feature flag has explicitly been toggled, we don't make any modification. // If the feature flag has explicitly been toggled, we don't make any modification.
if !cfg.Features.DisableCUDACompatLibHook.IsEnabled() { if !cfg.Features.DisableCUDACompatLibHook.IsEnabled() {
cudaCompatDiscoverer, err := getCudaCompatModeDiscoverer(logger, cfg, driver) cudaCompatDiscoverer, err := getCudaCompatModeDiscoverer(logger, cfg, driver, hookCreator)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to construct CUDA Compat discoverer: %w", err) return nil, fmt.Errorf("failed to construct CUDA Compat discoverer: %w", err)
} }
@ -91,13 +91,13 @@ func NewFeatureGatedModifier(logger logger.Interface, cfg *config.Config, image
return NewModifierFromDiscoverer(logger, discover.Merge(discoverers...)) return NewModifierFromDiscoverer(logger, discover.Merge(discoverers...))
} }
func getCudaCompatModeDiscoverer(logger logger.Interface, cfg *config.Config, driver *root.Driver) (discover.Discover, error) { func getCudaCompatModeDiscoverer(logger logger.Interface, cfg *config.Config, driver *root.Driver, hookCreator discover.HookCreator) (discover.Discover, error) {
// For legacy mode, we only include the enable-cuda-compat hook if cuda-compat-mode is set to hook. // For legacy mode, we only include the enable-cuda-compat hook if cuda-compat-mode is set to hook.
if cfg.NVIDIAContainerRuntimeConfig.Mode == "legacy" && cfg.NVIDIAContainerRuntimeConfig.Modes.Legacy.CUDACompatMode != config.CUDACompatModeHook { if cfg.NVIDIAContainerRuntimeConfig.Mode == "legacy" && cfg.NVIDIAContainerRuntimeConfig.Modes.Legacy.CUDACompatMode != config.CUDACompatModeHook {
return nil, nil return nil, nil
} }
compatLibHookDiscoverer := discover.NewCUDACompatHookDiscoverer(logger, cfg.NVIDIACTKConfig.Path, driver) compatLibHookDiscoverer := discover.NewCUDACompatHookDiscoverer(logger, hookCreator, driver)
// For non-legacy modes we return the hook as is. These modes *should* already include the update-ldcache hook. // For non-legacy modes we return the hook as is. These modes *should* already include the update-ldcache hook.
if cfg.NVIDIAContainerRuntimeConfig.Mode != "legacy" { if cfg.NVIDIAContainerRuntimeConfig.Mode != "legacy" {
return compatLibHookDiscoverer, nil return compatLibHookDiscoverer, nil
@ -108,7 +108,7 @@ func getCudaCompatModeDiscoverer(logger logger.Interface, cfg *config.Config, dr
ldcacheUpdateHookDiscoverer, err := discover.NewLDCacheUpdateHook( ldcacheUpdateHookDiscoverer, err := discover.NewLDCacheUpdateHook(
logger, logger,
discover.None{}, discover.None{},
cfg.NVIDIACTKConfig.Path, hookCreator,
"", "",
) )
if err != nil { if err != nil {

View File

@ -29,18 +29,16 @@ import (
// NewGraphicsModifier constructs a modifier that injects graphics-related modifications into an OCI runtime specification. // NewGraphicsModifier constructs a modifier that injects graphics-related modifications into an OCI runtime specification.
// The value of the NVIDIA_DRIVER_CAPABILITIES environment variable is checked to determine if this modification should be made. // The value of the NVIDIA_DRIVER_CAPABILITIES environment variable is checked to determine if this modification should be made.
func NewGraphicsModifier(logger logger.Interface, cfg *config.Config, containerImage image.CUDA, driver *root.Driver) (oci.SpecModifier, error) { func NewGraphicsModifier(logger logger.Interface, cfg *config.Config, containerImage image.CUDA, driver *root.Driver, hookCreator discover.HookCreator) (oci.SpecModifier, error) {
if required, reason := requiresGraphicsModifier(containerImage); !required { if required, reason := requiresGraphicsModifier(containerImage); !required {
logger.Infof("No graphics modifier required: %v", reason) logger.Infof("No graphics modifier required: %v", reason)
return nil, nil return nil, nil
} }
nvidiaCDIHookPath := cfg.NVIDIACTKConfig.Path
mounts, err := discover.NewGraphicsMountsDiscoverer( mounts, err := discover.NewGraphicsMountsDiscoverer(
logger, logger,
driver, driver,
nvidiaCDIHookPath, hookCreator,
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create mounts discoverer: %v", err) return nil, fmt.Errorf("failed to create mounts discoverer: %v", err)
@ -52,7 +50,7 @@ func NewGraphicsModifier(logger logger.Interface, cfg *config.Config, containerI
logger, logger,
containerImage.DevicesFromEnvvars(image.EnvVarNvidiaVisibleDevices), containerImage.DevicesFromEnvvars(image.EnvVarNvidiaVisibleDevices),
devRoot, devRoot,
nvidiaCDIHookPath, hookCreator,
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to construct discoverer: %v", err) return nil, fmt.Errorf("failed to construct discoverer: %v", err)

View File

@ -27,11 +27,11 @@ import (
// byPathHookDiscoverer discovers the entities required for injecting by-path DRM device links // byPathHookDiscoverer discovers the entities required for injecting by-path DRM device links
type byPathHookDiscoverer struct { type byPathHookDiscoverer struct {
logger logger.Interface logger logger.Interface
devRoot string devRoot string
nvidiaCDIHookPath string hookCreator discover.HookCreator
pciBusID string pciBusID string
deviceNodes discover.Discover deviceNodes discover.Discover
} }
var _ discover.Discover = (*byPathHookDiscoverer)(nil) var _ discover.Discover = (*byPathHookDiscoverer)(nil)
@ -53,18 +53,9 @@ func (d *byPathHookDiscoverer) Hooks() ([]discover.Hook, error) {
return nil, nil return nil, nil
} }
var args []string hook := d.hookCreator.Create("create-symlinks", links...)
for _, l := range links {
args = append(args, "--link", l)
}
hook := discover.CreateNvidiaCDIHook( return hook.Hooks()
d.nvidiaCDIHookPath,
"create-symlinks",
args...,
)
return []discover.Hook{hook}, nil
} }
// Mounts returns an empty slice for a full GPU // Mounts returns an empty slice for a full GPU

View File

@ -58,11 +58,11 @@ func (o *options) newNvmlDGPUDiscoverer(d requiredInfo) (discover.Discover, erro
) )
byPathHooks := &byPathHookDiscoverer{ byPathHooks := &byPathHookDiscoverer{
logger: o.logger, logger: o.logger,
devRoot: o.devRoot, devRoot: o.devRoot,
nvidiaCDIHookPath: o.nvidiaCDIHookPath, hookCreator: o.hookCreator,
pciBusID: pciBusID, pciBusID: pciBusID,
deviceNodes: deviceNodes, deviceNodes: deviceNodes,
} }
dd := discover.Merge( dd := discover.Merge(

View File

@ -28,12 +28,12 @@ import (
) )
type nvsandboxutilsDGPU struct { type nvsandboxutilsDGPU struct {
lib nvsandboxutils.Interface lib nvsandboxutils.Interface
uuid string uuid string
devRoot string devRoot string
isMig bool isMig bool
nvidiaCDIHookPath string hookCreator discover.HookCreator
deviceLinks []string deviceLinks []string
} }
var _ discover.Discover = (*nvsandboxutilsDGPU)(nil) var _ discover.Discover = (*nvsandboxutilsDGPU)(nil)
@ -53,11 +53,11 @@ func (o *options) newNvsandboxutilsDGPUDiscoverer(d UUIDer) (discover.Discover,
} }
nvd := nvsandboxutilsDGPU{ nvd := nvsandboxutilsDGPU{
lib: o.nvsandboxutilslib, lib: o.nvsandboxutilslib,
uuid: uuid, uuid: uuid,
devRoot: strings.TrimSuffix(filepath.Clean(o.devRoot), "/dev"), devRoot: strings.TrimSuffix(filepath.Clean(o.devRoot), "/dev"),
isMig: o.isMigDevice, isMig: o.isMigDevice,
nvidiaCDIHookPath: o.nvidiaCDIHookPath, hookCreator: o.hookCreator,
} }
return &nvd, nil return &nvd, nil
@ -112,18 +112,9 @@ func (d *nvsandboxutilsDGPU) Hooks() ([]discover.Hook, error) {
return nil, nil return nil, nil
} }
var args []string hook := d.hookCreator.Create("create-symlinks", d.deviceLinks...)
for _, l := range d.deviceLinks {
args = append(args, "--link", l)
}
hook := discover.CreateNvidiaCDIHook( return hook.Hooks()
d.nvidiaCDIHookPath,
"create-symlinks",
args...,
)
return []discover.Hook{hook}, nil
} }
func (d *nvsandboxutilsDGPU) Mounts() ([]discover.Mount, error) { func (d *nvsandboxutilsDGPU) Mounts() ([]discover.Mount, error) {

View File

@ -17,15 +17,16 @@
package dgpu package dgpu
import ( import (
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/NVIDIA/nvidia-container-toolkit/internal/nvcaps" "github.com/NVIDIA/nvidia-container-toolkit/internal/nvcaps"
"github.com/NVIDIA/nvidia-container-toolkit/internal/nvsandboxutils" "github.com/NVIDIA/nvidia-container-toolkit/internal/nvsandboxutils"
) )
type options struct { type options struct {
logger logger.Interface logger logger.Interface
devRoot string devRoot string
nvidiaCDIHookPath string hookCreator discover.HookCreator
isMigDevice bool isMigDevice bool
// migCaps stores the MIG capabilities for the system. // migCaps stores the MIG capabilities for the system.
@ -52,10 +53,10 @@ func WithLogger(logger logger.Interface) Option {
} }
} }
// WithNVIDIACDIHookPath sets the path to the NVIDIA Container Toolkit CLI path for the library // WithHookCreator sets the hook creator for the library
func WithNVIDIACDIHookPath(path string) Option { func WithHookCreator(hookCreator discover.HookCreator) Option {
return func(l *options) { return func(l *options) {
l.nvidiaCDIHookPath = path l.hookCreator = hookCreator
} }
} }

View File

@ -59,7 +59,7 @@ func (o tegraOptions) newDiscovererFromCSVFiles() (discover.Discover, error) {
targetsByType[csv.MountSpecLib], targetsByType[csv.MountSpecLib],
), ),
"", "",
o.nvidiaCDIHookPath, o.hookCreator,
) )
// We process the explicitly requested symlinks. // We process the explicitly requested symlinks.

View File

@ -181,13 +181,14 @@ func TestDiscovererFromCSVFiles(t *testing.T) {
}, },
} }
hookCreator := discover.NewHookCreator("/usr/bin/nvidia-cdi-hook")
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) { t.Run(tc.description, func(t *testing.T) {
defer setGetTargetsFromCSVFiles(tc.moutSpecs)() defer setGetTargetsFromCSVFiles(tc.moutSpecs)()
o := tegraOptions{ o := tegraOptions{
logger: logger, logger: logger,
nvidiaCDIHookPath: "/usr/bin/nvidia-cdi-hook", hookCreator: hookCreator,
csvFiles: []string{"dummy"}, csvFiles: []string{"dummy"},
ignorePatterns: tc.ignorePatterns, ignorePatterns: tc.ignorePatterns,
symlinkLocator: tc.symlinkLocator, symlinkLocator: tc.symlinkLocator,

View File

@ -26,9 +26,9 @@ import (
type symlinkHook struct { type symlinkHook struct {
discover.None discover.None
logger logger.Interface logger logger.Interface
nvidiaCDIHookPath string hookCreator discover.HookCreator
targets []string targets []string
// The following can be overridden for testing // The following can be overridden for testing
symlinkChainLocator lookup.Locator symlinkChainLocator lookup.Locator
@ -39,7 +39,7 @@ type symlinkHook struct {
func (o tegraOptions) createCSVSymlinkHooks(targets []string) discover.Discover { func (o tegraOptions) createCSVSymlinkHooks(targets []string) discover.Discover {
return symlinkHook{ return symlinkHook{
logger: o.logger, logger: o.logger,
nvidiaCDIHookPath: o.nvidiaCDIHookPath, hookCreator: o.hookCreator,
targets: targets, targets: targets,
symlinkChainLocator: o.symlinkChainLocator, symlinkChainLocator: o.symlinkChainLocator,
resolveSymlink: o.resolveSymlink, resolveSymlink: o.resolveSymlink,
@ -48,10 +48,7 @@ func (o tegraOptions) createCSVSymlinkHooks(targets []string) discover.Discover
// Hooks returns a hook to create the symlinks from the required CSV files // Hooks returns a hook to create the symlinks from the required CSV files
func (d symlinkHook) Hooks() ([]discover.Hook, error) { func (d symlinkHook) Hooks() ([]discover.Hook, error) {
return discover.CreateCreateSymlinkHook( return d.hookCreator.Create("create-symlinks", d.getCSVFileSymlinks()...).Hooks()
d.nvidiaCDIHookPath,
d.getCSVFileSymlinks(),
).Hooks()
} }
// getSymlinkCandidates returns a list of symlinks that are candidates for being created. // getSymlinkCandidates returns a list of symlinks that are candidates for being created.

View File

@ -30,7 +30,7 @@ type tegraOptions struct {
csvFiles []string csvFiles []string
driverRoot string driverRoot string
devRoot string devRoot string
nvidiaCDIHookPath string hookCreator discover.HookCreator
ldconfigPath string ldconfigPath string
librarySearchPaths []string librarySearchPaths []string
ignorePatterns ignoreMountSpecPatterns ignorePatterns ignoreMountSpecPatterns
@ -80,7 +80,7 @@ func New(opts ...Option) (discover.Discover, error) {
return nil, fmt.Errorf("failed to create CSV discoverer: %v", err) return nil, fmt.Errorf("failed to create CSV discoverer: %v", err)
} }
ldcacheUpdateHook, err := discover.NewLDCacheUpdateHook(o.logger, csvDiscoverer, o.nvidiaCDIHookPath, o.ldconfigPath) ldcacheUpdateHook, err := discover.NewLDCacheUpdateHook(o.logger, csvDiscoverer, o.hookCreator, o.ldconfigPath)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create ldcach update hook discoverer: %v", err) return nil, fmt.Errorf("failed to create ldcach update hook discoverer: %v", err)
} }
@ -133,10 +133,10 @@ func WithCSVFiles(csvFiles []string) Option {
} }
} }
// WithNVIDIACDIHookPath sets the path to the nvidia-cdi-hook binary. // WithHookCreator sets the hook creator for the discoverer.
func WithNVIDIACDIHookPath(nvidiaCDIHookPath string) Option { func WithHookCreator(hookCreator discover.HookCreator) Option {
return func(o *tegraOptions) { return func(o *tegraOptions) {
o.nvidiaCDIHookPath = nvidiaCDIHookPath o.hookCreator = hookCreator
} }
} }

View File

@ -21,6 +21,7 @@ import (
"github.com/NVIDIA/nvidia-container-toolkit/internal/config" "github.com/NVIDIA/nvidia-container-toolkit/internal/config"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
"github.com/NVIDIA/nvidia-container-toolkit/internal/info" "github.com/NVIDIA/nvidia-container-toolkit/internal/info"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root"
@ -74,6 +75,8 @@ func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Sp
return nil, err return nil, err
} }
hookCreator := discover.NewHookCreator(cfg.NVIDIACTKConfig.Path)
mode := info.ResolveAutoMode(logger, cfg.NVIDIAContainerRuntimeConfig.Mode, image) mode := info.ResolveAutoMode(logger, cfg.NVIDIAContainerRuntimeConfig.Mode, image)
// We update the mode here so that we can continue passing just the config to other functions. // We update the mode here so that we can continue passing just the config to other functions.
cfg.NVIDIAContainerRuntimeConfig.Mode = mode cfg.NVIDIAContainerRuntimeConfig.Mode = mode
@ -90,13 +93,13 @@ func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Sp
case "nvidia-hook-remover": case "nvidia-hook-remover":
modifiers = append(modifiers, modifier.NewNvidiaContainerRuntimeHookRemover(logger)) modifiers = append(modifiers, modifier.NewNvidiaContainerRuntimeHookRemover(logger))
case "graphics": case "graphics":
graphicsModifier, err := modifier.NewGraphicsModifier(logger, cfg, image, driver) graphicsModifier, err := modifier.NewGraphicsModifier(logger, cfg, image, driver, hookCreator)
if err != nil { if err != nil {
return nil, err return nil, err
} }
modifiers = append(modifiers, graphicsModifier) modifiers = append(modifiers, graphicsModifier)
case "feature-gated": case "feature-gated":
featureGatedModifier, err := modifier.NewFeatureGatedModifier(logger, cfg, image, driver) featureGatedModifier, err := modifier.NewFeatureGatedModifier(logger, cfg, image, driver, hookCreator)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -36,7 +36,7 @@ func (l *nvmllib) newCommonNVMLDiscoverer() (discover.Discover, error) {
}, },
) )
graphicsMounts, err := discover.NewGraphicsMountsDiscoverer(l.logger, l.driver, l.nvidiaCDIHookPath) graphicsMounts, err := discover.NewGraphicsMountsDiscoverer(l.logger, l.driver, l.hookCreator)
if err != nil { if err != nil {
l.logger.Warningf("failed to create discoverer for graphics mounts: %v", err) l.logger.Warningf("failed to create discoverer for graphics mounts: %v", err)
} }

View File

@ -102,17 +102,17 @@ func (l *nvcdilib) NewDriverLibraryDiscoverer(version string) (discover.Discover
driverDotSoSymlinksDiscoverer := discover.WithDriverDotSoSymlinks( driverDotSoSymlinksDiscoverer := discover.WithDriverDotSoSymlinks(
libraries, libraries,
version, version,
l.nvidiaCDIHookPath, l.hookCreator,
) )
discoverers = append(discoverers, driverDotSoSymlinksDiscoverer) discoverers = append(discoverers, driverDotSoSymlinksDiscoverer)
if l.HookIsSupported(HookEnableCudaCompat) { if l.HookIsSupported(HookEnableCudaCompat) {
// TODO: The following should use the version directly. // TODO: The following should use the version directly.
cudaCompatLibHookDiscoverer := discover.NewCUDACompatHookDiscoverer(l.logger, l.nvidiaCDIHookPath, l.driver) cudaCompatLibHookDiscoverer := discover.NewCUDACompatHookDiscoverer(l.logger, l.hookCreator, l.driver)
discoverers = append(discoverers, cudaCompatLibHookDiscoverer) discoverers = append(discoverers, cudaCompatLibHookDiscoverer)
} }
updateLDCache, _ := discover.NewLDCacheUpdateHook(l.logger, libraries, l.nvidiaCDIHookPath, l.ldconfigPath) updateLDCache, _ := discover.NewLDCacheUpdateHook(l.logger, libraries, l.hookCreator, l.ldconfigPath)
discoverers = append(discoverers, updateLDCache) discoverers = append(discoverers, updateLDCache)
d := discover.Merge(discoverers...) d := discover.Merge(discoverers...)

View File

@ -39,7 +39,7 @@ var requiredDriverStoreFiles = []string{
} }
// newWSLDriverDiscoverer returns a Discoverer for WSL2 drivers. // newWSLDriverDiscoverer returns a Discoverer for WSL2 drivers.
func newWSLDriverDiscoverer(logger logger.Interface, driverRoot string, nvidiaCDIHookPath, ldconfigPath string) (discover.Discover, error) { func newWSLDriverDiscoverer(logger logger.Interface, driverRoot string, hookCreator discover.HookCreator, ldconfigPath string) (discover.Discover, error) {
err := dxcore.Init() err := dxcore.Init()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to initialize dxcore: %v", err) return nil, fmt.Errorf("failed to initialize dxcore: %v", err)
@ -56,11 +56,11 @@ func newWSLDriverDiscoverer(logger logger.Interface, driverRoot string, nvidiaCD
} }
logger.Infof("Using WSL driver store paths: %v", driverStorePaths) logger.Infof("Using WSL driver store paths: %v", driverStorePaths)
return newWSLDriverStoreDiscoverer(logger, driverRoot, nvidiaCDIHookPath, ldconfigPath, driverStorePaths) return newWSLDriverStoreDiscoverer(logger, driverRoot, hookCreator, ldconfigPath, driverStorePaths)
} }
// newWSLDriverStoreDiscoverer returns a Discoverer for WSL2 drivers in the driver store associated with a dxcore adapter. // newWSLDriverStoreDiscoverer returns a Discoverer for WSL2 drivers in the driver store associated with a dxcore adapter.
func newWSLDriverStoreDiscoverer(logger logger.Interface, driverRoot string, nvidiaCDIHookPath string, ldconfigPath string, driverStorePaths []string) (discover.Discover, error) { func newWSLDriverStoreDiscoverer(logger logger.Interface, driverRoot string, hookCreator discover.HookCreator, ldconfigPath string, driverStorePaths []string) (discover.Discover, error) {
var searchPaths []string var searchPaths []string
seen := make(map[string]bool) seen := make(map[string]bool)
for _, path := range driverStorePaths { for _, path := range driverStorePaths {
@ -88,12 +88,12 @@ func newWSLDriverStoreDiscoverer(logger logger.Interface, driverRoot string, nvi
) )
symlinkHook := nvidiaSMISimlinkHook{ symlinkHook := nvidiaSMISimlinkHook{
logger: logger, logger: logger,
mountsFrom: libraries, mountsFrom: libraries,
nvidiaCDIHookPath: nvidiaCDIHookPath, hookCreator: hookCreator,
} }
ldcacheHook, _ := discover.NewLDCacheUpdateHook(logger, libraries, nvidiaCDIHookPath, ldconfigPath) ldcacheHook, _ := discover.NewLDCacheUpdateHook(logger, libraries, hookCreator, ldconfigPath)
d := discover.Merge( d := discover.Merge(
libraries, libraries,
@ -106,9 +106,9 @@ func newWSLDriverStoreDiscoverer(logger logger.Interface, driverRoot string, nvi
type nvidiaSMISimlinkHook struct { type nvidiaSMISimlinkHook struct {
discover.None discover.None
logger logger.Interface logger logger.Interface
mountsFrom discover.Discover mountsFrom discover.Discover
nvidiaCDIHookPath string hookCreator discover.HookCreator
} }
// Hooks returns a hook that creates a symlink to nvidia-smi in the driver store. // Hooks returns a hook that creates a symlink to nvidia-smi in the driver store.
@ -135,7 +135,7 @@ func (m nvidiaSMISimlinkHook) Hooks() ([]discover.Hook, error) {
} }
link := "/usr/bin/nvidia-smi" link := "/usr/bin/nvidia-smi"
links := []string{fmt.Sprintf("%s::%s", target, link)} links := []string{fmt.Sprintf("%s::%s", target, link)}
symlinkHook := discover.CreateCreateSymlinkHook(m.nvidiaCDIHookPath, links) symlinkHook := m.hookCreator.Create("create-symlinks", links...)
return symlinkHook.Hooks() return symlinkHook.Hooks()
} }

View File

@ -29,6 +29,7 @@ import (
func TestNvidiaSMISymlinkHook(t *testing.T) { func TestNvidiaSMISymlinkHook(t *testing.T) {
logger, _ := testlog.NewNullLogger() logger, _ := testlog.NewNullLogger()
hookCreator := discover.NewHookCreator("nvidia-cdi-hook")
errMounts := errors.New("mounts error") errMounts := errors.New("mounts error")
@ -143,9 +144,9 @@ func TestNvidiaSMISymlinkHook(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) { t.Run(tc.description, func(t *testing.T) {
m := nvidiaSMISimlinkHook{ m := nvidiaSMISimlinkHook{
logger: logger, logger: logger,
mountsFrom: tc.mounts, mountsFrom: tc.mounts,
nvidiaCDIHookPath: "nvidia-cdi-hook", hookCreator: hookCreator,
} }
devices, err := m.Devices() devices, err := m.Devices()

View File

@ -71,7 +71,7 @@ func (l *nvmllib) newFullGPUDiscoverer(d device.Device) (discover.Discover, erro
deviceNodes, err := dgpu.NewForDevice(d, deviceNodes, err := dgpu.NewForDevice(d,
dgpu.WithDevRoot(l.devRoot), dgpu.WithDevRoot(l.devRoot),
dgpu.WithLogger(l.logger), dgpu.WithLogger(l.logger),
dgpu.WithNVIDIACDIHookPath(l.nvidiaCDIHookPath), dgpu.WithHookCreator(l.hookCreator),
dgpu.WithNvsandboxuitilsLib(l.nvsandboxutilslib), dgpu.WithNvsandboxuitilsLib(l.nvsandboxutilslib),
) )
if err != nil { if err != nil {
@ -81,7 +81,7 @@ func (l *nvmllib) newFullGPUDiscoverer(d device.Device) (discover.Discover, erro
deviceFolderPermissionHooks := newDeviceFolderPermissionHookDiscoverer( deviceFolderPermissionHooks := newDeviceFolderPermissionHookDiscoverer(
l.logger, l.logger,
l.devRoot, l.devRoot,
l.nvidiaCDIHookPath, l.hookCreator,
deviceNodes, deviceNodes,
) )

View File

@ -44,7 +44,7 @@ func (l *csvlib) GetAllDeviceSpecs() ([]specs.Device, error) {
tegra.WithLogger(l.logger), tegra.WithLogger(l.logger),
tegra.WithDriverRoot(l.driverRoot), tegra.WithDriverRoot(l.driverRoot),
tegra.WithDevRoot(l.devRoot), tegra.WithDevRoot(l.devRoot),
tegra.WithNVIDIACDIHookPath(l.nvidiaCDIHookPath), tegra.WithHookCreator(l.hookCreator),
tegra.WithLdconfigPath(l.ldconfigPath), tegra.WithLdconfigPath(l.ldconfigPath),
tegra.WithCSVFiles(l.csvFiles), tegra.WithCSVFiles(l.csvFiles),
tegra.WithLibrarySearchPaths(l.librarySearchPaths...), tegra.WithLibrarySearchPaths(l.librarySearchPaths...),

View File

@ -54,7 +54,7 @@ func (l *wsllib) GetAllDeviceSpecs() ([]specs.Device, error) {
// GetCommonEdits generates a CDI specification that can be used for ANY devices // GetCommonEdits generates a CDI specification that can be used for ANY devices
func (l *wsllib) GetCommonEdits() (*cdi.ContainerEdits, error) { func (l *wsllib) GetCommonEdits() (*cdi.ContainerEdits, error) {
driver, err := newWSLDriverDiscoverer(l.logger, l.driverRoot, l.nvidiaCDIHookPath, l.ldconfigPath) driver, err := newWSLDriverDiscoverer(l.logger, l.driverRoot, l.hookCreator, l.ldconfigPath)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create discoverer for WSL driver: %v", err) return nil, fmt.Errorf("failed to create discoverer for WSL driver: %v", err)
} }

View File

@ -23,6 +23,7 @@ import (
"github.com/NVIDIA/go-nvlib/pkg/nvlib/info" "github.com/NVIDIA/go-nvlib/pkg/nvlib/info"
"github.com/NVIDIA/go-nvml/pkg/nvml" "github.com/NVIDIA/go-nvml/pkg/nvml"
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root"
"github.com/NVIDIA/nvidia-container-toolkit/internal/nvsandboxutils" "github.com/NVIDIA/nvidia-container-toolkit/internal/nvsandboxutils"
@ -56,6 +57,7 @@ type nvcdilib struct {
mergedDeviceOptions []transform.MergedDeviceOption mergedDeviceOptions []transform.MergedDeviceOption
disabledHooks disabledHooks disabledHooks disabledHooks
hookCreator discover.HookCreator
} }
// New creates a new nvcdi library // New creates a new nvcdi library
@ -79,6 +81,9 @@ func New(opts ...Option) (Interface, error) {
if l.nvidiaCDIHookPath == "" { if l.nvidiaCDIHookPath == "" {
l.nvidiaCDIHookPath = "/usr/bin/nvidia-cdi-hook" l.nvidiaCDIHookPath = "/usr/bin/nvidia-cdi-hook"
} }
// create hookCreator
l.hookCreator = discover.NewHookCreator(l.nvidiaCDIHookPath)
if l.driverRoot == "" { if l.driverRoot == "" {
l.driverRoot = "/" l.driverRoot = "/"
} }

View File

@ -138,7 +138,7 @@ func (m *managementlib) newManagementDeviceDiscoverer() (discover.Discover, erro
deviceFolderPermissionHooks := newDeviceFolderPermissionHookDiscoverer( deviceFolderPermissionHooks := newDeviceFolderPermissionHookDiscoverer(
m.logger, m.logger,
m.devRoot, m.devRoot,
m.nvidiaCDIHookPath, m.hookCreator,
deviceNodes, deviceNodes,
) )

View File

@ -54,7 +54,7 @@ func (l *nvmllib) GetMIGDeviceEdits(parent device.Device, mig device.MigDevice)
deviceNodes, err := dgpu.NewForMigDevice(parent, mig, deviceNodes, err := dgpu.NewForMigDevice(parent, mig,
dgpu.WithDevRoot(l.devRoot), dgpu.WithDevRoot(l.devRoot),
dgpu.WithLogger(l.logger), dgpu.WithLogger(l.logger),
dgpu.WithNVIDIACDIHookPath(l.nvidiaCDIHookPath), dgpu.WithHookCreator(l.hookCreator),
dgpu.WithNvsandboxuitilsLib(l.nvsandboxutilslib), dgpu.WithNvsandboxuitilsLib(l.nvsandboxutilslib),
) )
if err != nil { if err != nil {

View File

@ -25,10 +25,10 @@ import (
) )
type deviceFolderPermissions struct { type deviceFolderPermissions struct {
logger logger.Interface logger logger.Interface
devRoot string devRoot string
nvidiaCDIHookPath string devices discover.Discover
devices discover.Discover hookCreator discover.HookCreator
} }
var _ discover.Discover = (*deviceFolderPermissions)(nil) var _ discover.Discover = (*deviceFolderPermissions)(nil)
@ -39,12 +39,12 @@ var _ discover.Discover = (*deviceFolderPermissions)(nil)
// The nested devices that are applicable to the NVIDIA GPU devices are: // The nested devices that are applicable to the NVIDIA GPU devices are:
// - DRM devices at /dev/dri/* // - DRM devices at /dev/dri/*
// - NVIDIA Caps devices at /dev/nvidia-caps/* // - NVIDIA Caps devices at /dev/nvidia-caps/*
func newDeviceFolderPermissionHookDiscoverer(logger logger.Interface, devRoot string, nvidiaCDIHookPath string, devices discover.Discover) discover.Discover { func newDeviceFolderPermissionHookDiscoverer(logger logger.Interface, devRoot string, hookCreator discover.HookCreator, devices discover.Discover) discover.Discover {
d := &deviceFolderPermissions{ d := &deviceFolderPermissions{
logger: logger, logger: logger,
devRoot: devRoot, devRoot: devRoot,
nvidiaCDIHookPath: nvidiaCDIHookPath, hookCreator: hookCreator,
devices: devices, devices: devices,
} }
return d return d
@ -70,13 +70,9 @@ func (d *deviceFolderPermissions) Hooks() ([]discover.Hook, error) {
args = append(args, "--path", folder) args = append(args, "--path", folder)
} }
hook := discover.CreateNvidiaCDIHook( hook := d.hookCreator.Create("chmod", args...)
d.nvidiaCDIHookPath,
"chmod",
args...,
)
return []discover.Hook{hook}, nil return []discover.Hook{*hook}, nil
} }
func (d *deviceFolderPermissions) getDeviceSubfolders() ([]string, error) { func (d *deviceFolderPermissions) getDeviceSubfolders() ([]string, error) {