diff --git a/internal/discover/hooks.go b/internal/discover/hooks.go index 3e0d1fc2..f20d000c 100644 --- a/internal/discover/hooks.go +++ b/internal/discover/hooks.go @@ -45,6 +45,29 @@ func (h *Hook) Hooks() ([]Hook, error) { return []Hook{*h}, nil } +type HookName string + +// DisabledHooks allows individual hooks to be disabled. +type DisabledHooks map[HookName]bool + +const ( + // HookEnableCudaCompat refers to the hook used to enable CUDA Forward Compatibility. + // This was added with v1.17.5 of the NVIDIA Container Toolkit. + HookEnableCudaCompat = HookName("enable-cuda-compat") + // directory path to be mounted into a container. + HookCreateSymlinks = HookName("create-symlinks") + // HookUpdateLDCache refers to the hook used to Update the dynamic linker + // cache inside the directory path to be mounted into a container. + HookUpdateLDCache = HookName("update-ldcache") +) + +// AllHooks maintains a future-proof list of all defined hooks. +var AllHooks = []HookName{ + HookEnableCudaCompat, + HookCreateSymlinks, + HookUpdateLDCache, +} + // Option is a function that configures the nvcdilib type Option func(*CDIHook) @@ -54,7 +77,7 @@ type CDIHook struct { } type HookCreator interface { - Create(string, ...string) *Hook + Create(HookName, ...string) *Hook } func NewHookCreator(nvidiaCDIHookPath string, debugLogging bool) HookCreator { @@ -66,7 +89,7 @@ func NewHookCreator(nvidiaCDIHookPath string, debugLogging bool) HookCreator { return CDIHook } -func (c CDIHook) Create(name string, args ...string) *Hook { +func (c CDIHook) Create(name HookName, args ...string) *Hook { if name == "create-symlinks" { if len(args) == 0 { return nil diff --git a/internal/discover/ldconfig.go b/internal/discover/ldconfig.go index 3fab927a..0c632a74 100644 --- a/internal/discover/ldconfig.go +++ b/internal/discover/ldconfig.go @@ -72,7 +72,7 @@ func createLDCacheUpdateHook(hookCreator HookCreator, ldconfig string, libraries args = append(args, "--folder", f) } - return hookCreator.Create("update-ldcache", args...) + return hookCreator.Create(HookUpdateLDCache, args...) } // getLibraryPaths extracts the library dirs from the specified mounts diff --git a/pkg/nvcdi/api.go b/pkg/nvcdi/api.go index 8c37f277..f9d9f83d 100644 --- a/pkg/nvcdi/api.go +++ b/pkg/nvcdi/api.go @@ -21,6 +21,7 @@ import ( "tags.cncf.io/container-device-interface/pkg/cdi" "tags.cncf.io/container-device-interface/specs-go" + "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec" ) @@ -36,14 +37,14 @@ type Interface interface { GetDeviceSpecsByID(...string) ([]specs.Device, error) } -// A HookName refers to one of the predefined set of CDI hooks that may be -// included in the generated CDI specification. -type HookName string +// HookName is an alias for the discover.HookName type. +type HookName = discover.HookName +// Aliases for the discover.HookName constants. const ( - // HookEnableCudaCompat refers to the hook used to enable CUDA Forward Compatibility. - // This was added with v1.17.5 of the NVIDIA Container Toolkit. - HookEnableCudaCompat = HookName("enable-cuda-compat") + HookEnableCudaCompat = discover.HookEnableCudaCompat + HookCreateSymlinks = discover.HookCreateSymlinks + HookUpdateLDCache = discover.HookUpdateLDCache ) // A FeatureFlag refers to a specific feature that can be toggled in the CDI api. diff --git a/pkg/nvcdi/driver-wsl.go b/pkg/nvcdi/driver-wsl.go index 97e267c5..26a67b11 100644 --- a/pkg/nvcdi/driver-wsl.go +++ b/pkg/nvcdi/driver-wsl.go @@ -135,7 +135,7 @@ func (m nvidiaSMISimlinkHook) Hooks() ([]discover.Hook, error) { } link := "/usr/bin/nvidia-smi" links := []string{fmt.Sprintf("%s::%s", target, link)} - symlinkHook := m.hookCreator.Create("create-symlinks", links...) + symlinkHook := m.hookCreator.Create(HookCreateSymlinks, links...) return symlinkHook.Hooks() } diff --git a/pkg/nvcdi/hooks.go b/pkg/nvcdi/hooks.go index a4620dc8..20ef59a4 100644 --- a/pkg/nvcdi/hooks.go +++ b/pkg/nvcdi/hooks.go @@ -16,9 +16,6 @@ package nvcdi -// disabledHooks allows individual hooks to be disabled. -type disabledHooks map[HookName]bool - // HookIsSupported checks whether a hook of the specified name is supported. // Hooks must be explicitly disabled, meaning that if no disabled hooks are // all hooks are supported. diff --git a/pkg/nvcdi/lib.go b/pkg/nvcdi/lib.go index cbbf2419..15d6361c 100644 --- a/pkg/nvcdi/lib.go +++ b/pkg/nvcdi/lib.go @@ -58,16 +58,13 @@ type nvcdilib struct { featureFlags map[FeatureFlag]bool - disabledHooks disabledHooks + disabledHooks []discover.HookName hookCreator discover.HookCreator } // New creates a new nvcdi library func New(opts ...Option) (Interface, error) { - l := &nvcdilib{ - disabledHooks: make(disabledHooks), - featureFlags: make(map[FeatureFlag]bool), - } + l := &nvcdilib{} for _, opt := range opts { opt(l) } @@ -136,7 +133,7 @@ func New(opts ...Option) (Interface, error) { l.vendor = "management.nvidia.com" } // Management containers in general do not require CUDA Forward compatibility. - l.disabledHooks[HookEnableCudaCompat] = true + l.disabledHooks = append(l.disabledHooks, discover.HookEnableCudaCompat) lib = (*managementlib)(l) case ModeNvml: lib = (*nvmllib)(l) diff --git a/pkg/nvcdi/options.go b/pkg/nvcdi/options.go index 7c76f7fc..25494ebf 100644 --- a/pkg/nvcdi/options.go +++ b/pkg/nvcdi/options.go @@ -21,6 +21,7 @@ import ( "github.com/NVIDIA/go-nvlib/pkg/nvlib/info" "github.com/NVIDIA/go-nvml/pkg/nvml" + "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/transform" ) @@ -158,12 +159,9 @@ func WithLibrarySearchPaths(paths []string) Option { // WithDisabledHook allows specific hooks to the disabled. // This option can be specified multiple times for each hook. -func WithDisabledHook(hook HookName) Option { +func WithDisabledHook[T string | HookName](hook T) Option { return func(o *nvcdilib) { - if o.disabledHooks == nil { - o.disabledHooks = make(map[HookName]bool) - } - o.disabledHooks[hook] = true + o.disabledHooks = append(o.disabledHooks, discover.HookName(hook)) } }