feat: add additinal SRIOV info to NvidiaPciDevice

Signed-off-by: PiotrProkop <pprokop@nvidia.com>
This commit is contained in:
PiotrProkop 2024-05-27 11:40:08 +02:00
parent 7604335102
commit bf3f431fc8
5 changed files with 302 additions and 53 deletions

View File

@ -321,21 +321,16 @@ func (m *Device) Delete() error {
}
// GetPhysicalFunction gets the physical PCI device backing a 'parent' device.
func (p *ParentDevice) GetPhysicalFunction() (*nvpci.NvidiaPCIDevice, error) {
if !p.IsVF {
return p.NvidiaPCIDevice, nil
func (p *ParentDevice) GetPhysicalFunction() *nvpci.NvidiaPCIDevice {
if p.SriovInfo.IsVF() {
return p.SriovInfo.VirtualFunction.PhysicalFunction
}
physfnPath, err := filepath.EvalSymlinks(path.Join(p.Path, "physfn"))
if err != nil {
return nil, fmt.Errorf("unable to resolve %s: %v", path.Join(p.Path, "physfn"), err)
}
return newNvidiaPCIDeviceFromPath(physfnPath)
// Either it is an SRIOV physical function or a non-SRIOV device, so return the device itself
return p.NvidiaPCIDevice
}
// GetPhysicalFunction gets the physical PCI device that a vGPU is created on.
func (m *Device) GetPhysicalFunction() (*nvpci.NvidiaPCIDevice, error) {
func (m *Device) GetPhysicalFunction() *nvpci.NvidiaPCIDevice {
return m.Parent.GetPhysicalFunction()
}

View File

@ -35,8 +35,7 @@ func TestNvmdev(t *testing.T) {
parentA100 := parentDevs[0]
pf, err := parentA100.GetPhysicalFunction()
require.Nil(t, err, "Error getting physical function backing the Mock A100 parent device")
pf := parentA100.GetPhysicalFunction()
require.Equal(t, "0000:3b:04.1", pf.Address, "Wrong address for Mock A100 physical function")
supported := parentA100.IsMDEVTypeSupported("A100-4C")
@ -59,7 +58,6 @@ func TestNvmdev(t *testing.T) {
require.Equal(t, "vfio_mdev", mdevA100.Driver, "Wrong driver detected for mdev device")
require.Equal(t, 200, mdevA100.IommuGroup, "Wrong value for iommu_group")
pf, err = mdevA100.GetPhysicalFunction()
require.Nil(t, err, "Error getting the physical function for Mock A100 mediated device")
pf = mdevA100.GetPhysicalFunction()
require.Equal(t, "0000:3b:04.1", pf.Address, "Wrong address for Mock A100 physical function")
}

View File

@ -20,6 +20,8 @@ import (
"fmt"
"os"
"path/filepath"
"regexp"
"strconv"
"github.com/NVIDIA/go-nvlib/pkg/nvpci/bytes"
)
@ -55,14 +57,82 @@ func (m *MockNvpci) Cleanup() {
os.RemoveAll(m.pciDevicesRoot)
}
func validatePCIAddress(addr string) error {
r := regexp.MustCompile(`0{4}:[0-9a-f]{2}:[0-9a-f]{2}\.[0-9]`)
if !r.Match([]byte(addr)) {
return fmt.Errorf(`invalid PCI address should match 0{4}:[0-9a-f]{2}:[0-9a-f]{2}\.[0-9]: %s`, addr)
}
return nil
}
// AddMockA100 Create an A100 like GPU mock device.
func (m *MockNvpci) AddMockA100(address string, numaNode int) error {
deviceDir := filepath.Join(m.pciDevicesRoot, address)
err := os.MkdirAll(deviceDir, 0755)
func (m *MockNvpci) AddMockA100(address string, numaNode int, sriov *SriovInfo) error {
err := validatePCIAddress(address)
if err != nil {
return err
}
deviceDir := filepath.Join(m.pciDevicesRoot, address)
err = os.MkdirAll(deviceDir, 0755)
if err != nil {
return err
}
err = createNVIDIAgpuFiles(deviceDir)
if err != nil {
return err
}
iommuGroup := 20
_, err = os.Create(filepath.Join(deviceDir, strconv.Itoa(iommuGroup)))
if err != nil {
return err
}
err = os.Symlink(filepath.Join(deviceDir, strconv.Itoa(iommuGroup)), filepath.Join(deviceDir, "iommu_group"))
if err != nil {
return err
}
numa, err := os.Create(filepath.Join(deviceDir, "numa_node"))
if err != nil {
return err
}
_, err = numa.WriteString(fmt.Sprintf("%v", numaNode))
if err != nil {
return err
}
if sriov != nil && sriov.PhysicalFunction != nil {
totalVFs, err := os.Create(filepath.Join(deviceDir, "sriov_totalvfs"))
if err != nil {
return err
}
_, err = fmt.Fprintf(totalVFs, "%d", sriov.PhysicalFunction.TotalVFs)
if err != nil {
return err
}
numVFs, err := os.Create(filepath.Join(deviceDir, "sriov_numvfs"))
if err != nil {
return err
}
_, err = fmt.Fprintf(numVFs, "%d", sriov.PhysicalFunction.NumVFs)
if err != nil {
return err
}
for i := 1; i <= int(sriov.PhysicalFunction.NumVFs); i++ {
err = m.createVf(address, i, iommuGroup, numaNode)
if err != nil {
return err
}
}
}
return nil
}
func createNVIDIAgpuFiles(deviceDir string) error {
vendor, err := os.Create(filepath.Join(deviceDir, "vendor"))
if err != nil {
return err
@ -99,24 +169,6 @@ func (m *MockNvpci) AddMockA100(address string, numaNode int) error {
return err
}
_, err = os.Create(filepath.Join(deviceDir, "20"))
if err != nil {
return err
}
err = os.Symlink(filepath.Join(deviceDir, "20"), filepath.Join(deviceDir, "iommu_group"))
if err != nil {
return err
}
numa, err := os.Create(filepath.Join(deviceDir, "numa_node"))
if err != nil {
return err
}
_, err = numa.WriteString(fmt.Sprintf("%v", numaNode))
if err != nil {
return err
}
config, err := os.Create(filepath.Join(deviceDir, "config"))
if err != nil {
return err
@ -156,3 +208,53 @@ func (m *MockNvpci) AddMockA100(address string, numaNode int) error {
return nil
}
func (m *MockNvpci) createVf(pfAddress string, id, iommu_group, numaNode int) error {
functionID := pfAddress[len(pfAddress)-1]
// we are verifying the last character of pfAddress is integer.
functionNumber, err := strconv.Atoi(string(functionID))
if err != nil {
return fmt.Errorf("can't conver physical function pci address function number %s to integer: %v", string(functionID), err)
}
vfFunctionNumber := functionNumber + id
vfAddress := pfAddress[:len(pfAddress)-1] + strconv.Itoa(vfFunctionNumber)
deviceDir := filepath.Join(m.pciDevicesRoot, vfAddress)
err = os.MkdirAll(deviceDir, 0755)
if err != nil {
return err
}
err = createNVIDIAgpuFiles(deviceDir)
if err != nil {
return err
}
vfIommuGroup := strconv.Itoa(iommu_group + id)
_, err = os.Create(filepath.Join(deviceDir, vfIommuGroup))
if err != nil {
return err
}
err = os.Symlink(filepath.Join(deviceDir, vfIommuGroup), filepath.Join(deviceDir, "iommu_group"))
if err != nil {
return err
}
numa, err := os.Create(filepath.Join(deviceDir, "numa_node"))
if err != nil {
return err
}
_, err = numa.WriteString(fmt.Sprintf("%v", numaNode))
if err != nil {
return err
}
err = os.Symlink(filepath.Join(m.pciDevicesRoot, pfAddress), filepath.Join(deviceDir, "physfn"))
if err != nil {
return err
}
return nil
}

View File

@ -76,6 +76,32 @@ type nvpci struct {
var _ Interface = (*nvpci)(nil)
var _ ResourceInterface = (*MemoryResources)(nil)
// SriovInfo indicates whether device is VF/PF for SRIOV capable devices.
// Only one should be set at any given time.
type SriovInfo struct {
PhysicalFunction *SriovPhysicalFunction
VirtualFunction *SriovVirtualFunction
}
// SriovPhysicalFunction stores info about SRIOV physical function.
type SriovPhysicalFunction struct {
TotalVFs uint64
NumVFs uint64
}
// SriovVirtualFunction keeps data about SRIOV virtual function.
type SriovVirtualFunction struct {
PhysicalFunction *NvidiaPCIDevice
}
func (s *SriovInfo) IsPF() bool {
return s != nil && s.PhysicalFunction != nil
}
func (s *SriovInfo) IsVF() bool {
return s != nil && s.VirtualFunction != nil
}
// NvidiaPCIDevice represents a PCI device for an NVIDIA product.
type NvidiaPCIDevice struct {
Path string
@ -90,7 +116,7 @@ type NvidiaPCIDevice struct {
NumaNode int
Config *ConfigSpace
Resources MemoryResources
IsVF bool
SriovInfo SriovInfo
}
// IsVGAController if class == 0x300.
@ -178,9 +204,11 @@ func (p *nvpci) GetAllDevices() ([]*NvidiaPCIDevice, error) {
}
var nvdevices []*NvidiaPCIDevice
// Cache devices for each GetAllDevices invocation to speed things up.
cache := make(map[string]*NvidiaPCIDevice)
for _, deviceDir := range deviceDirs {
deviceAddress := deviceDir.Name()
nvdevice, err := p.GetGPUByPciBusID(deviceAddress)
nvdevice, err := p.getGPUByPciBusID(deviceAddress, cache)
if err != nil {
return nil, fmt.Errorf("error constructing NVIDIA PCI device %s: %v", deviceAddress, err)
}
@ -206,6 +234,16 @@ func (p *nvpci) GetAllDevices() ([]*NvidiaPCIDevice, error) {
// GetGPUByPciBusID constructs an NvidiaPCIDevice for the specified address (PCI Bus ID).
func (p *nvpci) GetGPUByPciBusID(address string) (*NvidiaPCIDevice, error) {
// Pass nil as to force reading device information from sysfs.
return p.getGPUByPciBusID(address, nil)
}
func (p *nvpci) getGPUByPciBusID(address string, cache map[string]*NvidiaPCIDevice) (*NvidiaPCIDevice, error) {
if cache != nil {
if pciDevice, exists := cache[address]; exists {
return pciDevice, nil
}
}
devicePath := filepath.Join(p.pciDevicesRoot, address)
vendor, err := os.ReadFile(path.Join(devicePath, "vendor"))
@ -265,16 +303,6 @@ func (p *nvpci) GetGPUByPciBusID(address string) (*NvidiaPCIDevice, error) {
return nil, fmt.Errorf("unable to detect iommu_group for %s: %v", address, err)
}
// device is a virtual function (VF) if "physfn" symlink exists.
var isVF bool
_, err = filepath.EvalSymlinks(path.Join(devicePath, "physfn"))
if err == nil {
isVF = true
}
if err != nil && !os.IsNotExist(err) {
return nil, fmt.Errorf("unable to resolve %s: %v", path.Join(devicePath, "physfn"), err)
}
numa, err := os.ReadFile(path.Join(devicePath, "numa_node"))
if err != nil {
return nil, fmt.Errorf("unable to read PCI NUMA node for %s: %v", address, err)
@ -328,6 +356,28 @@ func (p *nvpci) GetGPUByPciBusID(address string) (*NvidiaPCIDevice, error) {
className = UnknownClassString
}
var sriovInfo SriovInfo
// Device is a virtual function (VF) if "physfn" symlink exists.
physFnAddress, err := filepath.EvalSymlinks(path.Join(devicePath, "physfn"))
if err == nil {
physFn, err := p.getGPUByPciBusID(filepath.Base(physFnAddress), cache)
if err != nil {
return nil, fmt.Errorf("unable to detect physfn for %s: %v", address, err)
}
sriovInfo = SriovInfo{
VirtualFunction: &SriovVirtualFunction{
PhysicalFunction: physFn,
},
}
} else if os.IsNotExist(err) {
sriovInfo, err = p.getSriovInfoForPhysicalFunction(devicePath)
if err != nil {
return nil, fmt.Errorf("unable to read SRIOV physical function details for %s: %v", devicePath, err)
}
} else {
return nil, fmt.Errorf("unable to read %s: %v", path.Join(devicePath, "physfn"), err)
}
nvdevice := &NvidiaPCIDevice{
Path: devicePath,
Address: address,
@ -339,9 +389,14 @@ func (p *nvpci) GetGPUByPciBusID(address string) (*NvidiaPCIDevice, error) {
NumaNode: int(numaNode),
Config: config,
Resources: resources,
IsVF: isVF,
DeviceName: deviceName,
ClassName: className,
SriovInfo: sriovInfo,
}
// Cache physical functions only as VF can't be a root device.
if cache != nil && sriovInfo.IsPF() {
cache[address] = nvdevice
}
return nvdevice, nil
@ -407,7 +462,7 @@ func (p *nvpci) GetGPUs() ([]*NvidiaPCIDevice, error) {
var filtered []*NvidiaPCIDevice
for _, d := range devices {
if d.IsGPU() && !d.IsVF {
if d.IsGPU() && !d.SriovInfo.IsVF() {
filtered = append(filtered, d)
}
}
@ -428,3 +483,41 @@ func (p *nvpci) GetGPUByIndex(i int) (*NvidiaPCIDevice, error) {
return gpus[i], nil
}
func (p *nvpci) getSriovInfoForPhysicalFunction(devicePath string) (sriovInfo SriovInfo, err error) {
totalVfsPath := filepath.Join(devicePath, "sriov_totalvfs")
numVfsPath := filepath.Join(devicePath, "sriov_numvfs")
// No file for sriov_totalvfs exists? Not an SRIOV device, return nil
_, err = os.Stat(totalVfsPath)
if err != nil && os.IsNotExist(err) {
return sriovInfo, nil
}
sriovTotalVfs, err := os.ReadFile(totalVfsPath)
if err != nil {
return sriovInfo, fmt.Errorf("unable to read sriov_totalvfs: %v", err)
}
totalVfsStr := strings.TrimSpace(string(sriovTotalVfs))
totalVfsInt, err := strconv.ParseUint(totalVfsStr, 10, 16)
if err != nil {
return sriovInfo, fmt.Errorf("unable to convert sriov_totalvfs to uint64: %v", err)
}
sriovNumVfs, err := os.ReadFile(numVfsPath)
if err != nil {
return sriovInfo, fmt.Errorf("unable to read sriov_numvfs for: %v", err)
}
numVfsStr := strings.TrimSpace(string(sriovNumVfs))
numVfsInt, err := strconv.ParseUint(numVfsStr, 10, 16)
if err != nil {
return sriovInfo, fmt.Errorf("unable to convert sriov_numvfs to uint64: %v", err)
}
sriovInfo = SriovInfo{
PhysicalFunction: &SriovPhysicalFunction{
TotalVFs: totalVfsInt,
NumVFs: numVfsInt,
},
}
return sriovInfo, nil
}

View File

@ -31,7 +31,7 @@ func TestNvpci(t *testing.T) {
require.Nil(t, err, "Error creating NewMockNvpci")
defer nvpci.Cleanup()
err = nvpci.AddMockA100("0000:80:05.1", 0)
err = nvpci.AddMockA100("0000:80:05.1", 0, nil)
require.Nil(t, err, "Error adding Mock A100 device to MockNvpci")
devices, err := nvpci.GetGPUs()
@ -65,7 +65,7 @@ func TestNvpci(t *testing.T) {
require.Equal(t, int(resource0.End-resource0.Start+1), bar0.Len())
require.Equal(t, ga100PmcID, bar0.Read32(0))
require.Equal(t, devices[0].IsVF, false, "Device incorrectly identified as a VF")
require.Equal(t, devices[0].SriovInfo.IsVF(), false, "Device incorrectly identified as a VF")
device, err := nvpci.GetGPUByIndex(0)
require.Nil(t, err, "Error getting GPU at index 0")
@ -100,7 +100,7 @@ func TestNvpciNUMANode(t *testing.T) {
require.Nil(t, err, "Error creating NewMockNvpci")
defer nvpci.Cleanup()
err = nvpci.AddMockA100("0000:80:05.1", tc.NumaNode)
err = nvpci.AddMockA100("0000:80:05.1", tc.NumaNode, nil)
require.Nil(t, err, "Error adding Mock A100 device to MockNvpci")
devices, err := nvpci.GetGPUs()
@ -110,3 +110,64 @@ func TestNvpciNUMANode(t *testing.T) {
})
}
}
func TestNvpciSRIOV(t *testing.T) {
testCases := []struct {
Description string
Sriov *SriovInfo
}{
{
Description: "sriov set",
Sriov: &SriovInfo{
PhysicalFunction: &SriovPhysicalFunction{
TotalVFs: 32,
NumVFs: 16,
},
},
},
{
Description: "sriov not set",
},
}
for _, tc := range testCases {
t.Run(tc.Description, func(t *testing.T) {
nvpci, err := NewMockNvpci()
require.Nil(t, err, "Error creating NewMockNvpci")
defer nvpci.Cleanup()
err = nvpci.AddMockA100("0000:80:05.1", 0, tc.Sriov)
require.Nil(t, err, "Error adding Mock A100 device to MockNvpci")
gpus, err := nvpci.GetGPUs()
require.Nil(t, err, "Error getting GPUs")
require.Equal(t, 1, len(gpus), "Wrong number of GPU devices")
devices, err := nvpci.GetAllDevices()
require.Nil(t, err, "Error getting devices")
if tc.Sriov != nil {
require.Len(t, devices, int(tc.Sriov.PhysicalFunction.NumVFs)+1, "Expected number of devices to be NumVFs +1(PF)")
require.Equal(t, false, gpus[0].SriovInfo.IsVF(), "GPU should not be marked as VF")
require.Equal(t, true, gpus[0].SriovInfo.IsPF(), "GPU should be marked as PF")
require.NotNil(t, gpus[0].SriovInfo, "SriovInfo should not be set to nil")
require.NotNil(t, gpus[0].SriovInfo.PhysicalFunction, "SriovInfo.PhysicalFunction should not be set to nil")
require.Equal(t, uint64(32), gpus[0].SriovInfo.PhysicalFunction.TotalVFs, "Wrong number of total VFs")
require.Equal(t, uint64(16), gpus[0].SriovInfo.PhysicalFunction.NumVFs, "Wrong number of num VFs")
require.Nil(t, gpus[0].SriovInfo.VirtualFunction, "VirtualFunction should be set to nil")
for i := 1; i < int(tc.Sriov.PhysicalFunction.NumVFs); i++ {
require.Equal(t, true, devices[i].SriovInfo.IsVF(), "Device should be marked as VF")
require.Equal(t, false, devices[i].SriovInfo.IsPF(), "Device should not be marked as PF")
require.Equal(t, gpus[0], devices[i].SriovInfo.VirtualFunction.PhysicalFunction, "VFs PhysicalFunction should be equal only GPU in the system")
}
} else {
require.Equal(t, len(gpus), len(devices), "When no SRIOV specified number of GPUs should equal number of devices")
require.Equal(t, false, gpus[0].SriovInfo.IsVF(), "GPU should not be marked as VF")
require.Equal(t, false, gpus[0].SriovInfo.IsPF(), "GPU should not be marked as PF")
require.Equal(t, SriovInfo{}, gpus[0].SriovInfo, "SriovInfo should be empty")
}
})
}
}