diff --git a/pkg/nvmdev/nvmdev.go b/pkg/nvmdev/nvmdev.go
index c1521ca..9f7b5f1 100644
--- a/pkg/nvmdev/nvmdev.go
+++ b/pkg/nvmdev/nvmdev.go
@@ -319,6 +319,30 @@ func (m *Device) Delete() error {
 	return nil
 }
 
+// GetPhysicalFunction gets the physical PCI device backing a 'parent' device.
+// A parent that is not a virtual function is returned as-is; for a virtual
+// function the 'physfn' symlink under the device path is resolved instead.
+func (p *ParentDevice) GetPhysicalFunction() (*nvpci.NvidiaPCIDevice, error) {
+	if !p.IsVF {
+		return p.NvidiaPCIDevice, nil
+	}
+
+	// Build the link path once so the lookup and the error message agree.
+	physfnLink := filepath.Join(p.Path, "physfn")
+	physfnPath, err := filepath.EvalSymlinks(physfnLink)
+	if err != nil {
+		return nil, fmt.Errorf("unable to resolve %s: %w", physfnLink, err)
+	}
+
+	return nvpci.NewDevice(physfnPath)
+}
+
+// GetPhysicalFunction gets the physical PCI device that a vGPU is created on.
+// It delegates to the mediated device's parent.
+func (m *Device) GetPhysicalFunction() (*nvpci.NvidiaPCIDevice, error) {
+	return m.Parent.GetPhysicalFunction()
+}
+
 // IsMDEVTypeSupported checks if the mdevType is supported by the GPU
 func (p *ParentDevice) IsMDEVTypeSupported(mdevType string) bool {
 	_, found := p.mdevPaths[mdevType]
diff --git a/pkg/nvmdev/nvmdev_test.go b/pkg/nvmdev/nvmdev_test.go
index bd5b60c..517c895 100644
--- a/pkg/nvmdev/nvmdev_test.go
+++ b/pkg/nvmdev/nvmdev_test.go
@@ -33,6 +33,11 @@ func TestNvmdev(t *testing.T) {
 	require.Equal(t, 1, len(parentDevs), "Wrong number of parent GPU devices")
 
 	parentA100 := parentDevs[0]
+
+	pf, err := parentA100.GetPhysicalFunction()
+	require.Nil(t, err, "Error getting physical function backing the Mock A100 parent device")
+	require.Equal(t, "0000:3b:04.1", pf.Address, "Wrong address for Mock A100 physical function")
+
 	supported := parentA100.IsMDEVTypeSupported("A100-4C")
 	require.True(t, supported, "A100-4C should be a supported vGPU type")
 
@@ -46,7 +51,14 @@ func TestNvmdev(t *testing.T) {
 	mdevs, err := nvmdev.GetAllDevices()
 	require.Nil(t, err, "Error getting NVIDIA MDEV (vGPU) devices")
 	require.Equal(t, 1, len(mdevs), "Wrong number of NVIDIA MDEV (vGPU) devices")
-	require.Equal(t, "A100-4C", mdevs[0].MDEVType, "Wrong value for mdev_type")
-	require.Equal(t, "vfio_mdev", mdevs[0].Driver, "Wrong driver detected for mdev device")
-	require.Equal(t, 200, mdevs[0].IommuGroup, "Wrong value for iommu_group")
+
+	mdevA100 := mdevs[0]
+
+	require.Equal(t, "A100-4C", mdevA100.MDEVType, "Wrong value for mdev_type")
+	require.Equal(t, "vfio_mdev", mdevA100.Driver, "Wrong driver detected for mdev device")
+	require.Equal(t, 200, mdevA100.IommuGroup, "Wrong value for iommu_group")
+
+	pf, err = mdevA100.GetPhysicalFunction()
+	require.Nil(t, err, "Error getting the physical function for Mock A100 mediated device")
+	require.Equal(t, "0000:3b:04.1", pf.Address, "Wrong address for Mock A100 physical function")
 }
diff --git a/pkg/nvpci/nvpci.go b/pkg/nvpci/nvpci.go
index da9040b..e413e73 100644
--- a/pkg/nvpci/nvpci.go
+++ b/pkg/nvpci/nvpci.go
@@ -46,6 +46,7 @@ type Interface interface {
 	GetVGAControllers() ([]*NvidiaPCIDevice, error)
 	GetNVSwitches() ([]*NvidiaPCIDevice, error)
 	GetGPUs() ([]*NvidiaPCIDevice, error)
+	GetGPUByIndex(int) (*NvidiaPCIDevice, error)
 }
 
 // MemoryResources a more human readable handle
@@ -353,3 +354,19 @@ func (p *nvpci) GetGPUs() ([]*NvidiaPCIDevice, error) {
 
 	return filtered, nil
 }
+
+// GetGPUByIndex returns an NVIDIA GPU device at a particular index.
+// The index is a position in the slice returned by GetGPUs; an error is
+// returned if GetGPUs fails or if the index is outside that slice.
+func (p *nvpci) GetGPUByIndex(i int) (*NvidiaPCIDevice, error) {
+	gpus, err := p.GetGPUs()
+	if err != nil {
+		return nil, fmt.Errorf("error getting all gpus: %w", err)
+	}
+
+	if i < 0 || i >= len(gpus) {
+		return nil, fmt.Errorf("invalid index '%d'", i)
+	}
+
+	return gpus[i], nil
+}
diff --git a/pkg/nvpci/nvpci_test.go b/pkg/nvpci/nvpci_test.go
index af05dc9..c44b7f3 100644
--- a/pkg/nvpci/nvpci_test.go
+++ b/pkg/nvpci/nvpci_test.go
@@ -66,6 +66,13 @@ func TestNvpci(t *testing.T) {
 	require.Equal(t, ga100PmcID, bar0.Read32(0))
 
 	require.Equal(t, devices[0].IsVF, false, "Device incorrectly identified as a VF")
+
+	device, err := nvpci.GetGPUByIndex(0)
+	require.Nil(t, err, "Error getting GPU at index 0")
+	require.Equal(t, "0000:80:05.1", device.Address, "Wrong Address found for device")
+
+	_, err = nvpci.GetGPUByIndex(1)
+	require.Error(t, err, "No error returned when getting GPU at invalid index")
 }
 
 func TestNvpciNUMANode(t *testing.T) {