Add --disable-hook flag to cdi generate command
Some checks are pending
CI Pipeline / code-scanning (push) Waiting to run
CI Pipeline / variables (push) Waiting to run
CI Pipeline / golang (push) Waiting to run
CI Pipeline / image (push) Blocked by required conditions
CI Pipeline / e2e-test (push) Blocked by required conditions

When running the nvidia-ctk cdi generate command, a user should be able to opt out of specific hooks.
We propose to add a flag --disable-hook that will take a comma-separated list of hooks that will be skipped when creating the CDI spec.

Signed-off-by: Carlos Eduardo Arango Gutierrez <eduardoa@nvidia.com>
This commit is contained in:
Carlos Eduardo Arango Gutierrez 2025-05-22 13:51:15 +02:00
parent aa696ef919
commit f095bd5070
No known key found for this signature in database
GPG Key ID: 42D9CB42F300A852
13 changed files with 225 additions and 45 deletions

View File

@ -57,6 +57,7 @@ type options struct {
configSearchPaths cli.StringSlice
librarySearchPaths cli.StringSlice
disabledHooks cli.StringSlice
csv struct {
files cli.StringSlice
@ -176,6 +177,13 @@ func (m command) build() *cli.Command {
Usage: "Specify a pattern the CSV mount specifications.",
Destination: &opts.csv.ignorePatterns,
},
&cli.StringSliceFlag{
Name: "disable-hook",
Aliases: []string{"disable-hooks"},
Usage: "Hook to skip when generating the CDI specification. Can be specified multiple times. Can be a comma-separated list of hooks or a single hook name.",
Value: cli.NewStringSlice(),
Destination: &opts.disabledHooks,
},
}
return &c
@ -262,7 +270,7 @@ func (m command) generateSpec(opts *options) (spec.Interface, error) {
deviceNamers = append(deviceNamers, deviceNamer)
}
cdilib, err := nvcdi.New(
initOpts := []nvcdi.Option{
nvcdi.WithLogger(m.logger),
nvcdi.WithDriverRoot(opts.driverRoot),
nvcdi.WithDevRoot(opts.devRoot),
@ -276,7 +284,13 @@ func (m command) generateSpec(opts *options) (spec.Interface, error) {
nvcdi.WithCSVIgnorePatterns(opts.csv.ignorePatterns.Value()),
// We set the following to allow for dependency injection:
nvcdi.WithNvmlLib(opts.nvmllib),
)
}
for _, hook := range opts.disabledHooks.Value() {
initOpts = append(initOpts, nvcdi.WithDisabledHook(hook))
}
cdilib, err := nvcdi.New(initOpts...)
if err != nil {
return nil, fmt.Errorf("failed to create CDI library: %v", err)
}

View File

@ -26,6 +26,7 @@ import (
"github.com/NVIDIA/go-nvml/pkg/nvml/mock/dgxa100"
testlog "github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/require"
"github.com/urfave/cli/v2"
"github.com/NVIDIA/nvidia-container-toolkit/internal/test"
)
@ -36,6 +37,9 @@ func TestGenerateSpec(t *testing.T) {
require.NoError(t, err)
driverRoot := filepath.Join(moduleRoot, "testdata", "lookup", "rootfs-1")
disableHook1 := cli.NewStringSlice("enable-cuda-compat")
disableHook2 := cli.NewStringSlice("enable-cuda-compat", "update-ldcache")
disableHook3 := cli.NewStringSlice("all")
logger, _ := testlog.NewNullLogger()
testCases := []struct {
@ -113,6 +117,179 @@ containerEdits:
- nodev
- rbind
- rprivate
`,
},
{
description: "disableHooks1",
options: options{
format: "yaml",
mode: "nvml",
vendor: "example.com",
class: "device",
driverRoot: driverRoot,
disabledHooks: *disableHook1,
},
expectedOptions: options{
format: "yaml",
mode: "nvml",
vendor: "example.com",
class: "device",
nvidiaCDIHookPath: "/usr/bin/nvidia-cdi-hook",
driverRoot: driverRoot,
disabledHooks: *disableHook1,
},
expectedSpec: `---
cdiVersion: 0.5.0
kind: example.com/device
devices:
- name: "0"
containerEdits:
deviceNodes:
- path: /dev/nvidia0
hostPath: {{ .driverRoot }}/dev/nvidia0
- name: all
containerEdits:
deviceNodes:
- path: /dev/nvidia0
hostPath: {{ .driverRoot }}/dev/nvidia0
containerEdits:
env:
- NVIDIA_VISIBLE_DEVICES=void
deviceNodes:
- path: /dev/nvidiactl
hostPath: {{ .driverRoot }}/dev/nvidiactl
hooks:
- hookName: createContainer
path: /usr/bin/nvidia-cdi-hook
args:
- nvidia-cdi-hook
- create-symlinks
- --link
- libcuda.so.1::/lib/x86_64-linux-gnu/libcuda.so
- hookName: createContainer
path: /usr/bin/nvidia-cdi-hook
args:
- nvidia-cdi-hook
- update-ldcache
- --folder
- /lib/x86_64-linux-gnu
mounts:
- hostPath: {{ .driverRoot }}/lib/x86_64-linux-gnu/libcuda.so.999.88.77
containerPath: /lib/x86_64-linux-gnu/libcuda.so.999.88.77
options:
- ro
- nosuid
- nodev
- rbind
- rprivate
`,
},
{
description: "disableHooks2",
options: options{
format: "yaml",
mode: "nvml",
vendor: "example.com",
class: "device",
driverRoot: driverRoot,
disabledHooks: *disableHook2,
},
expectedOptions: options{
format: "yaml",
mode: "nvml",
vendor: "example.com",
class: "device",
nvidiaCDIHookPath: "/usr/bin/nvidia-cdi-hook",
driverRoot: driverRoot,
disabledHooks: *disableHook2,
},
expectedSpec: `---
cdiVersion: 0.5.0
kind: example.com/device
devices:
- name: "0"
containerEdits:
deviceNodes:
- path: /dev/nvidia0
hostPath: {{ .driverRoot }}/dev/nvidia0
- name: all
containerEdits:
deviceNodes:
- path: /dev/nvidia0
hostPath: {{ .driverRoot }}/dev/nvidia0
containerEdits:
env:
- NVIDIA_VISIBLE_DEVICES=void
deviceNodes:
- path: /dev/nvidiactl
hostPath: {{ .driverRoot }}/dev/nvidiactl
hooks:
- hookName: createContainer
path: /usr/bin/nvidia-cdi-hook
args:
- nvidia-cdi-hook
- create-symlinks
- --link
- libcuda.so.1::/lib/x86_64-linux-gnu/libcuda.so
mounts:
- hostPath: {{ .driverRoot }}/lib/x86_64-linux-gnu/libcuda.so.999.88.77
containerPath: /lib/x86_64-linux-gnu/libcuda.so.999.88.77
options:
- ro
- nosuid
- nodev
- rbind
- rprivate
`,
},
{
description: "disableHooksAll",
options: options{
format: "yaml",
mode: "nvml",
vendor: "example.com",
class: "device",
driverRoot: driverRoot,
disabledHooks: *disableHook3,
},
expectedOptions: options{
format: "yaml",
mode: "nvml",
vendor: "example.com",
class: "device",
nvidiaCDIHookPath: "/usr/bin/nvidia-cdi-hook",
driverRoot: driverRoot,
disabledHooks: *disableHook3,
},
expectedSpec: `---
cdiVersion: 0.5.0
kind: example.com/device
devices:
- name: "0"
containerEdits:
deviceNodes:
- path: /dev/nvidia0
hostPath: {{ .driverRoot }}/dev/nvidia0
- name: all
containerEdits:
deviceNodes:
- path: /dev/nvidia0
hostPath: {{ .driverRoot }}/dev/nvidia0
containerEdits:
env:
- NVIDIA_VISIBLE_DEVICES=void
deviceNodes:
- path: /dev/nvidiactl
hostPath: {{ .driverRoot }}/dev/nvidiactl
mounts:
- hostPath: {{ .driverRoot }}/lib/x86_64-linux-gnu/libcuda.so.999.88.77
containerPath: /lib/x86_64-linux-gnu/libcuda.so.999.88.77
options:
- ro
- nosuid
- nodev
- rbind
- rprivate
`,
},
}

View File

@ -25,7 +25,7 @@ import (
func TestGraphicsLibrariesDiscoverer(t *testing.T) {
logger, _ := testlog.NewNullLogger()
hookCreator := NewHookCreator("/usr/bin/nvidia-cdi-hook")
hookCreator := NewHookCreator("/usr/bin/nvidia-cdi-hook", make(DisabledHooks))
testCases := []struct {
description string

View File

@ -49,6 +49,17 @@ type HookName string
// DisabledHooks allows individual hooks to be disabled.
type DisabledHooks map[HookName]bool
func (d DisabledHooks) Set(value HookName) {
if value == "all" {
for _, hook := range AllHooks {
d[hook] = true
}
return
}
d[value] = true
}
const (
// HookEnableCudaCompat refers to the hook used to enable CUDA Forward Compatibility.
// This was added with v1.17.5 of the NVIDIA Container Toolkit.
@ -72,21 +83,27 @@ type Option func(*CDIHook)
type CDIHook struct {
nvidiaCDIHookPath string
disabledHooks DisabledHooks
}
type HookCreator interface {
Create(HookName, ...string) *Hook
}
func NewHookCreator(nvidiaCDIHookPath string) HookCreator {
func NewHookCreator(nvidiaCDIHookPath string, disabledHooks DisabledHooks) HookCreator {
CDIHook := &CDIHook{
nvidiaCDIHookPath: nvidiaCDIHookPath,
disabledHooks: disabledHooks,
}
return CDIHook
}
func (c CDIHook) Create(name HookName, args ...string) *Hook {
if c.disabledHooks[name] {
return nil
}
if name == "create-symlinks" {
if len(args) == 0 {
return nil

View File

@ -31,7 +31,7 @@ const (
func TestLDCacheUpdateHook(t *testing.T) {
logger, _ := testlog.NewNullLogger()
hookCreator := NewHookCreator(testNvidiaCDIHookPath)
hookCreator := NewHookCreator(testNvidiaCDIHookPath, make(DisabledHooks))
testCases := []struct {
description string

View File

@ -306,7 +306,7 @@ func TestWithWithDriverDotSoSymlinks(t *testing.T) {
},
}
hookCreator := NewHookCreator("/path/to/nvidia-cdi-hook")
hookCreator := NewHookCreator("/path/to/nvidia-cdi-hook", make(DisabledHooks))
for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
d := WithDriverDotSoSymlinks(

View File

@ -181,7 +181,7 @@ func TestDiscovererFromCSVFiles(t *testing.T) {
},
}
hookCreator := discover.NewHookCreator("/usr/bin/nvidia-cdi-hook")
hookCreator := discover.NewHookCreator("/usr/bin/nvidia-cdi-hook", make(discover.DisabledHooks))
for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
defer setGetTargetsFromCSVFiles(tc.moutSpecs)()

View File

@ -75,7 +75,7 @@ func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Sp
return nil, err
}
hookCreator := discover.NewHookCreator(cfg.NVIDIACTKConfig.Path)
hookCreator := discover.NewHookCreator(cfg.NVIDIACTKConfig.Path, make(discover.DisabledHooks))
mode := info.ResolveAutoMode(logger, cfg.NVIDIAContainerRuntimeConfig.Mode, image)
// We update the mode here so that we can continue passing just the config to other functions.

View File

@ -106,11 +106,9 @@ func (l *nvcdilib) NewDriverLibraryDiscoverer(version string) (discover.Discover
)
discoverers = append(discoverers, driverDotSoSymlinksDiscoverer)
if l.HookIsSupported(HookEnableCudaCompat) {
// TODO: The following should use the version directly.
cudaCompatLibHookDiscoverer := discover.NewCUDACompatHookDiscoverer(l.logger, l.hookCreator, l.driver)
discoverers = append(discoverers, cudaCompatLibHookDiscoverer)
}
// TODO: The following should use the version directly.
cudaCompatLibHookDiscoverer := discover.NewCUDACompatHookDiscoverer(l.logger, l.hookCreator, l.driver)
discoverers = append(discoverers, cudaCompatLibHookDiscoverer)
updateLDCache, _ := discover.NewLDCacheUpdateHook(l.logger, libraries, l.hookCreator, l.ldconfigPath)
discoverers = append(discoverers, updateLDCache)

View File

@ -29,7 +29,7 @@ import (
func TestNvidiaSMISymlinkHook(t *testing.T) {
logger, _ := testlog.NewNullLogger()
hookCreator := discover.NewHookCreator("nvidia-cdi-hook")
hookCreator := discover.NewHookCreator("nvidia-cdi-hook", make(discover.DisabledHooks))
errMounts := errors.New("mounts error")

View File

@ -1,27 +0,0 @@
/**
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package nvcdi
// HookIsSupported checks whether a hook of the specified name is supported.
// Hooks must be explicitly disabled, meaning that if no disabled hooks are
// all hooks are supported.
func (l *nvcdilib) HookIsSupported(h HookName) bool {
if len(l.disabledHooks) == 0 {
return true
}
return !l.disabledHooks[h]
}

View File

@ -81,8 +81,6 @@ func New(opts ...Option) (Interface, error) {
if l.nvidiaCDIHookPath == "" {
l.nvidiaCDIHookPath = "/usr/bin/nvidia-cdi-hook"
}
// create hookCreator
l.hookCreator = discover.NewHookCreator(l.nvidiaCDIHookPath)
if l.driverRoot == "" {
l.driverRoot = "/"
@ -175,6 +173,9 @@ func New(opts ...Option) (Interface, error) {
return nil, fmt.Errorf("unknown mode %q", l.mode)
}
// create hookCreator
l.hookCreator = discover.NewHookCreator(l.nvidiaCDIHookPath, l.disabledHooks)
w := wrapper{
Interface: lib,
vendor: l.vendor,

View File

@ -158,11 +158,11 @@ func WithLibrarySearchPaths(paths []string) Option {
// WithDisabledHook allows specific hooks to the disabled.
// This option can be specified multiple times for each hook.
func WithDisabledHook(hook HookName) Option {
func WithDisabledHook[T string | HookName](hook T) Option {
return func(o *nvcdilib) {
if o.disabledHooks == nil {
o.disabledHooks = make(map[HookName]bool)
}
o.disabledHooks[hook] = true
o.disabledHooks.Set(HookName(hook))
}
}