Refactor the way we create CDI Hooks

Signed-off-by: Carlos Eduardo Arango Gutierrez <eduardoa@nvidia.com>
This commit is contained in:
Carlos Eduardo Arango Gutierrez
2025-05-16 18:44:59 +02:00
parent ac8f190c99
commit 641f4e6a4e
29 changed files with 208 additions and 208 deletions

View File

@@ -9,16 +9,12 @@ import (
// NewCUDACompatHookDiscoverer creates a discoverer for a enable-cuda-compat hook.
// This hook is responsible for setting up CUDA compatibility in the container and depends on the host driver version.
func NewCUDACompatHookDiscoverer(logger logger.Interface, nvidiaCDIHookPath string, driver *root.Driver) Discover {
func NewCUDACompatHookDiscoverer(logger logger.Interface, hookCreator HookCreator, driver *root.Driver) Discover {
_, cudaVersionPattern := getCUDALibRootAndVersionPattern(logger, driver)
var args []string
if !strings.Contains(cudaVersionPattern, "*") {
args = append(args, "--host-driver-version="+cudaVersionPattern)
}
return CreateNvidiaCDIHook(
nvidiaCDIHookPath,
"enable-cuda-compat",
args...,
)
return hookCreator.Create("enable-cuda-compat", args...)
}

View File

@@ -36,21 +36,21 @@ import (
// TODO: The logic for creating DRM devices should be consolidated between this
// and the logic for generating CDI specs for a single device. This is only used
// when applying OCI spec modifications to an incoming spec in "legacy" mode.
func NewDRMNodesDiscoverer(logger logger.Interface, devices image.VisibleDevices, devRoot string, nvidiaCDIHookPath string) (Discover, error) {
func NewDRMNodesDiscoverer(logger logger.Interface, devices image.VisibleDevices, devRoot string, hookCreator HookCreator) (Discover, error) {
drmDeviceNodes, err := newDRMDeviceDiscoverer(logger, devices, devRoot)
if err != nil {
return nil, fmt.Errorf("failed to create DRM device discoverer: %v", err)
}
drmByPathSymlinks := newCreateDRMByPathSymlinks(logger, drmDeviceNodes, devRoot, nvidiaCDIHookPath)
drmByPathSymlinks := newCreateDRMByPathSymlinks(logger, drmDeviceNodes, devRoot, hookCreator)
discover := Merge(drmDeviceNodes, drmByPathSymlinks)
return discover, nil
}
// NewGraphicsMountsDiscoverer creates a discoverer for the mounts required by graphics tools such as vulkan.
func NewGraphicsMountsDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCDIHookPath string) (Discover, error) {
libraries := newGraphicsLibrariesDiscoverer(logger, driver, nvidiaCDIHookPath)
func NewGraphicsMountsDiscoverer(logger logger.Interface, driver *root.Driver, hookCreator HookCreator) (Discover, error) {
libraries := newGraphicsLibrariesDiscoverer(logger, driver, hookCreator)
configs := NewMounts(
logger,
@@ -95,13 +95,13 @@ func newVulkanConfigsDiscover(logger logger.Interface, driver *root.Driver) Disc
type graphicsDriverLibraries struct {
Discover
logger logger.Interface
nvidiaCDIHookPath string
logger logger.Interface
hookCreator HookCreator
}
var _ Discover = (*graphicsDriverLibraries)(nil)
func newGraphicsLibrariesDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCDIHookPath string) Discover {
func newGraphicsLibrariesDiscoverer(logger logger.Interface, driver *root.Driver, hookCreator HookCreator) Discover {
cudaLibRoot, cudaVersionPattern := getCUDALibRootAndVersionPattern(logger, driver)
libraries := NewMounts(
@@ -140,9 +140,9 @@ func newGraphicsLibrariesDiscoverer(logger logger.Interface, driver *root.Driver
)
return &graphicsDriverLibraries{
Discover: Merge(libraries, xorgLibraries),
logger: logger,
nvidiaCDIHookPath: nvidiaCDIHookPath,
Discover: Merge(libraries, xorgLibraries),
logger: logger,
hookCreator: hookCreator,
}
}
@@ -203,7 +203,7 @@ func (d graphicsDriverLibraries) Hooks() ([]Hook, error) {
return nil, nil
}
hooks := CreateCreateSymlinkHook(d.nvidiaCDIHookPath, links)
hooks := d.hookCreator.Create("create-symlinks", links...)
return hooks.Hooks()
}
@@ -275,19 +275,19 @@ func buildXOrgSearchPaths(libRoot string) []string {
type drmDevicesByPath struct {
None
logger logger.Interface
nvidiaCDIHookPath string
devRoot string
devicesFrom Discover
logger logger.Interface
hookCreator HookCreator
devRoot string
devicesFrom Discover
}
// newCreateDRMByPathSymlinks creates a discoverer for a hook to create the by-path symlinks for DRM devices discovered by the specified devices discoverer
func newCreateDRMByPathSymlinks(logger logger.Interface, devices Discover, devRoot string, nvidiaCDIHookPath string) Discover {
func newCreateDRMByPathSymlinks(logger logger.Interface, devices Discover, devRoot string, hookCreator HookCreator) Discover {
d := drmDevicesByPath{
logger: logger,
nvidiaCDIHookPath: nvidiaCDIHookPath,
devRoot: devRoot,
devicesFrom: devices,
logger: logger,
hookCreator: hookCreator,
devRoot: devRoot,
devicesFrom: devices,
}
return &d
@@ -315,13 +315,9 @@ func (d drmDevicesByPath) Hooks() ([]Hook, error) {
args = append(args, "--link", l)
}
hook := CreateNvidiaCDIHook(
d.nvidiaCDIHookPath,
"create-symlinks",
args...,
)
hook := d.hookCreator.Create("create-symlinks", args...)
return []Hook{hook}, nil
return hook.Hooks()
}
// getSpecificLinkArgs returns the required specific links that need to be created

View File

@@ -25,6 +25,7 @@ import (
func TestGraphicsLibrariesDiscoverer(t *testing.T) {
logger, _ := testlog.NewNullLogger()
hookCreator := NewHookCreator("/usr/bin/nvidia-cdi-hook")
testCases := []struct {
description string
@@ -136,9 +137,9 @@ func TestGraphicsLibrariesDiscoverer(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
d := &graphicsDriverLibraries{
Discover: tc.libraries,
logger: logger,
nvidiaCDIHookPath: "/usr/bin/nvidia-cdi-hook",
Discover: tc.libraries,
logger: logger,
hookCreator: hookCreator,
}
devices, err := d.Devices()

View File

@@ -25,54 +25,73 @@ import (
var _ Discover = (*Hook)(nil)
// Devices returns an empty list of devices for a Hook discoverer.
func (h Hook) Devices() ([]Device, error) {
func (h *Hook) Devices() ([]Device, error) {
return nil, nil
}
// Mounts returns an empty list of mounts for a Hook discoverer.
func (h Hook) Mounts() ([]Mount, error) {
func (h *Hook) Mounts() ([]Mount, error) {
return nil, nil
}
// Hooks allows the Hook type to also implement the Discoverer interface.
// It returns a single hook
func (h Hook) Hooks() ([]Hook, error) {
return []Hook{h}, nil
}
// CreateCreateSymlinkHook creates a hook which creates a symlink from link -> target.
func CreateCreateSymlinkHook(nvidiaCDIHookPath string, links []string) Discover {
if len(links) == 0 {
return None{}
func (h *Hook) Hooks() ([]Hook, error) {
if h == nil {
return nil, nil
}
var args []string
for _, link := range links {
args = append(args, "--link", link)
return []Hook{*h}, nil
}
type CDIHook struct {
nvidiaCDIHookPath string
disabledHooks map[string]bool
}
type HookCreator interface {
Create(string, ...string) *Hook
DisableHook(string)
}
func NewHookCreator(nvidiaCDIHookPath string) HookCreator {
return &CDIHook{
nvidiaCDIHookPath: nvidiaCDIHookPath,
disabledHooks: make(map[string]bool),
}
return CreateNvidiaCDIHook(
nvidiaCDIHookPath,
"create-symlinks",
args...,
)
}
// CreateNvidiaCDIHook creates a hook which invokes the NVIDIA Container CLI hook subcommand.
func CreateNvidiaCDIHook(nvidiaCDIHookPath string, hookName string, additionalArgs ...string) Hook {
return cdiHook(nvidiaCDIHookPath).Create(hookName, additionalArgs...)
}
func (c CDIHook) Create(name string, args ...string) *Hook {
if c.disabledHooks[name] {
return nil
}
type cdiHook string
if name == "create-symlinks" {
if len(args) == 0 {
return nil
}
func (c cdiHook) Create(name string, args ...string) Hook {
return Hook{
links := make([]string, 0, len(args))
for _, arg := range args {
links = append(links, "--link", arg)
}
args = links
}
return &Hook{
Lifecycle: cdi.CreateContainerHook,
Path: string(c),
Path: c.nvidiaCDIHookPath,
Args: append(c.requiredArgs(name), args...),
}
}
func (c cdiHook) requiredArgs(name string) []string {
base := filepath.Base(string(c))
// DisableHook disables a hook by adding it to the disabled hooks map
func (c *CDIHook) DisableHook(name string) {
c.disabledHooks[name] = true
}
func (c CDIHook) requiredArgs(name string) []string {
base := filepath.Base(c.nvidiaCDIHookPath)
if base == "nvidia-ctk" {
return []string{base, "hook", name}
}

View File

@@ -25,12 +25,12 @@ import (
)
// NewLDCacheUpdateHook creates a discoverer that updates the ldcache for the specified mounts. A logger can also be specified
func NewLDCacheUpdateHook(logger logger.Interface, mounts Discover, nvidiaCDIHookPath, ldconfigPath string) (Discover, error) {
func NewLDCacheUpdateHook(logger logger.Interface, mounts Discover, hookCreator HookCreator, ldconfigPath string) (Discover, error) {
d := ldconfig{
logger: logger,
nvidiaCDIHookPath: nvidiaCDIHookPath,
ldconfigPath: ldconfigPath,
mountsFrom: mounts,
logger: logger,
hookCreator: hookCreator,
ldconfigPath: ldconfigPath,
mountsFrom: mounts,
}
return &d, nil
@@ -38,10 +38,10 @@ func NewLDCacheUpdateHook(logger logger.Interface, mounts Discover, nvidiaCDIHoo
type ldconfig struct {
None
logger logger.Interface
nvidiaCDIHookPath string
ldconfigPath string
mountsFrom Discover
logger logger.Interface
hookCreator HookCreator
ldconfigPath string
mountsFrom Discover
}
// Hooks checks the required mounts for libraries and returns a hook to update the LDcache for the discovered paths.
@@ -50,16 +50,18 @@ func (d ldconfig) Hooks() ([]Hook, error) {
if err != nil {
return nil, fmt.Errorf("failed to discover mounts for ldcache update: %v", err)
}
h := CreateLDCacheUpdateHook(
d.nvidiaCDIHookPath,
h := createLDCacheUpdateHook(
d.hookCreator,
d.ldconfigPath,
getLibraryPaths(mounts),
)
return []Hook{h}, nil
return h.Hooks()
}
// CreateLDCacheUpdateHook locates the NVIDIA Container Toolkit CLI and creates a hook for updating the LD Cache
func CreateLDCacheUpdateHook(executable string, ldconfig string, libraries []string) Hook {
// createLDCacheUpdateHook locates the NVIDIA Container Toolkit CLI and creates a hook for updating the LD Cache
func createLDCacheUpdateHook(hookCreator HookCreator, ldconfig string, libraries []string) *Hook {
var args []string
if ldconfig != "" {
@@ -70,13 +72,7 @@ func CreateLDCacheUpdateHook(executable string, ldconfig string, libraries []str
args = append(args, "--folder", f)
}
hook := CreateNvidiaCDIHook(
executable,
"update-ldcache",
args...,
)
return hook
return hookCreator.Create("update-ldcache", args...)
}
// getLibraryPaths extracts the library dirs from the specified mounts

View File

@@ -31,6 +31,7 @@ const (
func TestLDCacheUpdateHook(t *testing.T) {
logger, _ := testlog.NewNullLogger()
hookCreator := NewHookCreator(testNvidiaCDIHookPath)
testCases := []struct {
description string
@@ -97,7 +98,7 @@ func TestLDCacheUpdateHook(t *testing.T) {
Lifecycle: "createContainer",
}
d, err := NewLDCacheUpdateHook(logger, mountMock, testNvidiaCDIHookPath, tc.ldconfigPath)
d, err := NewLDCacheUpdateHook(logger, mountMock, hookCreator, tc.ldconfigPath)
require.NoError(t, err)
hooks, err := d.Hooks()

View File

@@ -23,20 +23,20 @@ import (
type additionalSymlinks struct {
Discover
version string
nvidiaCDIHookPath string
version string
hookCreator HookCreator
}
// WithDriverDotSoSymlinks decorates the provided discoverer.
// A hook is added that checks for specific driver symlinks that need to be created.
func WithDriverDotSoSymlinks(mounts Discover, version string, nvidiaCDIHookPath string) Discover {
func WithDriverDotSoSymlinks(mounts Discover, version string, hookCreator HookCreator) Discover {
if version == "" {
version = "*.*"
}
return &additionalSymlinks{
Discover: mounts,
nvidiaCDIHookPath: nvidiaCDIHookPath,
version: version,
Discover: mounts,
hookCreator: hookCreator,
version: version,
}
}
@@ -73,8 +73,12 @@ func (d *additionalSymlinks) Hooks() ([]Hook, error) {
return hooks, nil
}
hook := CreateCreateSymlinkHook(d.nvidiaCDIHookPath, links).(Hook)
return append(hooks, hook), nil
hook, err := d.hookCreator.Create("create-symlinks", links...).Hooks()
if err != nil {
return nil, fmt.Errorf("failed to create symlink hook: %v", err)
}
return append(hooks, hook...), nil
}
// getLinksForMount maps the path to created links if any.

View File

@@ -308,10 +308,12 @@ func TestWithWithDriverDotSoSymlinks(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
hookCreator := NewHookCreator("/path/to/nvidia-cdi-hook")
d := WithDriverDotSoSymlinks(
tc.discover,
tc.version,
"/path/to/nvidia-cdi-hook",
hookCreator,
)
devices, err := d.Devices()