From 6ff7845b92d1cf70597b87f8afba4e05b48eef04 Mon Sep 17 00:00:00 2001 From: Christopher Desiniotis Date: Tue, 23 Aug 2022 15:24:26 -0700 Subject: [PATCH] nvpci: Add GetGPUByIndex() Signed-off-by: Christopher Desiniotis --- pkg/nvpci/nvpci.go | 15 +++++++++++++++ pkg/nvpci/nvpci_test.go | 7 +++++++ 2 files changed, 22 insertions(+) diff --git a/pkg/nvpci/nvpci.go b/pkg/nvpci/nvpci.go index da9040b..e413e73 100644 --- a/pkg/nvpci/nvpci.go +++ b/pkg/nvpci/nvpci.go @@ -46,6 +46,7 @@ type Interface interface { GetVGAControllers() ([]*NvidiaPCIDevice, error) GetNVSwitches() ([]*NvidiaPCIDevice, error) GetGPUs() ([]*NvidiaPCIDevice, error) + GetGPUByIndex(int) (*NvidiaPCIDevice, error) } // MemoryResources a more human readable handle @@ -353,3 +354,17 @@ func (p *nvpci) GetGPUs() ([]*NvidiaPCIDevice, error) { return filtered, nil } + +// GetGPUByIndex returns an NVIDIA GPU device at a particular index +func (p *nvpci) GetGPUByIndex(i int) (*NvidiaPCIDevice, error) { + gpus, err := p.GetGPUs() + if err != nil { + return nil, fmt.Errorf("error getting all gpus: %v", err) + } + + if i < 0 || i >= len(gpus) { + return nil, fmt.Errorf("invalid index '%d'", i) + } + + return gpus[i], nil +} diff --git a/pkg/nvpci/nvpci_test.go b/pkg/nvpci/nvpci_test.go index af05dc9..c44b7f3 100644 --- a/pkg/nvpci/nvpci_test.go +++ b/pkg/nvpci/nvpci_test.go @@ -66,6 +66,13 @@ func TestNvpci(t *testing.T) { require.Equal(t, ga100PmcID, bar0.Read32(0)) require.Equal(t, devices[0].IsVF, false, "Device incorrectly identified as a VF") + + device, err := nvpci.GetGPUByIndex(0) + require.Nil(t, err, "Error getting GPU at index 0") + require.Equal(t, "0000:80:05.1", device.Address, "Wrong Address found for device") + + device, err = nvpci.GetGPUByIndex(1) + require.Error(t, err, "No error returned when getting GPU at invalid index") } func TestNvpciNUMANode(t *testing.T) {