mirror of
https://github.com/clearml/go-nvlib
synced 2025-01-31 02:47:02 +00:00
Add filtering of devices based on PCI device class
Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
parent
f156c34310
commit
e37e145458
@ -36,6 +36,7 @@ type Interface interface {
|
||||
|
||||
type devicelib struct {
|
||||
nvml nvml.Interface
|
||||
selectedDeviceClasses map[Class]struct{}
|
||||
}
|
||||
|
||||
var _ Interface = &devicelib{}
|
||||
@ -49,6 +50,11 @@ func New(opts ...Option) Interface {
|
||||
if d.nvml == nil {
|
||||
d.nvml = nvml.New()
|
||||
}
|
||||
if d.selectedDeviceClasses == nil {
|
||||
d.selectedDeviceClasses = map[Class]struct{}{
|
||||
ClassCompute: {},
|
||||
}
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
@ -59,5 +65,17 @@ func WithNvml(nvml nvml.Interface) Option {
|
||||
}
|
||||
}
|
||||
|
||||
// WithSelectedDeviceClasses selects the specified device classes when filtering devices
|
||||
func WithSelectedDeviceClasses(classes ...Class) Option {
|
||||
return func(d *devicelib) {
|
||||
if d.selectedDeviceClasses == nil {
|
||||
d.selectedDeviceClasses = make(map[Class]struct{})
|
||||
}
|
||||
for _, c := range classes {
|
||||
d.selectedDeviceClasses[c] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Option defines a function for passing options to the New() call
|
||||
type Option func(*devicelib)
|
||||
|
@ -18,9 +18,11 @@ package device
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/NVIDIA/go-nvml/pkg/dl"
|
||||
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml"
|
||||
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvpci"
|
||||
)
|
||||
|
||||
// Device defines the set of extended functions associated with a device.Device
|
||||
@ -39,6 +41,15 @@ type device struct {
|
||||
lib *devicelib
|
||||
}
|
||||
|
||||
// Class represents the PCI class for a device
|
||||
type Class uint32
|
||||
|
||||
// Define constants for common device classes
|
||||
const (
|
||||
ClassCompute = Class(nvpci.PCI3dControllerClass)
|
||||
ClassDisplay = Class(nvpci.PCIVgaControllerClass)
|
||||
)
|
||||
|
||||
var _ Device = &device{}
|
||||
|
||||
// NewDevice builds a new Device from an nvml.Device
|
||||
@ -51,6 +62,16 @@ func (d *devicelib) newDevice(dev nvml.Device) (*device, error) {
|
||||
return &device{dev, d}, nil
|
||||
}
|
||||
|
||||
// classIsSelected checks whether the specified class has been selected when constructing the devicelib
|
||||
func (d *devicelib) classIsSelected(c Class) bool {
|
||||
if d.selectedDeviceClasses == nil {
|
||||
return false
|
||||
}
|
||||
_, exists := d.selectedDeviceClasses[c]
|
||||
|
||||
return exists
|
||||
}
|
||||
|
||||
// IsMigCapable checks if a device is capable of having MIG paprtitions created on it
|
||||
func (d *device) IsMigCapable() (bool, error) {
|
||||
err := nvmlLookupSymbol("nvmlDeviceGetMigMode")
|
||||
@ -188,6 +209,35 @@ func (d *device) GetMigProfiles() ([]MigProfile, error) {
|
||||
return profiles, nil
|
||||
}
|
||||
|
||||
// getClass returns the PCI device class for the device
|
||||
func (d *device) getClass() (Class, error) {
|
||||
info, ret := d.GetPciInfo()
|
||||
if ret != nvml.SUCCESS {
|
||||
return 0, fmt.Errorf("failed to get PCI info: %v", ret)
|
||||
}
|
||||
|
||||
// We convert the BusId to a string
|
||||
var bytes []byte
|
||||
for _, b := range info.BusId {
|
||||
if byte(b) == '\x00' {
|
||||
break
|
||||
}
|
||||
bytes = append(bytes, byte(b))
|
||||
}
|
||||
id := string(bytes)
|
||||
|
||||
if id != "0000" {
|
||||
id = strings.TrimPrefix(id, "0000")
|
||||
}
|
||||
|
||||
device, err := nvpci.New().GetGPUByPciBusID(id)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to construct PCI device: %v", ret)
|
||||
}
|
||||
|
||||
return Class(device.Class), nil
|
||||
}
|
||||
|
||||
// VisitDevices visits each top-level device and invokes a callback function for it
|
||||
func (d *devicelib) VisitDevices(visit func(int, Device) error) error {
|
||||
count, ret := d.nvml.DeviceGetCount()
|
||||
@ -204,6 +254,15 @@ func (d *devicelib) VisitDevices(visit func(int, Device) error) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating new device wrapper: %v", err)
|
||||
}
|
||||
|
||||
class, err := dev.getClass()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting PCI device class for device: %v", err)
|
||||
}
|
||||
if !d.classIsSelected(class) {
|
||||
continue
|
||||
}
|
||||
|
||||
err = visit(i, dev)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error visiting device: %v", err)
|
||||
|
Loading…
Reference in New Issue
Block a user