diff --git a/internal/info/auto.go b/internal/info/auto.go index 77c5fb43..c6800da1 100644 --- a/internal/info/auto.go +++ b/internal/info/auto.go @@ -17,66 +17,40 @@ package info import ( - "github.com/NVIDIA/go-nvlib/pkg/nvlib/device" "github.com/NVIDIA/go-nvlib/pkg/nvlib/info" - "github.com/NVIDIA/go-nvml/pkg/nvml" "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" ) -type resolver struct { - logger logger.Interface - info info.PropertyExtractor -} - // ResolveAutoMode determines the correct mode for the platform if set to "auto" func ResolveAutoMode(logger logger.Interface, mode string, image image.CUDA) (rmode string) { - nvinfo := info.New() - nvmllib := nvml.New() - devicelib := device.New( - device.WithNvml(nvmllib), - ) - - info := additionalInfo{ - Interface: nvinfo, - nvmllib: nvmllib, - devicelib: devicelib, - } - - r := resolver{ - logger: logger, - info: info, - } - return r.resolveMode(mode, image) + return resolveMode(logger, mode, image, nil) } -// resolveMode determines the correct mode for the platform if set to "auto" -func (r resolver) resolveMode(mode string, image image.CUDA) (rmode string) { +func resolveMode(logger logger.Interface, mode string, image image.CUDA, propertyExtractor info.PropertyExtractor) (rmode string) { if mode != "auto" { - r.logger.Infof("Using requested mode '%s'", mode) + logger.Infof("Using requested mode '%s'", mode) return mode } defer func() { - r.logger.Infof("Auto-detected mode as '%v'", rmode) + logger.Infof("Auto-detected mode as '%v'", rmode) }() if image.OnlyFullyQualifiedCDIDevices() { return "cdi" } - isTegra, reason := r.info.HasTegraFiles() - r.logger.Debugf("Is Tegra-based system? %v: %v", isTegra, reason) + nvinfo := info.New( + info.WithLogger(logger), + info.WithPropertyExtractor(propertyExtractor), + ) - hasNVML, reason := r.info.HasNvml() - r.logger.Debugf("Has NVML? %v: %v", hasNVML, reason) - - usesNVGPUModule, reason := r.info.UsesOnlyNVGPUModule() - r.logger.Debugf("Uses nvgpu kernel module? %v: %v", usesNVGPUModule, reason) - - if (isTegra && !hasNVML) || usesNVGPUModule { + switch nvinfo.ResolvePlatform() { + case info.PlatformNVML, info.PlatformWSL: + return "legacy" + case info.PlatformTegra: return "csv" } - return "legacy" } diff --git a/internal/info/auto_test.go b/internal/info/auto_test.go index 947492a0..4fbfcde4 100644 --- a/internal/info/auto_test.go +++ b/internal/info/auto_test.go @@ -203,10 +203,16 @@ func TestResolveAutoMode(t *testing.T) { for _, tc := range testCases { t.Run(tc.description, func(t *testing.T) { - info := &info.PropertyExtractorMock{ + properties := &info.PropertyExtractorMock{ HasNvmlFunc: func() (bool, string) { return tc.info["nvml"], "nvml" }, + HasDXCoreFunc: func() (bool, string) { + return tc.info["dxcore"], "dxcore" + }, + IsTegraSystemFunc: func() (bool, string) { + return tc.info["tegra"], "tegra" + }, HasTegraFilesFunc: func() (bool, string) { return tc.info["tegra"], "tegra" }, @@ -215,11 +221,6 @@ func TestResolveAutoMode(t *testing.T) { }, } - r := resolver{ - logger: logger, - info: info, - } - var mounts []specs.Mount for _, d := range tc.mounts { mount := specs.Mount{ @@ -232,7 +233,7 @@ func TestResolveAutoMode(t *testing.T) { image.WithEnvMap(tc.envmap), image.WithMounts(mounts), ) - mode := r.resolveMode(tc.mode, image) + mode := resolveMode(logger, tc.mode, image, properties) require.EqualValues(t, tc.expectedMode, mode) }) } diff --git a/pkg/nvcdi/lib.go b/pkg/nvcdi/lib.go index 5c85199e..e3c162e7 100644 --- a/pkg/nvcdi/lib.go +++ b/pkg/nvcdi/lib.go @@ -91,7 +91,12 @@ func New(opts ...Option) (Interface, error) { l.nvidiaCTKPath = "/usr/bin/nvidia-ctk" } if l.infolib == nil { - l.infolib = info.New() + l.infolib = info.New( + info.WithRoot(l.driverRoot), + info.WithLogger(l.logger), + info.WithNvmlLib(l.nvmllib), + info.WithDeviceLib(l.devicelib), + ) } l.driver = root.New( @@ -184,26 +189,19 @@ func (l *nvcdilib) resolveMode() (rmode string) { return l.mode } defer func() { - l.logger.Infof("Auto-detected mode as %q", rmode) + l.logger.Infof("Auto-detected mode as '%v'", rmode) }() - isWSL, reason := l.infolib.HasDXCore() - l.logger.Debugf("Is WSL-based system? %v: %v", isWSL, reason) - - if isWSL { + platform := l.infolib.ResolvePlatform() + switch platform { + case info.PlatformNVML: + return ModeNvml + case info.PlatformTegra: + return ModeCSV + case info.PlatformWSL: return ModeWsl } - - isNvml, reason := l.infolib.HasNvml() - l.logger.Debugf("Is NVML-based system? %v: %v", isNvml, reason) - - isTegra, reason := l.infolib.HasTegraFiles() - l.logger.Debugf("Is Tegra-based system? %v: %v", isTegra, reason) - - if isTegra && !isNvml { - return ModeCSV - } - + l.logger.Warningf("Unsupported platform detected: %v; assuming %v", platform, ModeNvml) return ModeNvml } diff --git a/pkg/nvcdi/lib_test.go b/pkg/nvcdi/lib_test.go index 2384e4ab..61246336 100644 --- a/pkg/nvcdi/lib_test.go +++ b/pkg/nvcdi/lib_test.go @@ -29,11 +29,12 @@ func TestResolveMode(t *testing.T) { logger, _ := testlog.NewNullLogger() testCases := []struct { - mode string - isTegra bool - hasDXCore bool - hasNVML bool - expected string + mode string + isTegra bool + hasDXCore bool + hasNVML bool + UsesOnlyNVGPUModule bool + expected string }{ { mode: "auto", @@ -67,6 +68,13 @@ func TestResolveMode(t *testing.T) { isTegra: false, expected: "nvml", }, + { + mode: "auto", + isTegra: true, + hasNVML: true, + UsesOnlyNVGPUModule: true, + expected: "csv", + }, { mode: "nvml", hasDXCore: true, @@ -90,24 +98,16 @@ func TestResolveMode(t *testing.T) { l := nvcdilib{ logger: logger, mode: tc.mode, - infolib: infoMock{ - PropertyExtractor: &info.PropertyExtractorMock{ - HasDXCoreFunc: func() (bool, string) { return tc.hasDXCore, "" }, - HasNvmlFunc: func() (bool, string) { return tc.hasNVML, "" }, - HasTegraFilesFunc: func() (bool, string) { return tc.isTegra, "" }, - }, - }, + infolib: info.New( + info.WithPropertyExtractor(&info.PropertyExtractorMock{ + HasDXCoreFunc: func() (bool, string) { return tc.hasDXCore, "" }, + HasNvmlFunc: func() (bool, string) { return tc.hasNVML, "" }, + HasTegraFilesFunc: func() (bool, string) { return tc.isTegra, "" }, + UsesOnlyNVGPUModuleFunc: func() (bool, string) { return tc.UsesOnlyNVGPUModule, "" }, + }), + ), } - require.Equal(t, tc.expected, l.resolveMode()) }) } } - -type infoMock struct { - info.PropertyExtractor -} - -func (i infoMock) ResolvePlatform() info.Platform { - return "not-implemented" -}