Merge branch 'skip-display-devices' into 'main'

Skip devices based on PCI device class

See merge request nvidia/cloud-native/go-nvlib!24
This commit is contained in:
Kevin Klues 2022-11-16 09:48:48 +00:00
commit 9110850748
5 changed files with 103 additions and 12 deletions

View File

@ -59,7 +59,7 @@ generate:
lint: lint:
# We use `go list -f '{{.Dir}}' $(MODULE)/...` to skip the `vendor` folder. # We use `go list -f '{{.Dir}}' $(MODULE)/...` to skip the `vendor` folder.
go list -f '{{.Dir}}' $(MODULE)/... | xargs golint -set_exit_status go list -f '{{.Dir}}' $(MODULE)/... | grep -v pkg/nvml | xargs golint -set_exit_status
vet: vet:
go vet $(MODULE)/... go vet $(MODULE)/...

View File

@ -35,7 +35,8 @@ type Interface interface {
} }
type devicelib struct { type devicelib struct {
nvml nvml.Interface nvml nvml.Interface
selectedDeviceClasses map[Class]struct{}
} }
var _ Interface = &devicelib{} var _ Interface = &devicelib{}
@ -49,6 +50,11 @@ func New(opts ...Option) Interface {
if d.nvml == nil { if d.nvml == nil {
d.nvml = nvml.New() d.nvml = nvml.New()
} }
if d.selectedDeviceClasses == nil {
d.selectedDeviceClasses = map[Class]struct{}{
ClassCompute: {},
}
}
return d return d
} }
@ -59,5 +65,17 @@ func WithNvml(nvml nvml.Interface) Option {
} }
} }
// WithSelectedDeviceClasses selects the specified device classes when filtering devices
func WithSelectedDeviceClasses(classes ...Class) Option {
return func(d *devicelib) {
if d.selectedDeviceClasses == nil {
d.selectedDeviceClasses = make(map[Class]struct{})
}
for _, c := range classes {
d.selectedDeviceClasses[c] = struct{}{}
}
}
}
// Option defines a function for passing options to the New() call // Option defines a function for passing options to the New() call
type Option func(*devicelib) type Option func(*devicelib)

View File

