mirror of
https://github.com/clearml/go-nvlib
synced 2025-06-26 18:28:08 +00:00
Merge branch 'skip-display-devices' into 'main'
Skip devices based on PCI device class. See merge request nvidia/cloud-native/go-nvlib!24.
This commit is contained in:
commit
9110850748
2
Makefile
2
Makefile
@ -59,7 +59,7 @@ generate:
|
|||||||
|
|
||||||
lint:
|
lint:
|
||||||
# We use `go list -f '{{.Dir}}' $(MODULE)/...` to skip the `vendor` folder.
|
# We use `go list -f '{{.Dir}}' $(MODULE)/...` to skip the `vendor` folder.
|
||||||
go list -f '{{.Dir}}' $(MODULE)/... | xargs golint -set_exit_status
|
go list -f '{{.Dir}}' $(MODULE)/... | grep -v pkg/nvml | xargs golint -set_exit_status
|
||||||
|
|
||||||
vet:
|
vet:
|
||||||
go vet $(MODULE)/...
|
go vet $(MODULE)/...
|
||||||
|
@ -35,7 +35,8 @@ type Interface interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type devicelib struct {
|
type devicelib struct {
|
||||||
nvml nvml.Interface
|
nvml nvml.Interface
|
||||||
|
selectedDeviceClasses map[Class]struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ Interface = &devicelib{}
|
var _ Interface = &devicelib{}
|
||||||
@ -49,6 +50,11 @@ func New(opts ...Option) Interface {
|
|||||||
if d.nvml == nil {
|
if d.nvml == nil {
|
||||||
d.nvml = nvml.New()
|
d.nvml = nvml.New()
|
||||||
}
|
}
|
||||||
|
if d.selectedDeviceClasses == nil {
|
||||||
|
d.selectedDeviceClasses = map[Class]struct{}{
|
||||||
|
ClassCompute: {},
|
||||||
|
}
|
||||||
|
}
|
||||||
return d
|
return d
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -59,5 +65,17 @@ func WithNvml(nvml nvml.Interface) Option {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithSelectedDeviceClasses selects the specified device classes when filtering devices
|
||||||
|
func WithSelectedDeviceClasses(classes ...Class) Option {
|
||||||
|
return func(d *devicelib) {
|
||||||
|
if d.selectedDeviceClasses == nil {
|
||||||
|
d.selectedDeviceClasses = make(map[Class]struct{})
|
||||||
|
}
|
||||||
|
for _, c := range classes {
|
||||||
|
d.selectedDeviceClasses[c] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Option defines a function for passing options to the New() call
|
// Option defines a function for passing options to the New() call
|
||||||
type Option func(*devicelib)
|
type Option func(*devicelib)
|
||||||
|
@ -18,9 +18,11 @@ package device
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/NVIDIA/go-nvml/pkg/dl"
|
"github.com/NVIDIA/go-nvml/pkg/dl"
|
||||||
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml"
|
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml"
|
||||||
|
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvpci"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Device defines the set of extended functions associated with a device.Device
|
// Device defines the set of extended functions associated with a device.Device
|
||||||
@ -39,13 +41,37 @@ type device struct {
|
|||||||
lib *devicelib
|
lib *devicelib
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Class represents the PCI class for a device
|
||||||
|
type Class uint32
|
||||||
|
|
||||||
|
// Define constants for common device classes
|
||||||
|
const (
|
||||||
|
ClassCompute = Class(nvpci.PCI3dControllerClass)
|
||||||
|
ClassDisplay = Class(nvpci.PCIVgaControllerClass)
|
||||||
|
)
|
||||||
|
|
||||||
var _ Device = &device{}
|
var _ Device = &device{}
|
||||||
|
|
||||||
// NewDevice builds a new Device from an nvml.Device
|
// NewDevice builds a new Device from an nvml.Device
|
||||||
func (d *devicelib) NewDevice(dev nvml.Device) (Device, error) {
|
func (d *devicelib) NewDevice(dev nvml.Device) (Device, error) {
|
||||||
|
return d.newDevice(dev)
|
||||||
|
}
|
||||||
|
|
||||||
|
// newDevice creates a device from an nvml.Device
|
||||||
|
func (d *devicelib) newDevice(dev nvml.Device) (*device, error) {
|
||||||
return &device{dev, d}, nil
|
return &device{dev, d}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// classIsSelected checks whether the specified class has been selected when constructing the devicelib
|
||||||
|
func (d *devicelib) classIsSelected(c Class) bool {
|
||||||
|
if d.selectedDeviceClasses == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
_, exists := d.selectedDeviceClasses[c]
|
||||||
|
|
||||||
|
return exists
|
||||||
|
}
|
||||||
|
|
||||||
// IsMigCapable checks if a device is capable of having MIG partitions created on it
|
// IsMigCapable checks if a device is capable of having MIG partitions created on it
|
||||||
func (d *device) IsMigCapable() (bool, error) {
|
func (d *device) IsMigCapable() (bool, error) {
|
||||||
err := nvmlLookupSymbol("nvmlDeviceGetMigMode")
|
err := nvmlLookupSymbol("nvmlDeviceGetMigMode")
|
||||||
@ -183,6 +209,35 @@ func (d *device) GetMigProfiles() ([]MigProfile, error) {
|
|||||||
return profiles, nil
|
return profiles, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getClass returns the PCI device class for the device
|
||||||
|
func (d *device) getClass() (Class, error) {
|
||||||
|
info, ret := d.GetPciInfo()
|
||||||
|
if ret != nvml.SUCCESS {
|
||||||
|
return 0, fmt.Errorf("failed to get PCI info: %v", ret)
|
||||||
|
}
|
||||||
|
|
||||||
|
// We convert the BusId to a string
|
||||||
|
var bytes []byte
|
||||||
|
for _, b := range info.BusId {
|
||||||
|
if byte(b) == '\x00' {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
bytes = append(bytes, byte(b))
|
||||||
|
}
|
||||||
|
id := string(bytes)
|
||||||
|
|
||||||
|
if id != "0000" {
|
||||||
|
id = strings.TrimPrefix(id, "0000")
|
||||||
|
}
|
||||||
|
|
||||||
|
device, err := nvpci.New().GetGPUByPciBusID(id)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("failed to construct PCI device: %v", ret)
|
||||||
|
}
|
||||||
|
|
||||||
|
return Class(device.Class), nil
|
||||||
|
}
|
||||||
|
|
||||||
// VisitDevices visits each top-level device and invokes a callback function for it
|
// VisitDevices visits each top-level device and invokes a callback function for it
|
||||||
func (d *devicelib) VisitDevices(visit func(int, Device) error) error {
|
func (d *devicelib) VisitDevices(visit func(int, Device) error) error {
|
||||||
count, ret := d.nvml.DeviceGetCount()
|
count, ret := d.nvml.DeviceGetCount()
|
||||||
@ -195,10 +250,19 @@ func (d *devicelib) VisitDevices(visit func(int, Device) error) error {
|
|||||||
if ret != nvml.SUCCESS {
|
if ret != nvml.SUCCESS {
|
||||||
return fmt.Errorf("error getting device handle for index '%v': %v", i, ret)
|
return fmt.Errorf("error getting device handle for index '%v': %v", i, ret)
|
||||||
}
|
}
|
||||||
dev, err := d.NewDevice(device)
|
dev, err := d.newDevice(device)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error creating new device wrapper: %v", err)
|
return fmt.Errorf("error creating new device wrapper: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class, err := dev.getClass()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error getting PCI device class for device: %v", err)
|
||||||
|
}
|
||||||
|
if !d.classIsSelected(class) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
err = visit(i, dev)
|
err = visit(i, dev)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error visiting device: %v", err)
|
return fmt.Errorf("error visiting device: %v", err)
|
||||||
|
@ -18,13 +18,14 @@ package nvmdev
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvpci"
|
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvpci"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -241,7 +242,7 @@ func (m mdev) iommuGroup() (int, error) {
|
|||||||
|
|
||||||
// NewParentDevice constructs a ParentDevice
|
// NewParentDevice constructs a ParentDevice
|
||||||
func NewParentDevice(devicePath string) (*ParentDevice, error) {
|
func NewParentDevice(devicePath string) (*ParentDevice, error) {
|
||||||
nvdevice, err := nvpci.NewDevice(devicePath)
|
nvdevice, err := newNvidiaPCIDeviceFromPath(devicePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to construct NVIDIA PCI device: %v", err)
|
return nil, fmt.Errorf("failed to construct NVIDIA PCI device: %v", err)
|
||||||
}
|
}
|
||||||
@ -330,7 +331,7 @@ func (p *ParentDevice) GetPhysicalFunction() (*nvpci.NvidiaPCIDevice, error) {
|
|||||||
return nil, fmt.Errorf("unable to resolve %s: %v", path.Join(p.Path, "physfn"), err)
|
return nil, fmt.Errorf("unable to resolve %s: %v", path.Join(p.Path, "physfn"), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nvpci.NewDevice(physfnPath)
|
return newNvidiaPCIDeviceFromPath(physfnPath)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPhysicalFunction gets the physical PCI device that a vGPU is created on
|
// GetPhysicalFunction gets the physical PCI device that a vGPU is created on
|
||||||
@ -374,3 +375,10 @@ func (p *ParentDevice) GetAvailableMDEVInstances(mdevType string) (int, error) {
|
|||||||
|
|
||||||
return availableInstances, nil
|
return availableInstances, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// newNvidiaPCIDeviceFromPath constructs an NvidiaPCIDevice for the specified device path.
|
||||||
|
func newNvidiaPCIDeviceFromPath(devicePath string) (*nvpci.NvidiaPCIDevice, error) {
|
||||||
|
root := filepath.Dir(devicePath)
|
||||||
|
address := filepath.Base(devicePath)
|
||||||
|
return nvpci.NewFrom(root).GetGPUByPciBusID(address)
|
||||||
|
}
|
||||||
|
@ -49,6 +49,7 @@ type Interface interface {
|
|||||||
GetNVSwitches() ([]*NvidiaPCIDevice, error)
|
GetNVSwitches() ([]*NvidiaPCIDevice, error)
|
||||||
GetGPUs() ([]*NvidiaPCIDevice, error)
|
GetGPUs() ([]*NvidiaPCIDevice, error)
|
||||||
GetGPUByIndex(int) (*NvidiaPCIDevice, error)
|
GetGPUByIndex(int) (*NvidiaPCIDevice, error)
|
||||||
|
GetGPUByPciBusID(string) (*NvidiaPCIDevice, error)
|
||||||
GetNetworkControllers() ([]*NvidiaPCIDevice, error)
|
GetNetworkControllers() ([]*NvidiaPCIDevice, error)
|
||||||
GetPciBridges() ([]*NvidiaPCIDevice, error)
|
GetPciBridges() ([]*NvidiaPCIDevice, error)
|
||||||
GetDPUs() ([]*NvidiaPCIDevice, error)
|
GetDPUs() ([]*NvidiaPCIDevice, error)
|
||||||
@ -143,10 +144,10 @@ func (p *nvpci) GetAllDevices() ([]*NvidiaPCIDevice, error) {
|
|||||||
|
|
||||||
var nvdevices []*NvidiaPCIDevice
|
var nvdevices []*NvidiaPCIDevice
|
||||||
for _, deviceDir := range deviceDirs {
|
for _, deviceDir := range deviceDirs {
|
||||||
devicePath := path.Join(p.pciDevicesRoot, deviceDir.Name())
|
deviceAddress := deviceDir.Name()
|
||||||
nvdevice, err := NewDevice(devicePath)
|
nvdevice, err := p.GetGPUByPciBusID(deviceAddress)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error constructing NVIDIA PCI device %s: %v", deviceDir.Name(), err)
|
return nil, fmt.Errorf("error constructing NVIDIA PCI device %s: %v", deviceAddress, err)
|
||||||
}
|
}
|
||||||
if nvdevice == nil {
|
if nvdevice == nil {
|
||||||
continue
|
continue
|
||||||
@ -168,9 +169,9 @@ func (p *nvpci) GetAllDevices() ([]*NvidiaPCIDevice, error) {
|
|||||||
return nvdevices, nil
|
return nvdevices, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDevice constructs an NvidiaPCIDevice
|
// GetGPUByPciBusID constructs an NvidiaPCIDevice for the specified address (PCI Bus ID)
|
||||||
func NewDevice(devicePath string) (*NvidiaPCIDevice, error) {
|
func (p *nvpci) GetGPUByPciBusID(address string) (*NvidiaPCIDevice, error) {
|
||||||
address := path.Base(devicePath)
|
devicePath := filepath.Join(p.pciDevicesRoot, address)
|
||||||
|
|
||||||
vendor, err := os.ReadFile(path.Join(devicePath, "vendor"))
|
vendor, err := os.ReadFile(path.Join(devicePath, "vendor"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
Loading…
Reference in New Issue
Block a user