diff --git a/internal/info/proc/devices/devices.go b/internal/info/proc/devices/devices.go index 95430428..232df3b2 100644 --- a/internal/info/proc/devices/devices.go +++ b/internal/info/proc/devices/devices.go @@ -33,7 +33,7 @@ const ( NVIDIAModesetMinor = 254 NVIDIAFrontend = Name("nvidia-frontend") - NVIDIAGPU = NVIDIAFrontend + NVIDIAGPU = Name("nvidia") NVIDIACaps = Name("nvidia-caps") NVIDIAUVM = Name("nvidia-uvm") @@ -65,10 +65,25 @@ func (d devices) Exists(name Name) bool { return exists } -// Get a Device from Devices +// Get a Device from Devices. It also has fallback logic to ensure device name changes in /proc/devices are handled +// For e.g:- For GPU drivers 550.40.x or greater, the gpu device has been renamed from "nvidia-frontend" to "nvidia". func (d devices) Get(name Name) (Major, bool) { - device, exists := d[name] - return device, exists + for _, n := range name.getWithFallback() { + device, exists := d[n] + if exists { + return device, true + } + } + return 0, false +} + +// getWithFallback returns a prioritised list of device names for a specific name. +// This allows multiple names to be associated with a single name to support various driver versions. +func (n Name) getWithFallback() []Name { + if n == NVIDIAGPU || n == NVIDIAFrontend { + return []Name{NVIDIAGPU, NVIDIAFrontend} + } + return []Name{n} } // GetNVIDIADevices returns the set of NVIDIA Devices on the machine diff --git a/internal/info/proc/devices/devices_test.go b/internal/info/proc/devices/devices_test.go index 037ccbe0..97c56b7a 100644 --- a/internal/info/proc/devices/devices_test.go +++ b/internal/info/proc/devices/devices_test.go @@ -41,6 +41,11 @@ func TestNvidiaDevices(t *testing.T) { } _, exists := nvidiaDevices.Get("bogus") require.False(t, exists, "Unexpected 'bogus' device found") + + // assert that nvidia and nvidia-frontend can be used interchangeably and have the device major numbers + m, exists := nvidiaDevices.Get("nvidia") + require.True(t, exists) + require.Equal(t, devices["nvidia-frontend"], m) } func TestProcessDeviceFile(t *testing.T) { diff --git a/internal/system/nvdevices/devices_test.go b/internal/system/nvdevices/devices_test.go index 388f91dc..e1b68e65 100644 --- a/internal/system/nvdevices/devices_test.go +++ b/internal/system/nvdevices/devices_test.go @@ -31,11 +31,25 @@ func TestCreateControlDevices(t *testing.T) { nvidiaDevices := &devices.DevicesMock{ GetFunc: func(name devices.Name) (devices.Major, bool) { - devices := map[devices.Name]devices.Major{ + devs := map[devices.Name]devices.Major{ "nvidia-frontend": 195, "nvidia-uvm": 243, } - return devices[name], true + + // devs550_40 represents the device map from the nvidia gpu drivers >= 550.40.x + devs550_40 := map[devices.Name]devices.Major{ + "nvidia": 195, + "nvidia-uvm": 243, + } + + d, ok := devs[name] + if ok { + return d, ok + } + + // if device d is not found, fallback to the second mock device map + d, ok = devs550_40[name] + return d, ok }, } @@ -46,6 +60,7 @@ func TestCreateControlDevices(t *testing.T) { root string devices devices.Devices mknodeError error + hasError bool expectedError error expectedCalls []struct { S string @@ -58,6 +73,7 @@ func TestCreateControlDevices(t *testing.T) { root: "", devices: nvidiaDevices, mknodeError: nil, + hasError: false, expectedCalls: []struct { S string N1 int @@ -73,6 +89,7 @@ func TestCreateControlDevices(t *testing.T) { description: "some root specified", root: "/some/root", devices: nvidiaDevices, + hasError: false, mknodeError: nil, expectedCalls: []struct { S string @@ -88,6 +105,7 @@ func TestCreateControlDevices(t *testing.T) { { description: "mknod error returns error", devices: nvidiaDevices, + hasError: true, mknodeError: mknodeError, expectedError: mknodeError, // We expect the first call to this to fail, and the rest to be skipped @@ -106,8 +124,24 @@ func TestCreateControlDevices(t *testing.T) { return 0, false }, }, + hasError: true, expectedError: errInvalidDeviceNode, }, + { + description: "nvidia device renamed from nvidia-frontend to nvidia", + devices: nvidiaDevices, + hasError: false, + expectedCalls: []struct { + S string + N1 int + N2 int + }{ + {"/dev/nvidiactl", 195, 255}, + {"/dev/nvidia-modeset", 195, 254}, + {"/dev/nvidia-uvm", 243, 0}, + {"/dev/nvidia-uvm-tools", 243, 1}, + }, + }, } for _, tc := range testCases { @@ -126,9 +160,12 @@ func TestCreateControlDevices(t *testing.T) { d.mknoder = mknode err := d.CreateNVIDIAControlDevices() - require.ErrorIs(t, err, tc.expectedError) + if tc.hasError { + require.ErrorContains(t, err, tc.expectedError.Error()) + } else { + require.Nil(t, err) + } require.EqualValues(t, tc.expectedCalls, mknode.MknodeCalls()) }) } - }