mirror of
https://github.com/NVIDIA/nvidia-container-toolkit
synced 2025-06-26 18:18:24 +00:00
Create an internal/hooks pkg to centralize hook management
Signed-off-by: Carlos Eduardo Arango Gutierrez <eduardoa@nvidia.com>
This commit is contained in:
parent
f93d96a0de
commit
61ae3dc746
@ -3,18 +3,48 @@ package discover
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/hooks"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root"
|
||||
)
|
||||
|
||||
// cudaCompatHook is a discoverer for the enable-cuda-compat hook.
|
||||
type cudaCompatHook struct {
|
||||
hooks.Hook
|
||||
}
|
||||
|
||||
// NewCUDACompatHookDiscoverer creates a discoverer for a enable-cuda-compat hook.
|
||||
// This hook is responsible for setting up CUDA compatibility in the container and depends on the host driver version.
|
||||
func NewCUDACompatHookDiscoverer(logger logger.Interface, hookCreator HookCreator, driver *root.Driver) Discover {
|
||||
func NewCUDACompatHookDiscoverer(logger logger.Interface, hookCreator hooks.HookCreator, driver *root.Driver) Discover {
|
||||
_, cudaVersionPattern := getCUDALibRootAndVersionPattern(logger, driver)
|
||||
var args []string
|
||||
if !strings.Contains(cudaVersionPattern, "*") {
|
||||
args = append(args, "--host-driver-version="+cudaVersionPattern)
|
||||
}
|
||||
|
||||
return hookCreator.Create("enable-cuda-compat", args...)
|
||||
hook := hookCreator.Create(hooks.EnableCudaCompat, args...)
|
||||
if hook == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &cudaCompatHook{
|
||||
Hook: *hook,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cudaCompatHook) Hooks() ([]Hook, error) {
|
||||
return []Hook{
|
||||
{
|
||||
Lifecycle: h.Lifecycle,
|
||||
Path: h.Path,
|
||||
Args: h.Args,
|
||||
}}, nil
|
||||
}
|
||||
|
||||
func (h *cudaCompatHook) Devices() ([]Device, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (h *cudaCompatHook) Mounts() ([]Mount, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
@ -23,6 +23,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/hooks"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/info/drm"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/info/proc"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
|
||||
@ -36,7 +37,7 @@ import (
|
||||
// TODO: The logic for creating DRM devices should be consolidated between this
|
||||
// and the logic for generating CDI specs for a single device. This is only used
|
||||
// when applying OCI spec modifications to an incoming spec in "legacy" mode.
|
||||
func NewDRMNodesDiscoverer(logger logger.Interface, devices image.VisibleDevices, devRoot string, hookCreator HookCreator) (Discover, error) {
|
||||
func NewDRMNodesDiscoverer(logger logger.Interface, devices image.VisibleDevices, devRoot string, hookCreator hooks.HookCreator) (Discover, error) {
|
||||
drmDeviceNodes, err := newDRMDeviceDiscoverer(logger, devices, devRoot)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create DRM device discoverer: %v", err)
|
||||
@ -49,7 +50,7 @@ func NewDRMNodesDiscoverer(logger logger.Interface, devices image.VisibleDevices
|
||||
}
|
||||
|
||||
// NewGraphicsMountsDiscoverer creates a discoverer for the mounts required by graphics tools such as vulkan.
|
||||
func NewGraphicsMountsDiscoverer(logger logger.Interface, driver *root.Driver, hookCreator HookCreator) (Discover, error) {
|
||||
func NewGraphicsMountsDiscoverer(logger logger.Interface, driver *root.Driver, hookCreator hooks.HookCreator) (Discover, error) {
|
||||
libraries := newGraphicsLibrariesDiscoverer(logger, driver, hookCreator)
|
||||
|
||||
configs := NewMounts(
|
||||
@ -96,12 +97,12 @@ func newVulkanConfigsDiscover(logger logger.Interface, driver *root.Driver) Disc
|
||||
type graphicsDriverLibraries struct {
|
||||
Discover
|
||||
logger logger.Interface
|
||||
hookCreator HookCreator
|
||||
hookCreator hooks.HookCreator
|
||||
}
|
||||
|
||||
var _ Discover = (*graphicsDriverLibraries)(nil)
|
||||
|
||||
func newGraphicsLibrariesDiscoverer(logger logger.Interface, driver *root.Driver, hookCreator HookCreator) Discover {
|
||||
func newGraphicsLibrariesDiscoverer(logger logger.Interface, driver *root.Driver, hookCreator hooks.HookCreator) Discover {
|
||||
cudaLibRoot, cudaVersionPattern := getCUDALibRootAndVersionPattern(logger, driver)
|
||||
|
||||
libraries := NewMounts(
|
||||
@ -203,9 +204,15 @@ func (d graphicsDriverLibraries) Hooks() ([]Hook, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
hook := d.hookCreator.Create("create-symlinks", links...)
|
||||
hook := d.hookCreator.Create(hooks.CreateSymlinks, links...)
|
||||
|
||||
return hook.Hooks()
|
||||
return []Hook{
|
||||
{
|
||||
Lifecycle: hook.Lifecycle,
|
||||
Path: hook.Path,
|
||||
Args: hook.Args,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// isDriverLibrary checks whether the specified filename is a specific driver library.
|
||||
@ -276,13 +283,13 @@ func buildXOrgSearchPaths(libRoot string) []string {
|
||||
type drmDevicesByPath struct {
|
||||
None
|
||||
logger logger.Interface
|
||||
hookCreator HookCreator
|
||||
hookCreator hooks.HookCreator
|
||||
devRoot string
|
||||
devicesFrom Discover
|
||||
}
|
||||
|
||||
// newCreateDRMByPathSymlinks creates a discoverer for a hook to create the by-path symlinks for DRM devices discovered by the specified devices discoverer
|
||||
func newCreateDRMByPathSymlinks(logger logger.Interface, devices Discover, devRoot string, hookCreator HookCreator) Discover {
|
||||
func newCreateDRMByPathSymlinks(logger logger.Interface, devices Discover, devRoot string, hookCreator hooks.HookCreator) Discover {
|
||||
d := drmDevicesByPath{
|
||||
logger: logger,
|
||||
hookCreator: hookCreator,
|
||||
@ -315,9 +322,15 @@ func (d drmDevicesByPath) Hooks() ([]Hook, error) {
|
||||
args = append(args, "--link", l)
|
||||
}
|
||||
|
||||
hook := d.hookCreator.Create("create-symlinks", args...)
|
||||
hook := d.hookCreator.Create(hooks.CreateSymlinks, args...)
|
||||
|
||||
return hook.Hooks()
|
||||
return []Hook{
|
||||
{
|
||||
Lifecycle: hook.Lifecycle,
|
||||
Path: hook.Path,
|
||||
Args: hook.Args,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// getSpecificLinkArgs returns the required specific links that need to be created
|
||||
|
@ -19,13 +19,15 @@ package discover
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/hooks"
|
||||
|
||||
testlog "github.com/sirupsen/logrus/hooks/test"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGraphicsLibrariesDiscoverer(t *testing.T) {
|
||||
logger, _ := testlog.NewNullLogger()
|
||||
hookCreator := NewHookCreator("/usr/bin/nvidia-cdi-hook")
|
||||
hookCreator := hooks.NewHookCreator("/usr/bin/nvidia-cdi-hook", nil)
|
||||
|
||||
testCases := []struct {
|
||||
description string
|
||||
|
@ -1,92 +0,0 @@
|
||||
/**
|
||||
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
**/
|
||||
|
||||
package discover
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
|
||||
"tags.cncf.io/container-device-interface/pkg/cdi"
|
||||
)
|
||||
|
||||
var _ Discover = (*Hook)(nil)
|
||||
|
||||
// Devices returns an empty list of devices for a Hook discoverer.
|
||||
func (h *Hook) Devices() ([]Device, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Mounts returns an empty list of mounts for a Hook discoverer.
|
||||
func (h *Hook) Mounts() ([]Mount, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Hooks allows the Hook type to also implement the Discoverer interface.
|
||||
// It returns a single hook
|
||||
func (h *Hook) Hooks() ([]Hook, error) {
|
||||
if h == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return []Hook{*h}, nil
|
||||
}
|
||||
|
||||
// Option is a function that configures the nvcdilib
|
||||
type Option func(*CDIHook)
|
||||
|
||||
type CDIHook struct {
|
||||
nvidiaCDIHookPath string
|
||||
}
|
||||
|
||||
type HookCreator interface {
|
||||
Create(string, ...string) *Hook
|
||||
}
|
||||
|
||||
func NewHookCreator(nvidiaCDIHookPath string) HookCreator {
|
||||
CDIHook := &CDIHook{
|
||||
nvidiaCDIHookPath: nvidiaCDIHookPath,
|
||||
}
|
||||
|
||||
return CDIHook
|
||||
}
|
||||
|
||||
func (c CDIHook) Create(name string, args ...string) *Hook {
|
||||
if name == "create-symlinks" {
|
||||
if len(args) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
links := []string{}
|
||||
for _, arg := range args {
|
||||
links = append(links, "--link", arg)
|
||||
}
|
||||
args = links
|
||||
}
|
||||
|
||||
return &Hook{
|
||||
Lifecycle: cdi.CreateContainerHook,
|
||||
Path: c.nvidiaCDIHookPath,
|
||||
Args: append(c.requiredArgs(name), args...),
|
||||
}
|
||||
}
|
||||
|
||||
func (c CDIHook) requiredArgs(name string) []string {
|
||||
base := filepath.Base(c.nvidiaCDIHookPath)
|
||||
if base == "nvidia-ctk" {
|
||||
return []string{base, "hook", name}
|
||||
}
|
||||
return []string{base, name}
|
||||
}
|
@ -21,11 +21,12 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/hooks"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
|
||||
)
|
||||
|
||||
// NewLDCacheUpdateHook creates a discoverer that updates the ldcache for the specified mounts. A logger can also be specified
|
||||
func NewLDCacheUpdateHook(logger logger.Interface, mounts Discover, hookCreator HookCreator, ldconfigPath string) (Discover, error) {
|
||||
func NewLDCacheUpdateHook(logger logger.Interface, mounts Discover, hookCreator hooks.HookCreator, ldconfigPath string) (Discover, error) {
|
||||
d := ldconfig{
|
||||
logger: logger,
|
||||
hookCreator: hookCreator,
|
||||
@ -39,7 +40,7 @@ func NewLDCacheUpdateHook(logger logger.Interface, mounts Discover, hookCreator
|
||||
type ldconfig struct {
|
||||
None
|
||||
logger logger.Interface
|
||||
hookCreator HookCreator
|
||||
hookCreator hooks.HookCreator
|
||||
ldconfigPath string
|
||||
mountsFrom Discover
|
||||
}
|
||||
@ -57,11 +58,17 @@ func (d ldconfig) Hooks() ([]Hook, error) {
|
||||
getLibraryPaths(mounts),
|
||||
)
|
||||
|
||||
return h.Hooks()
|
||||
return []Hook{
|
||||
{
|
||||
Lifecycle: h.Lifecycle,
|
||||
Path: h.Path,
|
||||
Args: h.Args,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// createLDCacheUpdateHook locates the NVIDIA Container Toolkit CLI and creates a hook for updating the LD Cache
|
||||
func createLDCacheUpdateHook(hookCreator HookCreator, ldconfig string, libraries []string) *Hook {
|
||||
func createLDCacheUpdateHook(hookCreator hooks.HookCreator, ldconfig string, libraries []string) *Hook {
|
||||
var args []string
|
||||
|
||||
if ldconfig != "" {
|
||||
@ -72,7 +79,12 @@ func createLDCacheUpdateHook(hookCreator HookCreator, ldconfig string, libraries
|
||||
args = append(args, "--folder", f)
|
||||
}
|
||||
|
||||
return hookCreator.Create("update-ldcache", args...)
|
||||
h := hookCreator.Create(hooks.UpdateLDCache, args...)
|
||||
return &Hook{
|
||||
Lifecycle: h.Lifecycle,
|
||||
Path: h.Path,
|
||||
Args: h.Args,
|
||||
}
|
||||
}
|
||||
|
||||
// getLibraryPaths extracts the library dirs from the specified mounts
|
||||
|
@ -20,6 +20,8 @@ import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/hooks"
|
||||
|
||||
testlog "github.com/sirupsen/logrus/hooks/test"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@ -31,7 +33,7 @@ const (
|
||||
|
||||
func TestLDCacheUpdateHook(t *testing.T) {
|
||||
logger, _ := testlog.NewNullLogger()
|
||||
hookCreator := NewHookCreator(testNvidiaCDIHookPath)
|
||||
hookCreator := hooks.NewHookCreator(testNvidiaCDIHookPath, nil)
|
||||
|
||||
testCases := []struct {
|
||||
description string
|
||||
|
@ -19,17 +19,19 @@ package discover
|
||||
import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/hooks"
|
||||
)
|
||||
|
||||
type additionalSymlinks struct {
|
||||
Discover
|
||||
version string
|
||||
hookCreator HookCreator
|
||||
hookCreator hooks.HookCreator
|
||||
}
|
||||
|
||||
// WithDriverDotSoSymlinks decorates the provided discoverer.
|
||||
// A hook is added that checks for specific driver symlinks that need to be created.
|
||||
func WithDriverDotSoSymlinks(mounts Discover, version string, hookCreator HookCreator) Discover {
|
||||
func WithDriverDotSoSymlinks(mounts Discover, version string, hookCreator hooks.HookCreator) Discover {
|
||||
if version == "" {
|
||||
version = "*.*"
|
||||
}
|
||||
@ -46,7 +48,7 @@ func (d *additionalSymlinks) Hooks() ([]Hook, error) {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get library mounts: %v", err)
|
||||
}
|
||||
hooks, err := d.Discover.Hooks()
|
||||
h, err := d.Discover.Hooks()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get hooks: %v", err)
|
||||
}
|
||||
@ -70,15 +72,16 @@ func (d *additionalSymlinks) Hooks() ([]Hook, error) {
|
||||
}
|
||||
|
||||
if len(links) == 0 {
|
||||
return hooks, nil
|
||||
return h, nil
|
||||
}
|
||||
|
||||
createSymlinkHooks, err := d.hookCreator.Create("create-symlinks", links...).Hooks()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create symlink hook: %v", err)
|
||||
}
|
||||
createSymlinkHook := d.hookCreator.Create(hooks.CreateSymlinks, links...)
|
||||
|
||||
return append(hooks, createSymlinkHooks...), nil
|
||||
return append(h, Hook{
|
||||
Lifecycle: createSymlinkHook.Lifecycle,
|
||||
Path: createSymlinkHook.Path,
|
||||
Args: createSymlinkHook.Args,
|
||||
}), nil
|
||||
}
|
||||
|
||||
// getLinksForMount maps the path to created links if any.
|
||||
|
@ -19,6 +19,8 @@ package discover
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/hooks"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@ -306,7 +308,7 @@ func TestWithWithDriverDotSoSymlinks(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
hookCreator := NewHookCreator("/path/to/nvidia-cdi-hook")
|
||||
hookCreator := hooks.NewHookCreator("/path/to/nvidia-cdi-hook", nil)
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.description, func(t *testing.T) {
|
||||
d := WithDriverDotSoSymlinks(
|
||||
|
116
internal/hooks/hooks.go
Normal file
116
internal/hooks/hooks.go
Normal file
@ -0,0 +1,116 @@
|
||||
/**
|
||||
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
**/
|
||||
|
||||
package hooks
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
|
||||
"tags.cncf.io/container-device-interface/pkg/cdi"
|
||||
)
|
||||
|
||||
// A HookName refers to one of the predefined set of CDI hooks that may be
|
||||
// included in the generated CDI specification.
|
||||
type HookName string
|
||||
|
||||
// DisabledHooks allows individual hooks to be disabled.
|
||||
type DisabledHooks map[HookName]bool
|
||||
|
||||
const (
|
||||
// EnableCudaCompat refers to the hook used to enable CUDA Forward Compatibility.
|
||||
// This was added with v1.17.5 of the NVIDIA Container Toolkit.
|
||||
EnableCudaCompat = HookName("enable-cuda-compat")
|
||||
// CreateSymlinks refers to the hook used to create symlinks for the NVIDIA
|
||||
// Container Toolkit. This was added with v1.17.5 of the NVIDIA Container Toolkit.
|
||||
CreateSymlinks = HookName("create-symlinks")
|
||||
// UpdateLDCache refers to the hook used to update the LD cache for the NVIDIA
|
||||
// Container Toolkit. This was added with v1.17.5 of the NVIDIA Container Toolkit.
|
||||
UpdateLDCache = HookName("update-ldcache")
|
||||
)
|
||||
|
||||
// Hook represents an OCI container hook.
|
||||
type Hook struct {
|
||||
Lifecycle string
|
||||
Path string
|
||||
Args []string
|
||||
}
|
||||
|
||||
// Option is a function that configures the nvcdilib
|
||||
type Option func(*CDIHook)
|
||||
|
||||
type CDIHook struct {
|
||||
nvidiaCDIHookPath string
|
||||
disabledHooks DisabledHooks
|
||||
}
|
||||
|
||||
type HookCreator interface {
|
||||
Create(HookName, ...string) *Hook
|
||||
}
|
||||
|
||||
func NewHookCreator(nvidiaCDIHookPath string, disabledHooks DisabledHooks) HookCreator {
|
||||
if disabledHooks == nil {
|
||||
disabledHooks = make(DisabledHooks)
|
||||
}
|
||||
|
||||
CDIHook := &CDIHook{
|
||||
nvidiaCDIHookPath: nvidiaCDIHookPath,
|
||||
disabledHooks: disabledHooks,
|
||||
}
|
||||
|
||||
return CDIHook
|
||||
}
|
||||
|
||||
func (c CDIHook) Create(name HookName, args ...string) *Hook {
|
||||
if c.disabledHooks[name] {
|
||||
return nil
|
||||
}
|
||||
|
||||
if name == CreateSymlinks {
|
||||
if len(args) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
links := []string{}
|
||||
for _, arg := range args {
|
||||
links = append(links, "--link", arg)
|
||||
}
|
||||
args = links
|
||||
}
|
||||
|
||||
return &Hook{
|
||||
Lifecycle: cdi.CreateContainerHook,
|
||||
Path: c.nvidiaCDIHookPath,
|
||||
Args: append(c.requiredArgs(name), args...),
|
||||
}
|
||||
}
|
||||
|
||||
func (c CDIHook) requiredArgs(name HookName) []string {
|
||||
base := filepath.Base(c.nvidiaCDIHookPath)
|
||||
if base == "nvidia-ctk" {
|
||||
return []string{base, "hook", string(name)}
|
||||
}
|
||||
return []string{base, string(name)}
|
||||
}
|
||||
|
||||
// IsSupported checks whether a hook of the specified name is supported.
|
||||
// Hooks must be explicitly disabled, meaning that if no disabled hooks are
|
||||
// all hooks are supported.
|
||||
func (c CDIHook) IsSupported(h HookName) bool {
|
||||
if len(c.disabledHooks) == 0 {
|
||||
return true
|
||||
}
|
||||
return !c.disabledHooks[h]
|
||||
}
|
177
internal/hooks/hooks_test.go
Normal file
177
internal/hooks/hooks_test.go
Normal file
@ -0,0 +1,177 @@
|
||||
package hooks
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewHookCreator(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
hookPath string
|
||||
disabledHooks DisabledHooks
|
||||
expectedCreator HookCreator
|
||||
}{
|
||||
{
|
||||
name: "nil disabled hooks",
|
||||
hookPath: "/usr/bin/nvidia-ctk",
|
||||
expectedCreator: &CDIHook{
|
||||
nvidiaCDIHookPath: "/usr/bin/nvidia-ctk",
|
||||
disabledHooks: DisabledHooks{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with disabled hooks",
|
||||
hookPath: "/usr/bin/nvidia-ctk",
|
||||
disabledHooks: DisabledHooks{
|
||||
EnableCudaCompat: true,
|
||||
},
|
||||
expectedCreator: &CDIHook{
|
||||
nvidiaCDIHookPath: "/usr/bin/nvidia-ctk",
|
||||
disabledHooks: DisabledHooks{
|
||||
EnableCudaCompat: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
creator := NewHookCreator(tt.hookPath, tt.disabledHooks)
|
||||
if creator == nil {
|
||||
t.Fatal("NewHookCreator returned nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCDIHook_Create(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
hookPath string
|
||||
disabledHooks DisabledHooks
|
||||
hookName HookName
|
||||
args []string
|
||||
expectedHook *Hook
|
||||
}{
|
||||
{
|
||||
name: "disabled hook returns nil",
|
||||
hookPath: "/usr/bin/nvidia-ctk",
|
||||
disabledHooks: DisabledHooks{
|
||||
EnableCudaCompat: true,
|
||||
},
|
||||
hookName: EnableCudaCompat,
|
||||
expectedHook: nil,
|
||||
},
|
||||
{
|
||||
name: "create symlinks with no args returns nil",
|
||||
hookPath: "/usr/bin/nvidia-ctk",
|
||||
hookName: CreateSymlinks,
|
||||
expectedHook: nil,
|
||||
},
|
||||
{
|
||||
name: "create symlinks with args",
|
||||
hookPath: "/usr/bin/nvidia-ctk",
|
||||
hookName: CreateSymlinks,
|
||||
args: []string{"/path/to/lib1", "/path/to/lib2"},
|
||||
expectedHook: &Hook{
|
||||
Lifecycle: "createContainer",
|
||||
Path: "/usr/bin/nvidia-ctk",
|
||||
Args: []string{"nvidia-ctk", "hook", "create-symlinks", "--link", "/path/to/lib1", "--link", "/path/to/lib2"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "enable cuda compat",
|
||||
hookPath: "/usr/bin/nvidia-ctk",
|
||||
hookName: EnableCudaCompat,
|
||||
expectedHook: &Hook{
|
||||
Lifecycle: "createContainer",
|
||||
Path: "/usr/bin/nvidia-ctk",
|
||||
Args: []string{"nvidia-ctk", "hook", "enable-cuda-compat"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
hook := &CDIHook{
|
||||
nvidiaCDIHookPath: tt.hookPath,
|
||||
disabledHooks: tt.disabledHooks,
|
||||
}
|
||||
|
||||
result := hook.Create(tt.hookName, tt.args...)
|
||||
if tt.expectedHook == nil {
|
||||
if result != nil {
|
||||
t.Errorf("expected nil hook, got %v", result)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil hook, got nil")
|
||||
}
|
||||
|
||||
if result.Lifecycle != tt.expectedHook.Lifecycle {
|
||||
t.Errorf("expected lifecycle %q, got %q", tt.expectedHook.Lifecycle, result.Lifecycle)
|
||||
}
|
||||
|
||||
if result.Path != tt.expectedHook.Path {
|
||||
t.Errorf("expected path %q, got %q", tt.expectedHook.Path, result.Path)
|
||||
}
|
||||
|
||||
if len(result.Args) != len(tt.expectedHook.Args) {
|
||||
t.Errorf("expected %d args, got %d", len(tt.expectedHook.Args), len(result.Args))
|
||||
return
|
||||
}
|
||||
|
||||
for i, arg := range tt.expectedHook.Args {
|
||||
if result.Args[i] != arg {
|
||||
t.Errorf("expected arg[%d] %q, got %q", i, arg, result.Args[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCDIHook_IsSupported(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
disabledHooks DisabledHooks
|
||||
hookName HookName
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "no disabled hooks",
|
||||
hookName: EnableCudaCompat,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "disabled hook",
|
||||
disabledHooks: DisabledHooks{
|
||||
EnableCudaCompat: true,
|
||||
},
|
||||
hookName: EnableCudaCompat,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "non-disabled hook",
|
||||
disabledHooks: DisabledHooks{
|
||||
EnableCudaCompat: true,
|
||||
},
|
||||
hookName: CreateSymlinks,
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
hook := &CDIHook{
|
||||
disabledHooks: tt.disabledHooks,
|
||||
}
|
||||
|
||||
result := hook.IsSupported(tt.hookName)
|
||||
if result != tt.expected {
|
||||
t.Errorf("expected %v, got %v", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -22,6 +22,7 @@ import (
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/hooks"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
|
||||
@ -36,7 +37,7 @@ import (
|
||||
// NVIDIA_GDRCOPY=enabled
|
||||
//
|
||||
// If not devices are selected, no changes are made.
|
||||
func NewFeatureGatedModifier(logger logger.Interface, cfg *config.Config, image image.CUDA, driver *root.Driver, hookCreator discover.HookCreator) (oci.SpecModifier, error) {
|
||||
func NewFeatureGatedModifier(logger logger.Interface, cfg *config.Config, image image.CUDA, driver *root.Driver, hookCreator hooks.HookCreator) (oci.SpecModifier, error) {
|
||||
if devices := image.VisibleDevicesFromEnvVar(); len(devices) == 0 {
|
||||
logger.Infof("No modification required; no devices requested")
|
||||
return nil, nil
|
||||
@ -91,7 +92,7 @@ func NewFeatureGatedModifier(logger logger.Interface, cfg *config.Config, image
|
||||
return NewModifierFromDiscoverer(logger, discover.Merge(discoverers...))
|
||||
}
|
||||
|
||||
func getCudaCompatModeDiscoverer(logger logger.Interface, cfg *config.Config, driver *root.Driver, hookCreator discover.HookCreator) (discover.Discover, error) {
|
||||
func getCudaCompatModeDiscoverer(logger logger.Interface, cfg *config.Config, driver *root.Driver, hookCreator hooks.HookCreator) (discover.Discover, error) {
|
||||
// For legacy mode, we only include the enable-cuda-compat hook if cuda-compat-mode is set to hook.
|
||||
if cfg.NVIDIAContainerRuntimeConfig.Mode == "legacy" && cfg.NVIDIAContainerRuntimeConfig.Modes.Legacy.CUDACompatMode != config.CUDACompatModeHook {
|
||||
return nil, nil
|
||||
|
@ -22,6 +22,7 @@ import (
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/hooks"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
|
||||
@ -29,7 +30,7 @@ import (
|
||||
|
||||
// NewGraphicsModifier constructs a modifier that injects graphics-related modifications into an OCI runtime specification.
|
||||
// The value of the NVIDIA_DRIVER_CAPABILITIES environment variable is checked to determine if this modification should be made.
|
||||
func NewGraphicsModifier(logger logger.Interface, cfg *config.Config, containerImage image.CUDA, driver *root.Driver, hookCreator discover.HookCreator) (oci.SpecModifier, error) {
|
||||
func NewGraphicsModifier(logger logger.Interface, cfg *config.Config, containerImage image.CUDA, driver *root.Driver, hookCreator hooks.HookCreator) (oci.SpecModifier, error) {
|
||||
if required, reason := requiresGraphicsModifier(containerImage); !required {
|
||||
logger.Infof("No graphics modifier required: %v", reason)
|
||||
return nil, nil
|
||||
|
@ -22,6 +22,7 @@ import (
|
||||
"path/filepath"
|
||||
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/hooks"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
|
||||
)
|
||||
|
||||
@ -29,7 +30,7 @@ import (
|
||||
type byPathHookDiscoverer struct {
|
||||
logger logger.Interface
|
||||
devRoot string
|
||||
hookCreator discover.HookCreator
|
||||
hookCreator hooks.HookCreator
|
||||
pciBusID string
|
||||
deviceNodes discover.Discover
|
||||
}
|
||||
@ -53,9 +54,14 @@ func (d *byPathHookDiscoverer) Hooks() ([]discover.Hook, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
hook := d.hookCreator.Create("create-symlinks", links...)
|
||||
hook := d.hookCreator.Create(hooks.CreateSymlinks, links...)
|
||||
|
||||
return hook.Hooks()
|
||||
return []discover.Hook{
|
||||
{
|
||||
Lifecycle: hook.Lifecycle,
|
||||
Path: hook.Path,
|
||||
Args: hook.Args,
|
||||
}}, nil
|
||||
}
|
||||
|
||||
// Mounts returns an empty slice for a full GPU
|
||||
|
@ -24,6 +24,7 @@ import (
|
||||
"github.com/NVIDIA/go-nvml/pkg/nvml"
|
||||
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/hooks"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/nvsandboxutils"
|
||||
)
|
||||
|
||||
@ -32,7 +33,7 @@ type nvsandboxutilsDGPU struct {
|
||||
uuid string
|
||||
devRoot string
|
||||
isMig bool
|
||||
hookCreator discover.HookCreator
|
||||
hookCreator hooks.HookCreator
|
||||
deviceLinks []string
|
||||
}
|
||||
|
||||
@ -112,9 +113,14 @@ func (d *nvsandboxutilsDGPU) Hooks() ([]discover.Hook, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
hook := d.hookCreator.Create("create-symlinks", d.deviceLinks...)
|
||||
hook := d.hookCreator.Create(hooks.CreateSymlinks, d.deviceLinks...)
|
||||
|
||||
return hook.Hooks()
|
||||
return []discover.Hook{
|
||||
{
|
||||
Lifecycle: hook.Lifecycle,
|
||||
Path: hook.Path,
|
||||
Args: hook.Args,
|
||||
}}, nil
|
||||
}
|
||||
|
||||
func (d *nvsandboxutilsDGPU) Mounts() ([]discover.Mount, error) {
|
||||
|
@ -17,7 +17,7 @@
|
||||
package dgpu
|
||||
|
||||
import (
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/hooks"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/nvcaps"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/nvsandboxutils"
|
||||
@ -26,7 +26,7 @@ import (
|
||||
type options struct {
|
||||
logger logger.Interface
|
||||
devRoot string
|
||||
hookCreator discover.HookCreator
|
||||
hookCreator hooks.HookCreator
|
||||
|
||||
isMigDevice bool
|
||||
// migCaps stores the MIG capabilities for the system.
|
||||
@ -54,7 +54,7 @@ func WithLogger(logger logger.Interface) Option {
|
||||
}
|
||||
|
||||
// WithHookCreator sets the hook creator for the library
|
||||
func WithHookCreator(hookCreator discover.HookCreator) Option {
|
||||
func WithHookCreator(hookCreator hooks.HookCreator) Option {
|
||||
return func(l *options) {
|
||||
l.hookCreator = hookCreator
|
||||
}
|
||||
|
@ -24,6 +24,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/hooks"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup"
|
||||
|
||||
@ -181,7 +182,7 @@ func TestDiscovererFromCSVFiles(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
hookCreator := discover.NewHookCreator("/usr/bin/nvidia-cdi-hook")
|
||||
hookCreator := hooks.NewHookCreator("/usr/bin/nvidia-cdi-hook", nil)
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.description, func(t *testing.T) {
|
||||
defer setGetTargetsFromCSVFiles(tc.moutSpecs)()
|
||||
|
@ -20,6 +20,7 @@ import (
|
||||
"fmt"
|
||||
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/hooks"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup"
|
||||
)
|
||||
@ -27,7 +28,7 @@ import (
|
||||
type symlinkHook struct {
|
||||
discover.None
|
||||
logger logger.Interface
|
||||
hookCreator discover.HookCreator
|
||||
hookCreator hooks.HookCreator
|
||||
targets []string
|
||||
|
||||
// The following can be overridden for testing
|
||||
@ -48,7 +49,17 @@ func (o tegraOptions) createCSVSymlinkHooks(targets []string) discover.Discover
|
||||
|
||||
// Hooks returns a hook to create the symlinks from the required CSV files
|
||||
func (d symlinkHook) Hooks() ([]discover.Hook, error) {
|
||||
return d.hookCreator.Create("create-symlinks", d.getCSVFileSymlinks()...).Hooks()
|
||||
hook := d.hookCreator.Create(hooks.CreateSymlinks, d.getCSVFileSymlinks()...)
|
||||
if hook == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return []discover.Hook{
|
||||
{
|
||||
Lifecycle: hook.Lifecycle,
|
||||
Path: hook.Path,
|
||||
Args: hook.Args,
|
||||
}}, nil
|
||||
}
|
||||
|
||||
// getSymlinkCandidates returns a list of symlinks that are candidates for being created.
|
||||
|
@ -20,6 +20,7 @@ import (
|
||||
"fmt"
|
||||
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/hooks"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/symlinks"
|
||||
@ -30,7 +31,7 @@ type tegraOptions struct {
|
||||
csvFiles []string
|
||||
driverRoot string
|
||||
devRoot string
|
||||
hookCreator discover.HookCreator
|
||||
hookCreator hooks.HookCreator
|
||||
ldconfigPath string
|
||||
librarySearchPaths []string
|
||||
ignorePatterns ignoreMountSpecPatterns
|
||||
@ -134,7 +135,7 @@ func WithCSVFiles(csvFiles []string) Option {
|
||||
}
|
||||
|
||||
// WithHookCreator sets the hook creator for the discoverer.
|
||||
func WithHookCreator(hookCreator discover.HookCreator) Option {
|
||||
func WithHookCreator(hookCreator hooks.HookCreator) Option {
|
||||
return func(o *tegraOptions) {
|
||||
o.hookCreator = hookCreator
|
||||
}
|
||||
|
@ -21,7 +21,7 @@ import (
|
||||
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/hooks"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/info"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root"
|
||||
@ -75,7 +75,7 @@ func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Sp
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hookCreator := discover.NewHookCreator(cfg.NVIDIACTKConfig.Path)
|
||||
hookCreator := hooks.NewHookCreator(cfg.NVIDIACTKConfig.Path, nil)
|
||||
|
||||
mode := info.ResolveAutoMode(logger, cfg.NVIDIAContainerRuntimeConfig.Mode, image)
|
||||
// We update the mode here so that we can continue passing just the config to other functions.
|
||||
|
@ -35,13 +35,3 @@ type Interface interface {
|
||||
GetMIGDeviceSpecs(int, device.Device, int, device.MigDevice) ([]specs.Device, error)
|
||||
GetDeviceSpecsByID(...string) ([]specs.Device, error)
|
||||
}
|
||||
|
||||
// A HookName refers to one of the predefined set of CDI hooks that may be
|
||||
// included in the generated CDI specification.
|
||||
type HookName string
|
||||
|
||||
const (
|
||||
// HookEnableCudaCompat refers to the hook used to enable CUDA Forward Compatibility.
|
||||
// This was added with v1.17.5 of the NVIDIA Container Toolkit.
|
||||
HookEnableCudaCompat = HookName("enable-cuda-compat")
|
||||
)
|
||||
|
@ -106,11 +106,9 @@ func (l *nvcdilib) NewDriverLibraryDiscoverer(version string) (discover.Discover
|
||||
)
|
||||
discoverers = append(discoverers, driverDotSoSymlinksDiscoverer)
|
||||
|
||||
if l.HookIsSupported(HookEnableCudaCompat) {
|
||||
// TODO: The following should use the version directly.
|
||||
cudaCompatLibHookDiscoverer := discover.NewCUDACompatHookDiscoverer(l.logger, l.hookCreator, l.driver)
|
||||
discoverers = append(discoverers, cudaCompatLibHookDiscoverer)
|
||||
}
|
||||
// TODO: The following should use the version directly.
|
||||
cudaCompatLibHookDiscoverer := discover.NewCUDACompatHookDiscoverer(l.logger, l.hookCreator, l.driver)
|
||||
discoverers = append(discoverers, cudaCompatLibHookDiscoverer)
|
||||
|
||||
updateLDCache, _ := discover.NewLDCacheUpdateHook(l.logger, libraries, l.hookCreator, l.ldconfigPath)
|
||||
discoverers = append(discoverers, updateLDCache)
|
||||
|
@ -22,6 +22,7 @@ import (
|
||||
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/dxcore"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/hooks"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup"
|
||||
)
|
||||
@ -39,7 +40,7 @@ var requiredDriverStoreFiles = []string{
|
||||
}
|
||||
|
||||
// newWSLDriverDiscoverer returns a Discoverer for WSL2 drivers.
|
||||
func newWSLDriverDiscoverer(logger logger.Interface, driverRoot string, hookCreator discover.HookCreator, ldconfigPath string) (discover.Discover, error) {
|
||||
func newWSLDriverDiscoverer(logger logger.Interface, driverRoot string, hookCreator hooks.HookCreator, ldconfigPath string) (discover.Discover, error) {
|
||||
err := dxcore.Init()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize dxcore: %v", err)
|
||||
@ -60,7 +61,7 @@ func newWSLDriverDiscoverer(logger logger.Interface, driverRoot string, hookCrea
|
||||
}
|
||||
|
||||
// newWSLDriverStoreDiscoverer returns a Discoverer for WSL2 drivers in the driver store associated with a dxcore adapter.
|
||||
func newWSLDriverStoreDiscoverer(logger logger.Interface, driverRoot string, hookCreator discover.HookCreator, ldconfigPath string, driverStorePaths []string) (discover.Discover, error) {
|
||||
func newWSLDriverStoreDiscoverer(logger logger.Interface, driverRoot string, hookCreator hooks.HookCreator, ldconfigPath string, driverStorePaths []string) (discover.Discover, error) {
|
||||
var searchPaths []string
|
||||
seen := make(map[string]bool)
|
||||
for _, path := range driverStorePaths {
|
||||
@ -108,7 +109,7 @@ type nvidiaSMISimlinkHook struct {
|
||||
discover.None
|
||||
logger logger.Interface
|
||||
mountsFrom discover.Discover
|
||||
hookCreator discover.HookCreator
|
||||
hookCreator hooks.HookCreator
|
||||
}
|
||||
|
||||
// Hooks returns a hook that creates a symlink to nvidia-smi in the driver store.
|
||||
@ -135,7 +136,13 @@ func (m nvidiaSMISimlinkHook) Hooks() ([]discover.Hook, error) {
|
||||
}
|
||||
link := "/usr/bin/nvidia-smi"
|
||||
links := []string{fmt.Sprintf("%s::%s", target, link)}
|
||||
symlinkHook := m.hookCreator.Create("create-symlinks", links...)
|
||||
symlinkHook := m.hookCreator.Create(hooks.CreateSymlinks, links...)
|
||||
|
||||
return symlinkHook.Hooks()
|
||||
return []discover.Hook{
|
||||
{
|
||||
Lifecycle: symlinkHook.Lifecycle,
|
||||
Path: symlinkHook.Path,
|
||||
Args: symlinkHook.Args,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
@ -23,13 +23,14 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/hooks"
|
||||
|
||||
testlog "github.com/sirupsen/logrus/hooks/test"
|
||||
)
|
||||
|
||||
func TestNvidiaSMISymlinkHook(t *testing.T) {
|
||||
logger, _ := testlog.NewNullLogger()
|
||||
hookCreator := discover.NewHookCreator("nvidia-cdi-hook")
|
||||
hookCreator := hooks.NewHookCreator("nvidia-cdi-hook", nil)
|
||||
|
||||
errMounts := errors.New("mounts error")
|
||||
|
||||
|
@ -1,30 +0,0 @@
|
||||
/**
|
||||
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
**/
|
||||
|
||||
package nvcdi
|
||||
|
||||
// disabledHooks allows individual hooks to be disabled.
|
||||
type disabledHooks map[HookName]bool
|
||||
|
||||
// HookIsSupported checks whether a hook of the specified name is supported.
|
||||
// Hooks must be explicitly disabled, meaning that if no disabled hooks are
|
||||
// all hooks are supported.
|
||||
func (l *nvcdilib) HookIsSupported(h HookName) bool {
|
||||
if len(l.disabledHooks) == 0 {
|
||||
return true
|
||||
}
|
||||
return !l.disabledHooks[h]
|
||||
}
|
@ -23,7 +23,7 @@ import (
|
||||
"github.com/NVIDIA/go-nvlib/pkg/nvlib/info"
|
||||
"github.com/NVIDIA/go-nvml/pkg/nvml"
|
||||
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/hooks"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/nvsandboxutils"
|
||||
@ -56,14 +56,14 @@ type nvcdilib struct {
|
||||
|
||||
mergedDeviceOptions []transform.MergedDeviceOption
|
||||
|
||||
disabledHooks disabledHooks
|
||||
hookCreator discover.HookCreator
|
||||
disabledHooks hooks.DisabledHooks
|
||||
hookCreator hooks.HookCreator
|
||||
}
|
||||
|
||||
// New creates a new nvcdi library
|
||||
func New(opts ...Option) (Interface, error) {
|
||||
l := &nvcdilib{
|
||||
disabledHooks: make(disabledHooks),
|
||||
disabledHooks: make(hooks.DisabledHooks),
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(l)
|
||||
@ -81,9 +81,6 @@ func New(opts ...Option) (Interface, error) {
|
||||
if l.nvidiaCDIHookPath == "" {
|
||||
l.nvidiaCDIHookPath = "/usr/bin/nvidia-cdi-hook"
|
||||
}
|
||||
// create hookCreator
|
||||
l.hookCreator = discover.NewHookCreator(l.nvidiaCDIHookPath)
|
||||
|
||||
if l.driverRoot == "" {
|
||||
l.driverRoot = "/"
|
||||
}
|
||||
@ -150,7 +147,7 @@ func New(opts ...Option) (Interface, error) {
|
||||
l.vendor = "management.nvidia.com"
|
||||
}
|
||||
// Management containers in general do not require CUDA Forward compatibility.
|
||||
l.disabledHooks[HookEnableCudaCompat] = true
|
||||
l.disabledHooks[hooks.EnableCudaCompat] = true
|
||||
lib = (*managementlib)(l)
|
||||
case ModeNvml:
|
||||
lib = (*nvmllib)(l)
|
||||
@ -175,6 +172,8 @@ func New(opts ...Option) (Interface, error) {
|
||||
return nil, fmt.Errorf("unknown mode %q", l.mode)
|
||||
}
|
||||
|
||||
l.hookCreator = hooks.NewHookCreator(l.nvidiaCDIHookPath, l.disabledHooks)
|
||||
|
||||
w := wrapper{
|
||||
Interface: lib,
|
||||
vendor: l.vendor,
|
||||
|
@ -21,6 +21,7 @@ import (
|
||||
"github.com/NVIDIA/go-nvlib/pkg/nvlib/info"
|
||||
"github.com/NVIDIA/go-nvml/pkg/nvml"
|
||||
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/hooks"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/transform"
|
||||
)
|
||||
@ -158,10 +159,10 @@ func WithLibrarySearchPaths(paths []string) Option {
|
||||
|
||||
// WithDisabledHook allows specific hooks to the disabled.
|
||||
// This option can be specified multiple times for each hook.
|
||||
func WithDisabledHook(hook HookName) Option {
|
||||
func WithDisabledHook(hook hooks.HookName) Option {
|
||||
return func(o *nvcdilib) {
|
||||
if o.disabledHooks == nil {
|
||||
o.disabledHooks = make(map[HookName]bool)
|
||||
o.disabledHooks = make(hooks.DisabledHooks)
|
||||
}
|
||||
o.disabledHooks[hook] = true
|
||||
}
|
||||
|
@ -21,6 +21,7 @@ import (
|
||||
"path/filepath"
|
||||
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/hooks"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
|
||||
)
|
||||
|
||||
@ -28,7 +29,7 @@ type deviceFolderPermissions struct {
|
||||
logger logger.Interface
|
||||
devRoot string
|
||||
devices discover.Discover
|
||||
hookCreator discover.HookCreator
|
||||
hookCreator hooks.HookCreator
|
||||
}
|
||||
|
||||
var _ discover.Discover = (*deviceFolderPermissions)(nil)
|
||||
@ -39,7 +40,7 @@ var _ discover.Discover = (*deviceFolderPermissions)(nil)
|
||||
// The nested devices that are applicable to the NVIDIA GPU devices are:
|
||||
// - DRM devices at /dev/dri/*
|
||||
// - NVIDIA Caps devices at /dev/nvidia-caps/*
|
||||
func newDeviceFolderPermissionHookDiscoverer(logger logger.Interface, devRoot string, hookCreator discover.HookCreator, devices discover.Discover) discover.Discover {
|
||||
func newDeviceFolderPermissionHookDiscoverer(logger logger.Interface, devRoot string, hookCreator hooks.HookCreator, devices discover.Discover) discover.Discover {
|
||||
d := &deviceFolderPermissions{
|
||||
logger: logger,
|
||||
devRoot: devRoot,
|
||||
@ -72,7 +73,13 @@ func (d *deviceFolderPermissions) Hooks() ([]discover.Hook, error) {
|
||||
|
||||
hook := d.hookCreator.Create("chmod", args...)
|
||||
|
||||
return []discover.Hook{*hook}, nil
|
||||
return []discover.Hook{
|
||||
{
|
||||
Lifecycle: hook.Lifecycle,
|
||||
Path: hook.Path,
|
||||
Args: hook.Args,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (d *deviceFolderPermissions) getDeviceSubfolders() ([]string, error) {
|
||||
|
Loading…
Reference in New Issue
Block a user