diff --git a/cmd/nvidia-ctk/system/create-dev-char-symlinks/all.go b/cmd/nvidia-ctk/system/create-dev-char-symlinks/all.go index 92d3d631..c8c0e3b7 100644 --- a/cmd/nvidia-ctk/system/create-dev-char-symlinks/all.go +++ b/cmd/nvidia-ctk/system/create-dev-char-symlinks/all.go @@ -40,12 +40,23 @@ func newAllPossible(logger *logrus.Logger, driverRoot string) (nodeLister, error if err != nil { return nil, fmt.Errorf("failed reading device majors: %v", err) } + + var requiredMajors []devices.Name migCaps, err := nvcaps.NewMigCaps() if err != nil { return nil, fmt.Errorf("failed to read MIG caps: %v", err) } if migCaps == nil { migCaps = make(nvcaps.MigCaps) + } else { + requiredMajors = append(requiredMajors, devices.NVIDIACaps) + } + + requiredMajors = append(requiredMajors, devices.NVIDIAGPU, devices.NVIDIAUVM) + for _, name := range requiredMajors { + if !deviceMajors.Exists(name) { + return nil, fmt.Errorf("missing required device major %s", name) + } } l := allPossible{ diff --git a/internal/info/proc/devices/devices.go b/internal/info/proc/devices/devices.go index 377f20d1..95430428 100644 --- a/internal/info/proc/devices/devices.go +++ b/internal/info/proc/devices/devices.go @@ -18,6 +18,7 @@ package devices import ( "bufio" + "errors" "fmt" "io" "os" @@ -72,7 +73,14 @@ func (d devices) Get(name Name) (Major, bool) { // GetNVIDIADevices returns the set of NVIDIA Devices on the machine func GetNVIDIADevices() (Devices, error) { - devicesFile, err := os.Open(procDevicesPath) + return nvidiaDevices(procDevicesPath) +} + +// nvidiaDevices returns the set of NVIDIA Devices from the specified devices file. +// This is useful for testing since we may be testing on a system where `/proc/devices` does +// contain a reference to NVIDIA devices. +func nvidiaDevices(devicesPath string) (Devices, error) { + devicesFile, err := os.Open(devicesPath) if os.IsNotExist(err) { return nil, nil } @@ -81,20 +89,28 @@ func GetNVIDIADevices() (Devices, error) { } defer devicesFile.Close() - return nvidiaDeviceFrom(devicesFile), nil + return nvidiaDeviceFrom(devicesFile) } -func nvidiaDeviceFrom(reader io.Reader) devices { +var errNoNvidiaDevices = errors.New("no NVIDIA devices found") + +func nvidiaDeviceFrom(reader io.Reader) (devices, error) { allDevices := devicesFrom(reader) nvidiaDevices := make(devices) + + var hasNvidiaDevices bool for n, d := range allDevices { if !strings.HasPrefix(string(n), nvidiaDevicePrefix) { continue } nvidiaDevices[n] = d + hasNvidiaDevices = true } - return nvidiaDevices + if !hasNvidiaDevices { + return nil, errNoNvidiaDevices + } + return nvidiaDevices, nil } func devicesFrom(reader io.Reader) devices { diff --git a/internal/info/proc/devices/devices_test.go b/internal/info/proc/devices/devices_test.go index 78bce52f..037ccbe0 100644 --- a/internal/info/proc/devices/devices_test.go +++ b/internal/info/proc/devices/devices_test.go @@ -45,21 +45,23 @@ func TestNvidiaDevices(t *testing.T) { func TestProcessDeviceFile(t *testing.T) { testCases := []struct { - lines []string - expected devices + lines []string + expected devices + expectedError error }{ - {[]string{}, make(devices)}, - {[]string{"Not a valid line:"}, make(devices)}, - {[]string{"195 nvidia-frontend"}, devices{"nvidia-frontend": 195}}, - {[]string{"195 nvidia-frontend", "235 nvidia-caps"}, devices{"nvidia-frontend": 195, "nvidia-caps": 235}}, - {[]string{" 195 nvidia-frontend"}, devices{"nvidia-frontend": 195}}, - {[]string{"Not a valid line:", "", "195 nvidia-frontend"}, devices{"nvidia-frontend": 195}}, - {[]string{"195 not-nvidia-frontend"}, make(devices)}, + {lines: []string{}, expectedError: errNoNvidiaDevices}, + {lines: []string{"Not a valid line:"}, expectedError: errNoNvidiaDevices}, + {lines: []string{"195 nvidia-frontend"}, expected: devices{"nvidia-frontend": 195}}, + {lines: []string{"195 nvidia-frontend", "235 nvidia-caps"}, expected: devices{"nvidia-frontend": 195, "nvidia-caps": 235}}, + {lines: []string{" 195 nvidia-frontend"}, expected: devices{"nvidia-frontend": 195}}, + {lines: []string{"Not a valid line:", "", "195 nvidia-frontend"}, expected: devices{"nvidia-frontend": 195}}, + {lines: []string{"195 not-nvidia-frontend"}, expectedError: errNoNvidiaDevices}, } for i, tc := range testCases { t.Run(fmt.Sprintf("testcase %d", i), func(t *testing.T) { contents := strings.NewReader(strings.Join(tc.lines, "\n")) - d := nvidiaDeviceFrom(contents) + d, err := nvidiaDeviceFrom(contents) + require.ErrorIs(t, err, tc.expectedError) require.EqualValues(t, tc.expected, d) })