diff --git a/cmd/nvidia-container-runtime-hook/container_config.go b/cmd/nvidia-container-runtime-hook/container_config.go index 20a3b09f..36278dc0 100644 --- a/cmd/nvidia-container-runtime-hook/container_config.go +++ b/cmd/nvidia-container-runtime-hook/container_config.go @@ -45,7 +45,7 @@ type nvidiaConfig struct { type containerConfig struct { Pid int Rootfs string - Env map[string]string + Image image.CUDA Nvidia *nvidiaConfig } @@ -362,7 +362,7 @@ func getContainerConfig(hook HookConfig) (config containerConfig) { return containerConfig{ Pid: h.Pid, Rootfs: s.Root.Path, - Env: image, + Image: image, Nvidia: getNvidiaConfig(&hook, image, s.Mounts, privileged), } } diff --git a/cmd/nvidia-container-runtime-hook/main.go b/cmd/nvidia-container-runtime-hook/main.go index 9e5d4346..99583aa5 100644 --- a/cmd/nvidia-container-runtime-hook/main.go +++ b/cmd/nvidia-container-runtime-hook/main.go @@ -78,10 +78,6 @@ func doPrestart() { } cli := hook.NvidiaContainerCLI - if !hook.NVIDIAContainerRuntimeHook.SkipModeDetection && info.ResolveAutoMode(&logInterceptor{}, hook.NVIDIAContainerRuntime.Mode) != "legacy" { - log.Panicln("invoking the NVIDIA Container Runtime Hook directly (e.g. specifying the docker --gpus flag) is not supported. Please use the NVIDIA Container Runtime (e.g. specify the --runtime=nvidia flag) instead.") - } - container := getContainerConfig(*hook) nvidia := container.Nvidia if nvidia == nil { @@ -89,6 +85,10 @@ func doPrestart() { return } + if !hook.NVIDIAContainerRuntimeHook.SkipModeDetection && info.ResolveAutoMode(&logInterceptor{}, hook.NVIDIAContainerRuntime.Mode, container.Image) != "legacy" { + log.Panicln("invoking the NVIDIA Container Runtime Hook directly (e.g. specifying the docker --gpus flag) is not supported. Please use the NVIDIA Container Runtime (e.g. specify the --runtime=nvidia flag) instead.") + } + rootfs := getRootfsPath(container) args := []string{getCLIPath(cli)} diff --git a/internal/info/auto.go b/internal/info/auto.go index 7ec8bb86..396d127b 100644 --- a/internal/info/auto.go +++ b/internal/info/auto.go @@ -17,26 +17,50 @@ package info import ( + "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" + "github.com/container-orchestrated-devices/container-device-interface/pkg/cdi" "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/info" ) +// infoInterface provides an alias for mocking. +// +//go:generate moq -stub -out info-interface_mock.go . infoInterface +type infoInterface info.Interface + +type resolver struct { + logger logger.Interface + info info.Interface +} + // ResolveAutoMode determines the correct mode for the platform if set to "auto" -func ResolveAutoMode(logger logger.Interface, mode string) (rmode string) { +func ResolveAutoMode(logger logger.Interface, mode string, image image.CUDA) (rmode string) { + nvinfo := info.New() + r := resolver{ + logger: logger, + info: nvinfo, + } + return r.resolveMode(mode, image) +} + +// resolveMode determines the correct mode for the platform if set to "auto" +func (r resolver) resolveMode(mode string, image image.CUDA) (rmode string) { if mode != "auto" { return mode } defer func() { - logger.Infof("Auto-detected mode as '%v'", rmode) + r.logger.Infof("Auto-detected mode as '%v'", rmode) }() - nvinfo := info.New() + if onlyFullyQualifiedCDIDevices(image) { + return "cdi" + } - isTegra, reason := nvinfo.IsTegraSystem() - logger.Debugf("Is Tegra-based system? %v: %v", isTegra, reason) + isTegra, reason := r.info.IsTegraSystem() + r.logger.Debugf("Is Tegra-based system? %v: %v", isTegra, reason) - hasNVML, reason := nvinfo.HasNvml() - logger.Debugf("Has NVML? %v: %v", hasNVML, reason) + hasNVML, reason := r.info.HasNvml() + r.logger.Debugf("Has NVML? %v: %v", hasNVML, reason) if isTegra && !hasNVML { return "csv" @@ -44,3 +68,14 @@ func ResolveAutoMode(logger logger.Interface, mode string) (rmode string) { return "legacy" } + +func onlyFullyQualifiedCDIDevices(image image.CUDA) bool { + var hasCDIdevice bool + for _, device := range image.DevicesFromEnvvars("NVIDIA_VISIBLE_DEVICES").List() { + if !cdi.IsQualifiedName(device) { + return false + } + hasCDIdevice = true + } + return hasCDIdevice +} diff --git a/internal/info/auto_test.go b/internal/info/auto_test.go index 2a905c90..71919c01 100644 --- a/internal/info/auto_test.go +++ b/internal/info/auto_test.go @@ -19,8 +19,10 @@ package info import ( "testing" + "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" testlog "github.com/sirupsen/logrus/hooks/test" "github.com/stretchr/testify/require" + "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/info" ) func TestResolveAutoMode(t *testing.T) { @@ -30,23 +32,123 @@ func TestResolveAutoMode(t *testing.T) { description string mode string expectedMode string + info info.Interface + image image.CUDA }{ { description: "non-auto resolves to input", mode: "not-auto", expectedMode: "not-auto", }, - // TODO: The following test is brittle in that it will break on Tegra-based systems. - // { - // description: "auto resolves to legacy", - // mode: "auto", - // expectedMode: "legacy", - // }, + { + description: "nvml non-tegra resolves to legacy", + mode: "auto", + info: &infoInterfaceMock{ + HasNvmlFunc: func() (bool, string) { + return true, "nvml" + }, + IsTegraSystemFunc: func() (bool, string) { + return false, "tegra" + }, + }, + expectedMode: "legacy", + }, + { + description: "non-nvml non-tegra resolves to legacy", + mode: "auto", + info: &infoInterfaceMock{ + HasNvmlFunc: func() (bool, string) { + return false, "nvml" + }, + IsTegraSystemFunc: func() (bool, string) { + return false, "tegra" + }, + }, + expectedMode: "legacy", + }, + { + description: "nvml tegra resolves to legacy", + mode: "auto", + info: &infoInterfaceMock{ + HasNvmlFunc: func() (bool, string) { + return true, "nvml" + }, + IsTegraSystemFunc: func() (bool, string) { + return true, "tegra" + }, + }, + expectedMode: "legacy", + }, + { + description: "non-nvml tegra resolves to csv", + mode: "auto", + info: &infoInterfaceMock{ + HasNvmlFunc: func() (bool, string) { + return false, "nvml" + }, + IsTegraSystemFunc: func() (bool, string) { + return true, "tegra" + }, + }, + expectedMode: "csv", + }, + { + description: "cdi devices resolves to cdi", + mode: "auto", + expectedMode: "cdi", + image: image.CUDA{ + "NVIDIA_VISIBLE_DEVICES": "nvidia.com/gpu=all", + }, + }, + { + description: "multiple cdi devices resolves to cdi", + mode: "auto", + expectedMode: "cdi", + image: image.CUDA{ + "NVIDIA_VISIBLE_DEVICES": "nvidia.com/gpu=0,nvidia.com/gpu=1", + }, + }, + { + description: "at least one non-cdi device resolves to legacy", + mode: "auto", + image: image.CUDA{ + "NVIDIA_VISIBLE_DEVICES": "nvidia.com/gpu=0,0", + }, + info: &infoInterfaceMock{ + HasNvmlFunc: func() (bool, string) { + return true, "nvml" + }, + IsTegraSystemFunc: func() (bool, string) { + return true, "tegra" + }, + }, + expectedMode: "legacy", + }, + { + description: "at least one non-cdi device resolves to csv", + mode: "auto", + image: image.CUDA{ + "NVIDIA_VISIBLE_DEVICES": "nvidia.com/gpu=0,0", + }, + info: &infoInterfaceMock{ + HasNvmlFunc: func() (bool, string) { + return false, "nvml" + }, + IsTegraSystemFunc: func() (bool, string) { + return true, "tegra" + }, + }, + expectedMode: "csv", + }, } for _, tc := range testCases { t.Run(tc.description, func(t *testing.T) { - mode := ResolveAutoMode(logger, tc.mode) + r := resolver{ + logger: logger, + info: tc.info, + } + mode := r.resolveMode(tc.mode, tc.image) require.EqualValues(t, tc.expectedMode, mode) }) } diff --git a/internal/info/info-interface_mock.go b/internal/info/info-interface_mock.go new file mode 100644 index 00000000..2375a4f1 --- /dev/null +++ b/internal/info/info-interface_mock.go @@ -0,0 +1,153 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package info + +import ( + "sync" +) + +// Ensure, that infoInterfaceMock does implement infoInterface. +// If this is not the case, regenerate this file with moq. +var _ infoInterface = &infoInterfaceMock{} + +// infoInterfaceMock is a mock implementation of infoInterface. +// +// func TestSomethingThatUsesinfoInterface(t *testing.T) { +// +// // make and configure a mocked infoInterface +// mockedinfoInterface := &infoInterfaceMock{ +// HasDXCoreFunc: func() (bool, string) { +// panic("mock out the HasDXCore method") +// }, +// HasNvmlFunc: func() (bool, string) { +// panic("mock out the HasNvml method") +// }, +// IsTegraSystemFunc: func() (bool, string) { +// panic("mock out the IsTegraSystem method") +// }, +// } +// +// // use mockedinfoInterface in code that requires infoInterface +// // and then make assertions. +// +// } +type infoInterfaceMock struct { + // HasDXCoreFunc mocks the HasDXCore method. + HasDXCoreFunc func() (bool, string) + + // HasNvmlFunc mocks the HasNvml method. + HasNvmlFunc func() (bool, string) + + // IsTegraSystemFunc mocks the IsTegraSystem method. + IsTegraSystemFunc func() (bool, string) + + // calls tracks calls to the methods. + calls struct { + // HasDXCore holds details about calls to the HasDXCore method. + HasDXCore []struct { + } + // HasNvml holds details about calls to the HasNvml method. + HasNvml []struct { + } + // IsTegraSystem holds details about calls to the IsTegraSystem method. + IsTegraSystem []struct { + } + } + lockHasDXCore sync.RWMutex + lockHasNvml sync.RWMutex + lockIsTegraSystem sync.RWMutex +} + +// HasDXCore calls HasDXCoreFunc. +func (mock *infoInterfaceMock) HasDXCore() (bool, string) { + callInfo := struct { + }{} + mock.lockHasDXCore.Lock() + mock.calls.HasDXCore = append(mock.calls.HasDXCore, callInfo) + mock.lockHasDXCore.Unlock() + if mock.HasDXCoreFunc == nil { + var ( + bOut bool + sOut string + ) + return bOut, sOut + } + return mock.HasDXCoreFunc() +} + +// HasDXCoreCalls gets all the calls that were made to HasDXCore. +// Check the length with: +// +// len(mockedinfoInterface.HasDXCoreCalls()) +func (mock *infoInterfaceMock) HasDXCoreCalls() []struct { +} { + var calls []struct { + } + mock.lockHasDXCore.RLock() + calls = mock.calls.HasDXCore + mock.lockHasDXCore.RUnlock() + return calls +} + +// HasNvml calls HasNvmlFunc. +func (mock *infoInterfaceMock) HasNvml() (bool, string) { + callInfo := struct { + }{} + mock.lockHasNvml.Lock() + mock.calls.HasNvml = append(mock.calls.HasNvml, callInfo) + mock.lockHasNvml.Unlock() + if mock.HasNvmlFunc == nil { + var ( + bOut bool + sOut string + ) + return bOut, sOut + } + return mock.HasNvmlFunc() +} + +// HasNvmlCalls gets all the calls that were made to HasNvml. +// Check the length with: +// +// len(mockedinfoInterface.HasNvmlCalls()) +func (mock *infoInterfaceMock) HasNvmlCalls() []struct { +} { + var calls []struct { + } + mock.lockHasNvml.RLock() + calls = mock.calls.HasNvml + mock.lockHasNvml.RUnlock() + return calls +} + +// IsTegraSystem calls IsTegraSystemFunc. +func (mock *infoInterfaceMock) IsTegraSystem() (bool, string) { + callInfo := struct { + }{} + mock.lockIsTegraSystem.Lock() + mock.calls.IsTegraSystem = append(mock.calls.IsTegraSystem, callInfo) + mock.lockIsTegraSystem.Unlock() + if mock.IsTegraSystemFunc == nil { + var ( + bOut bool + sOut string + ) + return bOut, sOut + } + return mock.IsTegraSystemFunc() +} + +// IsTegraSystemCalls gets all the calls that were made to IsTegraSystem. +// Check the length with: +// +// len(mockedinfoInterface.IsTegraSystemCalls()) +func (mock *infoInterfaceMock) IsTegraSystemCalls() []struct { +} { + var calls []struct { + } + mock.lockIsTegraSystem.RLock() + calls = mock.calls.IsTegraSystem + mock.lockIsTegraSystem.RUnlock() + return calls +} diff --git a/internal/modifier/gds.go b/internal/modifier/gds.go index 5334346c..9ef15992 100644 --- a/internal/modifier/gds.go +++ b/internal/modifier/gds.go @@ -32,17 +32,7 @@ const ( // NewGDSModifier creates the modifiers for GDS devices. // If the spec does not contain the NVIDIA_GDS=enabled environment variable no changes are made. -func NewGDSModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Spec) (oci.SpecModifier, error) { - rawSpec, err := ociSpec.Load() - if err != nil { - return nil, fmt.Errorf("failed to load OCI spec: %v", err) - } - - image, err := image.NewCUDAImageFromSpec(rawSpec) - if err != nil { - return nil, err - } - +func NewGDSModifier(logger logger.Interface, cfg *config.Config, image image.CUDA) (oci.SpecModifier, error) { if devices := image.DevicesFromEnvvars(visibleDevicesEnvvar); len(devices.List()) == 0 { logger.Infof("No modification required; no devices requested") return nil, nil diff --git a/internal/modifier/graphics.go b/internal/modifier/graphics.go index e80de124..57776a72 100644 --- a/internal/modifier/graphics.go +++ b/internal/modifier/graphics.go @@ -28,17 +28,7 @@ import ( // NewGraphicsModifier constructs a modifier that injects graphics-related modifications into an OCI runtime specification. // The value of the NVIDIA_DRIVER_CAPABILITIES environment variable is checked to determine if this modification should be made. -func NewGraphicsModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Spec) (oci.SpecModifier, error) { - rawSpec, err := ociSpec.Load() - if err != nil { - return nil, fmt.Errorf("failed to load OCI spec: %v", err) - } - - image, err := image.NewCUDAImageFromSpec(rawSpec) - if err != nil { - return nil, err - } - +func NewGraphicsModifier(logger logger.Interface, cfg *config.Config, image image.CUDA) (oci.SpecModifier, error) { if required, reason := requiresGraphicsModifier(image); !required { logger.Infof("No graphics modifier required: %v", reason) return nil, nil diff --git a/internal/modifier/mofed.go b/internal/modifier/mofed.go index 92796506..5cc69169 100644 --- a/internal/modifier/mofed.go +++ b/internal/modifier/mofed.go @@ -32,17 +32,7 @@ const ( // NewMOFEDModifier creates the modifiers for MOFED devices. // If the spec does not contain the NVIDIA_MOFED=enabled environment variable no changes are made. -func NewMOFEDModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Spec) (oci.SpecModifier, error) { - rawSpec, err := ociSpec.Load() - if err != nil { - return nil, fmt.Errorf("failed to load OCI spec: %v", err) - } - - image, err := image.NewCUDAImageFromSpec(rawSpec) - if err != nil { - return nil, err - } - +func NewMOFEDModifier(logger logger.Interface, cfg *config.Config, image image.CUDA) (oci.SpecModifier, error) { if devices := image.DevicesFromEnvvars(visibleDevicesEnvvar); len(devices.List()) == 0 { logger.Infof("No modification required; no devices requested") return nil, nil diff --git a/internal/runtime/runtime_factory.go b/internal/runtime/runtime_factory.go index dd1a847a..9832fd66 100644 --- a/internal/runtime/runtime_factory.go +++ b/internal/runtime/runtime_factory.go @@ -20,6 +20,7 @@ import ( "fmt" "github.com/NVIDIA/nvidia-container-toolkit/internal/config" + "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" "github.com/NVIDIA/nvidia-container-toolkit/internal/info" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/modifier" @@ -61,7 +62,17 @@ func newNVIDIAContainerRuntime(logger logger.Interface, cfg *config.Config, argv // newSpecModifier is a factory method that creates constructs an OCI spec modifer based on the provided config. func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Spec, argv []string) (oci.SpecModifier, error) { - mode := info.ResolveAutoMode(logger, cfg.NVIDIAContainerRuntimeConfig.Mode) + rawSpec, err := ociSpec.Load() + if err != nil { + return nil, fmt.Errorf("failed to load OCI spec: %v", err) + } + + image, err := image.NewCUDAImageFromSpec(rawSpec) + if err != nil { + return nil, err + } + + mode := info.ResolveAutoMode(logger, cfg.NVIDIAContainerRuntimeConfig.Mode, image) modeModifier, err := newModeModifier(logger, mode, cfg, ociSpec, argv) if err != nil { return nil, err @@ -71,17 +82,17 @@ func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Sp return modeModifier, nil } - graphicsModifier, err := modifier.NewGraphicsModifier(logger, cfg, ociSpec) + graphicsModifier, err := modifier.NewGraphicsModifier(logger, cfg, image) if err != nil { return nil, err } - gdsModifier, err := modifier.NewGDSModifier(logger, cfg, ociSpec) + gdsModifier, err := modifier.NewGDSModifier(logger, cfg, image) if err != nil { return nil, err } - mofedModifier, err := modifier.NewMOFEDModifier(logger, cfg, ociSpec) + mofedModifier, err := modifier.NewMOFEDModifier(logger, cfg, image) if err != nil { return nil, err }