@ -18,9 +18,11 @@ package device
import ( import (
"fmt" "fmt"
"strings"
"github.com/NVIDIA/go-nvml/pkg/dl" "github.com/NVIDIA/go-nvml/pkg/dl"
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml" "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml"
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvpci"
) )
// Device defines the set of extended functions associated with a device.Device // Device defines the set of extended functions associated with a device.Device
@ -39,13 +41,37 @@ type device struct {
lib *devicelib lib *devicelib
} }
// Class represents the PCI class for a device
type Class uint32
// Define constants for common device classes
const (
ClassCompute = Class(nvpci.PCI3dControllerClass)
ClassDisplay = Class(nvpci.PCIVgaControllerClass)
)
var _ Device = &device{} var _ Device = &device{}
// NewDevice builds a new Device from an nvml.Device // NewDevice builds a new Device from an nvml.Device
func (d *devicelib) NewDevice(dev nvml.Device) (Device, error) { func (d *devicelib) NewDevice(dev nvml.Device) (Device, error) {
return d.newDevice(dev)
}
// newDevice creates a device from an nvml.Device
func (d *devicelib) newDevice(dev nvml.Device) (*device, error) {
return &device{dev, d}, nil return &device{dev, d}, nil
} }
// classIsSelected checks whether the specified class has been selected when constructing the devicelib
func (d *devicelib) classIsSelected(c Class) bool {
if d.selectedDeviceClasses == nil {
return false
}
_, exists := d.selectedDeviceClasses[c]
return exists
}
// IsMigCapable checks if a device is capable of having MIG paprtitions created on it // IsMigCapable checks if a device is capable of having MIG paprtitions created on it
func (d *device) IsMigCapable() (bool, error) { func (d *device) IsMigCapable() (bool, error) {
err := nvmlLookupSymbol("nvmlDeviceGetMigMode") err := nvmlLookupSymbol("nvmlDeviceGetMigMode")
@ -183,6 +209,35 @@ func (d *device) GetMigProfiles() ([]MigProfile, error) {
return profiles, nil return profiles, nil
} }
// getClass returns the PCI device class for the device
func (d *device) getClass() (Class, error) {
info, ret := d.GetPciInfo()
if ret != nvml.SUCCESS {
return 0, fmt.Errorf("failed to get PCI info: %v", ret)
}
// We convert the BusId to a string
var bytes []byte
for _, b := range info.BusId {
if byte(b) == '\x00' {
break
}
bytes = append(bytes, byte(b))
}
id := string(bytes)
if id != "0000" {
id = strings.TrimPrefix(id, "0000")
}
device, err := nvpci.New().GetGPUByPciBusID(id)
if err != nil {
return 0, fmt.Errorf("failed to construct PCI device: %v", ret)
}
return Class(device.Class), nil
}
// VisitDevices visits each top-level device and invokes a callback function for it // VisitDevices visits each top-level device and invokes a callback function for it
func (d *devicelib) VisitDevices(visit func(int, Device) error) error { func (d *devicelib) VisitDevices(visit func(int, Device) error) error {
count, ret := d.nvml.DeviceGetCount() count, ret := d.nvml.DeviceGetCount()
@ -195,10 +250,19 @@ func (d *devicelib) VisitDevices(visit func(int, Device) error) error {
if ret != nvml.SUCCESS { if ret != nvml.SUCCESS {
return fmt.Errorf("error getting device handle for index '%v': %v", i, ret) return fmt.Errorf("error getting device handle for index '%v': %v", i, ret)
} }
dev, err := d.NewDevice(device) dev, err := d.newDevice(device)
if err != nil { if err != nil {
return fmt.Errorf("error creating new device wrapper: %v", err) return fmt.Errorf("error creating new device wrapper: %v", err)
} }
class, err := dev.getClass()
if err != nil {
return fmt.Errorf("error getting PCI device class for device: %v", err)
}
if !d.classIsSelected(class) {
continue
}
err = visit(i, dev) err = visit(i, dev)
if err != nil { if err != nil {
return fmt.Errorf("error visiting device: %v", err) return fmt.Errorf("error visiting device: %v", err)

View File

@ -18,13 +18,14 @@ package nvmdev
import ( import (
"fmt" "fmt"
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvpci"
"os" "os"
"path" "path"
"path/filepath" "path/filepath"
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvpci"
) )
const ( const (
@ -241,7 +242,7 @@ func (m mdev) iommuGroup() (int, error) {
// NewParentDevice constructs a ParentDevice // NewParentDevice constructs a ParentDevice
func NewParentDevice(devicePath string) (*ParentDevice, error) { func NewParentDevice(devicePath string) (*ParentDevice, error) {
nvdevice, err := nvpci.NewDevice(devicePath) nvdevice, err := newNvidiaPCIDeviceFromPath(devicePath)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to construct NVIDIA PCI device: %v", err) return nil, fmt.Errorf("failed to construct NVIDIA PCI device: %v", err)
} }
@ -330,7 +331,7 @@ func (p *ParentDevice) GetPhysicalFunction() (*nvpci.NvidiaPCIDevice, error) {
return nil, fmt.Errorf("unable to resolve %s: %v", path.Join(p.Path, "physfn"), err) return nil, fmt.Errorf("unable to resolve %s: %v", path.Join(p.Path, "physfn"), err)
} }
return nvpci.NewDevice(physfnPath) return newNvidiaPCIDeviceFromPath(physfnPath)
} }
// GetPhysicalFunction gets the physical PCI device that a vGPU is created on // GetPhysicalFunction gets the physical PCI device that a vGPU is created on
@ -374,3 +375,10 @@ func (p *ParentDevice) GetAvailableMDEVInstances(mdevType string) (int, error) {
return availableInstances, nil return availableInstances, nil
} }
// newNvidiaPCIDeviceFromPath constructs an NvidiaPCIDevice for the specified device path.
func newNvidiaPCIDeviceFromPath(devicePath string) (*nvpci.NvidiaPCIDevice, error) {
root := filepath.Dir(devicePath)
address := filepath.Base(devicePath)
return nvpci.NewFrom(root).GetGPUByPciBusID(address)
}

View File

@ -49,6 +49,7 @@ type Interface interface {
GetNVSwitches() ([]*NvidiaPCIDevice, error) GetNVSwitches() ([]*NvidiaPCIDevice, error)
GetGPUs() ([]*NvidiaPCIDevice, error) GetGPUs() ([]*NvidiaPCIDevice, error)
GetGPUByIndex(int) (*NvidiaPCIDevice, error) GetGPUByIndex(int) (*NvidiaPCIDevice, error)
GetGPUByPciBusID(string) (*NvidiaPCIDevice, error)
GetNetworkControllers() ([]*NvidiaPCIDevice, error) GetNetworkControllers() ([]*NvidiaPCIDevice, error)
GetPciBridges() ([]*NvidiaPCIDevice, error) GetPciBridges() ([]*NvidiaPCIDevice, error)
GetDPUs() ([]*NvidiaPCIDevice, error) GetDPUs() ([]*NvidiaPCIDevice, error)
@ -143,10 +144,10 @@ func (p *nvpci) GetAllDevices() ([]*NvidiaPCIDevice, error) {
var nvdevices []*NvidiaPCIDevice var nvdevices []*NvidiaPCIDevice
for _, deviceDir := range deviceDirs { for _, deviceDir := range deviceDirs {
devicePath := path.Join(p.pciDevicesRoot, deviceDir.Name()) deviceAddress := deviceDir.Name()
nvdevice, err := NewDevice(devicePath) nvdevice, err := p.GetGPUByPciBusID(deviceAddress)
if err != nil { if err != nil {
return nil, fmt.Errorf("error constructing NVIDIA PCI device %s: %v", deviceDir.Name(), err) return nil, fmt.Errorf("error constructing NVIDIA PCI device %s: %v", deviceAddress, err)
} }
if nvdevice == nil { if nvdevice == nil {
continue continue
@ -168,9 +169,9 @@ func (p *nvpci) GetAllDevices() ([]*NvidiaPCIDevice, error) {
return nvdevices, nil return nvdevices, nil
} }
// NewDevice constructs an NvidiaPCIDevice // GetGPUByPciBusID constructs an NvidiaPCIDevice for the specified address (PCI Bus ID)
func NewDevice(devicePath string) (*NvidiaPCIDevice, error) { func (p *nvpci) GetGPUByPciBusID(address string) (*NvidiaPCIDevice, error) {
address := path.Base(devicePath) devicePath := filepath.Join(p.pciDevicesRoot, address)
vendor, err := os.ReadFile(path.Join(devicePath, "vendor")) vendor, err := os.ReadFile(path.Join(devicePath, "vendor"))
if err != nil { if err != nil {