diff --git a/internal/discover/compat_libs.go b/internal/discover/compat_libs.go index 7e7f9ff4..3f57d360 100644 --- a/internal/discover/compat_libs.go +++ b/internal/discover/compat_libs.go @@ -3,18 +3,48 @@ package discover import ( "strings" + "github.com/NVIDIA/nvidia-container-toolkit/internal/hooks" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root" ) +// cudaCompatHook is a discoverer for the enable-cuda-compat hook. +type cudaCompatHook struct { + hooks.Hook +} + // NewCUDACompatHookDiscoverer creates a discoverer for a enable-cuda-compat hook. // This hook is responsible for setting up CUDA compatibility in the container and depends on the host driver version. -func NewCUDACompatHookDiscoverer(logger logger.Interface, hookCreator HookCreator, driver *root.Driver) Discover { +func NewCUDACompatHookDiscoverer(logger logger.Interface, hookCreator hooks.HookCreator, driver *root.Driver) Discover { _, cudaVersionPattern := getCUDALibRootAndVersionPattern(logger, driver) var args []string if !strings.Contains(cudaVersionPattern, "*") { args = append(args, "--host-driver-version="+cudaVersionPattern) } - return hookCreator.Create("enable-cuda-compat", args...) + hook := hookCreator.Create(hooks.EnableCudaCompat, args...) + if hook == nil { + return nil + } + + return &cudaCompatHook{ + Hook: *hook, + } +} + +func (h *cudaCompatHook) Hooks() ([]Hook, error) { + return []Hook{ + { + Lifecycle: h.Lifecycle, + Path: h.Path, + Args: h.Args, + }}, nil +} + +func (h *cudaCompatHook) Devices() ([]Device, error) { + return nil, nil +} + +func (h *cudaCompatHook) Mounts() ([]Mount, error) { + return nil, nil } diff --git a/internal/discover/graphics.go b/internal/discover/graphics.go index 4665ce29..4ce1aefe 100644 --- a/internal/discover/graphics.go +++ b/internal/discover/graphics.go @@ -23,6 +23,7 @@ import ( "strings" "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" + "github.com/NVIDIA/nvidia-container-toolkit/internal/hooks" "github.com/NVIDIA/nvidia-container-toolkit/internal/info/drm" "github.com/NVIDIA/nvidia-container-toolkit/internal/info/proc" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" @@ -36,7 +37,7 @@ import ( // TODO: The logic for creating DRM devices should be consolidated between this // and the logic for generating CDI specs for a single device. This is only used // when applying OCI spec modifications to an incoming spec in "legacy" mode. -func NewDRMNodesDiscoverer(logger logger.Interface, devices image.VisibleDevices, devRoot string, hookCreator HookCreator) (Discover, error) { +func NewDRMNodesDiscoverer(logger logger.Interface, devices image.VisibleDevices, devRoot string, hookCreator hooks.HookCreator) (Discover, error) { drmDeviceNodes, err := newDRMDeviceDiscoverer(logger, devices, devRoot) if err != nil { return nil, fmt.Errorf("failed to create DRM device discoverer: %v", err) @@ -49,7 +50,7 @@ func NewDRMNodesDiscoverer(logger logger.Interface, devices image.VisibleDevices } // NewGraphicsMountsDiscoverer creates a discoverer for the mounts required by graphics tools such as vulkan. -func NewGraphicsMountsDiscoverer(logger logger.Interface, driver *root.Driver, hookCreator HookCreator) (Discover, error) { +func NewGraphicsMountsDiscoverer(logger logger.Interface, driver *root.Driver, hookCreator hooks.HookCreator) (Discover, error) { libraries := newGraphicsLibrariesDiscoverer(logger, driver, hookCreator) configs := NewMounts( @@ -96,12 +97,12 @@ func newVulkanConfigsDiscover(logger logger.Interface, driver *root.Driver) Disc type graphicsDriverLibraries struct { Discover logger logger.Interface - hookCreator HookCreator + hookCreator hooks.HookCreator } var _ Discover = (*graphicsDriverLibraries)(nil) -func newGraphicsLibrariesDiscoverer(logger logger.Interface, driver *root.Driver, hookCreator HookCreator) Discover { +func newGraphicsLibrariesDiscoverer(logger logger.Interface, driver *root.Driver, hookCreator hooks.HookCreator) Discover { cudaLibRoot, cudaVersionPattern := getCUDALibRootAndVersionPattern(logger, driver) libraries := NewMounts( @@ -203,9 +204,15 @@ func (d graphicsDriverLibraries) Hooks() ([]Hook, error) { return nil, nil } - hook := d.hookCreator.Create("create-symlinks", links...) + hook := d.hookCreator.Create(hooks.CreateSymlinks, links...) - return hook.Hooks() + return []Hook{ + { + Lifecycle: hook.Lifecycle, + Path: hook.Path, + Args: hook.Args, + }, + }, nil } // isDriverLibrary checks whether the specified filename is a specific driver library. @@ -276,13 +283,13 @@ func buildXOrgSearchPaths(libRoot string) []string { type drmDevicesByPath struct { None logger logger.Interface - hookCreator HookCreator + hookCreator hooks.HookCreator devRoot string devicesFrom Discover } // newCreateDRMByPathSymlinks creates a discoverer for a hook to create the by-path symlinks for DRM devices discovered by the specified devices discoverer -func newCreateDRMByPathSymlinks(logger logger.Interface, devices Discover, devRoot string, hookCreator HookCreator) Discover { +func newCreateDRMByPathSymlinks(logger logger.Interface, devices Discover, devRoot string, hookCreator hooks.HookCreator) Discover { d := drmDevicesByPath{ logger: logger, hookCreator: hookCreator, @@ -315,9 +322,15 @@ func (d drmDevicesByPath) Hooks() ([]Hook, error) { args = append(args, "--link", l) } - hook := d.hookCreator.Create("create-symlinks", args...) + hook := d.hookCreator.Create(hooks.CreateSymlinks, args...) - return hook.Hooks() + return []Hook{ + { + Lifecycle: hook.Lifecycle, + Path: hook.Path, + Args: hook.Args, + }, + }, nil } // getSpecificLinkArgs returns the required specific links that need to be created diff --git a/internal/discover/graphics_test.go b/internal/discover/graphics_test.go index 3aea93cb..0adb1337 100644 --- a/internal/discover/graphics_test.go +++ b/internal/discover/graphics_test.go @@ -19,13 +19,15 @@ package discover import ( "testing" + "github.com/NVIDIA/nvidia-container-toolkit/internal/hooks" + testlog "github.com/sirupsen/logrus/hooks/test" "github.com/stretchr/testify/require" ) func TestGraphicsLibrariesDiscoverer(t *testing.T) { logger, _ := testlog.NewNullLogger() - hookCreator := NewHookCreator("/usr/bin/nvidia-cdi-hook") + hookCreator := hooks.NewHookCreator("/usr/bin/nvidia-cdi-hook", nil) testCases := []struct { description string diff --git a/internal/discover/hooks.go b/internal/discover/hooks.go deleted file mode 100644 index 0f239bfd..00000000 --- a/internal/discover/hooks.go +++ /dev/null @@ -1,92 +0,0 @@ -/** -# Copyright (c) NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -**/ - -package discover - -import ( - "path/filepath" - - "tags.cncf.io/container-device-interface/pkg/cdi" -) - -var _ Discover = (*Hook)(nil) - -// Devices returns an empty list of devices for a Hook discoverer. -func (h *Hook) Devices() ([]Device, error) { - return nil, nil -} - -// Mounts returns an empty list of mounts for a Hook discoverer. -func (h *Hook) Mounts() ([]Mount, error) { - return nil, nil -} - -// Hooks allows the Hook type to also implement the Discoverer interface. -// It returns a single hook -func (h *Hook) Hooks() ([]Hook, error) { - if h == nil { - return nil, nil - } - - return []Hook{*h}, nil -} - -// Option is a function that configures the nvcdilib -type Option func(*CDIHook) - -type CDIHook struct { - nvidiaCDIHookPath string -} - -type HookCreator interface { - Create(string, ...string) *Hook -} - -func NewHookCreator(nvidiaCDIHookPath string) HookCreator { - CDIHook := &CDIHook{ - nvidiaCDIHookPath: nvidiaCDIHookPath, - } - - return CDIHook -} - -func (c CDIHook) Create(name string, args ...string) *Hook { - if name == "create-symlinks" { - if len(args) == 0 { - return nil - } - - links := []string{} - for _, arg := range args { - links = append(links, "--link", arg) - } - args = links - } - - return &Hook{ - Lifecycle: cdi.CreateContainerHook, - Path: c.nvidiaCDIHookPath, - Args: append(c.requiredArgs(name), args...), - } -} - -func (c CDIHook) requiredArgs(name string) []string { - base := filepath.Base(c.nvidiaCDIHookPath) - if base == "nvidia-ctk" { - return []string{base, "hook", name} - } - return []string{base, name} -} diff --git a/internal/discover/ldconfig.go b/internal/discover/ldconfig.go index 3fab927a..54a167b2 100644 --- a/internal/discover/ldconfig.go +++ b/internal/discover/ldconfig.go @@ -21,11 +21,12 @@ import ( "path/filepath" "strings" + "github.com/NVIDIA/nvidia-container-toolkit/internal/hooks" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" ) // NewLDCacheUpdateHook creates a discoverer that updates the ldcache for the specified mounts. A logger can also be specified -func NewLDCacheUpdateHook(logger logger.Interface, mounts Discover, hookCreator HookCreator, ldconfigPath string) (Discover, error) { +func NewLDCacheUpdateHook(logger logger.Interface, mounts Discover, hookCreator hooks.HookCreator, ldconfigPath string) (Discover, error) { d := ldconfig{ logger: logger, hookCreator: hookCreator, @@ -39,7 +40,7 @@ func NewLDCacheUpdateHook(logger logger.Interface, mounts Discover, hookCreator type ldconfig struct { None logger logger.Interface - hookCreator HookCreator + hookCreator hooks.HookCreator ldconfigPath string mountsFrom Discover } @@ -57,11 +58,17 @@ func (d ldconfig) Hooks() ([]Hook, error) { getLibraryPaths(mounts), ) - return h.Hooks() + return []Hook{ + { + Lifecycle: h.Lifecycle, + Path: h.Path, + Args: h.Args, + }, + }, nil } // createLDCacheUpdateHook locates the NVIDIA Container Toolkit CLI and creates a hook for updating the LD Cache -func createLDCacheUpdateHook(hookCreator HookCreator, ldconfig string, libraries []string) *Hook { +func createLDCacheUpdateHook(hookCreator hooks.HookCreator, ldconfig string, libraries []string) *Hook { var args []string if ldconfig != "" { @@ -72,7 +79,12 @@ func createLDCacheUpdateHook(hookCreator HookCreator, ldconfig string, libraries args = append(args, "--folder", f) } - return hookCreator.Create("update-ldcache", args...) + h := hookCreator.Create(hooks.UpdateLDCache, args...) + return &Hook{ + Lifecycle: h.Lifecycle, + Path: h.Path, + Args: h.Args, + } } // getLibraryPaths extracts the library dirs from the specified mounts diff --git a/internal/discover/ldconfig_test.go b/internal/discover/ldconfig_test.go index ddbda4cc..11b32e78 100644 --- a/internal/discover/ldconfig_test.go +++ b/internal/discover/ldconfig_test.go @@ -20,6 +20,8 @@ import ( "fmt" "testing" + "github.com/NVIDIA/nvidia-container-toolkit/internal/hooks" + testlog "github.com/sirupsen/logrus/hooks/test" "github.com/stretchr/testify/require" ) @@ -31,7 +33,7 @@ const ( func TestLDCacheUpdateHook(t *testing.T) { logger, _ := testlog.NewNullLogger() - hookCreator := NewHookCreator(testNvidiaCDIHookPath) + hookCreator := hooks.NewHookCreator(testNvidiaCDIHookPath, nil) testCases := []struct { description string diff --git a/internal/discover/symlinks.go b/internal/discover/symlinks.go index a9cd811a..e33bd017 100644 --- a/internal/discover/symlinks.go +++ b/internal/discover/symlinks.go @@ -19,17 +19,19 @@ package discover import ( "fmt" "path/filepath" + + "github.com/NVIDIA/nvidia-container-toolkit/internal/hooks" ) type additionalSymlinks struct { Discover version string - hookCreator HookCreator + hookCreator hooks.HookCreator } // WithDriverDotSoSymlinks decorates the provided discoverer. // A hook is added that checks for specific driver symlinks that need to be created. -func WithDriverDotSoSymlinks(mounts Discover, version string, hookCreator HookCreator) Discover { +func WithDriverDotSoSymlinks(mounts Discover, version string, hookCreator hooks.HookCreator) Discover { if version == "" { version = "*.*" } @@ -46,7 +48,7 @@ func (d *additionalSymlinks) Hooks() ([]Hook, error) { if err != nil { return nil, fmt.Errorf("failed to get library mounts: %v", err) } - hooks, err := d.Discover.Hooks() + h, err := d.Discover.Hooks() if err != nil { return nil, fmt.Errorf("failed to get hooks: %v", err) } @@ -70,15 +72,16 @@ func (d *additionalSymlinks) Hooks() ([]Hook, error) { } if len(links) == 0 { - return hooks, nil + return h, nil } - createSymlinkHooks, err := d.hookCreator.Create("create-symlinks", links...).Hooks() - if err != nil { - return nil, fmt.Errorf("failed to create symlink hook: %v", err) - } + createSymlinkHook := d.hookCreator.Create(hooks.CreateSymlinks, links...) - return append(hooks, createSymlinkHooks...), nil + return append(h, Hook{ + Lifecycle: createSymlinkHook.Lifecycle, + Path: createSymlinkHook.Path, + Args: createSymlinkHook.Args, + }), nil } // getLinksForMount maps the path to created links if any. diff --git a/internal/discover/symlinks_test.go b/internal/discover/symlinks_test.go index 2a6c9812..b781f6d8 100644 --- a/internal/discover/symlinks_test.go +++ b/internal/discover/symlinks_test.go @@ -19,6 +19,8 @@ package discover import ( "testing" + "github.com/NVIDIA/nvidia-container-toolkit/internal/hooks" + "github.com/stretchr/testify/require" ) @@ -306,7 +308,7 @@ func TestWithWithDriverDotSoSymlinks(t *testing.T) { }, } - hookCreator := NewHookCreator("/path/to/nvidia-cdi-hook") + hookCreator := hooks.NewHookCreator("/path/to/nvidia-cdi-hook", nil) for _, tc := range testCases { t.Run(tc.description, func(t *testing.T) { d := WithDriverDotSoSymlinks( diff --git a/internal/hooks/hooks.go b/internal/hooks/hooks.go new file mode 100644 index 00000000..d2bc3029 --- /dev/null +++ b/internal/hooks/hooks.go @@ -0,0 +1,116 @@ +/** +# Copyright (c) NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package hooks + +import ( + "path/filepath" + + "tags.cncf.io/container-device-interface/pkg/cdi" +) + +// A HookName refers to one of the predefined set of CDI hooks that may be +// included in the generated CDI specification. +type HookName string + +// DisabledHooks allows individual hooks to be disabled. +type DisabledHooks map[HookName]bool + +const ( + // EnableCudaCompat refers to the hook used to enable CUDA Forward Compatibility. + // This was added with v1.17.5 of the NVIDIA Container Toolkit. + EnableCudaCompat = HookName("enable-cuda-compat") + // CreateSymlinks refers to the hook used to create symlinks for the NVIDIA + // Container Toolkit. This was added with v1.17.5 of the NVIDIA Container Toolkit. + CreateSymlinks = HookName("create-symlinks") + // UpdateLDCache refers to the hook used to update the LD cache for the NVIDIA + // Container Toolkit. This was added with v1.17.5 of the NVIDIA Container Toolkit. + UpdateLDCache = HookName("update-ldcache") +) + +// Hook represents an OCI container hook. +type Hook struct { + Lifecycle string + Path string + Args []string +} + +// Option is a function that configures the nvcdilib +type Option func(*CDIHook) + +type CDIHook struct { + nvidiaCDIHookPath string + disabledHooks DisabledHooks +} + +type HookCreator interface { + Create(HookName, ...string) *Hook +} + +func NewHookCreator(nvidiaCDIHookPath string, disabledHooks DisabledHooks) HookCreator { + if disabledHooks == nil { + disabledHooks = make(DisabledHooks) + } + + CDIHook := &CDIHook{ + nvidiaCDIHookPath: nvidiaCDIHookPath, + disabledHooks: disabledHooks, + } + + return CDIHook +} + +func (c CDIHook) Create(name HookName, args ...string) *Hook { + if c.disabledHooks[name] { + return nil + } + + if name == CreateSymlinks { + if len(args) == 0 { + return nil + } + + links := []string{} + for _, arg := range args { + links = append(links, "--link", arg) + } + args = links + } + + return &Hook{ + Lifecycle: cdi.CreateContainerHook, + Path: c.nvidiaCDIHookPath, + Args: append(c.requiredArgs(name), args...), + } +} + +func (c CDIHook) requiredArgs(name HookName) []string { + base := filepath.Base(c.nvidiaCDIHookPath) + if base == "nvidia-ctk" { + return []string{base, "hook", string(name)} + } + return []string{base, string(name)} +} + +// IsSupported checks whether a hook of the specified name is supported. +// Hooks must be explicitly disabled, meaning that if no disabled hooks are +// all hooks are supported. +func (c CDIHook) IsSupported(h HookName) bool { + if len(c.disabledHooks) == 0 { + return true + } + return !c.disabledHooks[h] +} diff --git a/internal/hooks/hooks_test.go b/internal/hooks/hooks_test.go new file mode 100644 index 00000000..49848b7c --- /dev/null +++ b/internal/hooks/hooks_test.go @@ -0,0 +1,177 @@ +package hooks + +import ( + "testing" +) + +func TestNewHookCreator(t *testing.T) { + tests := []struct { + name string + hookPath string + disabledHooks DisabledHooks + expectedCreator HookCreator + }{ + { + name: "nil disabled hooks", + hookPath: "/usr/bin/nvidia-ctk", + expectedCreator: &CDIHook{ + nvidiaCDIHookPath: "/usr/bin/nvidia-ctk", + disabledHooks: DisabledHooks{}, + }, + }, + { + name: "with disabled hooks", + hookPath: "/usr/bin/nvidia-ctk", + disabledHooks: DisabledHooks{ + EnableCudaCompat: true, + }, + expectedCreator: &CDIHook{ + nvidiaCDIHookPath: "/usr/bin/nvidia-ctk", + disabledHooks: DisabledHooks{ + EnableCudaCompat: true, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + creator := NewHookCreator(tt.hookPath, tt.disabledHooks) + if creator == nil { + t.Fatal("NewHookCreator returned nil") + } + }) + } +} + +func TestCDIHook_Create(t *testing.T) { + tests := []struct { + name string + hookPath string + disabledHooks DisabledHooks + hookName HookName + args []string + expectedHook *Hook + }{ + { + name: "disabled hook returns nil", + hookPath: "/usr/bin/nvidia-ctk", + disabledHooks: DisabledHooks{ + EnableCudaCompat: true, + }, + hookName: EnableCudaCompat, + expectedHook: nil, + }, + { + name: "create symlinks with no args returns nil", + hookPath: "/usr/bin/nvidia-ctk", + hookName: CreateSymlinks, + expectedHook: nil, + }, + { + name: "create symlinks with args", + hookPath: "/usr/bin/nvidia-ctk", + hookName: CreateSymlinks, + args: []string{"/path/to/lib1", "/path/to/lib2"}, + expectedHook: &Hook{ + Lifecycle: "createContainer", + Path: "/usr/bin/nvidia-ctk", + Args: []string{"nvidia-ctk", "hook", "create-symlinks", "--link", "/path/to/lib1", "--link", "/path/to/lib2"}, + }, + }, + { + name: "enable cuda compat", + hookPath: "/usr/bin/nvidia-ctk", + hookName: EnableCudaCompat, + expectedHook: &Hook{ + Lifecycle: "createContainer", + Path: "/usr/bin/nvidia-ctk", + Args: []string{"nvidia-ctk", "hook", "enable-cuda-compat"}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + hook := &CDIHook{ + nvidiaCDIHookPath: tt.hookPath, + disabledHooks: tt.disabledHooks, + } + + result := hook.Create(tt.hookName, tt.args...) + if tt.expectedHook == nil { + if result != nil { + t.Errorf("expected nil hook, got %v", result) + } + return + } + + if result == nil { + t.Fatal("expected non-nil hook, got nil") + } + + if result.Lifecycle != tt.expectedHook.Lifecycle { + t.Errorf("expected lifecycle %q, got %q", tt.expectedHook.Lifecycle, result.Lifecycle) + } + + if result.Path != tt.expectedHook.Path { + t.Errorf("expected path %q, got %q", tt.expectedHook.Path, result.Path) + } + + if len(result.Args) != len(tt.expectedHook.Args) { + t.Errorf("expected %d args, got %d", len(tt.expectedHook.Args), len(result.Args)) + return + } + + for i, arg := range tt.expectedHook.Args { + if result.Args[i] != arg { + t.Errorf("expected arg[%d] %q, got %q", i, arg, result.Args[i]) + } + } + }) + } +} + +func TestCDIHook_IsSupported(t *testing.T) { + tests := []struct { + name string + disabledHooks DisabledHooks + hookName HookName + expected bool + }{ + { + name: "no disabled hooks", + hookName: EnableCudaCompat, + expected: true, + }, + { + name: "disabled hook", + disabledHooks: DisabledHooks{ + EnableCudaCompat: true, + }, + hookName: EnableCudaCompat, + expected: false, + }, + { + name: "non-disabled hook", + disabledHooks: DisabledHooks{ + EnableCudaCompat: true, + }, + hookName: CreateSymlinks, + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + hook := &CDIHook{ + disabledHooks: tt.disabledHooks, + } + + result := hook.IsSupported(tt.hookName) + if result != tt.expected { + t.Errorf("expected %v, got %v", tt.expected, result) + } + }) + } +} diff --git a/internal/modifier/gated.go b/internal/modifier/gated.go index a0239df8..bef0975f 100644 --- a/internal/modifier/gated.go +++ b/internal/modifier/gated.go @@ -22,6 +22,7 @@ import ( "github.com/NVIDIA/nvidia-container-toolkit/internal/config" "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" + "github.com/NVIDIA/nvidia-container-toolkit/internal/hooks" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root" "github.com/NVIDIA/nvidia-container-toolkit/internal/oci" @@ -36,7 +37,7 @@ import ( // NVIDIA_GDRCOPY=enabled // // If not devices are selected, no changes are made. -func NewFeatureGatedModifier(logger logger.Interface, cfg *config.Config, image image.CUDA, driver *root.Driver, hookCreator discover.HookCreator) (oci.SpecModifier, error) { +func NewFeatureGatedModifier(logger logger.Interface, cfg *config.Config, image image.CUDA, driver *root.Driver, hookCreator hooks.HookCreator) (oci.SpecModifier, error) { if devices := image.VisibleDevicesFromEnvVar(); len(devices) == 0 { logger.Infof("No modification required; no devices requested") return nil, nil @@ -91,7 +92,7 @@ func NewFeatureGatedModifier(logger logger.Interface, cfg *config.Config, image return NewModifierFromDiscoverer(logger, discover.Merge(discoverers...)) } -func getCudaCompatModeDiscoverer(logger logger.Interface, cfg *config.Config, driver *root.Driver, hookCreator discover.HookCreator) (discover.Discover, error) { +func getCudaCompatModeDiscoverer(logger logger.Interface, cfg *config.Config, driver *root.Driver, hookCreator hooks.HookCreator) (discover.Discover, error) { // For legacy mode, we only include the enable-cuda-compat hook if cuda-compat-mode is set to hook. if cfg.NVIDIAContainerRuntimeConfig.Mode == "legacy" && cfg.NVIDIAContainerRuntimeConfig.Modes.Legacy.CUDACompatMode != config.CUDACompatModeHook { return nil, nil diff --git a/internal/modifier/graphics.go b/internal/modifier/graphics.go index 6e602d7a..e1482650 100644 --- a/internal/modifier/graphics.go +++ b/internal/modifier/graphics.go @@ -22,6 +22,7 @@ import ( "github.com/NVIDIA/nvidia-container-toolkit/internal/config" "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" + "github.com/NVIDIA/nvidia-container-toolkit/internal/hooks" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root" "github.com/NVIDIA/nvidia-container-toolkit/internal/oci" @@ -29,7 +30,7 @@ import ( // NewGraphicsModifier constructs a modifier that injects graphics-related modifications into an OCI runtime specification. // The value of the NVIDIA_DRIVER_CAPABILITIES environment variable is checked to determine if this modification should be made. -func NewGraphicsModifier(logger logger.Interface, cfg *config.Config, containerImage image.CUDA, driver *root.Driver, hookCreator discover.HookCreator) (oci.SpecModifier, error) { +func NewGraphicsModifier(logger logger.Interface, cfg *config.Config, containerImage image.CUDA, driver *root.Driver, hookCreator hooks.HookCreator) (oci.SpecModifier, error) { if required, reason := requiresGraphicsModifier(containerImage); !required { logger.Infof("No graphics modifier required: %v", reason) return nil, nil diff --git a/internal/platform-support/dgpu/by-path-hooks.go b/internal/platform-support/dgpu/by-path-hooks.go index b78720a2..3a8d0270 100644 --- a/internal/platform-support/dgpu/by-path-hooks.go +++ b/internal/platform-support/dgpu/by-path-hooks.go @@ -22,6 +22,7 @@ import ( "path/filepath" "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" + "github.com/NVIDIA/nvidia-container-toolkit/internal/hooks" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" ) @@ -29,7 +30,7 @@ import ( type byPathHookDiscoverer struct { logger logger.Interface devRoot string - hookCreator discover.HookCreator + hookCreator hooks.HookCreator pciBusID string deviceNodes discover.Discover } @@ -53,9 +54,14 @@ func (d *byPathHookDiscoverer) Hooks() ([]discover.Hook, error) { return nil, nil } - hook := d.hookCreator.Create("create-symlinks", links...) + hook := d.hookCreator.Create(hooks.CreateSymlinks, links...) - return hook.Hooks() + return []discover.Hook{ + { + Lifecycle: hook.Lifecycle, + Path: hook.Path, + Args: hook.Args, + }}, nil } // Mounts returns an empty slice for a full GPU diff --git a/internal/platform-support/dgpu/nvsandboxutils.go b/internal/platform-support/dgpu/nvsandboxutils.go index f8925e4a..94d8b459 100644 --- a/internal/platform-support/dgpu/nvsandboxutils.go +++ b/internal/platform-support/dgpu/nvsandboxutils.go @@ -24,6 +24,7 @@ import ( "github.com/NVIDIA/go-nvml/pkg/nvml" "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" + "github.com/NVIDIA/nvidia-container-toolkit/internal/hooks" "github.com/NVIDIA/nvidia-container-toolkit/internal/nvsandboxutils" ) @@ -32,7 +33,7 @@ type nvsandboxutilsDGPU struct { uuid string devRoot string isMig bool - hookCreator discover.HookCreator + hookCreator hooks.HookCreator deviceLinks []string } @@ -112,9 +113,14 @@ func (d *nvsandboxutilsDGPU) Hooks() ([]discover.Hook, error) { return nil, nil } - hook := d.hookCreator.Create("create-symlinks", d.deviceLinks...) + hook := d.hookCreator.Create(hooks.CreateSymlinks, d.deviceLinks...) - return hook.Hooks() + return []discover.Hook{ + { + Lifecycle: hook.Lifecycle, + Path: hook.Path, + Args: hook.Args, + }}, nil } func (d *nvsandboxutilsDGPU) Mounts() ([]discover.Mount, error) { diff --git a/internal/platform-support/dgpu/options.go b/internal/platform-support/dgpu/options.go index 6b2d62ce..038d411e 100644 --- a/internal/platform-support/dgpu/options.go +++ b/internal/platform-support/dgpu/options.go @@ -17,7 +17,7 @@ package dgpu import ( - "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" + "github.com/NVIDIA/nvidia-container-toolkit/internal/hooks" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/nvcaps" "github.com/NVIDIA/nvidia-container-toolkit/internal/nvsandboxutils" @@ -26,7 +26,7 @@ import ( type options struct { logger logger.Interface devRoot string - hookCreator discover.HookCreator + hookCreator hooks.HookCreator isMigDevice bool // migCaps stores the MIG capabilities for the system. @@ -54,7 +54,7 @@ func WithLogger(logger logger.Interface) Option { } // WithHookCreator sets the hook creator for the library -func WithHookCreator(hookCreator discover.HookCreator) Option { +func WithHookCreator(hookCreator hooks.HookCreator) Option { return func(l *options) { l.hookCreator = hookCreator } diff --git a/internal/platform-support/tegra/csv_test.go b/internal/platform-support/tegra/csv_test.go index 129bf00c..c754c1e9 100644 --- a/internal/platform-support/tegra/csv_test.go +++ b/internal/platform-support/tegra/csv_test.go @@ -24,6 +24,7 @@ import ( "github.com/stretchr/testify/require" "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" + "github.com/NVIDIA/nvidia-container-toolkit/internal/hooks" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" @@ -181,7 +182,7 @@ func TestDiscovererFromCSVFiles(t *testing.T) { }, } - hookCreator := discover.NewHookCreator("/usr/bin/nvidia-cdi-hook") + hookCreator := hooks.NewHookCreator("/usr/bin/nvidia-cdi-hook", nil) for _, tc := range testCases { t.Run(tc.description, func(t *testing.T) { defer setGetTargetsFromCSVFiles(tc.moutSpecs)() diff --git a/internal/platform-support/tegra/symlinks.go b/internal/platform-support/tegra/symlinks.go index 822d482f..574a4870 100644 --- a/internal/platform-support/tegra/symlinks.go +++ b/internal/platform-support/tegra/symlinks.go @@ -20,6 +20,7 @@ import ( "fmt" "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" + "github.com/NVIDIA/nvidia-container-toolkit/internal/hooks" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" ) @@ -27,7 +28,7 @@ import ( type symlinkHook struct { discover.None logger logger.Interface - hookCreator discover.HookCreator + hookCreator hooks.HookCreator targets []string // The following can be overridden for testing @@ -48,7 +49,17 @@ func (o tegraOptions) createCSVSymlinkHooks(targets []string) discover.Discover // Hooks returns a hook to create the symlinks from the required CSV files func (d symlinkHook) Hooks() ([]discover.Hook, error) { - return d.hookCreator.Create("create-symlinks", d.getCSVFileSymlinks()...).Hooks() + hook := d.hookCreator.Create(hooks.CreateSymlinks, d.getCSVFileSymlinks()...) + if hook == nil { + return nil, nil + } + + return []discover.Hook{ + { + Lifecycle: hook.Lifecycle, + Path: hook.Path, + Args: hook.Args, + }}, nil } // getSymlinkCandidates returns a list of symlinks that are candidates for being created. diff --git a/internal/platform-support/tegra/tegra.go b/internal/platform-support/tegra/tegra.go index 6ad774b4..25f92fa5 100644 --- a/internal/platform-support/tegra/tegra.go +++ b/internal/platform-support/tegra/tegra.go @@ -20,6 +20,7 @@ import ( "fmt" "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" + "github.com/NVIDIA/nvidia-container-toolkit/internal/hooks" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/symlinks" @@ -30,7 +31,7 @@ type tegraOptions struct { csvFiles []string driverRoot string devRoot string - hookCreator discover.HookCreator + hookCreator hooks.HookCreator ldconfigPath string librarySearchPaths []string ignorePatterns ignoreMountSpecPatterns @@ -134,7 +135,7 @@ func WithCSVFiles(csvFiles []string) Option { } // WithHookCreator sets the hook creator for the discoverer. -func WithHookCreator(hookCreator discover.HookCreator) Option { +func WithHookCreator(hookCreator hooks.HookCreator) Option { return func(o *tegraOptions) { o.hookCreator = hookCreator } diff --git a/internal/runtime/runtime_factory.go b/internal/runtime/runtime_factory.go index c1a82ac9..e9816485 100644 --- a/internal/runtime/runtime_factory.go +++ b/internal/runtime/runtime_factory.go @@ -21,7 +21,7 @@ import ( "github.com/NVIDIA/nvidia-container-toolkit/internal/config" "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" - "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" + "github.com/NVIDIA/nvidia-container-toolkit/internal/hooks" "github.com/NVIDIA/nvidia-container-toolkit/internal/info" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root" @@ -75,7 +75,7 @@ func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Sp return nil, err } - hookCreator := discover.NewHookCreator(cfg.NVIDIACTKConfig.Path) + hookCreator := hooks.NewHookCreator(cfg.NVIDIACTKConfig.Path, nil) mode := info.ResolveAutoMode(logger, cfg.NVIDIAContainerRuntimeConfig.Mode, image) // We update the mode here so that we can continue passing just the config to other functions. diff --git a/pkg/nvcdi/api.go b/pkg/nvcdi/api.go index 2988026f..f1c7b97a 100644 --- a/pkg/nvcdi/api.go +++ b/pkg/nvcdi/api.go @@ -35,13 +35,3 @@ type Interface interface { GetMIGDeviceSpecs(int, device.Device, int, device.MigDevice) ([]specs.Device, error) GetDeviceSpecsByID(...string) ([]specs.Device, error) } - -// A HookName refers to one of the predefined set of CDI hooks that may be -// included in the generated CDI specification. -type HookName string - -const ( - // HookEnableCudaCompat refers to the hook used to enable CUDA Forward Compatibility. - // This was added with v1.17.5 of the NVIDIA Container Toolkit. - HookEnableCudaCompat = HookName("enable-cuda-compat") -) diff --git a/pkg/nvcdi/driver-nvml.go b/pkg/nvcdi/driver-nvml.go index 3fbc0e94..ff02ac72 100644 --- a/pkg/nvcdi/driver-nvml.go +++ b/pkg/nvcdi/driver-nvml.go @@ -106,11 +106,9 @@ func (l *nvcdilib) NewDriverLibraryDiscoverer(version string) (discover.Discover ) discoverers = append(discoverers, driverDotSoSymlinksDiscoverer) - if l.HookIsSupported(HookEnableCudaCompat) { - // TODO: The following should use the version directly. - cudaCompatLibHookDiscoverer := discover.NewCUDACompatHookDiscoverer(l.logger, l.hookCreator, l.driver) - discoverers = append(discoverers, cudaCompatLibHookDiscoverer) - } + // TODO: The following should use the version directly. + cudaCompatLibHookDiscoverer := discover.NewCUDACompatHookDiscoverer(l.logger, l.hookCreator, l.driver) + discoverers = append(discoverers, cudaCompatLibHookDiscoverer) updateLDCache, _ := discover.NewLDCacheUpdateHook(l.logger, libraries, l.hookCreator, l.ldconfigPath) discoverers = append(discoverers, updateLDCache) diff --git a/pkg/nvcdi/driver-wsl.go b/pkg/nvcdi/driver-wsl.go index 97e267c5..24347971 100644 --- a/pkg/nvcdi/driver-wsl.go +++ b/pkg/nvcdi/driver-wsl.go @@ -22,6 +22,7 @@ import ( "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" "github.com/NVIDIA/nvidia-container-toolkit/internal/dxcore" + "github.com/NVIDIA/nvidia-container-toolkit/internal/hooks" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" ) @@ -39,7 +40,7 @@ var requiredDriverStoreFiles = []string{ } // newWSLDriverDiscoverer returns a Discoverer for WSL2 drivers. -func newWSLDriverDiscoverer(logger logger.Interface, driverRoot string, hookCreator discover.HookCreator, ldconfigPath string) (discover.Discover, error) { +func newWSLDriverDiscoverer(logger logger.Interface, driverRoot string, hookCreator hooks.HookCreator, ldconfigPath string) (discover.Discover, error) { err := dxcore.Init() if err != nil { return nil, fmt.Errorf("failed to initialize dxcore: %v", err) @@ -60,7 +61,7 @@ func newWSLDriverDiscoverer(logger logger.Interface, driverRoot string, hookCrea } // newWSLDriverStoreDiscoverer returns a Discoverer for WSL2 drivers in the driver store associated with a dxcore adapter. -func newWSLDriverStoreDiscoverer(logger logger.Interface, driverRoot string, hookCreator discover.HookCreator, ldconfigPath string, driverStorePaths []string) (discover.Discover, error) { +func newWSLDriverStoreDiscoverer(logger logger.Interface, driverRoot string, hookCreator hooks.HookCreator, ldconfigPath string, driverStorePaths []string) (discover.Discover, error) { var searchPaths []string seen := make(map[string]bool) for _, path := range driverStorePaths { @@ -108,7 +109,7 @@ type nvidiaSMISimlinkHook struct { discover.None logger logger.Interface mountsFrom discover.Discover - hookCreator discover.HookCreator + hookCreator hooks.HookCreator } // Hooks returns a hook that creates a symlink to nvidia-smi in the driver store. @@ -135,7 +136,13 @@ func (m nvidiaSMISimlinkHook) Hooks() ([]discover.Hook, error) { } link := "/usr/bin/nvidia-smi" links := []string{fmt.Sprintf("%s::%s", target, link)} - symlinkHook := m.hookCreator.Create("create-symlinks", links...) + symlinkHook := m.hookCreator.Create(hooks.CreateSymlinks, links...) - return symlinkHook.Hooks() + return []discover.Hook{ + { + Lifecycle: symlinkHook.Lifecycle, + Path: symlinkHook.Path, + Args: symlinkHook.Args, + }, + }, nil } diff --git a/pkg/nvcdi/driver-wsl_test.go b/pkg/nvcdi/driver-wsl_test.go index 27247cc6..e44dd4f4 100644 --- a/pkg/nvcdi/driver-wsl_test.go +++ b/pkg/nvcdi/driver-wsl_test.go @@ -23,13 +23,14 @@ import ( "github.com/stretchr/testify/require" "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" + "github.com/NVIDIA/nvidia-container-toolkit/internal/hooks" testlog "github.com/sirupsen/logrus/hooks/test" ) func TestNvidiaSMISymlinkHook(t *testing.T) { logger, _ := testlog.NewNullLogger() - hookCreator := discover.NewHookCreator("nvidia-cdi-hook") + hookCreator := hooks.NewHookCreator("nvidia-cdi-hook", nil) errMounts := errors.New("mounts error") diff --git a/pkg/nvcdi/hooks.go b/pkg/nvcdi/hooks.go deleted file mode 100644 index a4620dc8..00000000 --- a/pkg/nvcdi/hooks.go +++ /dev/null @@ -1,30 +0,0 @@ -/** -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -**/ - -package nvcdi - -// disabledHooks allows individual hooks to be disabled. -type disabledHooks map[HookName]bool - -// HookIsSupported checks whether a hook of the specified name is supported. -// Hooks must be explicitly disabled, meaning that if no disabled hooks are -// all hooks are supported. -func (l *nvcdilib) HookIsSupported(h HookName) bool { - if len(l.disabledHooks) == 0 { - return true - } - return !l.disabledHooks[h] -} diff --git a/pkg/nvcdi/lib.go b/pkg/nvcdi/lib.go index 97a39168..3464eca3 100644 --- a/pkg/nvcdi/lib.go +++ b/pkg/nvcdi/lib.go @@ -23,7 +23,7 @@ import ( "github.com/NVIDIA/go-nvlib/pkg/nvlib/info" "github.com/NVIDIA/go-nvml/pkg/nvml" - "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" + "github.com/NVIDIA/nvidia-container-toolkit/internal/hooks" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root" "github.com/NVIDIA/nvidia-container-toolkit/internal/nvsandboxutils" @@ -56,14 +56,14 @@ type nvcdilib struct { mergedDeviceOptions []transform.MergedDeviceOption - disabledHooks disabledHooks - hookCreator discover.HookCreator + disabledHooks hooks.DisabledHooks + hookCreator hooks.HookCreator } // New creates a new nvcdi library func New(opts ...Option) (Interface, error) { l := &nvcdilib{ - disabledHooks: make(disabledHooks), + disabledHooks: make(hooks.DisabledHooks), } for _, opt := range opts { opt(l) @@ -81,9 +81,6 @@ func New(opts ...Option) (Interface, error) { if l.nvidiaCDIHookPath == "" { l.nvidiaCDIHookPath = "/usr/bin/nvidia-cdi-hook" } - // create hookCreator - l.hookCreator = discover.NewHookCreator(l.nvidiaCDIHookPath) - if l.driverRoot == "" { l.driverRoot = "/" } @@ -150,7 +147,7 @@ func New(opts ...Option) (Interface, error) { l.vendor = "management.nvidia.com" } // Management containers in general do not require CUDA Forward compatibility. - l.disabledHooks[HookEnableCudaCompat] = true + l.disabledHooks[hooks.EnableCudaCompat] = true lib = (*managementlib)(l) case ModeNvml: lib = (*nvmllib)(l) @@ -175,6 +172,8 @@ func New(opts ...Option) (Interface, error) { return nil, fmt.Errorf("unknown mode %q", l.mode) } + l.hookCreator = hooks.NewHookCreator(l.nvidiaCDIHookPath, l.disabledHooks) + w := wrapper{ Interface: lib, vendor: l.vendor, diff --git a/pkg/nvcdi/options.go b/pkg/nvcdi/options.go index f38f2b4a..2ce7c972 100644 --- a/pkg/nvcdi/options.go +++ b/pkg/nvcdi/options.go @@ -21,6 +21,7 @@ import ( "github.com/NVIDIA/go-nvlib/pkg/nvlib/info" "github.com/NVIDIA/go-nvml/pkg/nvml" + "github.com/NVIDIA/nvidia-container-toolkit/internal/hooks" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/transform" ) @@ -158,10 +159,10 @@ func WithLibrarySearchPaths(paths []string) Option { // WithDisabledHook allows specific hooks to the disabled. // This option can be specified multiple times for each hook. -func WithDisabledHook(hook HookName) Option { +func WithDisabledHook(hook hooks.HookName) Option { return func(o *nvcdilib) { if o.disabledHooks == nil { - o.disabledHooks = make(map[HookName]bool) + o.disabledHooks = make(hooks.DisabledHooks) } o.disabledHooks[hook] = true } diff --git a/pkg/nvcdi/workarounds-device-folder-permissions.go b/pkg/nvcdi/workarounds-device-folder-permissions.go index 71967ac4..2b3b5bbb 100644 --- a/pkg/nvcdi/workarounds-device-folder-permissions.go +++ b/pkg/nvcdi/workarounds-device-folder-permissions.go @@ -21,6 +21,7 @@ import ( "path/filepath" "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" + "github.com/NVIDIA/nvidia-container-toolkit/internal/hooks" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" ) @@ -28,7 +29,7 @@ type deviceFolderPermissions struct { logger logger.Interface devRoot string devices discover.Discover - hookCreator discover.HookCreator + hookCreator hooks.HookCreator } var _ discover.Discover = (*deviceFolderPermissions)(nil) @@ -39,7 +40,7 @@ var _ discover.Discover = (*deviceFolderPermissions)(nil) // The nested devices that are applicable to the NVIDIA GPU devices are: // - DRM devices at /dev/dri/* // - NVIDIA Caps devices at /dev/nvidia-caps/* -func newDeviceFolderPermissionHookDiscoverer(logger logger.Interface, devRoot string, hookCreator discover.HookCreator, devices discover.Discover) discover.Discover { +func newDeviceFolderPermissionHookDiscoverer(logger logger.Interface, devRoot string, hookCreator hooks.HookCreator, devices discover.Discover) discover.Discover { d := &deviceFolderPermissions{ logger: logger, devRoot: devRoot, @@ -72,7 +73,13 @@ func (d *deviceFolderPermissions) Hooks() ([]discover.Hook, error) { hook := d.hookCreator.Create("chmod", args...) - return []discover.Hook{*hook}, nil + return []discover.Hook{ + { + Lifecycle: hook.Lifecycle, + Path: hook.Path, + Args: hook.Args, + }, + }, nil } func (d *deviceFolderPermissions) getDeviceSubfolders() ([]string, error) {