From dc53500d0e8e5c9a43e4fdb961a3b9b8497dfbd0 Mon Sep 17 00:00:00 2001
From: Christopher Desiniotis <cdesiniotis@nvidia.com>
Date: Thu, 7 Jul 2022 13:42:25 -0700
Subject: [PATCH 1/6] Update .gitignore

Signed-off-by: Christopher Desiniotis <cdesiniotis@nvidia.com>
---
 .gitignore | 1 +
 1 file changed, 1 insertion(+)

diff --git a/.gitignore b/.gitignore
index 3819313..0eb7e8c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,3 @@
 *.swp
 *.swo
+*.test

From e2d858daed9828d2310b9259945626136b9f2ae4 Mon Sep 17 00:00:00 2001
From: Christopher Desiniotis <cdesiniotis@nvidia.com>
Date: Thu, 7 Jul 2022 13:43:11 -0700
Subject: [PATCH 2/6] Use 'os' instead of 'ioutil' which is recommended
 starting with Go 1.16

Signed-off-by: Christopher Desiniotis <cdesiniotis@nvidia.com>
---
 pkg/nvmdev/mock.go   |  5 ++---
 pkg/nvmdev/nvmdev.go | 11 +++++------
 pkg/nvpci/config.go  |  4 ++--
 pkg/nvpci/mock.go    |  3 +--
 pkg/nvpci/nvpci.go   | 15 +++++++--------
 5 files changed, 17 insertions(+), 21 deletions(-)

diff --git a/pkg/nvmdev/mock.go b/pkg/nvmdev/mock.go
index 76a82e2..8fd0e8c 100644
--- a/pkg/nvmdev/mock.go
+++ b/pkg/nvmdev/mock.go
@@ -20,7 +20,6 @@ import (
 	"fmt"
 	"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvpci"
 	"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvpci/bytes"
-	"io/ioutil"
 	"os"
 	"path/filepath"
 )
@@ -34,7 +33,7 @@ var _ Interface = (*MockNvmdev)(nil)
 
 // NewMock creates new mock mediated (vGPU) and parent PCI devices and removes old devices
 func NewMock() (mock *MockNvmdev, rerr error) {
-	mdevParentsRootDir, err := ioutil.TempDir("", "")
+	mdevParentsRootDir, err := os.MkdirTemp(os.TempDir(), "")
 	if err != nil {
 		return nil, err
 	}
@@ -43,7 +42,7 @@ func NewMock() (mock *MockNvmdev, rerr error) {
 			os.RemoveAll(mdevParentsRootDir)
 		}
 	}()
-	mdevDevicesRootDir, err := ioutil.TempDir("", "")
+	mdevDevicesRootDir, err := os.MkdirTemp(os.TempDir(), "")
 	if err != nil {
 		return nil, err
 	}
