mirror of
https://github.com/clearml/go-nvlib
synced 2025-04-06 22:04:03 +00:00
Merge pull request #38 from PiotrProkop/add-vfs-info
feat: add additional SRIOV info to NvidiaPciDevice
This commit is contained in:
commit
20ba32166d
@ -321,21 +321,16 @@ func (m *Device) Delete() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetPhysicalFunction gets the physical PCI device backing a 'parent' device.
|
// GetPhysicalFunction gets the physical PCI device backing a 'parent' device.
|
||||||
func (p *ParentDevice) GetPhysicalFunction() (*nvpci.NvidiaPCIDevice, error) {
|
func (p *ParentDevice) GetPhysicalFunction() *nvpci.NvidiaPCIDevice {
|
||||||
if !p.IsVF {
|
if p.SriovInfo.IsVF() {
|
||||||
return p.NvidiaPCIDevice, nil
|
return p.SriovInfo.VirtualFunction.PhysicalFunction
|
||||||
}
|
}
|
||||||
|
// Either it is an SRIOV physical function or a non-SRIOV device, so return the device itself
|
||||||
physfnPath, err := filepath.EvalSymlinks(path.Join(p.Path, "physfn"))
|
return p.NvidiaPCIDevice
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to resolve %s: %v", path.Join(p.Path, "physfn"), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return newNvidiaPCIDeviceFromPath(physfnPath)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPhysicalFunction gets the physical PCI device that a vGPU is created on.
|
// GetPhysicalFunction gets the physical PCI device that a vGPU is created on.
|
||||||
func (m *Device) GetPhysicalFunction() (*nvpci.NvidiaPCIDevice, error) {
|
func (m *Device) GetPhysicalFunction() *nvpci.NvidiaPCIDevice {
|
||||||
return m.Parent.GetPhysicalFunction()
|
return m.Parent.GetPhysicalFunction()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -35,8 +35,7 @@ func TestNvmdev(t *testing.T) {
|
|||||||
|
|
||||||
parentA100 := parentDevs[0]
|
parentA100 := parentDevs[0]
|
||||||
|
|
||||||
pf, err := parentA100.GetPhysicalFunction()
|
pf := parentA100.GetPhysicalFunction()
|
||||||
require.Nil(t, err, "Error getting physical function backing the Mock A100 parent device")
|
|
||||||
require.Equal(t, "0000:3b:04.1", pf.Address, "Wrong address for Mock A100 physical function")
|
require.Equal(t, "0000:3b:04.1", pf.Address, "Wrong address for Mock A100 physical function")
|
||||||
|
|
||||||
supported := parentA100.IsMDEVTypeSupported("A100-4C")
|
supported := parentA100.IsMDEVTypeSupported("A100-4C")
|
||||||
@ -59,7 +58,6 @@ func TestNvmdev(t *testing.T) {
|
|||||||
require.Equal(t, "vfio_mdev", mdevA100.Driver, "Wrong driver detected for mdev device")
|
require.Equal(t, "vfio_mdev", mdevA100.Driver, "Wrong driver detected for mdev device")
|
||||||
require.Equal(t, 200, mdevA100.IommuGroup, "Wrong value for iommu_group")
|
require.Equal(t, 200, mdevA100.IommuGroup, "Wrong value for iommu_group")
|
||||||
|
|
||||||
pf, err = mdevA100.GetPhysicalFunction()
|
pf = mdevA100.GetPhysicalFunction()
|
||||||
require.Nil(t, err, "Error getting the physical function for Mock A100 mediated device")
|
|
||||||
require.Equal(t, "0000:3b:04.1", pf.Address, "Wrong address for Mock A100 physical function")
|
require.Equal(t, "0000:3b:04.1", pf.Address, "Wrong address for Mock A100 physical function")
|
||||||
}
|
}
|
||||||
|
@ -20,6 +20,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"regexp"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
"github.com/NVIDIA/go-nvlib/pkg/nvpci/bytes"
|
"github.com/NVIDIA/go-nvlib/pkg/nvpci/bytes"
|
||||||
)
|
)
|
||||||
@ -55,14 +57,82 @@ func (m *MockNvpci) Cleanup() {
|
|||||||
os.RemoveAll(m.pciDevicesRoot)
|
os.RemoveAll(m.pciDevicesRoot)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func validatePCIAddress(addr string) error {
|
||||||
|
r := regexp.MustCompile(`0{4}:[0-9a-f]{2}:[0-9a-f]{2}\.[0-9]`)
|
||||||
|
if !r.Match([]byte(addr)) {
|
||||||
|
return fmt.Errorf(`invalid PCI address should match 0{4}:[0-9a-f]{2}:[0-9a-f]{2}\.[0-9]: %s`, addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// AddMockA100 Create an A100 like GPU mock device.
|
// AddMockA100 Create an A100 like GPU mock device.
|
||||||
func (m *MockNvpci) AddMockA100(address string, numaNode int) error {
|
func (m *MockNvpci) AddMockA100(address string, numaNode int, sriov *SriovInfo) error {
|
||||||
deviceDir := filepath.Join(m.pciDevicesRoot, address)
|
err := validatePCIAddress(address)
|
||||||
err := os.MkdirAll(deviceDir, 0755)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
deviceDir := filepath.Join(m.pciDevicesRoot, address)
|
||||||
|
err = os.MkdirAll(deviceDir, 0755)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = createNVIDIAgpuFiles(deviceDir)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
iommuGroup := 20
|
||||||
|
_, err = os.Create(filepath.Join(deviceDir, strconv.Itoa(iommuGroup)))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = os.Symlink(filepath.Join(deviceDir, strconv.Itoa(iommuGroup)), filepath.Join(deviceDir, "iommu_group"))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
numa, err := os.Create(filepath.Join(deviceDir, "numa_node"))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = numa.WriteString(fmt.Sprintf("%v", numaNode))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if sriov != nil && sriov.PhysicalFunction != nil {
|
||||||
|
totalVFs, err := os.Create(filepath.Join(deviceDir, "sriov_totalvfs"))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = fmt.Fprintf(totalVFs, "%d", sriov.PhysicalFunction.TotalVFs)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
numVFs, err := os.Create(filepath.Join(deviceDir, "sriov_numvfs"))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = fmt.Fprintf(numVFs, "%d", sriov.PhysicalFunction.NumVFs)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for i := 1; i <= int(sriov.PhysicalFunction.NumVFs); i++ {
|
||||||
|
err = m.createVf(address, i, iommuGroup, numaNode)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func createNVIDIAgpuFiles(deviceDir string) error {
|
||||||
vendor, err := os.Create(filepath.Join(deviceDir, "vendor"))
|
vendor, err := os.Create(filepath.Join(deviceDir, "vendor"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -99,24 +169,6 @@ func (m *MockNvpci) AddMockA100(address string, numaNode int) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = os.Create(filepath.Join(deviceDir, "20"))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = os.Symlink(filepath.Join(deviceDir, "20"), filepath.Join(deviceDir, "iommu_group"))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
numa, err := os.Create(filepath.Join(deviceDir, "numa_node"))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
_, err = numa.WriteString(fmt.Sprintf("%v", numaNode))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
config, err := os.Create(filepath.Join(deviceDir, "config"))
|
config, err := os.Create(filepath.Join(deviceDir, "config"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -156,3 +208,53 @@ func (m *MockNvpci) AddMockA100(address string, numaNode int) error {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *MockNvpci) createVf(pfAddress string, id, iommu_group, numaNode int) error {
|
||||||
|
functionID := pfAddress[len(pfAddress)-1]
|
||||||
|
// we are verifying the last character of pfAddress is integer.
|
||||||
|
functionNumber, err := strconv.Atoi(string(functionID))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("can't conver physical function pci address function number %s to integer: %v", string(functionID), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
vfFunctionNumber := functionNumber + id
|
||||||
|
vfAddress := pfAddress[:len(pfAddress)-1] + strconv.Itoa(vfFunctionNumber)
|
||||||
|
|
||||||
|
deviceDir := filepath.Join(m.pciDevicesRoot, vfAddress)
|
||||||
|
err = os.MkdirAll(deviceDir, 0755)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = createNVIDIAgpuFiles(deviceDir)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
vfIommuGroup := strconv.Itoa(iommu_group + id)
|
||||||
|
|
||||||
|
_, err = os.Create(filepath.Join(deviceDir, vfIommuGroup))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = os.Symlink(filepath.Join(deviceDir, vfIommuGroup), filepath.Join(deviceDir, "iommu_group"))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
numa, err := os.Create(filepath.Join(deviceDir, "numa_node"))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = numa.WriteString(fmt.Sprintf("%v", numaNode))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = os.Symlink(filepath.Join(m.pciDevicesRoot, pfAddress), filepath.Join(deviceDir, "physfn"))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
@ -76,6 +76,32 @@ type nvpci struct {
|
|||||||
var _ Interface = (*nvpci)(nil)
|
var _ Interface = (*nvpci)(nil)
|
||||||
var _ ResourceInterface = (*MemoryResources)(nil)
|
var _ ResourceInterface = (*MemoryResources)(nil)
|
||||||
|
|
||||||
|
// SriovInfo indicates whether device is VF/PF for SRIOV capable devices.
|
||||||
|
// Only one should be set at any given time.
|
||||||
|
type SriovInfo struct {
|
||||||
|
PhysicalFunction *SriovPhysicalFunction
|
||||||
|
VirtualFunction *SriovVirtualFunction
|
||||||
|
}
|
||||||
|
|
||||||
|
// SriovPhysicalFunction stores info about SRIOV physical function.
|
||||||
|
type SriovPhysicalFunction struct {
|
||||||
|
TotalVFs uint64
|
||||||
|
NumVFs uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
// SriovVirtualFunction keeps data about SRIOV virtual function.
|
||||||
|
type SriovVirtualFunction struct {
|
||||||
|
PhysicalFunction *NvidiaPCIDevice
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SriovInfo) IsPF() bool {
|
||||||
|
return s != nil && s.PhysicalFunction != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SriovInfo) IsVF() bool {
|
||||||
|
return s != nil && s.VirtualFunction != nil
|
||||||
|
}
|
||||||
|
|
||||||
// NvidiaPCIDevice represents a PCI device for an NVIDIA product.
|
// NvidiaPCIDevice represents a PCI device for an NVIDIA product.
|
||||||
type NvidiaPCIDevice struct {
|
type NvidiaPCIDevice struct {
|
||||||
Path string
|
Path string
|
||||||
@ -90,7 +116,7 @@ type NvidiaPCIDevice struct {
|
|||||||
NumaNode int
|
NumaNode int
|
||||||
Config *ConfigSpace
|
Config *ConfigSpace
|
||||||
Resources MemoryResources
|
Resources MemoryResources
|
||||||
IsVF bool
|
SriovInfo SriovInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsVGAController if class == 0x300.
|
// IsVGAController if class == 0x300.
|
||||||
@ -178,9 +204,11 @@ func (p *nvpci) GetAllDevices() ([]*NvidiaPCIDevice, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var nvdevices []*NvidiaPCIDevice
|
var nvdevices []*NvidiaPCIDevice
|
||||||
|
// Cache devices for each GetAllDevices invocation to speed things up.
|
||||||
|
cache := make(map[string]*NvidiaPCIDevice)
|
||||||
for _, deviceDir := range deviceDirs {
|
for _, deviceDir := range deviceDirs {
|
||||||
deviceAddress := deviceDir.Name()
|
deviceAddress := deviceDir.Name()
|
||||||
nvdevice, err := p.GetGPUByPciBusID(deviceAddress)
|
nvdevice, err := p.getGPUByPciBusID(deviceAddress, cache)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error constructing NVIDIA PCI device %s: %v", deviceAddress, err)
|
return nil, fmt.Errorf("error constructing NVIDIA PCI device %s: %v", deviceAddress, err)
|
||||||
}
|
}
|
||||||
@ -206,6 +234,16 @@ func (p *nvpci) GetAllDevices() ([]*NvidiaPCIDevice, error) {
|
|||||||
|
|
||||||
// GetGPUByPciBusID constructs an NvidiaPCIDevice for the specified address (PCI Bus ID).
|
// GetGPUByPciBusID constructs an NvidiaPCIDevice for the specified address (PCI Bus ID).
|
||||||
func (p *nvpci) GetGPUByPciBusID(address string) (*NvidiaPCIDevice, error) {
|
func (p *nvpci) GetGPUByPciBusID(address string) (*NvidiaPCIDevice, error) {
|
||||||
|
// Pass nil as to force reading device information from sysfs.
|
||||||
|
return p.getGPUByPciBusID(address, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *nvpci) getGPUByPciBusID(address string, cache map[string]*NvidiaPCIDevice) (*NvidiaPCIDevice, error) {
|
||||||
|
if cache != nil {
|
||||||
|
if pciDevice, exists := cache[address]; exists {
|
||||||
|
return pciDevice, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
devicePath := filepath.Join(p.pciDevicesRoot, address)
|
devicePath := filepath.Join(p.pciDevicesRoot, address)
|
||||||
|
|
||||||
vendor, err := os.ReadFile(path.Join(devicePath, "vendor"))
|
vendor, err := os.ReadFile(path.Join(devicePath, "vendor"))
|
||||||
@ -265,16 +303,6 @@ func (p *nvpci) GetGPUByPciBusID(address string) (*NvidiaPCIDevice, error) {
|
|||||||
return nil, fmt.Errorf("unable to detect iommu_group for %s: %v", address, err)
|
return nil, fmt.Errorf("unable to detect iommu_group for %s: %v", address, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// device is a virtual function (VF) if "physfn" symlink exists.
|
|
||||||
var isVF bool
|
|
||||||
_, err = filepath.EvalSymlinks(path.Join(devicePath, "physfn"))
|
|
||||||
if err == nil {
|
|
||||||
isVF = true
|
|
||||||
}
|
|
||||||
if err != nil && !os.IsNotExist(err) {
|
|
||||||
return nil, fmt.Errorf("unable to resolve %s: %v", path.Join(devicePath, "physfn"), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
numa, err := os.ReadFile(path.Join(devicePath, "numa_node"))
|
numa, err := os.ReadFile(path.Join(devicePath, "numa_node"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("unable to read PCI NUMA node for %s: %v", address, err)
|
return nil, fmt.Errorf("unable to read PCI NUMA node for %s: %v", address, err)
|
||||||
@ -328,6 +356,28 @@ func (p *nvpci) GetGPUByPciBusID(address string) (*NvidiaPCIDevice, error) {
|
|||||||
className = UnknownClassString
|
className = UnknownClassString
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var sriovInfo SriovInfo
|
||||||
|
// Device is a virtual function (VF) if "physfn" symlink exists.
|
||||||
|
physFnAddress, err := filepath.EvalSymlinks(path.Join(devicePath, "physfn"))
|
||||||
|
if err == nil {
|
||||||
|
physFn, err := p.getGPUByPciBusID(filepath.Base(physFnAddress), cache)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to detect physfn for %s: %v", address, err)
|
||||||
|
}
|
||||||
|
sriovInfo = SriovInfo{
|
||||||
|
VirtualFunction: &SriovVirtualFunction{
|
||||||
|
PhysicalFunction: physFn,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
} else if os.IsNotExist(err) {
|
||||||
|
sriovInfo, err = p.getSriovInfoForPhysicalFunction(devicePath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to read SRIOV physical function details for %s: %v", devicePath, err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return nil, fmt.Errorf("unable to read %s: %v", path.Join(devicePath, "physfn"), err)
|
||||||
|
}
|
||||||
|
|
||||||
nvdevice := &NvidiaPCIDevice{
|
nvdevice := &NvidiaPCIDevice{
|
||||||
Path: devicePath,
|
Path: devicePath,
|
||||||
Address: address,
|
Address: address,
|
||||||
@ -339,9 +389,14 @@ func (p *nvpci) GetGPUByPciBusID(address string) (*NvidiaPCIDevice, error) {
|
|||||||
NumaNode: int(numaNode),
|
NumaNode: int(numaNode),
|
||||||
Config: config,
|
Config: config,
|
||||||
Resources: resources,
|
Resources: resources,
|
||||||
IsVF: isVF,
|
|
||||||
DeviceName: deviceName,
|
DeviceName: deviceName,
|
||||||
ClassName: className,
|
ClassName: className,
|
||||||
|
SriovInfo: sriovInfo,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cache physical functions only as VF can't be a root device.
|
||||||
|
if cache != nil && sriovInfo.IsPF() {
|
||||||
|
cache[address] = nvdevice
|
||||||
}
|
}
|
||||||
|
|
||||||
return nvdevice, nil
|
return nvdevice, nil
|
||||||
@ -407,7 +462,7 @@ func (p *nvpci) GetGPUs() ([]*NvidiaPCIDevice, error) {
|
|||||||
|
|
||||||
var filtered []*NvidiaPCIDevice
|
var filtered []*NvidiaPCIDevice
|
||||||
for _, d := range devices {
|
for _, d := range devices {
|
||||||
if d.IsGPU() && !d.IsVF {
|
if d.IsGPU() && !d.SriovInfo.IsVF() {
|
||||||
filtered = append(filtered, d)
|
filtered = append(filtered, d)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -428,3 +483,41 @@ func (p *nvpci) GetGPUByIndex(i int) (*NvidiaPCIDevice, error) {
|
|||||||
|
|
||||||
return gpus[i], nil
|
return gpus[i], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *nvpci) getSriovInfoForPhysicalFunction(devicePath string) (sriovInfo SriovInfo, err error) {
|
||||||
|
totalVfsPath := filepath.Join(devicePath, "sriov_totalvfs")
|
||||||
|
numVfsPath := filepath.Join(devicePath, "sriov_numvfs")
|
||||||
|
|
||||||
|
// No file for sriov_totalvfs exists? Not an SRIOV device, return nil
|
||||||
|
_, err = os.Stat(totalVfsPath)
|
||||||
|
if err != nil && os.IsNotExist(err) {
|
||||||
|
return sriovInfo, nil
|
||||||
|
}
|
||||||
|
sriovTotalVfs, err := os.ReadFile(totalVfsPath)
|
||||||
|
if err != nil {
|
||||||
|
return sriovInfo, fmt.Errorf("unable to read sriov_totalvfs: %v", err)
|
||||||
|
}
|
||||||
|
totalVfsStr := strings.TrimSpace(string(sriovTotalVfs))
|
||||||
|
totalVfsInt, err := strconv.ParseUint(totalVfsStr, 10, 16)
|
||||||
|
if err != nil {
|
||||||
|
return sriovInfo, fmt.Errorf("unable to convert sriov_totalvfs to uint64: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sriovNumVfs, err := os.ReadFile(numVfsPath)
|
||||||
|
if err != nil {
|
||||||
|
return sriovInfo, fmt.Errorf("unable to read sriov_numvfs for: %v", err)
|
||||||
|
}
|
||||||
|
numVfsStr := strings.TrimSpace(string(sriovNumVfs))
|
||||||
|
numVfsInt, err := strconv.ParseUint(numVfsStr, 10, 16)
|
||||||
|
if err != nil {
|
||||||
|
return sriovInfo, fmt.Errorf("unable to convert sriov_numvfs to uint64: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sriovInfo = SriovInfo{
|
||||||
|
PhysicalFunction: &SriovPhysicalFunction{
|
||||||
|
TotalVFs: totalVfsInt,
|
||||||
|
NumVFs: numVfsInt,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return sriovInfo, nil
|
||||||
|
}
|
||||||
|
@ -31,7 +31,7 @@ func TestNvpci(t *testing.T) {
|
|||||||
require.Nil(t, err, "Error creating NewMockNvpci")
|
require.Nil(t, err, "Error creating NewMockNvpci")
|
||||||
defer nvpci.Cleanup()
|
defer nvpci.Cleanup()
|
||||||
|
|
||||||
err = nvpci.AddMockA100("0000:80:05.1", 0)
|
err = nvpci.AddMockA100("0000:80:05.1", 0, nil)
|
||||||
require.Nil(t, err, "Error adding Mock A100 device to MockNvpci")
|
require.Nil(t, err, "Error adding Mock A100 device to MockNvpci")
|
||||||
|
|
||||||
devices, err := nvpci.GetGPUs()
|
devices, err := nvpci.GetGPUs()
|
||||||
@ -65,7 +65,7 @@ func TestNvpci(t *testing.T) {
|
|||||||
require.Equal(t, int(resource0.End-resource0.Start+1), bar0.Len())
|
require.Equal(t, int(resource0.End-resource0.Start+1), bar0.Len())
|
||||||
require.Equal(t, ga100PmcID, bar0.Read32(0))
|
require.Equal(t, ga100PmcID, bar0.Read32(0))
|
||||||
|
|
||||||
require.Equal(t, devices[0].IsVF, false, "Device incorrectly identified as a VF")
|
require.Equal(t, devices[0].SriovInfo.IsVF(), false, "Device incorrectly identified as a VF")
|
||||||
|
|
||||||
device, err := nvpci.GetGPUByIndex(0)
|
device, err := nvpci.GetGPUByIndex(0)
|
||||||
require.Nil(t, err, "Error getting GPU at index 0")
|
require.Nil(t, err, "Error getting GPU at index 0")
|
||||||
@ -100,7 +100,7 @@ func TestNvpciNUMANode(t *testing.T) {
|
|||||||
require.Nil(t, err, "Error creating NewMockNvpci")
|
require.Nil(t, err, "Error creating NewMockNvpci")
|
||||||
defer nvpci.Cleanup()
|
defer nvpci.Cleanup()
|
||||||
|
|
||||||
err = nvpci.AddMockA100("0000:80:05.1", tc.NumaNode)
|
err = nvpci.AddMockA100("0000:80:05.1", tc.NumaNode, nil)
|
||||||
require.Nil(t, err, "Error adding Mock A100 device to MockNvpci")
|
require.Nil(t, err, "Error adding Mock A100 device to MockNvpci")
|
||||||
|
|
||||||
devices, err := nvpci.GetGPUs()
|
devices, err := nvpci.GetGPUs()
|
||||||
@ -110,3 +110,64 @@ func TestNvpciNUMANode(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNvpciSRIOV(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
Description string
|
||||||
|
Sriov *SriovInfo
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
Description: "sriov set",
|
||||||
|
Sriov: &SriovInfo{
|
||||||
|
PhysicalFunction: &SriovPhysicalFunction{
|
||||||
|
TotalVFs: 32,
|
||||||
|
NumVFs: 16,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Description: "sriov not set",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.Description, func(t *testing.T) {
|
||||||
|
nvpci, err := NewMockNvpci()
|
||||||
|
require.Nil(t, err, "Error creating NewMockNvpci")
|
||||||
|
defer nvpci.Cleanup()
|
||||||
|
|
||||||
|
err = nvpci.AddMockA100("0000:80:05.1", 0, tc.Sriov)
|
||||||
|
require.Nil(t, err, "Error adding Mock A100 device to MockNvpci")
|
||||||
|
|
||||||
|
gpus, err := nvpci.GetGPUs()
|
||||||
|
require.Nil(t, err, "Error getting GPUs")
|
||||||
|
require.Equal(t, 1, len(gpus), "Wrong number of GPU devices")
|
||||||
|
|
||||||
|
devices, err := nvpci.GetAllDevices()
|
||||||
|
require.Nil(t, err, "Error getting devices")
|
||||||
|
|
||||||
|
if tc.Sriov != nil {
|
||||||
|
require.Len(t, devices, int(tc.Sriov.PhysicalFunction.NumVFs)+1, "Expected number of devices to be NumVFs +1(PF)")
|
||||||
|
|
||||||
|
require.Equal(t, false, gpus[0].SriovInfo.IsVF(), "GPU should not be marked as VF")
|
||||||
|
require.Equal(t, true, gpus[0].SriovInfo.IsPF(), "GPU should be marked as PF")
|
||||||
|
require.NotNil(t, gpus[0].SriovInfo, "SriovInfo should not be set to nil")
|
||||||
|
require.NotNil(t, gpus[0].SriovInfo.PhysicalFunction, "SriovInfo.PhysicalFunction should not be set to nil")
|
||||||
|
require.Equal(t, uint64(32), gpus[0].SriovInfo.PhysicalFunction.TotalVFs, "Wrong number of total VFs")
|
||||||
|
require.Equal(t, uint64(16), gpus[0].SriovInfo.PhysicalFunction.NumVFs, "Wrong number of num VFs")
|
||||||
|
require.Nil(t, gpus[0].SriovInfo.VirtualFunction, "VirtualFunction should be set to nil")
|
||||||
|
for i := 1; i < int(tc.Sriov.PhysicalFunction.NumVFs); i++ {
|
||||||
|
require.Equal(t, true, devices[i].SriovInfo.IsVF(), "Device should be marked as VF")
|
||||||
|
require.Equal(t, false, devices[i].SriovInfo.IsPF(), "Device should not be marked as PF")
|
||||||
|
require.Equal(t, gpus[0], devices[i].SriovInfo.VirtualFunction.PhysicalFunction, "VFs PhysicalFunction should be equal only GPU in the system")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
require.Equal(t, len(gpus), len(devices), "When no SRIOV specified number of GPUs should equal number of devices")
|
||||||
|
|
||||||
|
require.Equal(t, false, gpus[0].SriovInfo.IsVF(), "GPU should not be marked as VF")
|
||||||
|
require.Equal(t, false, gpus[0].SriovInfo.IsPF(), "GPU should not be marked as PF")
|
||||||
|
require.Equal(t, SriovInfo{}, gpus[0].SriovInfo, "SriovInfo should be empty")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user