Add GetGPUByPciBusID to nvpci.Interface

This change adds a GetGPUByPciBusID method to the nvpci Interface.
The exising NewDevice function is moved to nvmdev where it is used.

Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
Evan Lezar 2022-11-15 14:29:32 +01:00
parent 0e8a479bd5
commit e96d9c58f1
2 changed files with 18 additions and 9 deletions

View File

@ -18,13 +18,14 @@ package nvmdev
import ( import (
"fmt" "fmt"
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvpci"
"os" "os"
"path" "path"
"path/filepath" "path/filepath"
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvpci"
) )
const ( const (
@ -241,7 +242,7 @@ func (m mdev) iommuGroup() (int, error) {
// NewParentDevice constructs a ParentDevice // NewParentDevice constructs a ParentDevice
func NewParentDevice(devicePath string) (*ParentDevice, error) { func NewParentDevice(devicePath string) (*ParentDevice, error) {
nvdevice, err := nvpci.NewDevice(devicePath) nvdevice, err := newNvidiaPCIDeviceFromPath(devicePath)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to construct NVIDIA PCI device: %v", err) return nil, fmt.Errorf("failed to construct NVIDIA PCI device: %v", err)
} }
@ -330,7 +331,7 @@ func (p *ParentDevice) GetPhysicalFunction() (*nvpci.NvidiaPCIDevice, error) {
return nil, fmt.Errorf("unable to resolve %s: %v", path.Join(p.Path, "physfn"), err) return nil, fmt.Errorf("unable to resolve %s: %v", path.Join(p.Path, "physfn"), err)
} }
return nvpci.NewDevice(physfnPath) return newNvidiaPCIDeviceFromPath(physfnPath)
} }
// GetPhysicalFunction gets the physical PCI device that a vGPU is created on // GetPhysicalFunction gets the physical PCI device that a vGPU is created on
@ -374,3 +375,10 @@ func (p *ParentDevice) GetAvailableMDEVInstances(mdevType string) (int, error) {
return availableInstances, nil return availableInstances, nil
} }
// newNvidiaPCIDeviceFromPath constructs an NvidiaPCIDevice for the specified device path.
func newNvidiaPCIDeviceFromPath(devicePath string) (*nvpci.NvidiaPCIDevice, error) {
root := filepath.Dir(devicePath)
address := filepath.Base(devicePath)
return nvpci.NewFrom(root).GetGPUByPciBusID(address)
}

View File

@ -49,6 +49,7 @@ type Interface interface {
GetNVSwitches() ([]*NvidiaPCIDevice, error) GetNVSwitches() ([]*NvidiaPCIDevice, error)
GetGPUs() ([]*NvidiaPCIDevice, error) GetGPUs() ([]*NvidiaPCIDevice, error)
GetGPUByIndex(int) (*NvidiaPCIDevice, error) GetGPUByIndex(int) (*NvidiaPCIDevice, error)
GetGPUByPciBusID(string) (*NvidiaPCIDevice, error)
GetNetworkControllers() ([]*NvidiaPCIDevice, error) GetNetworkControllers() ([]*NvidiaPCIDevice, error)
GetPciBridges() ([]*NvidiaPCIDevice, error) GetPciBridges() ([]*NvidiaPCIDevice, error)
GetDPUs() ([]*NvidiaPCIDevice, error) GetDPUs() ([]*NvidiaPCIDevice, error)
@ -143,10 +144,10 @@ func (p *nvpci) GetAllDevices() ([]*NvidiaPCIDevice, error) {
var nvdevices []*NvidiaPCIDevice var nvdevices []*NvidiaPCIDevice
for _, deviceDir := range deviceDirs { for _, deviceDir := range deviceDirs {
devicePath := path.Join(p.pciDevicesRoot, deviceDir.Name()) deviceAddress := deviceDir.Name()
nvdevice, err := NewDevice(devicePath) nvdevice, err := p.GetGPUByPciBusID(deviceAddress)
if err != nil { if err != nil {
return nil, fmt.Errorf("error constructing NVIDIA PCI device %s: %v", deviceDir.Name(), err) return nil, fmt.Errorf("error constructing NVIDIA PCI device %s: %v", deviceAddress, err)
} }
if nvdevice == nil { if nvdevice == nil {
continue continue
@ -168,9 +169,9 @@ func (p *nvpci) GetAllDevices() ([]*NvidiaPCIDevice, error) {
return nvdevices, nil return nvdevices, nil
} }
// NewDevice constructs an NvidiaPCIDevice // GetGPUByPciBusID constructs an NvidiaPCIDevice for the specified address (PCI Bus ID)
func NewDevice(devicePath string) (*NvidiaPCIDevice, error) { func (p *nvpci) GetGPUByPciBusID(address string) (*NvidiaPCIDevice, error) {
address := path.Base(devicePath) devicePath := filepath.Join(p.pciDevicesRoot, address)
vendor, err := os.ReadFile(path.Join(devicePath, "vendor")) vendor, err := os.ReadFile(path.Join(devicePath, "vendor"))
if err != nil { if err != nil {