diff --git a/pkg/nvmdev/nvmdev.go b/pkg/nvmdev/nvmdev.go
index ea45836..6a6f060 100644
--- a/pkg/nvmdev/nvmdev.go
+++ b/pkg/nvmdev/nvmdev.go
@@ -19,7 +19,6 @@ package nvmdev
 import (
 	"fmt"
 	"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvpci"
-	"io/ioutil"
 	"os"
 	"path"
 	"path/filepath"
@@ -67,7 +66,7 @@ func New() Interface {
 
 // GetAllParentDevices returns all NVIDIA Parent PCI devices on the system
 func (m *nvmdev) GetAllParentDevices() ([]*ParentDevice, error) {
-	deviceDirs, err := ioutil.ReadDir(m.mdevParentsRoot)
+	deviceDirs, err := os.ReadDir(m.mdevParentsRoot)
 	if err != nil {
 		return nil, fmt.Errorf("unable to read PCI bus devices: %v", err)
 	}
@@ -101,7 +100,7 @@ func (m *nvmdev) GetAllParentDevices() ([]*ParentDevice, error) {
 
 // GetAllDevices returns all NVIDIA mdev (vGPU) devices on the system
 func (m *nvmdev) GetAllDevices() ([]*Device, error) {
-	deviceDirs, err := ioutil.ReadDir(m.mdevDevicesRoot)
+	deviceDirs, err := os.ReadDir(m.mdevDevicesRoot)
 	if err != nil {
 		return nil, fmt.Errorf("unable to read MDEV devices directory: %v", err)
 	}
@@ -174,7 +173,7 @@ func (m mdev) parentDevicePath() string {
 }
 
 func (m mdev) Type() (string, error) {
-	mdevType, err := ioutil.ReadFile(path.Join(string(m), "name"))
+	mdevType, err := os.ReadFile(path.Join(string(m), "name"))
 	if err != nil {
 		return "", fmt.Errorf("unable to read mdev_type name for mdev %s: %v", m, err)
 	}
@@ -205,7 +204,7 @@ func NewParentDevice(devicePath string) (*ParentDevice, error) {
 	}
 	mdevTypesMap := make(map[string]string)
 	for _, path := range paths {
-		name, err := ioutil.ReadFile(path)
+		name, err := os.ReadFile(path)
 		if err != nil {
 			return nil, fmt.Errorf("unable to read file %s: %v", path, err)
 		}
@@ -292,7 +291,7 @@ func (p *ParentDevice) GetAvailableMDEVInstances(mdevType string) (int, error) {
 		return -1, nil
 	}
 
-	available, err := ioutil.ReadFile(filepath.Join(mdevPath, "available_instances"))
+	available, err := os.ReadFile(filepath.Join(mdevPath, "available_instances"))
 	if err != nil {
 		return -1, fmt.Errorf("unable to read available_instances file: %v", err)
 	}
diff --git a/pkg/nvpci/config.go b/pkg/nvpci/config.go
index 7cd2920..8a23342 100644
--- a/pkg/nvpci/config.go
+++ b/pkg/nvpci/config.go
@@ -18,7 +18,7 @@ package nvpci
 
 import (
 	"fmt"
-	"io/ioutil"
+	"os"
 
 	"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvpci/bytes"
 )
@@ -71,7 +71,7 @@ type PCICapabilities struct {
 }
 
 func (cs *ConfigSpace) Read() (ConfigSpaceIO, error) {
-	config, err := ioutil.ReadFile(cs.Path)
+	config, err := os.ReadFile(cs.Path)
 	if err != nil {
 		return nil, fmt.Errorf("failed to open file: %v", err)
 	}
diff --git a/pkg/nvpci/mock.go b/pkg/nvpci/mock.go
index 5c13ae1..ce7ecb1 100644
--- a/pkg/nvpci/mock.go
+++ b/pkg/nvpci/mock.go
@@ -18,7 +18,6 @@ package nvpci
 
 import (
 	"fmt"
-	"io/ioutil"
 	"os"
 	"path/filepath"
 
@@ -34,7 +33,7 @@ var _ Interface = (*MockNvpci)(nil)
 
 // NewMockNvpci create new mock PCI and remove old devices
 func NewMockNvpci() (mock *MockNvpci, rerr error) {
-	rootDir, err := ioutil.TempDir("", "")
+	rootDir, err := os.MkdirTemp(os.TempDir(), "")
 	if err != nil {
 		return nil, err
 	}
diff --git a/pkg/nvpci/nvpci.go b/pkg/nvpci/nvpci.go
index 61a8bd3..d6737e6 100644
--- a/pkg/nvpci/nvpci.go
+++ b/pkg/nvpci/nvpci.go
@@ -18,7 +18,6 @@ package nvpci
 
 import (
 	"fmt"
-	"io/ioutil"
 	"os"
 	"path"
 	"sort"
@@ -104,7 +103,7 @@ func (d *NvidiaPCIDevice) IsResetAvailable() bool {
 
 // Reset perform a reset to apply a new configuration at HW level
 func (d *NvidiaPCIDevice) Reset() error {
-	err := ioutil.WriteFile(path.Join(d.Path, "reset"), []byte("1"), 0)
+	err := os.WriteFile(path.Join(d.Path, "reset"), []byte("1"), 0)
 	if err != nil {
 		return fmt.Errorf("unable to write to reset file: %v", err)
 	}
@@ -123,7 +122,7 @@ func NewFrom(root string) Interface {
 
 // GetAllDevices returns all Nvidia PCI devices on the system
 func (p *nvpci) GetAllDevices() ([]*NvidiaPCIDevice, error) {
-	deviceDirs, err := ioutil.ReadDir(p.pciDevicesRoot)
+	deviceDirs, err := os.ReadDir(p.pciDevicesRoot)
 	if err != nil {
 		return nil, fmt.Errorf("unable to read PCI bus devices: %v", err)
 	}
@@ -159,7 +158,7 @@ func (p *nvpci) GetAllDevices() ([]*NvidiaPCIDevice, error) {
 func NewDevice(devicePath string) (*NvidiaPCIDevice, error) {
 	address := path.Base(devicePath)
 
-	vendor, err := ioutil.ReadFile(path.Join(devicePath, "vendor"))
+	vendor, err := os.ReadFile(path.Join(devicePath, "vendor"))
 	if err != nil {
 		return nil, fmt.Errorf("unable to read PCI device vendor id for %s: %v", address, err)
 	}
@@ -173,7 +172,7 @@ func NewDevice(devicePath string) (*NvidiaPCIDevice, error) {
 		return nil, nil
 	}
 
-	class, err := ioutil.ReadFile(path.Join(devicePath, "class"))
+	class, err := os.ReadFile(path.Join(devicePath, "class"))
 	if err != nil {
 		return nil, fmt.Errorf("unable to read PCI device class for %s: %v", address, err)
 	}
@@ -183,7 +182,7 @@ func NewDevice(devicePath string) (*NvidiaPCIDevice, error) {
 		return nil, fmt.Errorf("unable to convert class string to uint32: %v", classStr)
 	}
 
-	device, err := ioutil.ReadFile(path.Join(devicePath, "device"))
+	device, err := os.ReadFile(path.Join(devicePath, "device"))
 	if err != nil {
 		return nil, fmt.Errorf("unable to read PCI device id for %s: %v", address, err)
 	}
@@ -193,7 +192,7 @@ func NewDevice(devicePath string) (*NvidiaPCIDevice, error) {
 		return nil, fmt.Errorf("unable to convert device string to uint16: %v", deviceStr)
 	}
 
-	numa, err := ioutil.ReadFile(path.Join(devicePath, "numa_node"))
+	numa, err := os.ReadFile(path.Join(devicePath, "numa_node"))
 	if err != nil {
 		return nil, fmt.Errorf("unable to read PCI NUMA node for %s: %v", address, err)
 	}
@@ -207,7 +206,7 @@ func NewDevice(devicePath string) (*NvidiaPCIDevice, error) {
 		Path: path.Join(devicePath, "config"),
 	}
 
-	resource, err := ioutil.ReadFile(path.Join(devicePath, "resource"))
+	resource, err := os.ReadFile(path.Join(devicePath, "resource"))
 	if err != nil {
 		return nil, fmt.Errorf("unable to read PCI resource file for %s: %v", address, err)
 	}

From 09edde0a0b219c116778edf62ddca94dbebc2619 Mon Sep 17 00:00:00 2001
From: Christopher Desiniotis <cdesiniotis@nvidia.com>
Date: Thu, 7 Jul 2022 14:40:09 -0700
Subject: [PATCH 3/6] Detect driver bound to an NvidiaPCIDevice

Signed-off-by: Christopher Desiniotis <cdesiniotis@nvidia.com>
---
 pkg/nvpci/mock.go       |  9 +++++++++
 pkg/nvpci/nvpci.go      | 12 ++++++++++++
 pkg/nvpci/nvpci_test.go |  1 +
 3 files changed, 22 insertions(+)

diff --git a/pkg/nvpci/mock.go b/pkg/nvpci/mock.go
index ce7ecb1..438f562 100644
--- a/pkg/nvpci/mock.go
+++ b/pkg/nvpci/mock.go
@@ -90,6 +90,15 @@ func (m *MockNvpci) AddMockA100(address string, numaNode int) error {
 		return err
 	}
 
+	_, err = os.Create(filepath.Join(deviceDir, "nvidia"))
+	if err != nil {
+		return err
+	}
+	err = os.Symlink(filepath.Join(deviceDir, "nvidia"), filepath.Join(deviceDir, "driver"))
+	if err != nil {
+		return err
+	}
+
 	numa, err := os.Create(filepath.Join(deviceDir, "numa_node"))
 	if err != nil {
 		return err
diff --git a/pkg/nvpci/nvpci.go b/pkg/nvpci/nvpci.go
index d6737e6..924189f 100644
--- a/pkg/nvpci/nvpci.go
+++ b/pkg/nvpci/nvpci.go
@@ -20,6 +20,7 @@ import (
 	"fmt"
 	"os"
 	"path"
+	"path/filepath"
 	"sort"
 	"strconv"
 	"strings"
@@ -69,6 +70,7 @@ type NvidiaPCIDevice struct {
 	Vendor    uint16
 	Class     uint32
 	Device    uint16
+	Driver    string
 	NumaNode  int
 	Config    *ConfigSpace
 	Resources MemoryResources
@@ -192,6 +194,15 @@ func NewDevice(devicePath string) (*NvidiaPCIDevice, error) {
 		return nil, fmt.Errorf("unable to convert device string to uint16: %v", deviceStr)
 	}
 
+	driver, err := filepath.EvalSymlinks(path.Join(devicePath, "driver"))
+	if err == nil {
+		driver = filepath.Base(driver)
+	} else if os.IsNotExist(err) {
+		driver = ""
+	} else {
+		return nil, fmt.Errorf("unable to detect driver for %s: %v", address, err)
+	}
+
 	numa, err := os.ReadFile(path.Join(devicePath, "numa_node"))
 	if err != nil {
 		return nil, fmt.Errorf("unable to read PCI NUMA node for %s: %v", address, err)
@@ -238,6 +249,7 @@ func NewDevice(devicePath string) (*NvidiaPCIDevice, error) {
 		Vendor:    uint16(vendorID),
 		Class:     uint32(classID),
 		Device:    uint16(deviceID),
+		Driver:    driver,
 		NumaNode:  int(numaNode),
 		Config:    config,
 		Resources: resources,
diff --git a/pkg/nvpci/nvpci_test.go b/pkg/nvpci/nvpci_test.go
index f651a56..8dbf50f 100644
--- a/pkg/nvpci/nvpci_test.go
+++ b/pkg/nvpci/nvpci_test.go
@@ -45,6 +45,7 @@ func TestNvpci(t *testing.T) {
 	require.Nil(t, err, "Error reading config")
 	require.Equal(t, devices[0].Vendor, config.GetVendorID(), "Vendor IDs do not match")
 	require.Equal(t, devices[0].Device, config.GetDeviceID(), "Device IDs do not match")
+	require.Equal(t, "nvidia", devices[0].Driver, "Wrong driver detected for device")
 
 	capabilities, err := config.GetPCICapabilities()
 	require.Nil(t, err, "Error getting PCI capabilities")

From d65cf69086da592893d95644ffd3b916cc2b938a Mon Sep 17 00:00:00 2001
From: Christopher Desiniotis <cdesiniotis@nvidia.com>
Date: Thu, 7 Jul 2022 14:54:52 -0700
Subject: [PATCH 4/6] Bump go to 1.16 in Makefile to align with minimum go
 version specified in go.mod

Signed-off-by: Christopher Desiniotis <cdesiniotis@nvidia.com>
---
 Makefile | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Makefile b/Makefile
index 61ccd68..fb5f48a 100644
--- a/Makefile
+++ b/Makefile
@@ -16,7 +16,7 @@ MODULE := gitlab.com/nvidia/cloud-native/go-nvlib
 
 DOCKER ?= docker
 
-GOLANG_VERSION := 1.15
+GOLANG_VERSION := 1.16
 
 ifeq ($(IMAGE),)
 REGISTRY ?= nvidia

From 805db5afa8c6442a2a6fcbf29d188e76fbff30a7 Mon Sep 17 00:00:00 2001
From: Christopher Desiniotis <cdesiniotis@nvidia.com>
Date: Fri, 8 Jul 2022 11:53:03 -0700
Subject: [PATCH 5/6] Refactor how mdev's are represented internally in nvmdev.

The 'mdev' string now represents the absolute path to an
mdev device (/sys/bus/pci/devices/<addr>/<uuid>) instead
of the  mdev_type directory for the mdev device
(/sys/bus/pci/devices/<addr>/mdev_supported_types/<mdev-type>).
This is more intuitive and will make it easier to get
more information about a particular mdev device -
like the driver or iommu_group it belongs to - which can
be found at /sys/bus/pci/devices/<addr>/<uuid>.

Signed-off-by: Christopher Desiniotis <cdesiniotis@nvidia.com>
---
 pkg/nvmdev/mock.go        | 14 ++++++++++----
 pkg/nvmdev/nvmdev.go      | 18 ++++++++++++------
 pkg/nvmdev/nvmdev_test.go |  4 +---
 3 files changed, 23 insertions(+), 13 deletions(-)

diff --git a/pkg/nvmdev/mock.go b/pkg/nvmdev/mock.go
index 8fd0e8c..ed7a055 100644
--- a/pkg/nvmdev/mock.go
+++ b/pkg/nvmdev/mock.go
@@ -183,14 +183,20 @@ func (m *MockNvmdev) AddMockA100Parent(address string, numaNode int) error {
 
 // AddMockA100Mdev creates an A100 like MDEV (vGPU) mock device.
 // The corresponding mocked parent A100 device must be created beforehand.
-func (m *MockNvmdev) AddMockA100Mdev(uuid string, mdevType string, parentMdevTypeDir string) error {
-	deviceDir := filepath.Join(m.mdevDevicesRoot, uuid)
-	err := os.MkdirAll(deviceDir, 0755)
+func (m *MockNvmdev) AddMockA100Mdev(uuid string, mdevType string, mdevTypeDir string, parentDeviceDir string) error {
+	mdevDeviceDir := filepath.Join(parentDeviceDir, uuid)
+	err := os.Mkdir(mdevDeviceDir, 0755)
 	if err != nil {
 		return err
 	}
 
-	err = os.Symlink(parentMdevTypeDir, filepath.Join(deviceDir, "mdev_type"))
+	parentMdevTypeDir := filepath.Join(parentDeviceDir, "mdev_supported_types", mdevTypeDir)
+	err = os.Symlink(parentMdevTypeDir, filepath.Join(mdevDeviceDir, "mdev_type"))
+	if err != nil {
+		return err
+	}
+
+	err = os.Symlink(mdevDeviceDir, filepath.Join(m.mdevDevicesRoot, uuid))
 	if err != nil {
 		return err
 	}
diff --git a/pkg/nvmdev/nvmdev.go b/pkg/nvmdev/nvmdev.go
index 6a6f060..3e4befe 100644
--- a/pkg/nvmdev/nvmdev.go
+++ b/pkg/nvmdev/nvmdev.go
@@ -153,27 +153,33 @@ func NewDevice(root string, uuid string) (*Device, error) {
 	return &device, nil
 }
 
+// mdev represents the path to an NVIDIA mdev (vGPU) device.
 type mdev string
 
 func newMdev(devicePath string) (mdev, error) {
-	mdevTypeDir, err := filepath.EvalSymlinks(path.Join(devicePath, "mdev_type"))
+	mdevDir, err := filepath.EvalSymlinks(devicePath)
 	if err != nil {
-		return "", fmt.Errorf("error resolving mdev_type link: %v", err)
+		return "", fmt.Errorf("error resolving symlink for %s: %v", devicePath, err)
 	}
 
-	return mdev(mdevTypeDir), nil
+	return mdev(mdevDir), nil
 }
 
 func (m mdev) String() string {
 	return string(m)
 }
 func (m mdev) parentDevicePath() string {
-	// /sys/bus/pci/devices/<addr>/mdev_supported_types/<mdev_type>
-	return path.Dir(path.Dir(string(m)))
+	// /sys/bus/pci/devices/<addr>/<uuid>
+	return path.Dir(string(m))
 }
 
 func (m mdev) Type() (string, error) {
-	mdevType, err := os.ReadFile(path.Join(string(m), "name"))
+	mdevTypeDir, err := filepath.EvalSymlinks(path.Join(string(m), "mdev_type"))
+	if err != nil {
+		return "", fmt.Errorf("error resolving mdev_type link for mdev %s: %v", m, err)
+	}
+
+	mdevType, err := os.ReadFile(path.Join(mdevTypeDir, "name"))
 	if err != nil {
 		return "", fmt.Errorf("unable to read mdev_type name for mdev %s: %v", m, err)
 	}
diff --git a/pkg/nvmdev/nvmdev_test.go b/pkg/nvmdev/nvmdev_test.go
index 43815f4..d044a9c 100644
--- a/pkg/nvmdev/nvmdev_test.go
+++ b/pkg/nvmdev/nvmdev_test.go
@@ -18,7 +18,6 @@ package nvmdev
 
 import (
 	"github.com/stretchr/testify/require"
-	"path/filepath"
 	"testing"
 )
 
@@ -41,8 +40,7 @@ func TestNvmdev(t *testing.T) {
 	require.Nil(t, err, "Error checking if A100-4Q vGPU type is available for creation")
 	require.True(t, available, "A100-4C should be available to create")
 
-	err = nvmdev.AddMockA100Mdev("b1914f0a-15cf-416e-8967-55fc7cb68e20", "A100-4C",
-		filepath.Join(parentDevs[0].Path, "mdev_supported_types/nvidia-500"))
+	err = nvmdev.AddMockA100Mdev("b1914f0a-15cf-416e-8967-55fc7cb68e20", "A100-4C", "nvidia-500", parentDevs[0].Path)
 	require.Nil(t, err, "Error adding Mock A100 mediated device")
 
 	mdevs, err := nvmdev.GetAllDevices()

From 8209652159f2e3cc7cdb5de9e475001058282274 Mon Sep 17 00:00:00 2001
From: Christopher Desiniotis <cdesiniotis@nvidia.com>
Date: Fri, 8 Jul 2022 13:11:01 -0700
Subject: [PATCH 6/6] Detect driver bound to mdev devices

Signed-off-by: Christopher Desiniotis <cdesiniotis@nvidia.com>
---
 pkg/nvmdev/mock.go        |  9 +++++++++
 pkg/nvmdev/nvmdev.go      | 15 +++++++++++++++
 pkg/nvmdev/nvmdev_test.go |  1 +
 3 files changed, 25 insertions(+)

diff --git a/pkg/nvmdev/mock.go b/pkg/nvmdev/mock.go
index ed7a055..626e3e9 100644
--- a/pkg/nvmdev/mock.go
+++ b/pkg/nvmdev/mock.go
@@ -196,6 +196,15 @@ func (m *MockNvmdev) AddMockA100Mdev(uuid string, mdevType string, mdevTypeDir s
 		return err
 	}
 
+	_, err = os.Create(filepath.Join(mdevDeviceDir, "vfio_mdev"))
+	if err != nil {
+		return err
+	}
+	err = os.Symlink(filepath.Join(mdevDeviceDir, "vfio_mdev"), filepath.Join(mdevDeviceDir, "driver"))
+	if err != nil {
+		return err
+	}
+
 	err = os.Symlink(mdevDeviceDir, filepath.Join(m.mdevDevicesRoot, uuid))
 	if err != nil {
 		return err
diff --git a/pkg/nvmdev/nvmdev.go b/pkg/nvmdev/nvmdev.go
index 3e4befe..dfcfef7 100644
--- a/pkg/nvmdev/nvmdev.go
+++ b/pkg/nvmdev/nvmdev.go
@@ -56,6 +56,7 @@ type Device struct {
 	Path     string
 	UUID     string
 	MDEVType string
+	Driver   string
 	Parent   *ParentDevice
 }
 
@@ -143,10 +144,16 @@ func NewDevice(root string, uuid string) (*Device, error) {
 		return nil, fmt.Errorf("error getting mdev type: %v", err)
 	}
 
+	driver, err := m.driver()
+	if err != nil {
+		return nil, fmt.Errorf("error detecting driver: %v", err)
+	}
+
 	device := Device{
 		Path:     path,
 		UUID:     uuid,
 		MDEVType: mdevType,
+		Driver:   driver,
 		Parent:   parent,
 	}
 
@@ -193,6 +200,14 @@ func (m mdev) Type() (string, error) {
 	return mdevTypeSplit[1], nil
 }
 
+func (m mdev) driver() (string, error) {
+	driver, err := filepath.EvalSymlinks(path.Join(string(m), "driver"))
+	if err != nil {
+		return "", err
+	}
+	return filepath.Base(driver), nil
+}
+
 // NewParentDevice constructs a ParentDevice
 func NewParentDevice(devicePath string) (*ParentDevice, error) {
 	nvdevice, err := nvpci.NewDevice(devicePath)
diff --git a/pkg/nvmdev/nvmdev_test.go b/pkg/nvmdev/nvmdev_test.go
index d044a9c..78d1058 100644
--- a/pkg/nvmdev/nvmdev_test.go
+++ b/pkg/nvmdev/nvmdev_test.go
@@ -46,4 +46,5 @@ func TestNvmdev(t *testing.T) {
 	mdevs, err := nvmdev.GetAllDevices()
 	require.Nil(t, err, "Error getting NVIDIA MDEV (vGPU) devices")
 	require.Equal(t, 1, len(mdevs), "Wrong number of NVIDIA MDEV (vGPU) devices")
+	require.Equal(t, "vfio_mdev", mdevs[0].Driver, "Wrong driver detected for mdev device")
 }