Merge branch 'fix-device-symlinks' into 'main'

Fix creation of device symlinks in /dev/char

See merge request nvidia/container-toolkit/container-toolkit!399
This commit is contained in:
Evan Lezar 2023-05-31 17:31:04 +00:00 committed by Evan Lezar
parent 2d7bb636b9
commit 8697267e6b
3 changed files with 43 additions and 14 deletions

View File

@ -40,12 +40,23 @@ func newAllPossible(logger *logrus.Logger, driverRoot string) (nodeLister, error
if err != nil { if err != nil {
return nil, fmt.Errorf("failed reading device majors: %v", err) return nil, fmt.Errorf("failed reading device majors: %v", err)
} }
var requiredMajors []devices.Name
migCaps, err := nvcaps.NewMigCaps() migCaps, err := nvcaps.NewMigCaps()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to read MIG caps: %v", err) return nil, fmt.Errorf("failed to read MIG caps: %v", err)
} }
if migCaps == nil { if migCaps == nil {
migCaps = make(nvcaps.MigCaps) migCaps = make(nvcaps.MigCaps)
} else {
requiredMajors = append(requiredMajors, devices.NVIDIACaps)
}
requiredMajors = append(requiredMajors, devices.NVIDIAGPU, devices.NVIDIAUVM)
for _, name := range requiredMajors {
if !deviceMajors.Exists(name) {
return nil, fmt.Errorf("missing required device major %s", name)
}
} }
l := allPossible{ l := allPossible{

View File

@ -18,6 +18,7 @@ package devices
import ( import (
"bufio" "bufio"
"errors"
"fmt" "fmt"
"io" "io"
"os" "os"
@ -72,7 +73,14 @@ func (d devices) Get(name Name) (Major, bool) {
// GetNVIDIADevices returns the set of NVIDIA Devices on the machine // GetNVIDIADevices returns the set of NVIDIA Devices on the machine
func GetNVIDIADevices() (Devices, error) { func GetNVIDIADevices() (Devices, error) {
devicesFile, err := os.Open(procDevicesPath) return nvidiaDevices(procDevicesPath)
}
// nvidiaDevices returns the set of NVIDIA Devices from the specified devices file.
// This is useful for testing since we may be testing on a system where `/proc/devices` does
// contain a reference to NVIDIA devices.
func nvidiaDevices(devicesPath string) (Devices, error) {
devicesFile, err := os.Open(devicesPath)
if os.IsNotExist(err) { if os.IsNotExist(err) {
return nil, nil return nil, nil
} }
@ -81,20 +89,28 @@ func GetNVIDIADevices() (Devices, error) {
} }
defer devicesFile.Close() defer devicesFile.Close()
return nvidiaDeviceFrom(devicesFile), nil return nvidiaDeviceFrom(devicesFile)
} }
func nvidiaDeviceFrom(reader io.Reader) devices { var errNoNvidiaDevices = errors.New("no NVIDIA devices found")
func nvidiaDeviceFrom(reader io.Reader) (devices, error) {
allDevices := devicesFrom(reader) allDevices := devicesFrom(reader)
nvidiaDevices := make(devices) nvidiaDevices := make(devices)
var hasNvidiaDevices bool
for n, d := range allDevices { for n, d := range allDevices {
if !strings.HasPrefix(string(n), nvidiaDevicePrefix) { if !strings.HasPrefix(string(n), nvidiaDevicePrefix) {
continue continue
} }
nvidiaDevices[n] = d nvidiaDevices[n] = d
hasNvidiaDevices = true
} }
return nvidiaDevices if !hasNvidiaDevices {
return nil, errNoNvidiaDevices
}
return nvidiaDevices, nil
} }
func devicesFrom(reader io.Reader) devices { func devicesFrom(reader io.Reader) devices {

View File

@ -45,21 +45,23 @@ func TestNvidiaDevices(t *testing.T) {
func TestProcessDeviceFile(t *testing.T) { func TestProcessDeviceFile(t *testing.T) {
testCases := []struct { testCases := []struct {
lines []string lines []string
expected devices expected devices
expectedError error
}{ }{
{[]string{}, make(devices)}, {lines: []string{}, expectedError: errNoNvidiaDevices},
{[]string{"Not a valid line:"}, make(devices)}, {lines: []string{"Not a valid line:"}, expectedError: errNoNvidiaDevices},
{[]string{"195 nvidia-frontend"}, devices{"nvidia-frontend": 195}}, {lines: []string{"195 nvidia-frontend"}, expected: devices{"nvidia-frontend": 195}},
{[]string{"195 nvidia-frontend", "235 nvidia-caps"}, devices{"nvidia-frontend": 195, "nvidia-caps": 235}}, {lines: []string{"195 nvidia-frontend", "235 nvidia-caps"}, expected: devices{"nvidia-frontend": 195, "nvidia-caps": 235}},
{[]string{" 195 nvidia-frontend"}, devices{"nvidia-frontend": 195}}, {lines: []string{" 195 nvidia-frontend"}, expected: devices{"nvidia-frontend": 195}},
{[]string{"Not a valid line:", "", "195 nvidia-frontend"}, devices{"nvidia-frontend": 195}}, {lines: []string{"Not a valid line:", "", "195 nvidia-frontend"}, expected: devices{"nvidia-frontend": 195}},
{[]string{"195 not-nvidia-frontend"}, make(devices)}, {lines: []string{"195 not-nvidia-frontend"}, expectedError: errNoNvidiaDevices},
} }
for i, tc := range testCases { for i, tc := range testCases {
t.Run(fmt.Sprintf("testcase %d", i), func(t *testing.T) { t.Run(fmt.Sprintf("testcase %d", i), func(t *testing.T) {
contents := strings.NewReader(strings.Join(tc.lines, "\n")) contents := strings.NewReader(strings.Join(tc.lines, "\n"))
d := nvidiaDeviceFrom(contents) d, err := nvidiaDeviceFrom(contents)
require.ErrorIs(t, err, tc.expectedError)
require.EqualValues(t, tc.expected, d) require.EqualValues(t, tc.expected, d)
}) })