diff --git a/pkg/nvcdi/lib.go b/pkg/nvcdi/lib.go index f3d7dba0..18d69d84 100644 --- a/pkg/nvcdi/lib.go +++ b/pkg/nvcdi/lib.go @@ -31,6 +31,8 @@ type nvcdilib struct { deviceNamer DeviceNamer driverRoot string nvidiaCTKPath string + + infolib info.Interface } // New creates a new nvcdi library @@ -54,6 +56,9 @@ func New(opts ...Option) Interface { if l.nvidiaCTKPath == "" { l.nvidiaCTKPath = "/usr/bin/nvidia-ctk" } + if l.infolib == nil { + l.infolib = info.New() + } switch l.resolveMode() { case "nvml": @@ -82,9 +87,7 @@ func (l *nvcdilib) resolveMode() (rmode string) { l.logger.Infof("Auto-detected mode as %q", rmode) }() - nvinfo := info.New() - - isWSL, reason := nvinfo.HasDXCore() + isWSL, reason := l.infolib.HasDXCore() l.logger.Debugf("Is WSL-based system? %v: %v", isWSL, reason) if isWSL { diff --git a/pkg/nvcdi/lib_test.go b/pkg/nvcdi/lib_test.go new file mode 100644 index 00000000..f0ddf96e --- /dev/null +++ b/pkg/nvcdi/lib_test.go @@ -0,0 +1,88 @@ +/** +# Copyright (c) NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package nvcdi + +import ( + "fmt" + "testing" + + testlog "github.com/sirupsen/logrus/hooks/test" + "github.com/stretchr/testify/require" +) + +func TestResolveMode(t *testing.T) { + logger, _ := testlog.NewNullLogger() + + testCases := []struct { + mode string + // TODO: This should be a proper mock + hasDXCore bool + expected string + }{ + { + mode: "auto", + hasDXCore: true, + expected: "wsl", + }, + { + mode: "auto", + hasDXCore: false, + expected: "nvml", + }, + { + mode: "nvml", + hasDXCore: true, + expected: "nvml", + }, + { + mode: "wsl", + hasDXCore: false, + expected: "wsl", + }, + { + mode: "not-auto", + hasDXCore: true, + expected: "not-auto", + }, + } + + for i, tc := range testCases { + t.Run(fmt.Sprintf("test case %d", i), func(t *testing.T) { + l := nvcdilib{ + logger: logger, + mode: tc.mode, + infolib: infoMock(tc.hasDXCore), + } + + require.Equal(t, tc.expected, l.resolveMode()) + }) + } +} + +type infoMock bool + +func (i infoMock) HasDXCore() (bool, string) { + return bool(i), "" +} + +func (i infoMock) HasNvml() (bool, string) { + panic("should not be called") +} + +func (i infoMock) IsTegraSystem() (bool, string) { + panic("should not be called") +}