diff --git a/cmd/nvidia-ctk/cdi/generate/generate.go b/cmd/nvidia-ctk/cdi/generate/generate.go index 7aef47b8..caee034c 100644 --- a/cmd/nvidia-ctk/cdi/generate/generate.go +++ b/cmd/nvidia-ctk/cdi/generate/generate.go @@ -36,10 +36,6 @@ import ( ) const ( - discoveryModeAuto = "auto" - discoveryModeNVML = "nvml" - discoveryModeWSL = "wsl" - formatJSON = "json" formatYAML = "yaml" @@ -97,8 +93,8 @@ func (m command) build() *cli.Command { }, &cli.StringFlag{ Name: "discovery-mode", - Usage: "The mode to use when discovering the available entities. One of [auto | nvml | wsl]. I mode is set to 'auto' the mode will be determined based on the system configuration.", - Value: discoveryModeAuto, + Usage: "The mode to use when discovering the available entities. One of [auto | nvml | wsl]. If mode is set to 'auto' the mode will be determined based on the system configuration.", + Value: nvcdi.ModeAuto, Destination: &cfg.discoveryMode, }, &cli.StringFlag{ @@ -133,9 +129,9 @@ func (m command) validateFlags(r *cli.Context, cfg *config) error { cfg.discoveryMode = strings.ToLower(cfg.discoveryMode) switch cfg.discoveryMode { - case discoveryModeAuto: - case discoveryModeNVML: - case discoveryModeWSL: + case nvcdi.ModeAuto: + case nvcdi.ModeNvml: + case nvcdi.ModeWsl: default: return fmt.Errorf("invalid discovery mode: %v", cfg.discoveryMode) } diff --git a/pkg/nvcdi/api.go b/pkg/nvcdi/api.go index 0fd07e72..267010e9 100644 --- a/pkg/nvcdi/api.go +++ b/pkg/nvcdi/api.go @@ -22,6 +22,15 @@ import ( "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device" ) +const ( + // ModeAuto configures the CDI spec generator to automatically detect the system configuration + ModeAuto = "auto" + // ModeNvml configures the CDI spec generator to use the NVML library. + ModeNvml = "nvml" + // ModeWsl configures the CDI spec generator to generate a WSL spec. + ModeWsl = "wsl" +) + // Interface defines the API for the nvcdi package type Interface interface { GetCommonEdits() (*cdi.ContainerEdits, error) diff --git a/pkg/nvcdi/lib.go b/pkg/nvcdi/lib.go index f3d7dba0..be5554eb 100644 --- a/pkg/nvcdi/lib.go +++ b/pkg/nvcdi/lib.go @@ -31,6 +31,8 @@ type nvcdilib struct { deviceNamer DeviceNamer driverRoot string nvidiaCTKPath string + + infolib info.Interface } // New creates a new nvcdi library @@ -40,7 +42,7 @@ func New(opts ...Option) Interface { opt(l) } if l.mode == "" { - l.mode = "auto" + l.mode = ModeAuto } if l.logger == nil { l.logger = logrus.StandardLogger() @@ -54,9 +56,12 @@ func New(opts ...Option) Interface { if l.nvidiaCTKPath == "" { l.nvidiaCTKPath = "/usr/bin/nvidia-ctk" } + if l.infolib == nil { + l.infolib = info.New() + } switch l.resolveMode() { - case "nvml": + case ModeNvml: if l.nvmllib == nil { l.nvmllib = nvml.New() } @@ -65,7 +70,7 @@ func New(opts ...Option) Interface { } return (*nvmllib)(l) - case "wsl": + case ModeWsl: return (*wsllib)(l) } @@ -75,21 +80,19 @@ func New(opts ...Option) Interface { // resolveMode resolves the mode for CDI spec generation based on the current system. func (l *nvcdilib) resolveMode() (rmode string) { - if l.mode != "auto" { + if l.mode != ModeAuto { return l.mode } defer func() { l.logger.Infof("Auto-detected mode as %q", rmode) }() - nvinfo := info.New() - - isWSL, reason := nvinfo.HasDXCore() + isWSL, reason := l.infolib.HasDXCore() l.logger.Debugf("Is WSL-based system? %v: %v", isWSL, reason) if isWSL { - return "wsl" + return ModeWsl } - return "nvml" + return ModeNvml } diff --git a/pkg/nvcdi/lib_test.go b/pkg/nvcdi/lib_test.go new file mode 100644 index 00000000..f0ddf96e --- /dev/null +++ b/pkg/nvcdi/lib_test.go @@ -0,0 +1,88 @@ +/** +# Copyright (c) NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package nvcdi + +import ( + "fmt" + "testing" + + testlog "github.com/sirupsen/logrus/hooks/test" + "github.com/stretchr/testify/require" +) + +func TestResolveMode(t *testing.T) { + logger, _ := testlog.NewNullLogger() + + testCases := []struct { + mode string + // TODO: This should be a proper mock + hasDXCore bool + expected string + }{ + { + mode: "auto", + hasDXCore: true, + expected: "wsl", + }, + { + mode: "auto", + hasDXCore: false, + expected: "nvml", + }, + { + mode: "nvml", + hasDXCore: true, + expected: "nvml", + }, + { + mode: "wsl", + hasDXCore: false, + expected: "wsl", + }, + { + mode: "not-auto", + hasDXCore: true, + expected: "not-auto", + }, + } + + for i, tc := range testCases { + t.Run(fmt.Sprintf("test case %d", i), func(t *testing.T) { + l := nvcdilib{ + logger: logger, + mode: tc.mode, + infolib: infoMock(tc.hasDXCore), + } + + require.Equal(t, tc.expected, l.resolveMode()) + }) + } +} + +type infoMock bool + +func (i infoMock) HasDXCore() (bool, string) { + return bool(i), "" +} + +func (i infoMock) HasNvml() (bool, string) { + panic("should not be called") +} + +func (i infoMock) IsTegraSystem() (bool, string) { + panic("should not be called") +}