Create an internal/hooks pkg to centralize hook management

Signed-off-by: Carlos Eduardo Arango Gutierrez <eduardoa@nvidia.com>
This commit is contained in:
Carlos Eduardo Arango Gutierrez 2025-05-21 17:00:23 +02:00
parent f93d96a0de
commit 61ae3dc746
No known key found for this signature in database
GPG Key ID: 42D9CB42F300A852
27 changed files with 469 additions and 204 deletions

View File

@ -3,18 +3,48 @@ package discover
import (
"strings"
"github.com/NVIDIA/nvidia-container-toolkit/internal/hooks"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root"
)
// cudaCompatHook is a discoverer for the enable-cuda-compat hook.
type cudaCompatHook struct {
hooks.Hook
}
// NewCUDACompatHookDiscoverer creates a discoverer for a enable-cuda-compat hook.
// This hook is responsible for setting up CUDA compatibility in the container and depends on the host driver version.
func NewCUDACompatHookDiscoverer(logger logger.Interface, hookCreator HookCreator, driver *root.Driver) Discover {
func NewCUDACompatHookDiscoverer(logger logger.Interface, hookCreator hooks.HookCreator, driver *root.Driver) Discover {
_, cudaVersionPattern := getCUDALibRootAndVersionPattern(logger, driver)
var args []string
if !strings.Contains(cudaVersionPattern, "*") {
args = append(args, "--host-driver-version="+cudaVersionPattern)
}
return hookCreator.Create("enable-cuda-compat", args...)
hook := hookCreator.Create(hooks.EnableCudaCompat, args...)
if hook == nil {
return nil
}
return &cudaCompatHook{
Hook: *hook,
}
}
func (h *cudaCompatHook) Hooks() ([]Hook, error) {
return []Hook{
{
Lifecycle: h.Lifecycle,
Path: h.Path,
Args: h.Args,
}}, nil
}
func (h *cudaCompatHook) Devices() ([]Device, error) {
return nil, nil
}
func (h *cudaCompatHook) Mounts() ([]Mount, error) {
return nil, nil
}

View File

@ -23,6 +23,7 @@ import (
"strings"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
"github.com/NVIDIA/nvidia-container-toolkit/internal/hooks"
"github.com/NVIDIA/nvidia-container-toolkit/internal/info/drm"
"github.com/NVIDIA/nvidia-container-toolkit/internal/info/proc"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
@ -36,7 +37,7 @@ import (
// TODO: The logic for creating DRM devices should be consolidated between this
// and the logic for generating CDI specs for a single device. This is only used
// when applying OCI spec modifications to an incoming spec in "legacy" mode.
func NewDRMNodesDiscoverer(logger logger.Interface, devices image.VisibleDevices, devRoot string, hookCreator HookCreator) (Discover, error) {
func NewDRMNodesDiscoverer(logger logger.Interface, devices image.VisibleDevices, devRoot string, hookCreator hooks.HookCreator) (Discover, error) {
drmDeviceNodes, err := newDRMDeviceDiscoverer(logger, devices, devRoot)
if err != nil {
return nil, fmt.Errorf("failed to create DRM device discoverer: %v", err)
@ -49,7 +50,7 @@ func NewDRMNodesDiscoverer(logger logger.Interface, devices image.VisibleDevices
}
// NewGraphicsMountsDiscoverer creates a discoverer for the mounts required by graphics tools such as vulkan.
func NewGraphicsMountsDiscoverer(logger logger.Interface, driver *root.Driver, hookCreator HookCreator) (Discover, error) {
func NewGraphicsMountsDiscoverer(logger logger.Interface, driver *root.Driver, hookCreator hooks.HookCreator) (Discover, error) {
libraries := newGraphicsLibrariesDiscoverer(logger, driver, hookCreator)
configs := NewMounts(
@ -96,12 +97,12 @@ func newVulkanConfigsDiscover(logger logger.Interface, driver *root.Driver) Disc
type graphicsDriverLibraries struct {
Discover
logger logger.Interface
hookCreator HookCreator
hookCreator hooks.HookCreator
}
var _ Discover = (*graphicsDriverLibraries)(nil)
func newGraphicsLibrariesDiscoverer(logger logger.Interface, driver *root.Driver, hookCreator HookCreator) Discover {
func newGraphicsLibrariesDiscoverer(logger logger.Interface, driver *root.Driver, hookCreator hooks.HookCreator) Discover {
cudaLibRoot, cudaVersionPattern := getCUDALibRootAndVersionPattern(logger, driver)
libraries := NewMounts(
@ -203,9 +204,15 @@ func (d graphicsDriverLibraries) Hooks() ([]Hook, error) {
return nil, nil
}
hook := d.hookCreator.Create("create-symlinks", links...)
hook := d.hookCreator.Create(hooks.CreateSymlinks, links...)
return hook.Hooks()
return []Hook{
{
Lifecycle: hook.Lifecycle,
Path: hook.Path,
Args: hook.Args,
},
}, nil
}
// isDriverLibrary checks whether the specified filename is a specific driver library.
@ -276,13 +283,13 @@ func buildXOrgSearchPaths(libRoot string) []string {
type drmDevicesByPath struct {
None
logger logger.Interface
hookCreator HookCreator
hookCreator hooks.HookCreator
devRoot string
devicesFrom Discover
}
// newCreateDRMByPathSymlinks creates a discoverer for a hook to create the by-path symlinks for DRM devices discovered by the specified devices discoverer
func newCreateDRMByPathSymlinks(logger logger.Interface, devices Discover, devRoot string, hookCreator HookCreator) Discover {
func newCreateDRMByPathSymlinks(logger logger.Interface, devices Discover, devRoot string, hookCreator hooks.HookCreator) Discover {
d := drmDevicesByPath{
logger: logger,
hookCreator: hookCreator,
@ -315,9 +322,15 @@ func (d drmDevicesByPath) Hooks() ([]Hook, error) {
args = append(args, "--link", l)
}
hook := d.hookCreator.Create("create-symlinks", args...)
hook := d.hookCreator.Create(hooks.CreateSymlinks, args...)
return hook.Hooks()
return []Hook{
{
Lifecycle: hook.Lifecycle,
Path: hook.Path,
Args: hook.Args,
},
}, nil
}
// getSpecificLinkArgs returns the required specific links that need to be created

View File

@ -19,13 +19,15 @@ package discover
import (
"testing"
"github.com/NVIDIA/nvidia-container-toolkit/internal/hooks"
testlog "github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/require"
)
func TestGraphicsLibrariesDiscoverer(t *testing.T) {
logger, _ := testlog.NewNullLogger()
hookCreator := NewHookCreator("/usr/bin/nvidia-cdi-hook")
hookCreator := hooks.NewHookCreator("/usr/bin/nvidia-cdi-hook", nil)
testCases := []struct {
description string

View File

@ -1,92 +0,0 @@
/**
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package discover
import (
"path/filepath"
"tags.cncf.io/container-device-interface/pkg/cdi"
)
var _ Discover = (*Hook)(nil)
// Devices returns an empty list of devices for a Hook discoverer.
func (h *Hook) Devices() ([]Device, error) {
return nil, nil
}
// Mounts returns an empty list of mounts for a Hook discoverer.
func (h *Hook) Mounts() ([]Mount, error) {
return nil, nil
}
// Hooks allows the Hook type to also implement the Discoverer interface.
// It returns a single hook
func (h *Hook) Hooks() ([]Hook, error) {
if h == nil {
return nil, nil
}
return []Hook{*h}, nil
}
// Option is a function that configures the nvcdilib
type Option func(*CDIHook)
type CDIHook struct {
nvidiaCDIHookPath string
}
type HookCreator interface {
Create(string, ...string) *Hook
}
func NewHookCreator(nvidiaCDIHookPath string) HookCreator {
CDIHook := &CDIHook{
nvidiaCDIHookPath: nvidiaCDIHookPath,
}
return CDIHook
}
func (c CDIHook) Create(name string, args ...string) *Hook {
if name == "create-symlinks" {
if len(args) == 0 {
return nil
}
links := []string{}
for _, arg := range args {
links = append(links, "--link", arg)
}
args = links
}
return &Hook{
Lifecycle: cdi.CreateContainerHook,
Path: c.nvidiaCDIHookPath,
Args: append(c.requiredArgs(name), args...),
}
}
func (c CDIHook) requiredArgs(name string) []string {
base := filepath.Base(c.nvidiaCDIHookPath)
if base == "nvidia-ctk" {
return []string{base, "hook", name}
}
return []string{base, name}
}

View File

@ -21,11 +21,12 @@ import (
"path/filepath"
"strings"
"github.com/NVIDIA/nvidia-container-toolkit/internal/hooks"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
)
// NewLDCacheUpdateHook creates a discoverer that updates the ldcache for the specified mounts. A logger can also be specified
func NewLDCacheUpdateHook(logger logger.Interface, mounts Discover, hookCreator HookCreator, ldconfigPath string) (Discover, error) {
func NewLDCacheUpdateHook(logger logger.Interface, mounts Discover, hookCreator hooks.HookCreator, ldconfigPath string) (Discover, error) {
d := ldconfig{
logger: logger,
hookCreator: hookCreator,
@ -39,7 +40,7 @@ func NewLDCacheUpdateHook(logger logger.Interface, mounts Discover, hookCreator
type ldconfig struct {
None
logger logger.Interface
hookCreator HookCreator
hookCreator hooks.HookCreator
ldconfigPath string
mountsFrom Discover
}
@ -57,11 +58,17 @@ func (d ldconfig) Hooks() ([]Hook, error) {
getLibraryPaths(mounts),
)
return h.Hooks()
return []Hook{
{
Lifecycle: h.Lifecycle,
Path: h.Path,
Args: h.Args,
},
}, nil
}
// createLDCacheUpdateHook locates the NVIDIA Container Toolkit CLI and creates a hook for updating the LD Cache
func createLDCacheUpdateHook(hookCreator HookCreator, ldconfig string, libraries []string) *Hook {
func createLDCacheUpdateHook(hookCreator hooks.HookCreator, ldconfig string, libraries []string) *Hook {
var args []string
if ldconfig != "" {
@ -72,7 +79,12 @@ func createLDCacheUpdateHook(hookCreator HookCreator, ldconfig string, libraries
args = append(args, "--folder", f)
}
return hookCreator.Create("update-ldcache", args...)
h := hookCreator.Create(hooks.UpdateLDCache, args...)
return &Hook{
Lifecycle: h.Lifecycle,
Path: h.Path,
Args: h.Args,
}
}
// getLibraryPaths extracts the library dirs from the specified mounts

View File

@ -20,6 +20,8 @@ import (
"fmt"
"testing"
"github.com/NVIDIA/nvidia-container-toolkit/internal/hooks"
testlog "github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/require"
)
@ -31,7 +33,7 @@ const (
func TestLDCacheUpdateHook(t *testing.T) {
logger, _ := testlog.NewNullLogger()
hookCreator := NewHookCreator(testNvidiaCDIHookPath)
hookCreator := hooks.NewHookCreator(testNvidiaCDIHookPath, nil)
testCases := []struct {
description string

View File

@ -19,17 +19,19 @@ package discover
import (
"fmt"
"path/filepath"
"github.com/NVIDIA/nvidia-container-toolkit/internal/hooks"
)
type additionalSymlinks struct {
Discover
version string
hookCreator HookCreator
hookCreator hooks.HookCreator
}
// WithDriverDotSoSymlinks decorates the provided discoverer.
// A hook is added that checks for specific driver symlinks that need to be created.
func WithDriverDotSoSymlinks(mounts Discover, version string, hookCreator HookCreator) Discover {
func WithDriverDotSoSymlinks(mounts Discover, version string, hookCreator hooks.HookCreator) Discover {
if version == "" {
version = "*.*"
}
@ -46,7 +48,7 @@ func (d *additionalSymlinks) Hooks() ([]Hook, error) {
if err != nil {
return nil, fmt.Errorf("failed to get library mounts: %v", err)
}
hooks, err := d.Discover.Hooks()
h, err := d.Discover.Hooks()
if err != nil {
return nil, fmt.Errorf("failed to get hooks: %v", err)
}
@ -70,15 +72,16 @@ func (d *additionalSymlinks) Hooks() ([]Hook, error) {
}
if len(links) == 0 {
return hooks, nil
return h, nil
}
createSymlinkHooks, err := d.hookCreator.Create("create-symlinks", links...).Hooks()
if err != nil {
return nil, fmt.Errorf("failed to create symlink hook: %v", err)
}
createSymlinkHook := d.hookCreator.Create(hooks.CreateSymlinks, links...)
return append(hooks, createSymlinkHooks...), nil
return append(h, Hook{
Lifecycle: createSymlinkHook.Lifecycle,
Path: createSymlinkHook.Path,
Args: createSymlinkHook.Args,
}), nil
}
// getLinksForMount maps the path to created links if any.

View File

@ -19,6 +19,8 @@ package discover
import (
"testing"
"github.com/NVIDIA/nvidia-container-toolkit/internal/hooks"
"github.com/stretchr/testify/require"
)
@ -306,7 +308,7 @@ func TestWithWithDriverDotSoSymlinks(t *testing.T) {
},
}
hookCreator := NewHookCreator("/path/to/nvidia-cdi-hook")
hookCreator := hooks.NewHookCreator("/path/to/nvidia-cdi-hook", nil)
for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
d := WithDriverDotSoSymlinks(

116
internal/hooks/hooks.go Normal file
View File

@ -0,0 +1,116 @@
/**
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package hooks
import (
"path/filepath"
"tags.cncf.io/container-device-interface/pkg/cdi"
)
// A HookName refers to one of the predefined set of CDI hooks that may be
// included in the generated CDI specification.
type HookName string
// DisabledHooks allows individual hooks to be disabled.
type DisabledHooks map[HookName]bool
const (
// EnableCudaCompat refers to the hook used to enable CUDA Forward Compatibility.
// This was added with v1.17.5 of the NVIDIA Container Toolkit.
EnableCudaCompat = HookName("enable-cuda-compat")
// CreateSymlinks refers to the hook used to create symlinks for the NVIDIA
// Container Toolkit. This was added with v1.17.5 of the NVIDIA Container Toolkit.
CreateSymlinks = HookName("create-symlinks")
// UpdateLDCache refers to the hook used to update the LD cache for the NVIDIA
// Container Toolkit. This was added with v1.17.5 of the NVIDIA Container Toolkit.
UpdateLDCache = HookName("update-ldcache")
)
// Hook represents an OCI container hook.
type Hook struct {
Lifecycle string
Path string
Args []string
}
// Option is a function that configures the nvcdilib
type Option func(*CDIHook)
type CDIHook struct {
nvidiaCDIHookPath string
disabledHooks DisabledHooks
}
type HookCreator interface {
Create(HookName, ...string) *Hook
}
func NewHookCreator(nvidiaCDIHookPath string, disabledHooks DisabledHooks) HookCreator {
if disabledHooks == nil {
disabledHooks = make(DisabledHooks)
}
CDIHook := &CDIHook{
nvidiaCDIHookPath: nvidiaCDIHookPath,
disabledHooks: disabledHooks,
}
return CDIHook
}
func (c CDIHook) Create(name HookName, args ...string) *Hook {
if c.disabledHooks[name] {
return nil
}
if name == CreateSymlinks {
if len(args) == 0 {
return nil
}
links := []string{}
for _, arg := range args {
links = append(links, "--link", arg)
}
args = links
}
return &Hook{
Lifecycle: cdi.CreateContainerHook,
Path: c.nvidiaCDIHookPath,
Args: append(c.requiredArgs(name), args...),
}
}
func (c CDIHook) requiredArgs(name HookName) []string {
base := filepath.Base(c.nvidiaCDIHookPath)
if base == "nvidia-ctk" {
return []string{base, "hook", string(name)}
}
return []string{base, string(name)}
}
// IsSupported checks whether a hook of the specified name is supported.
// Hooks must be explicitly disabled, meaning that if no disabled hooks are
// all hooks are supported.
func (c CDIHook) IsSupported(h HookName) bool {
if len(c.disabledHooks) == 0 {
return true
}
return !c.disabledHooks[h]
}

View File

@ -0,0 +1,177 @@
package hooks
import (
"testing"
)
func TestNewHookCreator(t *testing.T) {
tests := []struct {
name string
hookPath string
disabledHooks DisabledHooks
expectedCreator HookCreator
}{
{
name: "nil disabled hooks",
hookPath: "/usr/bin/nvidia-ctk",
expectedCreator: &CDIHook{
nvidiaCDIHookPath: "/usr/bin/nvidia-ctk",
disabledHooks: DisabledHooks{},
},
},
{
name: "with disabled hooks",
hookPath: "/usr/bin/nvidia-ctk",
disabledHooks: DisabledHooks{
EnableCudaCompat: true,
},
expectedCreator: &CDIHook{
nvidiaCDIHookPath: "/usr/bin/nvidia-ctk",
disabledHooks: DisabledHooks{
EnableCudaCompat: true,
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
creator := NewHookCreator(tt.hookPath, tt.disabledHooks)
if creator == nil {
t.Fatal("NewHookCreator returned nil")
}
})
}
}
func TestCDIHook_Create(t *testing.T) {
tests := []struct {
name string
hookPath string
disabledHooks DisabledHooks
hookName HookName
args []string
expectedHook *Hook
}{
{
name: "disabled hook returns nil",
hookPath: "/usr/bin/nvidia-ctk",
disabledHooks: DisabledHooks{
EnableCudaCompat: true,
},
hookName: EnableCudaCompat,
expectedHook: nil,
},
{
name: "create symlinks with no args returns nil",
hookPath: "/usr/bin/nvidia-ctk",
hookName: CreateSymlinks,
expectedHook: nil,
},
{
name: "create symlinks with args",
hookPath: "/usr/bin/nvidia-ctk",
hookName: CreateSymlinks,
args: []string{"/path/to/lib1", "/path/to/lib2"},
expectedHook: &Hook{
Lifecycle: "createContainer",
Path: "/usr/bin/nvidia-ctk",
Args: []string{"nvidia-ctk", "hook", "create-symlinks", "--link", "/path/to/lib1", "--link", "/path/to/lib2"},
},
},
{
name: "enable cuda compat",
hookPath: "/usr/bin/nvidia-ctk",
hookName: EnableCudaCompat,
expectedHook: &Hook{
Lifecycle: "createContainer",
Path: "/usr/bin/nvidia-ctk",
Args: []string{"nvidia-ctk", "hook", "enable-cuda-compat"},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
hook := &CDIHook{
nvidiaCDIHookPath: tt.hookPath,
disabledHooks: tt.disabledHooks,
}
result := hook.Create(tt.hookName, tt.args...)
if tt.expectedHook == nil {
if result != nil {
t.Errorf("expected nil hook, got %v", result)
}
return
}
if result == nil {
t.Fatal("expected non-nil hook, got nil")
}
if result.Lifecycle != tt.expectedHook.Lifecycle {
t.Errorf("expected lifecycle %q, got %q", tt.expectedHook.Lifecycle, result.Lifecycle)
}
if result.Path != tt.expectedHook.Path {
t.Errorf("expected path %q, got %q", tt.expectedHook.Path, result.Path)
}
if len(result.Args) != len(tt.expectedHook.Args) {
t.Errorf("expected %d args, got %d", len(tt.expectedHook.Args), len(result.Args))
return
}
for i, arg := range tt.expectedHook.Args {
if result.Args[i] != arg {
t.Errorf("expected arg[%d] %q, got %q", i, arg, result.Args[i])
}
}
})
}
}
func TestCDIHook_IsSupported(t *testing.T) {
tests := []struct {
name string
disabledHooks DisabledHooks
hookName HookName
expected bool
}{
{
name: "no disabled hooks",
hookName: EnableCudaCompat,
expected: true,
},
{
name: "disabled hook",
disabledHooks: DisabledHooks{
EnableCudaCompat: true,
},
hookName: EnableCudaCompat,
expected: false,
},
{
name: "non-disabled hook",
disabledHooks: DisabledHooks{
EnableCudaCompat: true,
},
hookName: CreateSymlinks,
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
hook := &CDIHook{
disabledHooks: tt.disabledHooks,
}
result := hook.IsSupported(tt.hookName)
if result != tt.expected {
t.Errorf("expected %v, got %v", tt.expected, result)
}
})
}
}

View File

@ -22,6 +22,7 @@ import (
"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
"github.com/NVIDIA/nvidia-container-toolkit/internal/hooks"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root"
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
@ -36,7 +37,7 @@ import (
// NVIDIA_GDRCOPY=enabled
//
// If not devices are selected, no changes are made.
func NewFeatureGatedModifier(logger logger.Interface, cfg *config.Config, image image.CUDA, driver *root.Driver, hookCreator discover.HookCreator) (oci.SpecModifier, error) {
func NewFeatureGatedModifier(logger logger.Interface, cfg *config.Config, image image.CUDA, driver *root.Driver, hookCreator hooks.HookCreator) (oci.SpecModifier, error) {
if devices := image.VisibleDevicesFromEnvVar(); len(devices) == 0 {
logger.Infof("No modification required; no devices requested")
return nil, nil
@ -91,7 +92,7 @@ func NewFeatureGatedModifier(logger logger.Interface, cfg *config.Config, image
return NewModifierFromDiscoverer(logger, discover.Merge(discoverers...))
}
func getCudaCompatModeDiscoverer(logger logger.Interface, cfg *config.Config, driver *root.Driver, hookCreator discover.HookCreator) (discover.Discover, error) {
func getCudaCompatModeDiscoverer(logger logger.Interface, cfg *config.Config, driver *root.Driver, hookCreator hooks.HookCreator) (discover.Discover, error) {
// For legacy mode, we only include the enable-cuda-compat hook if cuda-compat-mode is set to hook.
if cfg.NVIDIAContainerRuntimeConfig.Mode == "legacy" && cfg.NVIDIAContainerRuntimeConfig.Modes.Legacy.CUDACompatMode != config.CUDACompatModeHook {
return nil, nil

View File

@ -22,6 +22,7 @@ import (
"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
"github.com/NVIDIA/nvidia-container-toolkit/internal/hooks"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root"
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
@ -29,7 +30,7 @@ import (
// NewGraphicsModifier constructs a modifier that injects graphics-related modifications into an OCI runtime specification.
// The value of the NVIDIA_DRIVER_CAPABILITIES environment variable is checked to determine if this modification should be made.
func NewGraphicsModifier(logger logger.Interface, cfg *config.Config, containerImage image.CUDA, driver *root.Driver, hookCreator discover.HookCreator) (oci.SpecModifier, error) {
func NewGraphicsModifier(logger logger.Interface, cfg *config.Config, containerImage image.CUDA, driver *root.Driver, hookCreator hooks.HookCreator) (oci.SpecModifier, error) {
if required, reason := requiresGraphicsModifier(containerImage); !required {
logger.Infof("No graphics modifier required: %v", reason)
return nil, nil

View File

@ -22,6 +22,7 @@ import (
"path/filepath"
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
"github.com/NVIDIA/nvidia-container-toolkit/internal/hooks"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
)
@ -29,7 +30,7 @@ import (
type byPathHookDiscoverer struct {
logger logger.Interface
devRoot string
hookCreator discover.HookCreator
hookCreator hooks.HookCreator
pciBusID string
deviceNodes discover.Discover
}
@ -53,9 +54,14 @@ func (d *byPathHookDiscoverer) Hooks() ([]discover.Hook, error) {
return nil, nil
}
hook := d.hookCreator.Create("create-symlinks", links...)
hook := d.hookCreator.Create(hooks.CreateSymlinks, links...)
return hook.Hooks()
return []discover.Hook{
{
Lifecycle: hook.Lifecycle,
Path: hook.Path,
Args: hook.Args,
}}, nil
}
// Mounts returns an empty slice for a full GPU

View File

@ -24,6 +24,7 @@ import (
"github.com/NVIDIA/go-nvml/pkg/nvml"
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
"github.com/NVIDIA/nvidia-container-toolkit/internal/hooks"
"github.com/NVIDIA/nvidia-container-toolkit/internal/nvsandboxutils"
)
@ -32,7 +33,7 @@ type nvsandboxutilsDGPU struct {
uuid string
devRoot string
isMig bool
hookCreator discover.HookCreator
hookCreator hooks.HookCreator
deviceLinks []string
}
@ -112,9 +113,14 @@ func (d *nvsandboxutilsDGPU) Hooks() ([]discover.Hook, error) {
return nil, nil
}
hook := d.hookCreator.Create("create-symlinks", d.deviceLinks...)
hook := d.hookCreator.Create(hooks.CreateSymlinks, d.deviceLinks...)
return hook.Hooks()
return []discover.Hook{
{
Lifecycle: hook.Lifecycle,
Path: hook.Path,
Args: hook.Args,
}}, nil
}
func (d *nvsandboxutilsDGPU) Mounts() ([]discover.Mount, error) {

View File

@ -17,7 +17,7 @@
package dgpu
import (
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
"github.com/NVIDIA/nvidia-container-toolkit/internal/hooks"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/NVIDIA/nvidia-container-toolkit/internal/nvcaps"
"github.com/NVIDIA/nvidia-container-toolkit/internal/nvsandboxutils"
@ -26,7 +26,7 @@ import (
type options struct {
logger logger.Interface
devRoot string
hookCreator discover.HookCreator
hookCreator hooks.HookCreator
isMigDevice bool
// migCaps stores the MIG capabilities for the system.
@ -54,7 +54,7 @@ func WithLogger(logger logger.Interface) Option {
}
// WithHookCreator sets the hook creator for the library
func WithHookCreator(hookCreator discover.HookCreator) Option {
func WithHookCreator(hookCreator hooks.HookCreator) Option {
return func(l *options) {
l.hookCreator = hookCreator
}

View File

@ -24,6 +24,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
"github.com/NVIDIA/nvidia-container-toolkit/internal/hooks"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup"
@ -181,7 +182,7 @@ func TestDiscovererFromCSVFiles(t *testing.T) {
},
}
hookCreator := discover.NewHookCreator("/usr/bin/nvidia-cdi-hook")
hookCreator := hooks.NewHookCreator("/usr/bin/nvidia-cdi-hook", nil)
for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
defer setGetTargetsFromCSVFiles(tc.moutSpecs)()

View File

@ -20,6 +20,7 @@ import (
"fmt"
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
"github.com/NVIDIA/nvidia-container-toolkit/internal/hooks"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup"
)
@ -27,7 +28,7 @@ import (
type symlinkHook struct {
discover.None
logger logger.Interface
hookCreator discover.HookCreator
hookCreator hooks.HookCreator
targets []string
// The following can be overridden for testing
@ -48,7 +49,17 @@ func (o tegraOptions) createCSVSymlinkHooks(targets []string) discover.Discover
// Hooks returns a hook to create the symlinks from the required CSV files
func (d symlinkHook) Hooks() ([]discover.Hook, error) {
return d.hookCreator.Create("create-symlinks", d.getCSVFileSymlinks()...).Hooks()
hook := d.hookCreator.Create(hooks.CreateSymlinks, d.getCSVFileSymlinks()...)
if hook == nil {
return nil, nil
}
return []discover.Hook{
{
Lifecycle: hook.Lifecycle,
Path: hook.Path,
Args: hook.Args,
}}, nil
}
// getSymlinkCandidates returns a list of symlinks that are candidates for being created.

View File

@ -20,6 +20,7 @@ import (
"fmt"
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
"github.com/NVIDIA/nvidia-container-toolkit/internal/hooks"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup"
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/symlinks"
@ -30,7 +31,7 @@ type tegraOptions struct {
csvFiles []string
driverRoot string
devRoot string
hookCreator discover.HookCreator
hookCreator hooks.HookCreator
ldconfigPath string
librarySearchPaths []string
ignorePatterns ignoreMountSpecPatterns
@ -134,7 +135,7 @@ func WithCSVFiles(csvFiles []string) Option {
}
// WithHookCreator sets the hook creator for the discoverer.
func WithHookCreator(hookCreator discover.HookCreator) Option {
func WithHookCreator(hookCreator hooks.HookCreator) Option {
return func(o *tegraOptions) {
o.hookCreator = hookCreator
}

View File

@ -21,7 +21,7 @@ import (
"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
"github.com/NVIDIA/nvidia-container-toolkit/internal/hooks"
"github.com/NVIDIA/nvidia-container-toolkit/internal/info"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root"
@ -75,7 +75,7 @@ func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Sp
return nil, err
}
hookCreator := discover.NewHookCreator(cfg.NVIDIACTKConfig.Path)
hookCreator := hooks.NewHookCreator(cfg.NVIDIACTKConfig.Path, nil)
mode := info.ResolveAutoMode(logger, cfg.NVIDIAContainerRuntimeConfig.Mode, image)
// We update the mode here so that we can continue passing just the config to other functions.

View File

@ -35,13 +35,3 @@ type Interface interface {
GetMIGDeviceSpecs(int, device.Device, int, device.MigDevice) ([]specs.Device, error)
GetDeviceSpecsByID(...string) ([]specs.Device, error)
}
// A HookName refers to one of the predefined set of CDI hooks that may be
// included in the generated CDI specification.
type HookName string
const (
// HookEnableCudaCompat refers to the hook used to enable CUDA Forward Compatibility.
// This was added with v1.17.5 of the NVIDIA Container Toolkit.
HookEnableCudaCompat = HookName("enable-cuda-compat")
)

View File

@ -106,11 +106,9 @@ func (l *nvcdilib) NewDriverLibraryDiscoverer(version string) (discover.Discover
)
discoverers = append(discoverers, driverDotSoSymlinksDiscoverer)
if l.HookIsSupported(HookEnableCudaCompat) {
// TODO: The following should use the version directly.
cudaCompatLibHookDiscoverer := discover.NewCUDACompatHookDiscoverer(l.logger, l.hookCreator, l.driver)
discoverers = append(discoverers, cudaCompatLibHookDiscoverer)
}
updateLDCache, _ := discover.NewLDCacheUpdateHook(l.logger, libraries, l.hookCreator, l.ldconfigPath)
discoverers = append(discoverers, updateLDCache)

View File

@ -22,6 +22,7 @@ import (
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
"github.com/NVIDIA/nvidia-container-toolkit/internal/dxcore"
"github.com/NVIDIA/nvidia-container-toolkit/internal/hooks"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup"
)
@ -39,7 +40,7 @@ var requiredDriverStoreFiles = []string{
}
// newWSLDriverDiscoverer returns a Discoverer for WSL2 drivers.
func newWSLDriverDiscoverer(logger logger.Interface, driverRoot string, hookCreator discover.HookCreator, ldconfigPath string) (discover.Discover, error) {
func newWSLDriverDiscoverer(logger logger.Interface, driverRoot string, hookCreator hooks.HookCreator, ldconfigPath string) (discover.Discover, error) {
err := dxcore.Init()
if err != nil {
return nil, fmt.Errorf("failed to initialize dxcore: %v", err)
@ -60,7 +61,7 @@ func newWSLDriverDiscoverer(logger logger.Interface, driverRoot string, hookCrea
}
// newWSLDriverStoreDiscoverer returns a Discoverer for WSL2 drivers in the driver store associated with a dxcore adapter.
func newWSLDriverStoreDiscoverer(logger logger.Interface, driverRoot string, hookCreator discover.HookCreator, ldconfigPath string, driverStorePaths []string) (discover.Discover, error) {
func newWSLDriverStoreDiscoverer(logger logger.Interface, driverRoot string, hookCreator hooks.HookCreator, ldconfigPath string, driverStorePaths []string) (discover.Discover, error) {
var searchPaths []string
seen := make(map[string]bool)
for _, path := range driverStorePaths {
@ -108,7 +109,7 @@ type nvidiaSMISimlinkHook struct {
discover.None
logger logger.Interface
mountsFrom discover.Discover
hookCreator discover.HookCreator
hookCreator hooks.HookCreator
}
// Hooks returns a hook that creates a symlink to nvidia-smi in the driver store.
@ -135,7 +136,13 @@ func (m nvidiaSMISimlinkHook) Hooks() ([]discover.Hook, error) {
}
link := "/usr/bin/nvidia-smi"
links := []string{fmt.Sprintf("%s::%s", target, link)}
symlinkHook := m.hookCreator.Create("create-symlinks", links...)
symlinkHook := m.hookCreator.Create(hooks.CreateSymlinks, links...)
return symlinkHook.Hooks()
return []discover.Hook{
{
Lifecycle: symlinkHook.Lifecycle,
Path: symlinkHook.Path,
Args: symlinkHook.Args,
},
}, nil
}

View File

@ -23,13 +23,14 @@ import (
"github.com/stretchr/testify/require"
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
"github.com/NVIDIA/nvidia-container-toolkit/internal/hooks"
testlog "github.com/sirupsen/logrus/hooks/test"
)
func TestNvidiaSMISymlinkHook(t *testing.T) {
logger, _ := testlog.NewNullLogger()
hookCreator := discover.NewHookCreator("nvidia-cdi-hook")
hookCreator := hooks.NewHookCreator("nvidia-cdi-hook", nil)
errMounts := errors.New("mounts error")

View File

@ -1,30 +0,0 @@
/**
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package nvcdi
// disabledHooks allows individual hooks to be disabled.
type disabledHooks map[HookName]bool
// HookIsSupported checks whether a hook of the specified name is supported.
// Hooks must be explicitly disabled, meaning that if no disabled hooks are
// all hooks are supported.
func (l *nvcdilib) HookIsSupported(h HookName) bool {
if len(l.disabledHooks) == 0 {
return true
}
return !l.disabledHooks[h]
}

View File

@ -23,7 +23,7 @@ import (
"github.com/NVIDIA/go-nvlib/pkg/nvlib/info"
"github.com/NVIDIA/go-nvml/pkg/nvml"
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
"github.com/NVIDIA/nvidia-container-toolkit/internal/hooks"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root"
"github.com/NVIDIA/nvidia-container-toolkit/internal/nvsandboxutils"
@ -56,14 +56,14 @@ type nvcdilib struct {
mergedDeviceOptions []transform.MergedDeviceOption
disabledHooks disabledHooks
hookCreator discover.HookCreator
disabledHooks hooks.DisabledHooks
hookCreator hooks.HookCreator
}
// New creates a new nvcdi library
func New(opts ...Option) (Interface, error) {
l := &nvcdilib{
disabledHooks: make(disabledHooks),
disabledHooks: make(hooks.DisabledHooks),
}
for _, opt := range opts {
opt(l)
@ -81,9 +81,6 @@ func New(opts ...Option) (Interface, error) {
if l.nvidiaCDIHookPath == "" {
l.nvidiaCDIHookPath = "/usr/bin/nvidia-cdi-hook"
}
// create hookCreator
l.hookCreator = discover.NewHookCreator(l.nvidiaCDIHookPath)
if l.driverRoot == "" {
l.driverRoot = "/"
}
@ -150,7 +147,7 @@ func New(opts ...Option) (Interface, error) {
l.vendor = "management.nvidia.com"
}
// Management containers in general do not require CUDA Forward compatibility.
l.disabledHooks[HookEnableCudaCompat] = true
l.disabledHooks[hooks.EnableCudaCompat] = true
lib = (*managementlib)(l)
case ModeNvml:
lib = (*nvmllib)(l)
@ -175,6 +172,8 @@ func New(opts ...Option) (Interface, error) {
return nil, fmt.Errorf("unknown mode %q", l.mode)
}
l.hookCreator = hooks.NewHookCreator(l.nvidiaCDIHookPath, l.disabledHooks)
w := wrapper{
Interface: lib,
vendor: l.vendor,

View File

@ -21,6 +21,7 @@ import (
"github.com/NVIDIA/go-nvlib/pkg/nvlib/info"
"github.com/NVIDIA/go-nvml/pkg/nvml"
"github.com/NVIDIA/nvidia-container-toolkit/internal/hooks"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/transform"
)
@ -158,10 +159,10 @@ func WithLibrarySearchPaths(paths []string) Option {
// WithDisabledHook allows specific hooks to the disabled.
// This option can be specified multiple times for each hook.
func WithDisabledHook(hook HookName) Option {
func WithDisabledHook(hook hooks.HookName) Option {
return func(o *nvcdilib) {
if o.disabledHooks == nil {
o.disabledHooks = make(map[HookName]bool)
o.disabledHooks = make(hooks.DisabledHooks)
}
o.disabledHooks[hook] = true
}

View File

@ -21,6 +21,7 @@ import (
"path/filepath"
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
"github.com/NVIDIA/nvidia-container-toolkit/internal/hooks"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
)
@ -28,7 +29,7 @@ type deviceFolderPermissions struct {
logger logger.Interface
devRoot string
devices discover.Discover
hookCreator discover.HookCreator
hookCreator hooks.HookCreator
}
var _ discover.Discover = (*deviceFolderPermissions)(nil)
@ -39,7 +40,7 @@ var _ discover.Discover = (*deviceFolderPermissions)(nil)
// The nested devices that are applicable to the NVIDIA GPU devices are:
// - DRM devices at /dev/dri/*
// - NVIDIA Caps devices at /dev/nvidia-caps/*
func newDeviceFolderPermissionHookDiscoverer(logger logger.Interface, devRoot string, hookCreator discover.HookCreator, devices discover.Discover) discover.Discover {
func newDeviceFolderPermissionHookDiscoverer(logger logger.Interface, devRoot string, hookCreator hooks.HookCreator, devices discover.Discover) discover.Discover {
d := &deviceFolderPermissions{
logger: logger,
devRoot: devRoot,
@ -72,7 +73,13 @@ func (d *deviceFolderPermissions) Hooks() ([]discover.Hook, error) {
hook := d.hookCreator.Create("chmod", args...)
return []discover.Hook{*hook}, nil
return []discover.Hook{
{
Lifecycle: hook.Lifecycle,
Path: hook.Path,
Args: hook.Args,
},
}, nil
}
func (d *deviceFolderPermissions) getDeviceSubfolders() ([]string, error) {