From afdf3edd9914dfb03c0f525e61d3a988cad24a24 Mon Sep 17 00:00:00 2001
From: Christopher Desiniotis <cdesiniotis@nvidia.com>
Date: Thu, 28 Jul 2022 19:16:39 -0700
Subject: [PATCH] Detect if NvidiaPCIDevice is a PF or VF

Signed-off-by: Christopher Desiniotis <cdesiniotis@nvidia.com>
---
 pkg/nvpci/nvpci.go      | 16 ++++++++++++++--
 pkg/nvpci/nvpci_test.go |  2 ++
 2 files changed, 16 insertions(+), 2 deletions(-)

diff --git a/pkg/nvpci/nvpci.go b/pkg/nvpci/nvpci.go
index 8ad2e54..da9040b 100644
--- a/pkg/nvpci/nvpci.go
+++ b/pkg/nvpci/nvpci.go
@@ -75,6 +75,7 @@ type NvidiaPCIDevice struct {
 	NumaNode   int
 	Config     *ConfigSpace
 	Resources  MemoryResources
+	IsVF       bool
 }
 
 // IsVGAController if class == 0x300
@@ -87,7 +88,7 @@ func (d *NvidiaPCIDevice) Is3DController() bool {
 	return d.Class == PCI3dControllerClass
 }
 
-// IsNVSwitch if classe == 0x068
+// IsNVSwitch if class == 0x068
 func (d *NvidiaPCIDevice) IsNVSwitch() bool {
 	return d.Class == PCINvSwitchClass
 }
@@ -218,6 +219,16 @@ func NewDevice(devicePath string) (*NvidiaPCIDevice, error) {
 		return nil, fmt.Errorf("unable to detect iommu_group for %s: %v", address, err)
 	}
 
+	// device is a virtual function (VF) if "physfn" symlink exists
+	var isVF bool
+	_, err = filepath.EvalSymlinks(path.Join(devicePath, "physfn"))
+	if err == nil {
+		isVF = true
+	}
+	if err != nil && !os.IsNotExist(err) {
+		return nil, fmt.Errorf("unable to resolve %s: %v", path.Join(devicePath, "physfn"), err)
+	}
+
 	numa, err := os.ReadFile(path.Join(devicePath, "numa_node"))
 	if err != nil {
 		return nil, fmt.Errorf("unable to read PCI NUMA node for %s: %v", address, err)
@@ -269,6 +280,7 @@ func NewDevice(devicePath string) (*NvidiaPCIDevice, error) {
 		NumaNode:   int(numaNode),
 		Config:     config,
 		Resources:  resources,
+		IsVF:       isVF,
 	}
 
 	return nvdevice, nil
@@ -334,7 +346,7 @@ func (p *nvpci) GetGPUs() ([]*NvidiaPCIDevice, error) {
 
 	var filtered []*NvidiaPCIDevice
 	for _, d := range devices {
-		if d.IsGPU() {
+		if d.IsGPU() && !d.IsVF {
 			filtered = append(filtered, d)
 		}
 	}
diff --git a/pkg/nvpci/nvpci_test.go b/pkg/nvpci/nvpci_test.go
index e6cbdbf..af05dc9 100644
--- a/pkg/nvpci/nvpci_test.go
+++ b/pkg/nvpci/nvpci_test.go
@@ -64,6 +64,8 @@ func TestNvpci(t *testing.T) {
 	}()
 	require.Equal(t, int(resource0.End-resource0.Start+1), bar0.Len())
 	require.Equal(t, ga100PmcID, bar0.Read32(0))
+
+	require.Equal(t, devices[0].IsVF, false, "Device incorrectly identified as a VF")
 }
 
 func TestNvpciNUMANode(t *testing.T) {