diff --git a/pkg/nvlib/info/api.go b/pkg/nvlib/info/api.go index b466bcb..1c62d63 100644 --- a/pkg/nvlib/info/api.go +++ b/pkg/nvlib/info/api.go @@ -18,9 +18,15 @@ package info // Interface provides the API to the info package. type Interface interface { + PlatformResolver PropertyExtractor } +// PlatformResolver defines a function to resolve the current platform. +type PlatformResolver interface { + ResolvePlatform() Platform +} + // PropertyExtractor provides a set of functions to query capabilities of the // system. // diff --git a/pkg/nvlib/info/builder.go b/pkg/nvlib/info/builder.go index bf2dd89..87f20f0 100644 --- a/pkg/nvlib/info/builder.go +++ b/pkg/nvlib/info/builder.go @@ -22,18 +22,30 @@ import ( "github.com/NVIDIA/go-nvlib/pkg/nvlib/device" ) +type infolib struct { + PropertyExtractor + PlatformResolver +} + type options struct { + logger basicLogger root root nvmllib nvml.Interface devicelib device.Interface + + platform Platform + propertyExtractor PropertyExtractor } -// New creates a new instance of the 'info' Interface. +// New creates a new instance of the 'info' interface. func New(opts ...Option) Interface { o := &options{} for _, opt := range opts { opt(o) } + if o.logger == nil { + o.logger = &nullLogger{} + } if o.root == "" { o.root = "/" } @@ -45,9 +57,22 @@ func New(opts ...Option) Interface { if o.devicelib == nil { o.devicelib = device.New(device.WithNvml(o.nvmllib)) } - return &propertyExtractor{ - root: o.root, - nvmllib: o.nvmllib, - devicelib: o.devicelib, + if o.platform == "" { + o.platform = PlatformAuto + } + if o.propertyExtractor == nil { + o.propertyExtractor = &propertyExtractor{ + root: o.root, + nvmllib: o.nvmllib, + devicelib: o.devicelib, + } + } + return &infolib{ + PlatformResolver: &platformResolver{ + logger: o.logger, + platform: o.platform, + propertyExtractor: o.propertyExtractor, + }, + PropertyExtractor: o.propertyExtractor, } } diff --git a/pkg/nvlib/info/logger.go b/pkg/nvlib/info/logger.go new file mode 100644 index 0000000..6a6f74e --- /dev/null +++ b/pkg/nvlib/info/logger.go @@ -0,0 +1,28 @@ +/** +# Copyright 2024 NVIDIA CORPORATION +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package info + +type basicLogger interface { + Debugf(string, ...interface{}) + Infof(string, ...interface{}) +} + +type nullLogger struct{} + +func (n *nullLogger) Debugf(string, ...interface{}) {} + +func (n *nullLogger) Infof(string, ...interface{}) {} diff --git a/pkg/nvlib/info/options.go b/pkg/nvlib/info/options.go index f8b47aa..e05c2bf 100644 --- a/pkg/nvlib/info/options.go +++ b/pkg/nvlib/info/options.go @@ -27,15 +27,22 @@ type Option func(*options) // WithDeviceLib sets the device library for the library. func WithDeviceLib(devicelib device.Interface) Option { - return func(l *options) { - l.devicelib = devicelib + return func(i *options) { + i.devicelib = devicelib + } +} + +// WithLogger sets the logger for the library. +func WithLogger(logger basicLogger) Option { + return func(i *options) { + i.logger = logger } } // WithNvmlLib sets the nvml library for the library. func WithNvmlLib(nvmllib nvml.Interface) Option { - return func(l *options) { - l.nvmllib = nvmllib + return func(i *options) { + i.nvmllib = nvmllib } } @@ -45,3 +52,19 @@ func WithRoot(r string) Option { i.root = root(r) } } + +// WithPropertyExtractor provides an Option to set the PropertyExtractor +// interface implementation. +// This is predominantly used for testing. +func WithPropertyExtractor(propertyExtractor PropertyExtractor) Option { + return func(i *options) { + i.propertyExtractor = propertyExtractor + } +} + +// WithPlatform provides an option to set the platform explicitly. +func WithPlatform(platform Platform) Option { + return func(i *options) { + i.platform = platform + } +} diff --git a/pkg/nvlib/info/property-extractor.go b/pkg/nvlib/info/property-extractor.go index 43ec3b8..5d5d97c 100644 --- a/pkg/nvlib/info/property-extractor.go +++ b/pkg/nvlib/info/property-extractor.go @@ -32,7 +32,7 @@ type propertyExtractor struct { devicelib device.Interface } -var _ Interface = &propertyExtractor{} +var _ PropertyExtractor = &propertyExtractor{} // HasDXCore returns true if DXCore is detected on the system. func (i *propertyExtractor) HasDXCore() (bool, string) { diff --git a/pkg/nvlib/info/resolver.go b/pkg/nvlib/info/resolver.go new file mode 100644 index 0000000..1aeb04c --- /dev/null +++ b/pkg/nvlib/info/resolver.go @@ -0,0 +1,64 @@ +/** +# Copyright 2024 NVIDIA CORPORATION +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package info + +// Platform represents a supported plaform. +type Platform string + +const ( + PlatformAuto = Platform("auto") + PlatformNVML = Platform("nvml") + PlatformTegra = Platform("tegra") + PlatformWSL = Platform("wsl") + PlatformUnknown = Platform("unknown") +) + +type platformResolver struct { + logger basicLogger + platform Platform + propertyExtractor PropertyExtractor +} + +func (p platformResolver) ResolvePlatform() Platform { + if p.platform != PlatformAuto { + p.logger.Infof("Using requested platform '%s'", p.platform) + return p.platform + } + + hasDXCore, reason := p.propertyExtractor.HasDXCore() + p.logger.Debugf("Is WSL-based system? %v: %v", hasDXCore, reason) + + hasTegraFiles, reason := p.propertyExtractor.HasTegraFiles() + p.logger.Debugf("Is Tegra-based system? %v: %v", hasTegraFiles, reason) + + hasNVML, reason := p.propertyExtractor.HasNvml() + p.logger.Debugf("Is NVML-based system? %v: %v", hasNVML, reason) + + usesOnlyNVGPUModule, reason := p.propertyExtractor.UsesOnlyNVGPUModule() + p.logger.Debugf("Uses nvgpu kernel module? %v: %v", usesOnlyNVGPUModule, reason) + + switch { + case hasDXCore: + return PlatformWSL + case (hasTegraFiles && !hasNVML), usesOnlyNVGPUModule: + return PlatformTegra + case hasNVML: + return PlatformNVML + default: + return PlatformUnknown + } +} diff --git a/pkg/nvlib/info/resolver_test.go b/pkg/nvlib/info/resolver_test.go new file mode 100644 index 0000000..611ec54 --- /dev/null +++ b/pkg/nvlib/info/resolver_test.go @@ -0,0 +1,110 @@ +/** +# Copyright (c) NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package info + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestResolvePlatform(t *testing.T) { + testCases := []struct { + platform string + hasTegraFiles bool + hasDXCore bool + hasNVML bool + usesOnlyNVGPUModule bool + expected string + }{ + { + platform: "auto", + hasDXCore: true, + expected: "wsl", + }, + { + platform: "auto", + hasDXCore: false, + hasTegraFiles: true, + hasNVML: false, + expected: "tegra", + }, + { + platform: "auto", + hasDXCore: false, + hasTegraFiles: false, + hasNVML: false, + expected: "unknown", + }, + { + platform: "auto", + hasDXCore: false, + hasTegraFiles: true, + hasNVML: true, + expected: "nvml", + }, + { + platform: "auto", + hasDXCore: false, + hasTegraFiles: true, + hasNVML: true, + usesOnlyNVGPUModule: true, + expected: "tegra", + }, + { + platform: "nvml", + hasDXCore: true, + hasTegraFiles: true, + expected: "nvml", + }, + { + platform: "wsl", + hasDXCore: false, + expected: "wsl", + }, + { + platform: "not-auto", + hasDXCore: true, + expected: "not-auto", + }, + } + + for i, tc := range testCases { + t.Run(fmt.Sprintf("test case %d", i), func(t *testing.T) { + l := New( + WithPropertyExtractor(&PropertyExtractorMock{ + HasDXCoreFunc: func() (bool, string) { + return tc.hasDXCore, "" + }, + HasNvmlFunc: func() (bool, string) { + return tc.hasNVML, "" + }, + HasTegraFilesFunc: func() (bool, string) { + return tc.hasTegraFiles, "" + }, + UsesOnlyNVGPUModuleFunc: func() (bool, string) { + return tc.usesOnlyNVGPUModule, "" + }, + }), + WithPlatform(Platform(tc.platform)), + ) + + require.Equal(t, Platform(tc.expected), l.ResolvePlatform()) + }) + } +}