diff --git a/pkg/nvmdev/mock.go b/pkg/nvmdev/mock.go index 9425743..78192f8 100644 --- a/pkg/nvmdev/mock.go +++ b/pkg/nvmdev/mock.go @@ -28,6 +28,7 @@ import ( // MockNvmdev mock mdev device. type MockNvmdev struct { *nvmdev + pciDevicesRoot string } var _ Interface = (*MockNvmdev)(nil) @@ -53,8 +54,24 @@ func NewMock() (mock *MockNvmdev, rerr error) { } }() + pciRootDir, err := os.MkdirTemp(os.TempDir(), "") + if err != nil { + return nil, err + } + defer func() { + if rerr != nil { + os.RemoveAll(pciRootDir) + } + }() + + nvpciLib := nvpci.New(nvpci.WithPCIDevicesRoot(pciRootDir)) mock = &MockNvmdev{ - &nvmdev{mdevParentsRootDir, mdevDevicesRootDir}, + nvmdev: &nvmdev{ + mdevParentsRoot: mdevParentsRootDir, + mdevDevicesRoot: mdevDevicesRootDir, + nvpci: nvpciLib, + }, + pciDevicesRoot: pciRootDir, } return mock, nil @@ -64,12 +81,20 @@ func NewMock() (mock *MockNvmdev, rerr error) { func (m *MockNvmdev) Cleanup() { os.RemoveAll(m.mdevParentsRoot) os.RemoveAll(m.mdevDevicesRoot) + os.RemoveAll(m.pciDevicesRoot) } // AddMockA100Parent creates an A100 like parent GPU mock device. func (m *MockNvmdev) AddMockA100Parent(address string, numaNode int) error { + pciDeviceDir := filepath.Join(m.pciDevicesRoot, address) + err := os.MkdirAll(pciDeviceDir, 0755) + if err != nil { + return err + } + + // /sys/class/mdev_bus/
is a symlink to /sys/bus/pci/devices/
deviceDir := filepath.Join(m.mdevParentsRoot, address) - err := os.MkdirAll(deviceDir, 0755) + err = os.Symlink(pciDeviceDir, deviceDir) if err != nil { return err } diff --git a/pkg/nvmdev/nvmdev.go b/pkg/nvmdev/nvmdev.go index 33f2d5a..c85d79d 100644 --- a/pkg/nvmdev/nvmdev.go +++ b/pkg/nvmdev/nvmdev.go @@ -42,6 +42,7 @@ type Interface interface { type nvmdev struct { mdevParentsRoot string mdevDevicesRoot string + nvpci nvpci.Interface } var _ Interface = (*nvmdev)(nil) @@ -63,8 +64,25 @@ type Device struct { } // New interface that allows us to get a list of all NVIDIA parent and MDEV (vGPU) devices. -func New() Interface { - return &nvmdev{mdevParentsRoot, mdevDevicesRoot} +func New(opts ...Option) Interface { + n := &nvmdev{mdevParentsRoot: mdevParentsRoot, mdevDevicesRoot: mdevDevicesRoot} + for _, opt := range opts { + opt(n) + } + if n.nvpci == nil { + n.nvpci = nvpci.New() + } + return n +} + +// Option defines a function for passing options to the New() call. +type Option func(*nvmdev) + +// WithNvpciLib provides an Option to set the nvpci library. +func WithNvpciLib(nvpciLib nvpci.Interface) Option { + return func(n *nvmdev) { + n.nvpci = nvpciLib + } } // GetAllParentDevices returns all NVIDIA Parent PCI devices on the system. @@ -77,7 +95,7 @@ func (m *nvmdev) GetAllParentDevices() ([]*ParentDevice, error) { var nvdevices []*ParentDevice for _, deviceDir := range deviceDirs { devicePath := path.Join(m.mdevParentsRoot, deviceDir.Name()) - nvdevice, err := NewParentDevice(devicePath) + nvdevice, err := m.NewParentDevice(devicePath) if err != nil { return nil, fmt.Errorf("error constructing NVIDIA parent device: %v", err) } @@ -110,7 +128,7 @@ func (m *nvmdev) GetAllDevices() ([]*Device, error) { var nvdevices []*Device for _, deviceDir := range deviceDirs { - nvdevice, err := NewDevice(m.mdevDevicesRoot, deviceDir.Name()) + nvdevice, err := m.NewDevice(m.mdevDevicesRoot, deviceDir.Name()) if err != nil { return nil, fmt.Errorf("error constructing MDEV device: %v", err) } @@ -124,7 +142,7 @@ func (m *nvmdev) GetAllDevices() ([]*Device, error) { } // NewDevice constructs a Device, which represents an NVIDIA mdev (vGPU) device. -func NewDevice(root string, uuid string) (*Device, error) { +func (n *nvmdev) NewDevice(root string, uuid string) (*Device, error) { path := path.Join(root, uuid) m, err := newMdev(path) @@ -132,7 +150,7 @@ func NewDevice(root string, uuid string) (*Device, error) { return nil, err } - parent, err := NewParentDevice(m.parentDevicePath()) + parent, err := n.NewParentDevice(m.parentDevicePath()) if err != nil { return nil, fmt.Errorf("error constructing NVIDIA PCI device: %v", err) } @@ -241,8 +259,9 @@ func (m mdev) iommuGroup() (int, error) { } // NewParentDevice constructs a ParentDevice. -func NewParentDevice(devicePath string) (*ParentDevice, error) { - nvdevice, err := newNvidiaPCIDeviceFromPath(devicePath) +func (m *nvmdev) NewParentDevice(devicePath string) (*ParentDevice, error) { + address := filepath.Base(devicePath) + nvdevice, err := m.nvpci.GetGPUByPciBusID(address) if err != nil { return nil, fmt.Errorf("failed to construct NVIDIA PCI device: %v", err) } @@ -370,12 +389,3 @@ func (p *ParentDevice) GetAvailableMDEVInstances(mdevType string) (int, error) { return availableInstances, nil } - -// newNvidiaPCIDeviceFromPath constructs an NvidiaPCIDevice for the specified device path. -func newNvidiaPCIDeviceFromPath(devicePath string) (*nvpci.NvidiaPCIDevice, error) { - root := filepath.Dir(devicePath) - address := filepath.Base(devicePath) - return nvpci.New( - nvpci.WithPCIDevicesRoot(root), - ).GetGPUByPciBusID(address) -}