From d60aa34a78954a45d0a9b45ac434b74a4199df8d Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Tue, 11 Jun 2024 15:05:40 +0200 Subject: [PATCH] Add function to get the PCI bus ID for a device Signed-off-by: Evan Lezar --- pkg/nvlib/device/device.go | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/pkg/nvlib/device/device.go b/pkg/nvlib/device/device.go index 5e1510c..5b21fc1 100644 --- a/pkg/nvlib/device/device.go +++ b/pkg/nvlib/device/device.go @@ -18,6 +18,7 @@ package device import ( "fmt" + "strings" "github.com/NVIDIA/go-nvml/pkg/nvml" ) @@ -30,6 +31,7 @@ type Device interface { GetCudaComputeCapabilityAsString() (string, error) GetMigDevices() ([]MigDevice, error) GetMigProfiles() ([]MigProfile, error) + GetPCIBusID() (string, error) IsMigCapable() (bool, error) IsMigEnabled() (bool, error) VisitMigDevices(func(j int, m MigDevice) error) error @@ -140,6 +142,29 @@ func (d *device) GetBrandAsString() (string, error) { return "", fmt.Errorf("error interpreting device brand as string: %v", brand) } +// GetPCIBusID returns the string representation of the bus ID. +func (d *device) GetPCIBusID() (string, error) { + info, ret := d.GetPciInfo() + if ret != nvml.SUCCESS { + return "", fmt.Errorf("error getting PCI info: %w", ret) + } + + var bytes []byte + for _, b := range info.BusId { + if byte(b) == '\x00' { + break + } + bytes = append(bytes, byte(b)) + } + id := strings.ToLower(string(bytes)) + + if id != "0000" { + id = strings.TrimPrefix(id, "0000") + } + + return id, nil +} + // GetCudaComputeCapabilityAsString returns the Device's CUDA compute capability as a version string. func (d *device) GetCudaComputeCapabilityAsString() (string, error) { major, minor, ret := d.GetCudaComputeCapability()