Add nvpci.Interface to the nvmdev struct to aid in unit tests

Signed-off-by: Christopher Desiniotis <cdesiniotis@nvidia.com>
This commit is contained in:
Christopher Desiniotis 2024-07-15 14:32:53 -07:00
parent 44a5440a0a
commit 7c3222d683
No known key found for this signature in database
GPG Key ID: 603C8E544D789A89
2 changed files with 53 additions and 10 deletions

View File

@ -28,6 +28,7 @@ import (
// MockNvmdev mock mdev device.
type MockNvmdev struct {
*nvmdev
pciDevicesRoot string
}
var _ Interface = (*MockNvmdev)(nil)
@ -53,8 +54,24 @@ func NewMock() (mock *MockNvmdev, rerr error) {
}
}()
pciRootDir, err := os.MkdirTemp(os.TempDir(), "")
if err != nil {
return nil, err
}
defer func() {
if rerr != nil {
os.RemoveAll(pciRootDir)
}
}()
nvpciLib := nvpci.New(nvpci.WithPCIDevicesRoot(pciRootDir))
mock = &MockNvmdev{
&nvmdev{mdevParentsRootDir, mdevDevicesRootDir},
nvmdev: &nvmdev{
mdevParentsRoot: mdevParentsRootDir,
mdevDevicesRoot: mdevDevicesRootDir,
nvpci: nvpciLib,
},
pciDevicesRoot: pciRootDir,
}
return mock, nil
@ -64,12 +81,20 @@ func NewMock() (mock *MockNvmdev, rerr error) {
func (m *MockNvmdev) Cleanup() {
os.RemoveAll(m.mdevParentsRoot)
os.RemoveAll(m.mdevDevicesRoot)
os.RemoveAll(m.pciDevicesRoot)
}
// AddMockA100Parent creates an A100 like parent GPU mock device.
func (m *MockNvmdev) AddMockA100Parent(address string, numaNode int) error {
pciDeviceDir := filepath.Join(m.pciDevicesRoot, address)
err := os.MkdirAll(pciDeviceDir, 0755)
if err != nil {
return err
}
// /sys/class/mdev_bus/<address> is a symlink to /sys/bus/pci/devices/<address>
deviceDir := filepath.Join(m.mdevParentsRoot, address)
err := os.MkdirAll(deviceDir, 0755)
err = os.Symlink(pciDeviceDir, deviceDir)
if err != nil {
return err
}

View File

@ -42,6 +42,7 @@ type Interface interface {
type nvmdev struct {
mdevParentsRoot string
mdevDevicesRoot string
nvpci nvpci.Interface
}
var _ Interface = (*nvmdev)(nil)
@ -63,8 +64,25 @@ type Device struct {
}
// New interface that allows us to get a list of all NVIDIA parent and MDEV (vGPU) devices.
func New() Interface {
return &nvmdev{mdevParentsRoot, mdevDevicesRoot}
func New(opts ...Option) Interface {
n := &nvmdev{mdevParentsRoot: mdevParentsRoot, mdevDevicesRoot: mdevDevicesRoot}
for _, opt := range opts {
opt(n)
}
if n.nvpci == nil {
n.nvpci = nvpci.New()
}
return n
}
// Option defines a function for passing options to the New() call.
type Option func(*nvmdev)
// WithNvpciLib provides an Option to set the nvpci library.
func WithNvpciLib(nvpciLib nvpci.Interface) Option {
return func(n *nvmdev) {
n.nvpci = nvpciLib
}
}
// GetAllParentDevices returns all NVIDIA Parent PCI devices on the system.
@ -77,7 +95,7 @@ func (m *nvmdev) GetAllParentDevices() ([]*ParentDevice, error) {
var nvdevices []*ParentDevice
for _, deviceDir := range deviceDirs {
devicePath := path.Join(m.mdevParentsRoot, deviceDir.Name())
nvdevice, err := NewParentDevice(devicePath)
nvdevice, err := m.NewParentDevice(devicePath)
if err != nil {
return nil, fmt.Errorf("error constructing NVIDIA parent device: %v", err)
}
@ -110,7 +128,7 @@ func (m *nvmdev) GetAllDevices() ([]*Device, error) {
var nvdevices []*Device
for _, deviceDir := range deviceDirs {
nvdevice, err := NewDevice(m.mdevDevicesRoot, deviceDir.Name())
nvdevice, err := m.NewDevice(m.mdevDevicesRoot, deviceDir.Name())
if err != nil {
return nil, fmt.Errorf("error constructing MDEV device: %v", err)
}
@ -124,7 +142,7 @@ func (m *nvmdev) GetAllDevices() ([]*Device, error) {
}
// NewDevice constructs a Device, which represents an NVIDIA mdev (vGPU) device.
func NewDevice(root string, uuid string) (*Device, error) {
func (n *nvmdev) NewDevice(root string, uuid string) (*Device, error) {
path := path.Join(root, uuid)
m, err := newMdev(path)
@ -132,7 +150,7 @@ func NewDevice(root string, uuid string) (*Device, error) {
return nil, err
}
parent, err := NewParentDevice(m.parentDevicePath())
parent, err := n.NewParentDevice(m.parentDevicePath())
if err != nil {
return nil, fmt.Errorf("error constructing NVIDIA PCI device: %v", err)
}
@ -241,9 +259,9 @@ func (m mdev) iommuGroup() (int, error) {
}
// NewParentDevice constructs a ParentDevice.
func NewParentDevice(devicePath string) (*ParentDevice, error) {
func (m *nvmdev) NewParentDevice(devicePath string) (*ParentDevice, error) {
address := filepath.Base(devicePath)
nvdevice, err := nvpci.New().GetGPUByPciBusID(address)
nvdevice, err := m.nvpci.GetGPUByPciBusID(address)
if err != nil {
return nil, fmt.Errorf("failed to construct NVIDIA PCI device: %v", err)
}