mirror of
https://github.com/NVIDIA/nvidia-container-toolkit
synced 2025-06-26 18:18:24 +00:00
Refactor the way we create CDI Hooks
Signed-off-by: Carlos Eduardo Arango Gutierrez <eduardoa@nvidia.com>
This commit is contained in:
@@ -9,16 +9,12 @@ import (
|
||||
|
||||
// NewCUDACompatHookDiscoverer creates a discoverer for a enable-cuda-compat hook.
|
||||
// This hook is responsible for setting up CUDA compatibility in the container and depends on the host driver version.
|
||||
func NewCUDACompatHookDiscoverer(logger logger.Interface, nvidiaCDIHookPath string, driver *root.Driver) Discover {
|
||||
func NewCUDACompatHookDiscoverer(logger logger.Interface, hookCreator HookCreator, driver *root.Driver) Discover {
|
||||
_, cudaVersionPattern := getCUDALibRootAndVersionPattern(logger, driver)
|
||||
var args []string
|
||||
if !strings.Contains(cudaVersionPattern, "*") {
|
||||
args = append(args, "--host-driver-version="+cudaVersionPattern)
|
||||
}
|
||||
|
||||
return CreateNvidiaCDIHook(
|
||||
nvidiaCDIHookPath,
|
||||
"enable-cuda-compat",
|
||||
args...,
|
||||
)
|
||||
return hookCreator.Create("enable-cuda-compat", args...)
|
||||
}
|
||||
|
||||
@@ -36,21 +36,21 @@ import (
|
||||
// TODO: The logic for creating DRM devices should be consolidated between this
|
||||
// and the logic for generating CDI specs for a single device. This is only used
|
||||
// when applying OCI spec modifications to an incoming spec in "legacy" mode.
|
||||
func NewDRMNodesDiscoverer(logger logger.Interface, devices image.VisibleDevices, devRoot string, nvidiaCDIHookPath string) (Discover, error) {
|
||||
func NewDRMNodesDiscoverer(logger logger.Interface, devices image.VisibleDevices, devRoot string, hookCreator HookCreator) (Discover, error) {
|
||||
drmDeviceNodes, err := newDRMDeviceDiscoverer(logger, devices, devRoot)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create DRM device discoverer: %v", err)
|
||||
}
|
||||
|
||||
drmByPathSymlinks := newCreateDRMByPathSymlinks(logger, drmDeviceNodes, devRoot, nvidiaCDIHookPath)
|
||||
drmByPathSymlinks := newCreateDRMByPathSymlinks(logger, drmDeviceNodes, devRoot, hookCreator)
|
||||
|
||||
discover := Merge(drmDeviceNodes, drmByPathSymlinks)
|
||||
return discover, nil
|
||||
}
|
||||
|
||||
// NewGraphicsMountsDiscoverer creates a discoverer for the mounts required by graphics tools such as vulkan.
|
||||
func NewGraphicsMountsDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCDIHookPath string) (Discover, error) {
|
||||
libraries := newGraphicsLibrariesDiscoverer(logger, driver, nvidiaCDIHookPath)
|
||||
func NewGraphicsMountsDiscoverer(logger logger.Interface, driver *root.Driver, hookCreator HookCreator) (Discover, error) {
|
||||
libraries := newGraphicsLibrariesDiscoverer(logger, driver, hookCreator)
|
||||
|
||||
configs := NewMounts(
|
||||
logger,
|
||||
@@ -95,13 +95,13 @@ func newVulkanConfigsDiscover(logger logger.Interface, driver *root.Driver) Disc
|
||||
|
||||
type graphicsDriverLibraries struct {
|
||||
Discover
|
||||
logger logger.Interface
|
||||
nvidiaCDIHookPath string
|
||||
logger logger.Interface
|
||||
hookCreator HookCreator
|
||||
}
|
||||
|
||||
var _ Discover = (*graphicsDriverLibraries)(nil)
|
||||
|
||||
func newGraphicsLibrariesDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCDIHookPath string) Discover {
|
||||
func newGraphicsLibrariesDiscoverer(logger logger.Interface, driver *root.Driver, hookCreator HookCreator) Discover {
|
||||
cudaLibRoot, cudaVersionPattern := getCUDALibRootAndVersionPattern(logger, driver)
|
||||
|
||||
libraries := NewMounts(
|
||||
@@ -140,9 +140,9 @@ func newGraphicsLibrariesDiscoverer(logger logger.Interface, driver *root.Driver
|
||||
)
|
||||
|
||||
return &graphicsDriverLibraries{
|
||||
Discover: Merge(libraries, xorgLibraries),
|
||||
logger: logger,
|
||||
nvidiaCDIHookPath: nvidiaCDIHookPath,
|
||||
Discover: Merge(libraries, xorgLibraries),
|
||||
logger: logger,
|
||||
hookCreator: hookCreator,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -203,7 +203,7 @@ func (d graphicsDriverLibraries) Hooks() ([]Hook, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
hooks := CreateCreateSymlinkHook(d.nvidiaCDIHookPath, links)
|
||||
hooks := d.hookCreator.Create("create-symlinks", links...)
|
||||
|
||||
return hooks.Hooks()
|
||||
}
|
||||
@@ -275,19 +275,19 @@ func buildXOrgSearchPaths(libRoot string) []string {
|
||||
|
||||
type drmDevicesByPath struct {
|
||||
None
|
||||
logger logger.Interface
|
||||
nvidiaCDIHookPath string
|
||||
devRoot string
|
||||
devicesFrom Discover
|
||||
logger logger.Interface
|
||||
hookCreator HookCreator
|
||||
devRoot string
|
||||
devicesFrom Discover
|
||||
}
|
||||
|
||||
// newCreateDRMByPathSymlinks creates a discoverer for a hook to create the by-path symlinks for DRM devices discovered by the specified devices discoverer
|
||||
func newCreateDRMByPathSymlinks(logger logger.Interface, devices Discover, devRoot string, nvidiaCDIHookPath string) Discover {
|
||||
func newCreateDRMByPathSymlinks(logger logger.Interface, devices Discover, devRoot string, hookCreator HookCreator) Discover {
|
||||
d := drmDevicesByPath{
|
||||
logger: logger,
|
||||
nvidiaCDIHookPath: nvidiaCDIHookPath,
|
||||
devRoot: devRoot,
|
||||
devicesFrom: devices,
|
||||
logger: logger,
|
||||
hookCreator: hookCreator,
|
||||
devRoot: devRoot,
|
||||
devicesFrom: devices,
|
||||
}
|
||||
|
||||
return &d
|
||||
@@ -315,13 +315,9 @@ func (d drmDevicesByPath) Hooks() ([]Hook, error) {
|
||||
args = append(args, "--link", l)
|
||||
}
|
||||
|
||||
hook := CreateNvidiaCDIHook(
|
||||
d.nvidiaCDIHookPath,
|
||||
"create-symlinks",
|
||||
args...,
|
||||
)
|
||||
hook := d.hookCreator.Create("create-symlinks", args...)
|
||||
|
||||
return []Hook{hook}, nil
|
||||
return hook.Hooks()
|
||||
}
|
||||
|
||||
// getSpecificLinkArgs returns the required specific links that need to be created
|
||||
|
||||
@@ -25,6 +25,7 @@ import (
|
||||
|
||||
func TestGraphicsLibrariesDiscoverer(t *testing.T) {
|
||||
logger, _ := testlog.NewNullLogger()
|
||||
hookCreator := NewHookCreator("/usr/bin/nvidia-cdi-hook")
|
||||
|
||||
testCases := []struct {
|
||||
description string
|
||||
@@ -136,9 +137,9 @@ func TestGraphicsLibrariesDiscoverer(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.description, func(t *testing.T) {
|
||||
d := &graphicsDriverLibraries{
|
||||
Discover: tc.libraries,
|
||||
logger: logger,
|
||||
nvidiaCDIHookPath: "/usr/bin/nvidia-cdi-hook",
|
||||
Discover: tc.libraries,
|
||||
logger: logger,
|
||||
hookCreator: hookCreator,
|
||||
}
|
||||
|
||||
devices, err := d.Devices()
|
||||
|
||||
@@ -25,54 +25,73 @@ import (
|
||||
var _ Discover = (*Hook)(nil)
|
||||
|
||||
// Devices returns an empty list of devices for a Hook discoverer.
|
||||
func (h Hook) Devices() ([]Device, error) {
|
||||
func (h *Hook) Devices() ([]Device, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Mounts returns an empty list of mounts for a Hook discoverer.
|
||||
func (h Hook) Mounts() ([]Mount, error) {
|
||||
func (h *Hook) Mounts() ([]Mount, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Hooks allows the Hook type to also implement the Discoverer interface.
|
||||
// It returns a single hook
|
||||
func (h Hook) Hooks() ([]Hook, error) {
|
||||
return []Hook{h}, nil
|
||||
}
|
||||
|
||||
// CreateCreateSymlinkHook creates a hook which creates a symlink from link -> target.
|
||||
func CreateCreateSymlinkHook(nvidiaCDIHookPath string, links []string) Discover {
|
||||
if len(links) == 0 {
|
||||
return None{}
|
||||
func (h *Hook) Hooks() ([]Hook, error) {
|
||||
if h == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var args []string
|
||||
for _, link := range links {
|
||||
args = append(args, "--link", link)
|
||||
return []Hook{*h}, nil
|
||||
}
|
||||
|
||||
type CDIHook struct {
|
||||
nvidiaCDIHookPath string
|
||||
disabledHooks map[string]bool
|
||||
}
|
||||
|
||||
type HookCreator interface {
|
||||
Create(string, ...string) *Hook
|
||||
DisableHook(string)
|
||||
}
|
||||
|
||||
func NewHookCreator(nvidiaCDIHookPath string) HookCreator {
|
||||
return &CDIHook{
|
||||
nvidiaCDIHookPath: nvidiaCDIHookPath,
|
||||
disabledHooks: make(map[string]bool),
|
||||
}
|
||||
return CreateNvidiaCDIHook(
|
||||
nvidiaCDIHookPath,
|
||||
"create-symlinks",
|
||||
args...,
|
||||
)
|
||||
}
|
||||
|
||||
// CreateNvidiaCDIHook creates a hook which invokes the NVIDIA Container CLI hook subcommand.
|
||||
func CreateNvidiaCDIHook(nvidiaCDIHookPath string, hookName string, additionalArgs ...string) Hook {
|
||||
return cdiHook(nvidiaCDIHookPath).Create(hookName, additionalArgs...)
|
||||
}
|
||||
func (c CDIHook) Create(name string, args ...string) *Hook {
|
||||
if c.disabledHooks[name] {
|
||||
return nil
|
||||
}
|
||||
|
||||
type cdiHook string
|
||||
if name == "create-symlinks" {
|
||||
if len(args) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c cdiHook) Create(name string, args ...string) Hook {
|
||||
return Hook{
|
||||
links := make([]string, 0, len(args))
|
||||
for _, arg := range args {
|
||||
links = append(links, "--link", arg)
|
||||
}
|
||||
args = links
|
||||
}
|
||||
|
||||
return &Hook{
|
||||
Lifecycle: cdi.CreateContainerHook,
|
||||
Path: string(c),
|
||||
Path: c.nvidiaCDIHookPath,
|
||||
Args: append(c.requiredArgs(name), args...),
|
||||
}
|
||||
}
|
||||
func (c cdiHook) requiredArgs(name string) []string {
|
||||
base := filepath.Base(string(c))
|
||||
|
||||
// DisableHook disables a hook by adding it to the disabled hooks map
|
||||
func (c *CDIHook) DisableHook(name string) {
|
||||
c.disabledHooks[name] = true
|
||||
}
|
||||
|
||||
func (c CDIHook) requiredArgs(name string) []string {
|
||||
base := filepath.Base(c.nvidiaCDIHookPath)
|
||||
if base == "nvidia-ctk" {
|
||||
return []string{base, "hook", name}
|
||||
}
|
||||
|
||||
@@ -25,12 +25,12 @@ import (
|
||||
)
|
||||
|
||||
// NewLDCacheUpdateHook creates a discoverer that updates the ldcache for the specified mounts. A logger can also be specified
|
||||
func NewLDCacheUpdateHook(logger logger.Interface, mounts Discover, nvidiaCDIHookPath, ldconfigPath string) (Discover, error) {
|
||||
func NewLDCacheUpdateHook(logger logger.Interface, mounts Discover, hookCreator HookCreator, ldconfigPath string) (Discover, error) {
|
||||
d := ldconfig{
|
||||
logger: logger,
|
||||
nvidiaCDIHookPath: nvidiaCDIHookPath,
|
||||
ldconfigPath: ldconfigPath,
|
||||
mountsFrom: mounts,
|
||||
logger: logger,
|
||||
hookCreator: hookCreator,
|
||||
ldconfigPath: ldconfigPath,
|
||||
mountsFrom: mounts,
|
||||
}
|
||||
|
||||
return &d, nil
|
||||
@@ -38,10 +38,10 @@ func NewLDCacheUpdateHook(logger logger.Interface, mounts Discover, nvidiaCDIHoo
|
||||
|
||||
type ldconfig struct {
|
||||
None
|
||||
logger logger.Interface
|
||||
nvidiaCDIHookPath string
|
||||
ldconfigPath string
|
||||
mountsFrom Discover
|
||||
logger logger.Interface
|
||||
hookCreator HookCreator
|
||||
ldconfigPath string
|
||||
mountsFrom Discover
|
||||
}
|
||||
|
||||
// Hooks checks the required mounts for libraries and returns a hook to update the LDcache for the discovered paths.
|
||||
@@ -50,16 +50,18 @@ func (d ldconfig) Hooks() ([]Hook, error) {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to discover mounts for ldcache update: %v", err)
|
||||
}
|
||||
h := CreateLDCacheUpdateHook(
|
||||
d.nvidiaCDIHookPath,
|
||||
|
||||
h := createLDCacheUpdateHook(
|
||||
d.hookCreator,
|
||||
d.ldconfigPath,
|
||||
getLibraryPaths(mounts),
|
||||
)
|
||||
return []Hook{h}, nil
|
||||
|
||||
return h.Hooks()
|
||||
}
|
||||
|
||||
// CreateLDCacheUpdateHook locates the NVIDIA Container Toolkit CLI and creates a hook for updating the LD Cache
|
||||
func CreateLDCacheUpdateHook(executable string, ldconfig string, libraries []string) Hook {
|
||||
// createLDCacheUpdateHook locates the NVIDIA Container Toolkit CLI and creates a hook for updating the LD Cache
|
||||
func createLDCacheUpdateHook(hookCreator HookCreator, ldconfig string, libraries []string) *Hook {
|
||||
var args []string
|
||||
|
||||
if ldconfig != "" {
|
||||
@@ -70,13 +72,7 @@ func CreateLDCacheUpdateHook(executable string, ldconfig string, libraries []str
|
||||
args = append(args, "--folder", f)
|
||||
}
|
||||
|
||||
hook := CreateNvidiaCDIHook(
|
||||
executable,
|
||||
"update-ldcache",
|
||||
args...,
|
||||
)
|
||||
|
||||
return hook
|
||||
return hookCreator.Create("update-ldcache", args...)
|
||||
}
|
||||
|
||||
// getLibraryPaths extracts the library dirs from the specified mounts
|
||||
|
||||
@@ -31,6 +31,7 @@ const (
|
||||
|
||||
func TestLDCacheUpdateHook(t *testing.T) {
|
||||
logger, _ := testlog.NewNullLogger()
|
||||
hookCreator := NewHookCreator(testNvidiaCDIHookPath)
|
||||
|
||||
testCases := []struct {
|
||||
description string
|
||||
@@ -97,7 +98,7 @@ func TestLDCacheUpdateHook(t *testing.T) {
|
||||
Lifecycle: "createContainer",
|
||||
}
|
||||
|
||||
d, err := NewLDCacheUpdateHook(logger, mountMock, testNvidiaCDIHookPath, tc.ldconfigPath)
|
||||
d, err := NewLDCacheUpdateHook(logger, mountMock, hookCreator, tc.ldconfigPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
hooks, err := d.Hooks()
|
||||
|
||||
@@ -23,20 +23,20 @@ import (
|
||||
|
||||
type additionalSymlinks struct {
|
||||
Discover
|
||||
version string
|
||||
nvidiaCDIHookPath string
|
||||
version string
|
||||
hookCreator HookCreator
|
||||
}
|
||||
|
||||
// WithDriverDotSoSymlinks decorates the provided discoverer.
|
||||
// A hook is added that checks for specific driver symlinks that need to be created.
|
||||
func WithDriverDotSoSymlinks(mounts Discover, version string, nvidiaCDIHookPath string) Discover {
|
||||
func WithDriverDotSoSymlinks(mounts Discover, version string, hookCreator HookCreator) Discover {
|
||||
if version == "" {
|
||||
version = "*.*"
|
||||
}
|
||||
return &additionalSymlinks{
|
||||
Discover: mounts,
|
||||
nvidiaCDIHookPath: nvidiaCDIHookPath,
|
||||
version: version,
|
||||
Discover: mounts,
|
||||
hookCreator: hookCreator,
|
||||
version: version,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -73,8 +73,12 @@ func (d *additionalSymlinks) Hooks() ([]Hook, error) {
|
||||
return hooks, nil
|
||||
}
|
||||
|
||||
hook := CreateCreateSymlinkHook(d.nvidiaCDIHookPath, links).(Hook)
|
||||
return append(hooks, hook), nil
|
||||
hook, err := d.hookCreator.Create("create-symlinks", links...).Hooks()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create symlink hook: %v", err)
|
||||
}
|
||||
|
||||
return append(hooks, hook...), nil
|
||||
}
|
||||
|
||||
// getLinksForMount maps the path to created links if any.
|
||||
|
||||
@@ -308,10 +308,12 @@ func TestWithWithDriverDotSoSymlinks(t *testing.T) {
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.description, func(t *testing.T) {
|
||||
hookCreator := NewHookCreator("/path/to/nvidia-cdi-hook")
|
||||
|
||||
d := WithDriverDotSoSymlinks(
|
||||
tc.discover,
|
||||
tc.version,
|
||||
"/path/to/nvidia-cdi-hook",
|
||||
hookCreator,
|
||||
)
|
||||
|
||||
devices, err := d.Devices()
|
||||
|
||||
Reference in New Issue
Block a user