diff --git a/internal/config/features.go b/internal/config/features.go index 80d3c95a..d5ce22df 100644 --- a/internal/config/features.go +++ b/internal/config/features.go @@ -21,6 +21,9 @@ type features struct { // DisableImexChannelCreation ensures that the implicit creation of // requested IMEX channels is skipped when invoking the nvidia-container-cli. DisableImexChannelCreation *feature `toml:"disable-imex-channel-creation,omitempty"` + // RequireNvidiaKernelModules indicates that the NVIDIA kernel module must be + // loaded for the NVIDIA Container Runtime to perform any OCI spec modifications. + RequireNvidiaKernelModules *feature `toml:"require-nvidia-kernel-module,omitempty"` } //nolint:unused diff --git a/internal/runtime/runtime_factory.go b/internal/runtime/runtime_factory.go index 50c19a4f..436db3fd 100644 --- a/internal/runtime/runtime_factory.go +++ b/internal/runtime/runtime_factory.go @@ -18,6 +18,7 @@ package runtime import ( "fmt" + "os" "github.com/NVIDIA/nvidia-container-toolkit/internal/config" "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" @@ -41,6 +42,11 @@ func newNVIDIAContainerRuntime(logger logger.Interface, cfg *config.Config, argv return lowLevelRuntime, nil } + if cfg.Features.RequireNvidiaKernelModules.IsEnabled() && !isNvidiaModuleLoaded() { + logger.Tracef("NVIDIA driver modules are not yet loaded; skipping modifer") + return lowLevelRuntime, nil + } + ociSpec, err := oci.NewSpec(logger, argv) if err != nil { return nil, fmt.Errorf("error constructing OCI specification: %v", err) @@ -62,6 +68,19 @@ func newNVIDIAContainerRuntime(logger logger.Interface, cfg *config.Config, argv return r, nil } +// isNvidiaKernelModuleLoaded checks whether the NVIDIA GPU driver is installed +// and the kernel module is available. +func isNvidiaModuleLoaded() bool { + // TODO: This was implemented as: + // cat /proc/modules | grep -e \"^nvidia \" >/dev/null 2>&1 + // if [ "${?}" != "0" ]; then + // echo "nvidia driver modules are not yet loaded, invoking runc directly" + // exec runc "$@" + // fi + _, err := os.Stat("/proc/driver/nvidia/version") + return err == nil +} + // newSpecModifier is a factory method that creates constructs an OCI spec modifer based on the provided config. func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Spec, driver *root.Driver) (oci.SpecModifier, error) { rawSpec, err := ociSpec.Load()