From e96d9c58f131faf3d49813d0a7c9e7ff56cbe036 Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Tue, 15 Nov 2022 14:29:32 +0100 Subject: [PATCH 1/4] Add GetGPUByPciBusID to nvpci.Interface This change adds a GetGPUByPciBusID method to the nvpci Interface. The exising NewDevice function is moved to nvmdev where it is used. Signed-off-by: Evan Lezar --- pkg/nvmdev/nvmdev.go | 14 +++++++++++--- pkg/nvpci/nvpci.go | 13 +++++++------ 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/pkg/nvmdev/nvmdev.go b/pkg/nvmdev/nvmdev.go index 9f7b5f1..fa2c3fc 100644 --- a/pkg/nvmdev/nvmdev.go +++ b/pkg/nvmdev/nvmdev.go @@ -18,13 +18,14 @@ package nvmdev import ( "fmt" - "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvpci" "os" "path" "path/filepath" "sort" "strconv" "strings" + + "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvpci" ) const ( @@ -241,7 +242,7 @@ func (m mdev) iommuGroup() (int, error) { // NewParentDevice constructs a ParentDevice func NewParentDevice(devicePath string) (*ParentDevice, error) { - nvdevice, err := nvpci.NewDevice(devicePath) + nvdevice, err := newNvidiaPCIDeviceFromPath(devicePath) if err != nil { return nil, fmt.Errorf("failed to construct NVIDIA PCI device: %v", err) } @@ -330,7 +331,7 @@ func (p *ParentDevice) GetPhysicalFunction() (*nvpci.NvidiaPCIDevice, error) { return nil, fmt.Errorf("unable to resolve %s: %v", path.Join(p.Path, "physfn"), err) } - return nvpci.NewDevice(physfnPath) + return newNvidiaPCIDeviceFromPath(physfnPath) } // GetPhysicalFunction gets the physical PCI device that a vGPU is created on @@ -374,3 +375,10 @@ func (p *ParentDevice) GetAvailableMDEVInstances(mdevType string) (int, error) { return availableInstances, nil } + +// newNvidiaPCIDeviceFromPath constructs an NvidiaPCIDevice for the specified device path. +func newNvidiaPCIDeviceFromPath(devicePath string) (*nvpci.NvidiaPCIDevice, error) { + root := filepath.Dir(devicePath) + address := filepath.Base(devicePath) + return nvpci.NewFrom(root).GetGPUByPciBusID(address) +} diff --git a/pkg/nvpci/nvpci.go b/pkg/nvpci/nvpci.go index ae5bcd2..e6b23b6 100644 --- a/pkg/nvpci/nvpci.go +++ b/pkg/nvpci/nvpci.go @@ -49,6 +49,7 @@ type Interface interface { GetNVSwitches() ([]*NvidiaPCIDevice, error) GetGPUs() ([]*NvidiaPCIDevice, error) GetGPUByIndex(int) (*NvidiaPCIDevice, error) + GetGPUByPciBusID(string) (*NvidiaPCIDevice, error) GetNetworkControllers() ([]*NvidiaPCIDevice, error) GetPciBridges() ([]*NvidiaPCIDevice, error) GetDPUs() ([]*NvidiaPCIDevice, error) @@ -143,10 +144,10 @@ func (p *nvpci) GetAllDevices() ([]*NvidiaPCIDevice, error) { var nvdevices []*NvidiaPCIDevice for _, deviceDir := range deviceDirs { - devicePath := path.Join(p.pciDevicesRoot, deviceDir.Name()) - nvdevice, err := NewDevice(devicePath) + deviceAddress := deviceDir.Name() + nvdevice, err := p.GetGPUByPciBusID(deviceAddress) if err != nil { - return nil, fmt.Errorf("error constructing NVIDIA PCI device %s: %v", deviceDir.Name(), err) + return nil, fmt.Errorf("error constructing NVIDIA PCI device %s: %v", deviceAddress, err) } if nvdevice == nil { continue @@ -168,9 +169,9 @@ func (p *nvpci) GetAllDevices() ([]*NvidiaPCIDevice, error) { return nvdevices, nil } -// NewDevice constructs an NvidiaPCIDevice -func NewDevice(devicePath string) (*NvidiaPCIDevice, error) { - address := path.Base(devicePath) +// GetGPUByPciBusID constructs an NvidiaPCIDevice for the specified address (PCI Bus ID) +func (p *nvpci) GetGPUByPciBusID(address string) (*NvidiaPCIDevice, error) { + devicePath := filepath.Join(p.pciDevicesRoot, address) vendor, err := os.ReadFile(path.Join(devicePath, "vendor")) if err != nil { From f156c34310a59edf4067adbb8c1546811d6112ec Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Tue, 15 Nov 2022 13:59:50 +0100 Subject: [PATCH 2/4] Add private constructor for creating a device Signed-off-by: Evan Lezar --- pkg/nvlib/device/device.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pkg/nvlib/device/device.go b/pkg/nvlib/device/device.go index e33480e..a659354 100644 --- a/pkg/nvlib/device/device.go +++ b/pkg/nvlib/device/device.go @@ -43,6 +43,11 @@ var _ Device = &device{} // NewDevice builds a new Device from an nvml.Device func (d *devicelib) NewDevice(dev nvml.Device) (Device, error) { + return d.newDevice(dev) +} + +// newDevice creates a device from an nvml.Device +func (d *devicelib) newDevice(dev nvml.Device) (*device, error) { return &device{dev, d}, nil } @@ -195,7 +200,7 @@ func (d *devicelib) VisitDevices(visit func(int, Device) error) error { if ret != nvml.SUCCESS { return fmt.Errorf("error getting device handle for index '%v': %v", i, ret) } - dev, err := d.NewDevice(device) + dev, err := d.newDevice(device) if err != nil { return fmt.Errorf("error creating new device wrapper: %v", err) } From e37e145458d3661f48e6597d30f0998b233f12f1 Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Tue, 15 Nov 2022 14:16:37 +0100 Subject: [PATCH 3/4] Add filtering of devices based on PCI device class Signed-off-by: Evan Lezar --- pkg/nvlib/device/api.go | 20 ++++++++++++- pkg/nvlib/device/device.go | 59 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 78 insertions(+), 1 deletion(-) diff --git a/pkg/nvlib/device/api.go b/pkg/nvlib/device/api.go index bcc37eb..498bda8 100644 --- a/pkg/nvlib/device/api.go +++ b/pkg/nvlib/device/api.go @@ -35,7 +35,8 @@ type Interface interface { } type devicelib struct { - nvml nvml.Interface + nvml nvml.Interface + selectedDeviceClasses map[Class]struct{} } var _ Interface = &devicelib{} @@ -49,6 +50,11 @@ func New(opts ...Option) Interface { if d.nvml == nil { d.nvml = nvml.New() } + if d.selectedDeviceClasses == nil { + d.selectedDeviceClasses = map[Class]struct{}{ + ClassCompute: {}, + } + } return d } @@ -59,5 +65,17 @@ func WithNvml(nvml nvml.Interface) Option { } } +// WithSelectedDeviceClasses selects the specified device classes when filtering devices +func WithSelectedDeviceClasses(classes ...Class) Option { + return func(d *devicelib) { + if d.selectedDeviceClasses == nil { + d.selectedDeviceClasses = make(map[Class]struct{}) + } + for _, c := range classes { + d.selectedDeviceClasses[c] = struct{}{} + } + } +} + // Option defines a function for passing options to the New() call type Option func(*devicelib) diff --git a/pkg/nvlib/device/device.go b/pkg/nvlib/device/device.go index a659354..a1fbac5 100644 --- a/pkg/nvlib/device/device.go +++ b/pkg/nvlib/device/device.go @@ -18,9 +18,11 @@ package device import ( "fmt" + "strings" "github.com/NVIDIA/go-nvml/pkg/dl" "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml" + "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvpci" ) // Device defines the set of extended functions associated with a device.Device @@ -39,6 +41,15 @@ type device struct { lib *devicelib } +// Class represents the PCI class for a device +type Class uint32 + +// Define constants for common device classes +const ( + ClassCompute = Class(nvpci.PCI3dControllerClass) + ClassDisplay = Class(nvpci.PCIVgaControllerClass) +) + var _ Device = &device{} // NewDevice builds a new Device from an nvml.Device @@ -51,6 +62,16 @@ func (d *devicelib) newDevice(dev nvml.Device) (*device, error) { return &device{dev, d}, nil } +// classIsSelected checks whether the specified class has been selected when constructing the devicelib +func (d *devicelib) classIsSelected(c Class) bool { + if d.selectedDeviceClasses == nil { + return false + } + _, exists := d.selectedDeviceClasses[c] + + return exists +} + // IsMigCapable checks if a device is capable of having MIG paprtitions created on it func (d *device) IsMigCapable() (bool, error) { err := nvmlLookupSymbol("nvmlDeviceGetMigMode") @@ -188,6 +209,35 @@ func (d *device) GetMigProfiles() ([]MigProfile, error) { return profiles, nil } +// getClass returns the PCI device class for the device +func (d *device) getClass() (Class, error) { + info, ret := d.GetPciInfo() + if ret != nvml.SUCCESS { + return 0, fmt.Errorf("failed to get PCI info: %v", ret) + } + + // We convert the BusId to a string + var bytes []byte + for _, b := range info.BusId { + if byte(b) == '\x00' { + break + } + bytes = append(bytes, byte(b)) + } + id := string(bytes) + + if id != "0000" { + id = strings.TrimPrefix(id, "0000") + } + + device, err := nvpci.New().GetGPUByPciBusID(id) + if err != nil { + return 0, fmt.Errorf("failed to construct PCI device: %v", ret) + } + + return Class(device.Class), nil +} + // VisitDevices visits each top-level device and invokes a callback function for it func (d *devicelib) VisitDevices(visit func(int, Device) error) error { count, ret := d.nvml.DeviceGetCount() @@ -204,6 +254,15 @@ func (d *devicelib) VisitDevices(visit func(int, Device) error) error { if err != nil { return fmt.Errorf("error creating new device wrapper: %v", err) } + + class, err := dev.getClass() + if err != nil { + return fmt.Errorf("error getting PCI device class for device: %v", err) + } + if !d.classIsSelected(class) { + continue + } + err = visit(i, dev) if err != nil { return fmt.Errorf("error visiting device: %v", err) From 4a0fdc2e8a5c1869f82995b57726abb0d94337e3 Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Wed, 16 Nov 2022 10:01:51 +0100 Subject: [PATCH 4/4] Skip pkg/nvml folder when linting Signed-off-by: Evan Lezar --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index cd54b22..78a5692 100644 --- a/Makefile +++ b/Makefile @@ -59,7 +59,7 @@ generate: lint: # We use `go list -f '{{.Dir}}' $(MODULE)/...` to skip the `vendor` folder. - go list -f '{{.Dir}}' $(MODULE)/... | xargs golint -set_exit_status + go list -f '{{.Dir}}' $(MODULE)/... | grep -v pkg/nvml | xargs golint -set_exit_status vet: go vet $(MODULE)/...