diff --git a/internal/info/auto.go b/internal/info/auto.go index 5d1426d9..c6800da1 100644 --- a/internal/info/auto.go +++ b/internal/info/auto.go @@ -17,75 +17,40 @@ package info import ( - "github.com/NVIDIA/go-nvlib/pkg/nvlib/device" "github.com/NVIDIA/go-nvlib/pkg/nvlib/info" - "github.com/NVIDIA/go-nvml/pkg/nvml" "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" ) -// infoInterface provides an alias for mocking. -// -//go:generate moq -stub -out info-interface_mock.go . infoInterface -type infoInterface interface { - info.Interface - // UsesNVGPUModule indicates whether the system is using the nvgpu kernel module - UsesNVGPUModule() (bool, string) -} - -type resolver struct { - logger logger.Interface - info infoInterface -} - // ResolveAutoMode determines the correct mode for the platform if set to "auto" func ResolveAutoMode(logger logger.Interface, mode string, image image.CUDA) (rmode string) { - nvinfo := info.New() - nvmllib := nvml.New() - devicelib := device.New( - device.WithNvml(nvmllib), - ) - - info := additionalInfo{ - Interface: nvinfo, - nvmllib: nvmllib, - devicelib: devicelib, - } - - r := resolver{ - logger: logger, - info: info, - } - return r.resolveMode(mode, image) + return resolveMode(logger, mode, image, nil) } -// resolveMode determines the correct mode for the platform if set to "auto" -func (r resolver) resolveMode(mode string, image image.CUDA) (rmode string) { +func resolveMode(logger logger.Interface, mode string, image image.CUDA, propertyExtractor info.PropertyExtractor) (rmode string) { if mode != "auto" { - r.logger.Infof("Using requested mode '%s'", mode) + logger.Infof("Using requested mode '%s'", mode) return mode } defer func() { - r.logger.Infof("Auto-detected mode as '%v'", rmode) + logger.Infof("Auto-detected mode as '%v'", rmode) }() if image.OnlyFullyQualifiedCDIDevices() { return "cdi" } - isTegra, reason := r.info.IsTegraSystem() - r.logger.Debugf("Is Tegra-based system? %v: %v", isTegra, reason) + nvinfo := info.New( + info.WithLogger(logger), + info.WithPropertyExtractor(propertyExtractor), + ) - hasNVML, reason := r.info.HasNvml() - r.logger.Debugf("Has NVML? %v: %v", hasNVML, reason) - - usesNVGPUModule, reason := r.info.UsesNVGPUModule() - r.logger.Debugf("Uses nvgpu kernel module? %v: %v", usesNVGPUModule, reason) - - if (isTegra && !hasNVML) || usesNVGPUModule { + switch nvinfo.ResolvePlatform() { + case info.PlatformNVML, info.PlatformWSL: + return "legacy" + case info.PlatformTegra: return "csv" } - return "legacy" } diff --git a/internal/info/auto_test.go b/internal/info/auto_test.go index e986ced7..4fbfcde4 100644 --- a/internal/info/auto_test.go +++ b/internal/info/auto_test.go @@ -19,6 +19,7 @@ package info import ( "testing" + "github.com/NVIDIA/go-nvlib/pkg/nvlib/info" "github.com/opencontainers/runtime-spec/specs-go" testlog "github.com/sirupsen/logrus/hooks/test" "github.com/stretchr/testify/require" @@ -202,23 +203,24 @@ func TestResolveAutoMode(t *testing.T) { for _, tc := range testCases { t.Run(tc.description, func(t *testing.T) { - info := &infoInterfaceMock{ + properties := &info.PropertyExtractorMock{ HasNvmlFunc: func() (bool, string) { return tc.info["nvml"], "nvml" }, + HasDXCoreFunc: func() (bool, string) { + return tc.info["dxcore"], "dxcore" + }, IsTegraSystemFunc: func() (bool, string) { return tc.info["tegra"], "tegra" }, - UsesNVGPUModuleFunc: func() (bool, string) { + HasTegraFilesFunc: func() (bool, string) { + return tc.info["tegra"], "tegra" + }, + UsesOnlyNVGPUModuleFunc: func() (bool, string) { return tc.info["nvgpu"], "nvgpu" }, } - r := resolver{ - logger: logger, - info: info, - } - var mounts []specs.Mount for _, d := range tc.mounts { mount := specs.Mount{ @@ -231,7 +233,7 @@ func TestResolveAutoMode(t *testing.T) { image.WithEnvMap(tc.envmap), image.WithMounts(mounts), ) - mode := r.resolveMode(tc.mode, image) + mode := resolveMode(logger, tc.mode, image, properties) require.EqualValues(t, tc.expectedMode, mode) }) } diff --git a/internal/info/info-interface_mock.go b/internal/info/info-interface_mock.go deleted file mode 100644 index c1e491c9..00000000 --- a/internal/info/info-interface_mock.go +++ /dev/null @@ -1,194 +0,0 @@ -// Code generated by moq; DO NOT EDIT. -// github.com/matryer/moq - -package info - -import ( - "sync" -) - -// Ensure, that infoInterfaceMock does implement infoInterface. -// If this is not the case, regenerate this file with moq. -var _ infoInterface = &infoInterfaceMock{} - -// infoInterfaceMock is a mock implementation of infoInterface. -// -// func TestSomethingThatUsesinfoInterface(t *testing.T) { -// -// // make and configure a mocked infoInterface -// mockedinfoInterface := &infoInterfaceMock{ -// HasDXCoreFunc: func() (bool, string) { -// panic("mock out the HasDXCore method") -// }, -// HasNvmlFunc: func() (bool, string) { -// panic("mock out the HasNvml method") -// }, -// IsTegraSystemFunc: func() (bool, string) { -// panic("mock out the IsTegraSystem method") -// }, -// UsesNVGPUModuleFunc: func() (bool, string) { -// panic("mock out the UsesNVGPUModule method") -// }, -// } -// -// // use mockedinfoInterface in code that requires infoInterface -// // and then make assertions. -// -// } -type infoInterfaceMock struct { - // HasDXCoreFunc mocks the HasDXCore method. - HasDXCoreFunc func() (bool, string) - - // HasNvmlFunc mocks the HasNvml method. - HasNvmlFunc func() (bool, string) - - // IsTegraSystemFunc mocks the IsTegraSystem method. - IsTegraSystemFunc func() (bool, string) - - // UsesNVGPUModuleFunc mocks the UsesNVGPUModule method. - UsesNVGPUModuleFunc func() (bool, string) - - // calls tracks calls to the methods. - calls struct { - // HasDXCore holds details about calls to the HasDXCore method. - HasDXCore []struct { - } - // HasNvml holds details about calls to the HasNvml method. - HasNvml []struct { - } - // IsTegraSystem holds details about calls to the IsTegraSystem method. - IsTegraSystem []struct { - } - // UsesNVGPUModule holds details about calls to the UsesNVGPUModule method. - UsesNVGPUModule []struct { - } - } - lockHasDXCore sync.RWMutex - lockHasNvml sync.RWMutex - lockIsTegraSystem sync.RWMutex - lockUsesNVGPUModule sync.RWMutex -} - -// HasDXCore calls HasDXCoreFunc. -func (mock *infoInterfaceMock) HasDXCore() (bool, string) { - callInfo := struct { - }{} - mock.lockHasDXCore.Lock() - mock.calls.HasDXCore = append(mock.calls.HasDXCore, callInfo) - mock.lockHasDXCore.Unlock() - if mock.HasDXCoreFunc == nil { - var ( - bOut bool - sOut string - ) - return bOut, sOut - } - return mock.HasDXCoreFunc() -} - -// HasDXCoreCalls gets all the calls that were made to HasDXCore. -// Check the length with: -// -// len(mockedinfoInterface.HasDXCoreCalls()) -func (mock *infoInterfaceMock) HasDXCoreCalls() []struct { -} { - var calls []struct { - } - mock.lockHasDXCore.RLock() - calls = mock.calls.HasDXCore - mock.lockHasDXCore.RUnlock() - return calls -} - -// HasNvml calls HasNvmlFunc. -func (mock *infoInterfaceMock) HasNvml() (bool, string) { - callInfo := struct { - }{} - mock.lockHasNvml.Lock() - mock.calls.HasNvml = append(mock.calls.HasNvml, callInfo) - mock.lockHasNvml.Unlock() - if mock.HasNvmlFunc == nil { - var ( - bOut bool - sOut string - ) - return bOut, sOut - } - return mock.HasNvmlFunc() -} - -// HasNvmlCalls gets all the calls that were made to HasNvml. -// Check the length with: -// -// len(mockedinfoInterface.HasNvmlCalls()) -func (mock *infoInterfaceMock) HasNvmlCalls() []struct { -} { - var calls []struct { - } - mock.lockHasNvml.RLock() - calls = mock.calls.HasNvml - mock.lockHasNvml.RUnlock() - return calls -} - -// IsTegraSystem calls IsTegraSystemFunc. -func (mock *infoInterfaceMock) IsTegraSystem() (bool, string) { - callInfo := struct { - }{} - mock.lockIsTegraSystem.Lock() - mock.calls.IsTegraSystem = append(mock.calls.IsTegraSystem, callInfo) - mock.lockIsTegraSystem.Unlock() - if mock.IsTegraSystemFunc == nil { - var ( - bOut bool - sOut string - ) - return bOut, sOut - } - return mock.IsTegraSystemFunc() -} - -// IsTegraSystemCalls gets all the calls that were made to IsTegraSystem. -// Check the length with: -// -// len(mockedinfoInterface.IsTegraSystemCalls()) -func (mock *infoInterfaceMock) IsTegraSystemCalls() []struct { -} { - var calls []struct { - } - mock.lockIsTegraSystem.RLock() - calls = mock.calls.IsTegraSystem - mock.lockIsTegraSystem.RUnlock() - return calls -} - -// UsesNVGPUModule calls UsesNVGPUModuleFunc. -func (mock *infoInterfaceMock) UsesNVGPUModule() (bool, string) { - callInfo := struct { - }{} - mock.lockUsesNVGPUModule.Lock() - mock.calls.UsesNVGPUModule = append(mock.calls.UsesNVGPUModule, callInfo) - mock.lockUsesNVGPUModule.Unlock() - if mock.UsesNVGPUModuleFunc == nil { - var ( - bOut bool - sOut string - ) - return bOut, sOut - } - return mock.UsesNVGPUModuleFunc() -} - -// UsesNVGPUModuleCalls gets all the calls that were made to UsesNVGPUModule. -// Check the length with: -// -// len(mockedinfoInterface.UsesNVGPUModuleCalls()) -func (mock *infoInterfaceMock) UsesNVGPUModuleCalls() []struct { -} { - var calls []struct { - } - mock.lockUsesNVGPUModule.RLock() - calls = mock.calls.UsesNVGPUModule - mock.lockUsesNVGPUModule.RUnlock() - return calls -} diff --git a/pkg/nvcdi/lib.go b/pkg/nvcdi/lib.go index 7f07fd7e..e3c162e7 100644 --- a/pkg/nvcdi/lib.go +++ b/pkg/nvcdi/lib.go @@ -91,7 +91,12 @@ func New(opts ...Option) (Interface, error) { l.nvidiaCTKPath = "/usr/bin/nvidia-ctk" } if l.infolib == nil { - l.infolib = info.New() + l.infolib = info.New( + info.WithRoot(l.driverRoot), + info.WithLogger(l.logger), + info.WithNvmlLib(l.nvmllib), + info.WithDeviceLib(l.devicelib), + ) } l.driver = root.New( @@ -184,26 +189,19 @@ func (l *nvcdilib) resolveMode() (rmode string) { return l.mode } defer func() { - l.logger.Infof("Auto-detected mode as %q", rmode) + l.logger.Infof("Auto-detected mode as '%v'", rmode) }() - isWSL, reason := l.infolib.HasDXCore() - l.logger.Debugf("Is WSL-based system? %v: %v", isWSL, reason) - - if isWSL { + platform := l.infolib.ResolvePlatform() + switch platform { + case info.PlatformNVML: + return ModeNvml + case info.PlatformTegra: + return ModeCSV + case info.PlatformWSL: return ModeWsl } - - isNvml, reason := l.infolib.HasNvml() - l.logger.Debugf("Is NVML-based system? %v: %v", isNvml, reason) - - isTegra, reason := l.infolib.IsTegraSystem() - l.logger.Debugf("Is Tegra-based system? %v: %v", isTegra, reason) - - if isTegra && !isNvml { - return ModeCSV - } - + l.logger.Warningf("Unsupported platform detected: %v; assuming %v", platform, ModeNvml) return ModeNvml } diff --git a/pkg/nvcdi/lib_test.go b/pkg/nvcdi/lib_test.go deleted file mode 100644 index b467ee5b..00000000 --- a/pkg/nvcdi/lib_test.go +++ /dev/null @@ -1,116 +0,0 @@ -/** -# Copyright (c) NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -**/ - -package nvcdi - -import ( - "fmt" - "testing" - - testlog "github.com/sirupsen/logrus/hooks/test" - "github.com/stretchr/testify/require" -) - -func TestResolveMode(t *testing.T) { - logger, _ := testlog.NewNullLogger() - - testCases := []struct { - mode string - isTegra bool - hasDXCore bool - hasNVML bool - expected string - }{ - { - mode: "auto", - hasDXCore: true, - expected: "wsl", - }, - { - mode: "auto", - hasDXCore: false, - isTegra: true, - hasNVML: false, - expected: "csv", - }, - { - mode: "auto", - hasDXCore: false, - isTegra: false, - hasNVML: false, - expected: "nvml", - }, - { - mode: "auto", - hasDXCore: false, - isTegra: true, - hasNVML: true, - expected: "nvml", - }, - { - mode: "auto", - hasDXCore: false, - isTegra: false, - expected: "nvml", - }, - { - mode: "nvml", - hasDXCore: true, - isTegra: true, - expected: "nvml", - }, - { - mode: "wsl", - hasDXCore: false, - expected: "wsl", - }, - { - mode: "not-auto", - hasDXCore: true, - expected: "not-auto", - }, - } - - for i, tc := range testCases { - t.Run(fmt.Sprintf("test case %d", i), func(t *testing.T) { - l := nvcdilib{ - logger: logger, - mode: tc.mode, - infolib: infoMock{hasDXCore: tc.hasDXCore, isTegra: tc.isTegra, hasNVML: tc.hasNVML}, - } - - require.Equal(t, tc.expected, l.resolveMode()) - }) - } -} - -type infoMock struct { - hasDXCore bool - isTegra bool - hasNVML bool -} - -func (i infoMock) HasDXCore() (bool, string) { - return i.hasDXCore, "" -} - -func (i infoMock) HasNvml() (bool, string) { - return i.hasNVML, "" -} - -func (i infoMock) IsTegraSystem() (bool, string) { - return i.isTegra, "" -} diff --git a/vendor/github.com/NVIDIA/go-nvlib/pkg/nvlib/info/builder.go b/vendor/github.com/NVIDIA/go-nvlib/pkg/nvlib/info/builder.go new file mode 100644 index 00000000..74318433 --- /dev/null +++ b/vendor/github.com/NVIDIA/go-nvlib/pkg/nvlib/info/builder.go @@ -0,0 +1,89 @@ +/** +# Copyright 2024 NVIDIA CORPORATION +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package info + +import ( + "github.com/NVIDIA/go-nvlib/pkg/nvlib/device" + "github.com/NVIDIA/go-nvlib/pkg/nvml" +) + +type builder struct { + logger basicLogger + root string + nvmllib nvml.Interface + devicelib device.Interface + + preHook Resolver + properties Properties +} + +// New creates a new instance of the 'info' interface +func New(opts ...Option) Interface { + b := &builder{} + for _, opt := range opts { + opt(b) + } + if b.logger == nil { + b.logger = &nullLogger{} + } + if b.root == "" { + b.root = "/" + } + if b.nvmllib == nil { + b.nvmllib = nvml.New() + } + if b.devicelib == nil { + b.devicelib = device.New(device.WithNvml(b.nvmllib)) + } + if b.preHook == nil { + b.preHook = noop{} + } + if b.properties == nil { + b.properties = &info{ + root: b.root, + nvmllib: b.nvmllib, + devicelib: b.devicelib, + } + } + return b.build() +} + +func (b *builder) build() Interface { + return &infolib{ + logger: b.logger, + Resolver: b.getResolvers(), + Properties: b.properties, + } +} + +func (b *builder) getResolvers() Resolver { + auto := ¬EqualsResolver{ + logger: b.logger, + mode: "auto", + } + + systemMode := &systemMode{ + logger: b.logger, + Properties: b.properties, + } + + return firstOf([]Resolver{ + auto, + b.preHook, + systemMode, + }) +}