From e37e145458d3661f48e6597d30f0998b233f12f1 Mon Sep 17 00:00:00 2001
From: Evan Lezar
Date: Tue, 15 Nov 2022 14:16:37 +0100
Subject: [PATCH] Add filtering of devices based on PCI device class

Signed-off-by: Evan Lezar
---
 pkg/nvlib/device/api.go    | 20 ++++++++++++-
 pkg/nvlib/device/device.go | 59 ++++++++++++++++++++++++++++++++++++++
 2 files changed, 78 insertions(+), 1 deletion(-)

diff --git a/pkg/nvlib/device/api.go b/pkg/nvlib/device/api.go
index bcc37eb..498bda8 100644
--- a/pkg/nvlib/device/api.go
+++ b/pkg/nvlib/device/api.go
@@ -35,7 +35,8 @@ type Interface interface {
 }
 
 type devicelib struct {
-	nvml nvml.Interface
+	nvml                  nvml.Interface
+	selectedDeviceClasses map[Class]struct{}
 }
 
 var _ Interface = &devicelib{}
@@ -49,6 +50,11 @@ func New(opts ...Option) Interface {
 	if d.nvml == nil {
 		d.nvml = nvml.New()
 	}
+	if d.selectedDeviceClasses == nil {
+		d.selectedDeviceClasses = map[Class]struct{}{
+			ClassCompute: {},
+		}
+	}
 	return d
 }
 
@@ -59,5 +65,17 @@ func WithNvml(nvml nvml.Interface) Option {
 	}
 }
 
+// WithSelectedDeviceClasses adds the specified device classes to the set used when filtering devices
+func WithSelectedDeviceClasses(classes ...Class) Option {
+	return func(d *devicelib) {
+		if d.selectedDeviceClasses == nil {
+			d.selectedDeviceClasses = make(map[Class]struct{})
+		}
+		for _, c := range classes {
+			d.selectedDeviceClasses[c] = struct{}{}
+		}
+	}
+}
+
 // Option defines a function for passing options to the New() call
 type Option func(*devicelib)
diff --git a/pkg/nvlib/device/device.go b/pkg/nvlib/device/device.go
index a659354..a1fbac5 100644
--- a/pkg/nvlib/device/device.go
+++ b/pkg/nvlib/device/device.go
@@ -18,9 +18,11 @@ package device
 
 import (
 	"fmt"
+	"strings"
 
 	"github.com/NVIDIA/go-nvml/pkg/dl"
 	"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml"
+	"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvpci"
 )
 
 // Device defines the set of extended functions associated with a device.Device
@@ -39,6 +41,15 @@ type device struct {
 	lib *devicelib
 }
 
+// Class represents the PCI class for a device
+type Class uint32
+
+// Define constants for common device classes
+const (
+	ClassCompute = Class(nvpci.PCI3dControllerClass)
+	ClassDisplay = Class(nvpci.PCIVgaControllerClass)
+)
+
 var _ Device = &device{}
 
 // NewDevice builds a new Device from an nvml.Device
@@ -51,6 +62,16 @@ func (d *devicelib) newDevice(dev nvml.Device) (*device, error) {
 	return &device{dev, d}, nil
 }
 
+// classIsSelected checks whether the specified class has been selected when constructing the devicelib
+func (d *devicelib) classIsSelected(c Class) bool {
+	if d.selectedDeviceClasses == nil {
+		return false
+	}
+	_, exists := d.selectedDeviceClasses[c]
+
+	return exists
+}
+
 // IsMigCapable checks if a device is capable of having MIG paprtitions created on it
 func (d *device) IsMigCapable() (bool, error) {
 	err := nvmlLookupSymbol("nvmlDeviceGetMigMode")
@@ -188,6 +209,35 @@ func (d *device) GetMigProfiles() ([]MigProfile, error) {
 	return profiles, nil
 }
 
+// getClass returns the PCI device class for the device, looked up via nvpci using the NVML-reported bus ID
+func (d *device) getClass() (Class, error) {
+	info, ret := d.GetPciInfo()
+	if ret != nvml.SUCCESS {
+		return 0, fmt.Errorf("failed to get PCI info: %v", ret)
+	}
+
+	// Convert the NUL-terminated BusId to a string; it is lower-cased below because NVML uses upper-case hex
+	var bytes []byte
+	for _, b := range info.BusId {
+		if byte(b) == '\x00' {
+			break
+		}
+		bytes = append(bytes, byte(b))
+	}
+	id := strings.ToLower(string(bytes))
+
+	if id != "0000" {
+		id = strings.TrimPrefix(id, "0000")
+	}
+
+	device, err := nvpci.New().GetGPUByPciBusID(id)
+	if err != nil {
+		return 0, fmt.Errorf("failed to construct PCI device: %v", err)
+	}
+
+	return Class(device.Class), nil
+}
+
 // VisitDevices visits each top-level device and invokes a callback function for it
 func (d *devicelib) VisitDevices(visit func(int, Device) error) error {
 	count, ret := d.nvml.DeviceGetCount()
@@ -204,6 +254,15 @@ func (d *devicelib) VisitDevices(visit func(int, Device) error) error {
 		if err != nil {
 			return fmt.Errorf("error creating new device wrapper: %v", err)
 		}
+
+		class, err := dev.getClass()
+		if err != nil {
+			return fmt.Errorf("error getting PCI device class for device: %v", err)
+		}
+		if !d.classIsSelected(class) {
+			continue
+		}
+
 		err = visit(i, dev)
 		if err != nil {
 			return fmt.Errorf("error visiting device: %v", err)