diff --git a/pkg/nvpci/mock.go b/pkg/nvpci/mock.go index ce7ecb1..438f562 100644 --- a/pkg/nvpci/mock.go +++ b/pkg/nvpci/mock.go @@ -90,6 +90,15 @@ func (m *MockNvpci) AddMockA100(address string, numaNode int) error { return err } + _, err = os.Create(filepath.Join(deviceDir, "nvidia")) + if err != nil { + return err + } + err = os.Symlink(filepath.Join(deviceDir, "nvidia"), filepath.Join(deviceDir, "driver")) + if err != nil { + return err + } + numa, err := os.Create(filepath.Join(deviceDir, "numa_node")) if err != nil { return err diff --git a/pkg/nvpci/nvpci.go b/pkg/nvpci/nvpci.go index d6737e6..924189f 100644 --- a/pkg/nvpci/nvpci.go +++ b/pkg/nvpci/nvpci.go @@ -20,6 +20,7 @@ import ( "fmt" "os" "path" + "path/filepath" "sort" "strconv" "strings" @@ -69,6 +70,7 @@ type NvidiaPCIDevice struct { Vendor uint16 Class uint32 Device uint16 + Driver string NumaNode int Config *ConfigSpace Resources MemoryResources @@ -192,6 +194,15 @@ func NewDevice(devicePath string) (*NvidiaPCIDevice, error) { return nil, fmt.Errorf("unable to convert device string to uint16: %v", deviceStr) } + driver, err := filepath.EvalSymlinks(path.Join(devicePath, "driver")) + if err == nil { + driver = filepath.Base(driver) + } else if os.IsNotExist(err) { + driver = "" + } else { + return nil, fmt.Errorf("unable to detect driver for %s: %v", address, err) + } + numa, err := os.ReadFile(path.Join(devicePath, "numa_node")) if err != nil { return nil, fmt.Errorf("unable to read PCI NUMA node for %s: %v", address, err) @@ -238,6 +249,7 @@ func NewDevice(devicePath string) (*NvidiaPCIDevice, error) { Vendor: uint16(vendorID), Class: uint32(classID), Device: uint16(deviceID), + Driver: driver, NumaNode: int(numaNode), Config: config, Resources: resources, diff --git a/pkg/nvpci/nvpci_test.go b/pkg/nvpci/nvpci_test.go index f651a56..8dbf50f 100644 --- a/pkg/nvpci/nvpci_test.go +++ b/pkg/nvpci/nvpci_test.go @@ -45,6 +45,7 @@ func TestNvpci(t *testing.T) { require.Nil(t, err, "Error reading config") require.Equal(t, devices[0].Vendor, config.GetVendorID(), "Vendor IDs do not match") require.Equal(t, devices[0].Device, config.GetDeviceID(), "Device IDs do not match") + require.Equal(t, "nvidia", devices[0].Driver, "Wrong driver detected for device") capabilities, err := config.GetPCICapabilities() require.Nil(t, err, "Error getting PCI capabilities")