Treat missing nvidia device majors as an error

Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
Evan Lezar 2023-05-23 15:16:08 +02:00
parent 8df4a98d7b
commit 05632c0a40
2 changed files with 32 additions and 14 deletions

View File

@ -18,6 +18,7 @@ package devices
import (
"bufio"
"errors"
"fmt"
"io"
"os"
@ -72,7 +73,14 @@ func (d devices) Get(name Name) (Major, bool) {
// GetNVIDIADevices returns the set of NVIDIA Devices on the machine
func GetNVIDIADevices() (Devices, error) {
devicesFile, err := os.Open(procDevicesPath)
return nvidiaDevices(procDevicesPath)
}
// nvidiaDevices returns the set of NVIDIA Devices from the specified devices file.
// This is useful for testing since we may be testing on a system where `/proc/devices` does
// contain a reference to NVIDIA devices.
func nvidiaDevices(devicesPath string) (Devices, error) {
devicesFile, err := os.Open(devicesPath)
if os.IsNotExist(err) {
return nil, nil
}
@ -81,20 +89,28 @@ func GetNVIDIADevices() (Devices, error) {
}
defer devicesFile.Close()
return nvidiaDeviceFrom(devicesFile), nil
return nvidiaDeviceFrom(devicesFile)
}
func nvidiaDeviceFrom(reader io.Reader) devices {
var errNoNvidiaDevices = errors.New("no NVIDIA devices found")
func nvidiaDeviceFrom(reader io.Reader) (devices, error) {
allDevices := devicesFrom(reader)
nvidiaDevices := make(devices)
var hasNvidiaDevices bool
for n, d := range allDevices {
if !strings.HasPrefix(string(n), nvidiaDevicePrefix) {
continue
}
nvidiaDevices[n] = d
hasNvidiaDevices = true
}
return nvidiaDevices
if !hasNvidiaDevices {
return nil, errNoNvidiaDevices
}
return nvidiaDevices, nil
}
func devicesFrom(reader io.Reader) devices {

View File

@ -45,21 +45,23 @@ func TestNvidiaDevices(t *testing.T) {
func TestProcessDeviceFile(t *testing.T) {
testCases := []struct {
lines []string
expected devices
lines []string
expected devices
expectedError error
}{
{[]string{}, make(devices)},
{[]string{"Not a valid line:"}, make(devices)},
{[]string{"195 nvidia-frontend"}, devices{"nvidia-frontend": 195}},
{[]string{"195 nvidia-frontend", "235 nvidia-caps"}, devices{"nvidia-frontend": 195, "nvidia-caps": 235}},
{[]string{" 195 nvidia-frontend"}, devices{"nvidia-frontend": 195}},
{[]string{"Not a valid line:", "", "195 nvidia-frontend"}, devices{"nvidia-frontend": 195}},
{[]string{"195 not-nvidia-frontend"}, make(devices)},
{lines: []string{}, expectedError: errNoNvidiaDevices},
{lines: []string{"Not a valid line:"}, expectedError: errNoNvidiaDevices},
{lines: []string{"195 nvidia-frontend"}, expected: devices{"nvidia-frontend": 195}},
{lines: []string{"195 nvidia-frontend", "235 nvidia-caps"}, expected: devices{"nvidia-frontend": 195, "nvidia-caps": 235}},
{lines: []string{" 195 nvidia-frontend"}, expected: devices{"nvidia-frontend": 195}},
{lines: []string{"Not a valid line:", "", "195 nvidia-frontend"}, expected: devices{"nvidia-frontend": 195}},
{lines: []string{"195 not-nvidia-frontend"}, expectedError: errNoNvidiaDevices},
}
for i, tc := range testCases {
t.Run(fmt.Sprintf("testcase %d", i), func(t *testing.T) {
contents := strings.NewReader(strings.Join(tc.lines, "\n"))
d := nvidiaDeviceFrom(contents)
d, err := nvidiaDeviceFrom(contents)
require.ErrorIs(t, err, tc.expectedError)
require.EqualValues(t, tc.expected, d)
})