diff --git a/cmd/nvidia-container-runtime-hook/container_config.go b/cmd/nvidia-container-runtime-hook/container_config.go index 20a3b09f..36278dc0 100644 --- a/cmd/nvidia-container-runtime-hook/container_config.go +++ b/cmd/nvidia-container-runtime-hook/container_config.go @@ -45,7 +45,7 @@ type nvidiaConfig struct { type containerConfig struct { Pid int Rootfs string - Env map[string]string + Image image.CUDA Nvidia *nvidiaConfig } @@ -362,7 +362,7 @@ func getContainerConfig(hook HookConfig) (config containerConfig) { return containerConfig{ Pid: h.Pid, Rootfs: s.Root.Path, - Env: image, + Image: image, Nvidia: getNvidiaConfig(&hook, image, s.Mounts, privileged), } } diff --git a/cmd/nvidia-container-runtime-hook/main.go b/cmd/nvidia-container-runtime-hook/main.go index 9e5d4346..99583aa5 100644 --- a/cmd/nvidia-container-runtime-hook/main.go +++ b/cmd/nvidia-container-runtime-hook/main.go @@ -78,10 +78,6 @@ func doPrestart() { } cli := hook.NvidiaContainerCLI - if !hook.NVIDIAContainerRuntimeHook.SkipModeDetection && info.ResolveAutoMode(&logInterceptor{}, hook.NVIDIAContainerRuntime.Mode) != "legacy" { - log.Panicln("invoking the NVIDIA Container Runtime Hook directly (e.g. specifying the docker --gpus flag) is not supported. Please use the NVIDIA Container Runtime (e.g. specify the --runtime=nvidia flag) instead.") - } - container := getContainerConfig(*hook) nvidia := container.Nvidia if nvidia == nil { @@ -89,6 +85,10 @@ func doPrestart() { return } + if !hook.NVIDIAContainerRuntimeHook.SkipModeDetection && info.ResolveAutoMode(&logInterceptor{}, hook.NVIDIAContainerRuntime.Mode, container.Image) != "legacy" { + log.Panicln("invoking the NVIDIA Container Runtime Hook directly (e.g. specifying the docker --gpus flag) is not supported. Please use the NVIDIA Container Runtime (e.g. specify the --runtime=nvidia flag) instead.") + } + rootfs := getRootfsPath(container) args := []string{getCLIPath(cli)} diff --git a/internal/info/auto.go b/internal/info/auto.go index b9a9a5d8..396d127b 100644 --- a/internal/info/auto.go +++ b/internal/info/auto.go @@ -17,7 +17,9 @@ package info import ( + "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" + "github.com/container-orchestrated-devices/container-device-interface/pkg/cdi" "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/info" ) @@ -32,17 +34,17 @@ type resolver struct { } // ResolveAutoMode determines the correct mode for the platform if set to "auto" -func ResolveAutoMode(logger logger.Interface, mode string) (rmode string) { +func ResolveAutoMode(logger logger.Interface, mode string, image image.CUDA) (rmode string) { nvinfo := info.New() r := resolver{ logger: logger, info: nvinfo, } - return r.resolveMode(mode) + return r.resolveMode(mode, image) } // resolveMode determines the correct mode for the platform if set to "auto" -func (r resolver) resolveMode(mode string) (rmode string) { +func (r resolver) resolveMode(mode string, image image.CUDA) (rmode string) { if mode != "auto" { return mode } @@ -50,6 +52,10 @@ func (r resolver) resolveMode(mode string) (rmode string) { r.logger.Infof("Auto-detected mode as '%v'", rmode) }() + if onlyFullyQualifiedCDIDevices(image) { + return "cdi" + } + isTegra, reason := r.info.IsTegraSystem() r.logger.Debugf("Is Tegra-based system? %v: %v", isTegra, reason) @@ -62,3 +68,14 @@ func (r resolver) resolveMode(mode string) (rmode string) { return "legacy" } + +func onlyFullyQualifiedCDIDevices(image image.CUDA) bool { + var hasCDIdevice bool + for _, device := range image.DevicesFromEnvvars("NVIDIA_VISIBLE_DEVICES").List() { + if !cdi.IsQualifiedName(device) { + return false + } + hasCDIdevice = true + } + return hasCDIdevice +} diff --git a/internal/runtime/runtime_factory.go b/internal/runtime/runtime_factory.go index ed32585b..9832fd66 100644 --- a/internal/runtime/runtime_factory.go +++ b/internal/runtime/runtime_factory.go @@ -72,7 +72,7 @@ func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Sp return nil, err } - mode := info.ResolveAutoMode(logger, cfg.NVIDIAContainerRuntimeConfig.Mode) + mode := info.ResolveAutoMode(logger, cfg.NVIDIAContainerRuntimeConfig.Mode, image) modeModifier, err := newModeModifier(logger, mode, cfg, ociSpec, argv) if err != nil { return nil, err