diff --git a/pkg/nvmdev/nvmdev.go b/pkg/nvmdev/nvmdev.go index 989e4f3..33f2d5a 100644 --- a/pkg/nvmdev/nvmdev.go +++ b/pkg/nvmdev/nvmdev.go @@ -321,21 +321,16 @@ func (m *Device) Delete() error { } // GetPhysicalFunction gets the physical PCI device backing a 'parent' device. -func (p *ParentDevice) GetPhysicalFunction() (*nvpci.NvidiaPCIDevice, error) { - if !p.IsVF { - return p.NvidiaPCIDevice, nil +func (p *ParentDevice) GetPhysicalFunction() *nvpci.NvidiaPCIDevice { + if p.SriovInfo.IsVF() { + return p.SriovInfo.VirtualFunction.PhysicalFunction } - - physfnPath, err := filepath.EvalSymlinks(path.Join(p.Path, "physfn")) - if err != nil { - return nil, fmt.Errorf("unable to resolve %s: %v", path.Join(p.Path, "physfn"), err) - } - - return newNvidiaPCIDeviceFromPath(physfnPath) + // Either it is an SRIOV physical function or a non-SRIOV device, so return the device itself + return p.NvidiaPCIDevice } // GetPhysicalFunction gets the physical PCI device that a vGPU is created on. -func (m *Device) GetPhysicalFunction() (*nvpci.NvidiaPCIDevice, error) { +func (m *Device) GetPhysicalFunction() *nvpci.NvidiaPCIDevice { return m.Parent.GetPhysicalFunction() } diff --git a/pkg/nvmdev/nvmdev_test.go b/pkg/nvmdev/nvmdev_test.go index 7c03f59..7d7cab4 100644 --- a/pkg/nvmdev/nvmdev_test.go +++ b/pkg/nvmdev/nvmdev_test.go @@ -35,8 +35,7 @@ func TestNvmdev(t *testing.T) { parentA100 := parentDevs[0] - pf, err := parentA100.GetPhysicalFunction() - require.Nil(t, err, "Error getting physical function backing the Mock A100 parent device") + pf := parentA100.GetPhysicalFunction() require.Equal(t, "0000:3b:04.1", pf.Address, "Wrong address for Mock A100 physical function") supported := parentA100.IsMDEVTypeSupported("A100-4C") @@ -59,7 +58,6 @@ func TestNvmdev(t *testing.T) { require.Equal(t, "vfio_mdev", mdevA100.Driver, "Wrong driver detected for mdev device") require.Equal(t, 200, mdevA100.IommuGroup, "Wrong value for iommu_group") - pf, err = mdevA100.GetPhysicalFunction() - require.Nil(t, err, "Error getting the physical function for Mock A100 mediated device") + pf = mdevA100.GetPhysicalFunction() require.Equal(t, "0000:3b:04.1", pf.Address, "Wrong address for Mock A100 physical function") } diff --git a/pkg/nvpci/mock.go b/pkg/nvpci/mock.go index 7c1b69d..9b3d6e2 100644 --- a/pkg/nvpci/mock.go +++ b/pkg/nvpci/mock.go @@ -20,6 +20,8 @@ import ( "fmt" "os" "path/filepath" + "regexp" + "strconv" "github.com/NVIDIA/go-nvlib/pkg/nvpci/bytes" ) @@ -55,14 +57,82 @@ func (m *MockNvpci) Cleanup() { os.RemoveAll(m.pciDevicesRoot) } +func validatePCIAddress(addr string) error { + r := regexp.MustCompile(`0{4}:[0-9a-f]{2}:[0-9a-f]{2}\.[0-9]`) + if !r.Match([]byte(addr)) { + return fmt.Errorf(`invalid PCI address should match 0{4}:[0-9a-f]{2}:[0-9a-f]{2}\.[0-9]: %s`, addr) + } + + return nil +} + // AddMockA100 Create an A100 like GPU mock device. -func (m *MockNvpci) AddMockA100(address string, numaNode int) error { - deviceDir := filepath.Join(m.pciDevicesRoot, address) - err := os.MkdirAll(deviceDir, 0755) +func (m *MockNvpci) AddMockA100(address string, numaNode int, sriov *SriovInfo) error { + err := validatePCIAddress(address) if err != nil { return err } + deviceDir := filepath.Join(m.pciDevicesRoot, address) + err = os.MkdirAll(deviceDir, 0755) + if err != nil { + return err + } + + err = createNVIDIAgpuFiles(deviceDir) + if err != nil { + return err + } + + iommuGroup := 20 + _, err = os.Create(filepath.Join(deviceDir, strconv.Itoa(iommuGroup))) + if err != nil { + return err + } + err = os.Symlink(filepath.Join(deviceDir, strconv.Itoa(iommuGroup)), filepath.Join(deviceDir, "iommu_group")) + if err != nil { + return err + } + + numa, err := os.Create(filepath.Join(deviceDir, "numa_node")) + if err != nil { + return err + } + _, err = numa.WriteString(fmt.Sprintf("%v", numaNode)) + if err != nil { + return err + } + + if sriov != nil && sriov.PhysicalFunction != nil { + totalVFs, err := os.Create(filepath.Join(deviceDir, "sriov_totalvfs")) + if err != nil { + return err + } + _, err = fmt.Fprintf(totalVFs, "%d", sriov.PhysicalFunction.TotalVFs) + if err != nil { + return err + } + + numVFs, err := os.Create(filepath.Join(deviceDir, "sriov_numvfs")) + if err != nil { + return err + } + _, err = fmt.Fprintf(numVFs, "%d", sriov.PhysicalFunction.NumVFs) + if err != nil { + return err + } + for i := 1; i <= int(sriov.PhysicalFunction.NumVFs); i++ { + err = m.createVf(address, i, iommuGroup, numaNode) + if err != nil { + return err + } + } + } + + return nil +} + +func createNVIDIAgpuFiles(deviceDir string) error { vendor, err := os.Create(filepath.Join(deviceDir, "vendor")) if err != nil { return err @@ -99,24 +169,6 @@ func (m *MockNvpci) AddMockA100(address string, numaNode int) error { return err } - _, err = os.Create(filepath.Join(deviceDir, "20")) - if err != nil { - return err - } - err = os.Symlink(filepath.Join(deviceDir, "20"), filepath.Join(deviceDir, "iommu_group")) - if err != nil { - return err - } - - numa, err := os.Create(filepath.Join(deviceDir, "numa_node")) - if err != nil { - return err - } - _, err = numa.WriteString(fmt.Sprintf("%v", numaNode)) - if err != nil { - return err - } - config, err := os.Create(filepath.Join(deviceDir, "config")) if err != nil { return err @@ -156,3 +208,53 @@ func (m *MockNvpci) AddMockA100(address string, numaNode int) error { return nil } + +func (m *MockNvpci) createVf(pfAddress string, id, iommu_group, numaNode int) error { + functionID := pfAddress[len(pfAddress)-1] + // we are verifying the last character of pfAddress is integer. + functionNumber, err := strconv.Atoi(string(functionID)) + if err != nil { + return fmt.Errorf("can't conver physical function pci address function number %s to integer: %v", string(functionID), err) + } + + vfFunctionNumber := functionNumber + id + vfAddress := pfAddress[:len(pfAddress)-1] + strconv.Itoa(vfFunctionNumber) + + deviceDir := filepath.Join(m.pciDevicesRoot, vfAddress) + err = os.MkdirAll(deviceDir, 0755) + if err != nil { + return err + } + + err = createNVIDIAgpuFiles(deviceDir) + if err != nil { + return err + } + + vfIommuGroup := strconv.Itoa(iommu_group + id) + + _, err = os.Create(filepath.Join(deviceDir, vfIommuGroup)) + if err != nil { + return err + } + err = os.Symlink(filepath.Join(deviceDir, vfIommuGroup), filepath.Join(deviceDir, "iommu_group")) + if err != nil { + return err + } + + numa, err := os.Create(filepath.Join(deviceDir, "numa_node")) + if err != nil { + return err + } + _, err = numa.WriteString(fmt.Sprintf("%v", numaNode)) + if err != nil { + return err + } + + err = os.Symlink(filepath.Join(m.pciDevicesRoot, pfAddress), filepath.Join(deviceDir, "physfn")) + if err != nil { + return err + } + + return nil +} diff --git a/pkg/nvpci/nvpci.go b/pkg/nvpci/nvpci.go index 6d83a57..6ff197b 100644 --- a/pkg/nvpci/nvpci.go +++ b/pkg/nvpci/nvpci.go @@ -76,6 +76,32 @@ type nvpci struct { var _ Interface = (*nvpci)(nil) var _ ResourceInterface = (*MemoryResources)(nil) +// SriovInfo indicates whether device is VF/PF for SRIOV capable devices. +// Only one should be set at any given time. +type SriovInfo struct { + PhysicalFunction *SriovPhysicalFunction + VirtualFunction *SriovVirtualFunction +} + +// SriovPhysicalFunction stores info about SRIOV physical function. +type SriovPhysicalFunction struct { + TotalVFs uint64 + NumVFs uint64 +} + +// SriovVirtualFunction keeps data about SRIOV virtual function. +type SriovVirtualFunction struct { + PhysicalFunction *NvidiaPCIDevice +} + +func (s *SriovInfo) IsPF() bool { + return s != nil && s.PhysicalFunction != nil +} + +func (s *SriovInfo) IsVF() bool { + return s != nil && s.VirtualFunction != nil +} + // NvidiaPCIDevice represents a PCI device for an NVIDIA product. type NvidiaPCIDevice struct { Path string @@ -90,7 +116,7 @@ type NvidiaPCIDevice struct { NumaNode int Config *ConfigSpace Resources MemoryResources - IsVF bool + SriovInfo SriovInfo } // IsVGAController if class == 0x300. @@ -178,9 +204,11 @@ func (p *nvpci) GetAllDevices() ([]*NvidiaPCIDevice, error) { } var nvdevices []*NvidiaPCIDevice + // Cache devices for each GetAllDevices invocation to speed things up. + cache := make(map[string]*NvidiaPCIDevice) for _, deviceDir := range deviceDirs { deviceAddress := deviceDir.Name() - nvdevice, err := p.GetGPUByPciBusID(deviceAddress) + nvdevice, err := p.getGPUByPciBusID(deviceAddress, cache) if err != nil { return nil, fmt.Errorf("error constructing NVIDIA PCI device %s: %v", deviceAddress, err) } @@ -206,6 +234,16 @@ func (p *nvpci) GetAllDevices() ([]*NvidiaPCIDevice, error) { // GetGPUByPciBusID constructs an NvidiaPCIDevice for the specified address (PCI Bus ID). func (p *nvpci) GetGPUByPciBusID(address string) (*NvidiaPCIDevice, error) { + // Pass nil as to force reading device information from sysfs. + return p.getGPUByPciBusID(address, nil) +} + +func (p *nvpci) getGPUByPciBusID(address string, cache map[string]*NvidiaPCIDevice) (*NvidiaPCIDevice, error) { + if cache != nil { + if pciDevice, exists := cache[address]; exists { + return pciDevice, nil + } + } devicePath := filepath.Join(p.pciDevicesRoot, address) vendor, err := os.ReadFile(path.Join(devicePath, "vendor")) @@ -265,16 +303,6 @@ func (p *nvpci) GetGPUByPciBusID(address string) (*NvidiaPCIDevice, error) { return nil, fmt.Errorf("unable to detect iommu_group for %s: %v", address, err) } - // device is a virtual function (VF) if "physfn" symlink exists. - var isVF bool - _, err = filepath.EvalSymlinks(path.Join(devicePath, "physfn")) - if err == nil { - isVF = true - } - if err != nil && !os.IsNotExist(err) { - return nil, fmt.Errorf("unable to resolve %s: %v", path.Join(devicePath, "physfn"), err) - } - numa, err := os.ReadFile(path.Join(devicePath, "numa_node")) if err != nil { return nil, fmt.Errorf("unable to read PCI NUMA node for %s: %v", address, err) @@ -328,6 +356,28 @@ func (p *nvpci) GetGPUByPciBusID(address string) (*NvidiaPCIDevice, error) { className = UnknownClassString } + var sriovInfo SriovInfo + // Device is a virtual function (VF) if "physfn" symlink exists. + physFnAddress, err := filepath.EvalSymlinks(path.Join(devicePath, "physfn")) + if err == nil { + physFn, err := p.getGPUByPciBusID(filepath.Base(physFnAddress), cache) + if err != nil { + return nil, fmt.Errorf("unable to detect physfn for %s: %v", address, err) + } + sriovInfo = SriovInfo{ + VirtualFunction: &SriovVirtualFunction{ + PhysicalFunction: physFn, + }, + } + } else if os.IsNotExist(err) { + sriovInfo, err = p.getSriovInfoForPhysicalFunction(devicePath) + if err != nil { + return nil, fmt.Errorf("unable to read SRIOV physical function details for %s: %v", devicePath, err) + } + } else { + return nil, fmt.Errorf("unable to read %s: %v", path.Join(devicePath, "physfn"), err) + } + nvdevice := &NvidiaPCIDevice{ Path: devicePath, Address: address, @@ -339,9 +389,14 @@ func (p *nvpci) GetGPUByPciBusID(address string) (*NvidiaPCIDevice, error) { NumaNode: int(numaNode), Config: config, Resources: resources, - IsVF: isVF, DeviceName: deviceName, ClassName: className, + SriovInfo: sriovInfo, + } + + // Cache physical functions only as VF can't be a root device. + if cache != nil && sriovInfo.IsPF() { + cache[address] = nvdevice } return nvdevice, nil @@ -407,7 +462,7 @@ func (p *nvpci) GetGPUs() ([]*NvidiaPCIDevice, error) { var filtered []*NvidiaPCIDevice for _, d := range devices { - if d.IsGPU() && !d.IsVF { + if d.IsGPU() && !d.SriovInfo.IsVF() { filtered = append(filtered, d) } } @@ -428,3 +483,41 @@ func (p *nvpci) GetGPUByIndex(i int) (*NvidiaPCIDevice, error) { return gpus[i], nil } + +func (p *nvpci) getSriovInfoForPhysicalFunction(devicePath string) (sriovInfo SriovInfo, err error) { + totalVfsPath := filepath.Join(devicePath, "sriov_totalvfs") + numVfsPath := filepath.Join(devicePath, "sriov_numvfs") + + // No file for sriov_totalvfs exists? Not an SRIOV device, return nil + _, err = os.Stat(totalVfsPath) + if err != nil && os.IsNotExist(err) { + return sriovInfo, nil + } + sriovTotalVfs, err := os.ReadFile(totalVfsPath) + if err != nil { + return sriovInfo, fmt.Errorf("unable to read sriov_totalvfs: %v", err) + } + totalVfsStr := strings.TrimSpace(string(sriovTotalVfs)) + totalVfsInt, err := strconv.ParseUint(totalVfsStr, 10, 16) + if err != nil { + return sriovInfo, fmt.Errorf("unable to convert sriov_totalvfs to uint64: %v", err) + } + + sriovNumVfs, err := os.ReadFile(numVfsPath) + if err != nil { + return sriovInfo, fmt.Errorf("unable to read sriov_numvfs for: %v", err) + } + numVfsStr := strings.TrimSpace(string(sriovNumVfs)) + numVfsInt, err := strconv.ParseUint(numVfsStr, 10, 16) + if err != nil { + return sriovInfo, fmt.Errorf("unable to convert sriov_numvfs to uint64: %v", err) + } + + sriovInfo = SriovInfo{ + PhysicalFunction: &SriovPhysicalFunction{ + TotalVFs: totalVfsInt, + NumVFs: numVfsInt, + }, + } + return sriovInfo, nil +} diff --git a/pkg/nvpci/nvpci_test.go b/pkg/nvpci/nvpci_test.go index 18b41f3..14a851d 100644 --- a/pkg/nvpci/nvpci_test.go +++ b/pkg/nvpci/nvpci_test.go @@ -31,7 +31,7 @@ func TestNvpci(t *testing.T) { require.Nil(t, err, "Error creating NewMockNvpci") defer nvpci.Cleanup() - err = nvpci.AddMockA100("0000:80:05.1", 0) + err = nvpci.AddMockA100("0000:80:05.1", 0, nil) require.Nil(t, err, "Error adding Mock A100 device to MockNvpci") devices, err := nvpci.GetGPUs() @@ -65,7 +65,7 @@ func TestNvpci(t *testing.T) { require.Equal(t, int(resource0.End-resource0.Start+1), bar0.Len()) require.Equal(t, ga100PmcID, bar0.Read32(0)) - require.Equal(t, devices[0].IsVF, false, "Device incorrectly identified as a VF") + require.Equal(t, devices[0].SriovInfo.IsVF(), false, "Device incorrectly identified as a VF") device, err := nvpci.GetGPUByIndex(0) require.Nil(t, err, "Error getting GPU at index 0") @@ -100,7 +100,7 @@ func TestNvpciNUMANode(t *testing.T) { require.Nil(t, err, "Error creating NewMockNvpci") defer nvpci.Cleanup() - err = nvpci.AddMockA100("0000:80:05.1", tc.NumaNode) + err = nvpci.AddMockA100("0000:80:05.1", tc.NumaNode, nil) require.Nil(t, err, "Error adding Mock A100 device to MockNvpci") devices, err := nvpci.GetGPUs() @@ -110,3 +110,64 @@ func TestNvpciNUMANode(t *testing.T) { }) } } + +func TestNvpciSRIOV(t *testing.T) { + testCases := []struct { + Description string + Sriov *SriovInfo + }{ + { + Description: "sriov set", + Sriov: &SriovInfo{ + PhysicalFunction: &SriovPhysicalFunction{ + TotalVFs: 32, + NumVFs: 16, + }, + }, + }, + { + Description: "sriov not set", + }, + } + + for _, tc := range testCases { + t.Run(tc.Description, func(t *testing.T) { + nvpci, err := NewMockNvpci() + require.Nil(t, err, "Error creating NewMockNvpci") + defer nvpci.Cleanup() + + err = nvpci.AddMockA100("0000:80:05.1", 0, tc.Sriov) + require.Nil(t, err, "Error adding Mock A100 device to MockNvpci") + + gpus, err := nvpci.GetGPUs() + require.Nil(t, err, "Error getting GPUs") + require.Equal(t, 1, len(gpus), "Wrong number of GPU devices") + + devices, err := nvpci.GetAllDevices() + require.Nil(t, err, "Error getting devices") + + if tc.Sriov != nil { + require.Len(t, devices, int(tc.Sriov.PhysicalFunction.NumVFs)+1, "Expected number of devices to be NumVFs +1(PF)") + + require.Equal(t, false, gpus[0].SriovInfo.IsVF(), "GPU should not be marked as VF") + require.Equal(t, true, gpus[0].SriovInfo.IsPF(), "GPU should be marked as PF") + require.NotNil(t, gpus[0].SriovInfo, "SriovInfo should not be set to nil") + require.NotNil(t, gpus[0].SriovInfo.PhysicalFunction, "SriovInfo.PhysicalFunction should not be set to nil") + require.Equal(t, uint64(32), gpus[0].SriovInfo.PhysicalFunction.TotalVFs, "Wrong number of total VFs") + require.Equal(t, uint64(16), gpus[0].SriovInfo.PhysicalFunction.NumVFs, "Wrong number of num VFs") + require.Nil(t, gpus[0].SriovInfo.VirtualFunction, "VirtualFunction should be set to nil") + for i := 1; i < int(tc.Sriov.PhysicalFunction.NumVFs); i++ { + require.Equal(t, true, devices[i].SriovInfo.IsVF(), "Device should be marked as VF") + require.Equal(t, false, devices[i].SriovInfo.IsPF(), "Device should not be marked as PF") + require.Equal(t, gpus[0], devices[i].SriovInfo.VirtualFunction.PhysicalFunction, "VFs PhysicalFunction should be equal only GPU in the system") + } + } else { + require.Equal(t, len(gpus), len(devices), "When no SRIOV specified number of GPUs should equal number of devices") + + require.Equal(t, false, gpus[0].SriovInfo.IsVF(), "GPU should not be marked as VF") + require.Equal(t, false, gpus[0].SriovInfo.IsPF(), "GPU should not be marked as PF") + require.Equal(t, SriovInfo{}, gpus[0].SriovInfo, "SriovInfo should be empty") + } + }) + } +}