mirror of
https://github.com/clearml/go-nvlib
synced 2025-02-25 05:23:52 +00:00
Add nvpci.Interface to the nvmdev struct to aid in unit tests
Signed-off-by: Christopher Desiniotis <cdesiniotis@nvidia.com>
This commit is contained in:
parent
44a5440a0a
commit
7c3222d683
@ -28,6 +28,7 @@ import (
|
||||
// MockNvmdev mock mdev device.
|
||||
type MockNvmdev struct {
|
||||
*nvmdev
|
||||
pciDevicesRoot string
|
||||
}
|
||||
|
||||
var _ Interface = (*MockNvmdev)(nil)
|
||||
@ -53,8 +54,24 @@ func NewMock() (mock *MockNvmdev, rerr error) {
|
||||
}
|
||||
}()
|
||||
|
||||
pciRootDir, err := os.MkdirTemp(os.TempDir(), "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if rerr != nil {
|
||||
os.RemoveAll(pciRootDir)
|
||||
}
|
||||
}()
|
||||
|
||||
nvpciLib := nvpci.New(nvpci.WithPCIDevicesRoot(pciRootDir))
|
||||
mock = &MockNvmdev{
|
||||
&nvmdev{mdevParentsRootDir, mdevDevicesRootDir},
|
||||
nvmdev: &nvmdev{
|
||||
mdevParentsRoot: mdevParentsRootDir,
|
||||
mdevDevicesRoot: mdevDevicesRootDir,
|
||||
nvpci: nvpciLib,
|
||||
},
|
||||
pciDevicesRoot: pciRootDir,
|
||||
}
|
||||
|
||||
return mock, nil
|
||||
@ -64,12 +81,20 @@ func NewMock() (mock *MockNvmdev, rerr error) {
|
||||
func (m *MockNvmdev) Cleanup() {
|
||||
os.RemoveAll(m.mdevParentsRoot)
|
||||
os.RemoveAll(m.mdevDevicesRoot)
|
||||
os.RemoveAll(m.pciDevicesRoot)
|
||||
}
|
||||
|
||||
// AddMockA100Parent creates an A100 like parent GPU mock device.
|
||||
func (m *MockNvmdev) AddMockA100Parent(address string, numaNode int) error {
|
||||
pciDeviceDir := filepath.Join(m.pciDevicesRoot, address)
|
||||
err := os.MkdirAll(pciDeviceDir, 0755)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// /sys/class/mdev_bus/<address> is a symlink to /sys/bus/pci/devices/<address>
|
||||
deviceDir := filepath.Join(m.mdevParentsRoot, address)
|
||||
err := os.MkdirAll(deviceDir, 0755)
|
||||
err = os.Symlink(pciDeviceDir, deviceDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -42,6 +42,7 @@ type Interface interface {
|
||||
type nvmdev struct {
|
||||
mdevParentsRoot string
|
||||
mdevDevicesRoot string
|
||||
nvpci nvpci.Interface
|
||||
}
|
||||
|
||||
var _ Interface = (*nvmdev)(nil)
|
||||
@ -63,8 +64,25 @@ type Device struct {
|
||||
}
|
||||
|
||||
// New interface that allows us to get a list of all NVIDIA parent and MDEV (vGPU) devices.
|
||||
func New() Interface {
|
||||
return &nvmdev{mdevParentsRoot, mdevDevicesRoot}
|
||||
func New(opts ...Option) Interface {
|
||||
n := &nvmdev{mdevParentsRoot: mdevParentsRoot, mdevDevicesRoot: mdevDevicesRoot}
|
||||
for _, opt := range opts {
|
||||
opt(n)
|
||||
}
|
||||
if n.nvpci == nil {
|
||||
n.nvpci = nvpci.New()
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// Option defines a function for passing options to the New() call.
|
||||
type Option func(*nvmdev)
|
||||
|
||||
// WithNvpciLib provides an Option to set the nvpci library.
|
||||
func WithNvpciLib(nvpciLib nvpci.Interface) Option {
|
||||
return func(n *nvmdev) {
|
||||
n.nvpci = nvpciLib
|
||||
}
|
||||
}
|
||||
|
||||
// GetAllParentDevices returns all NVIDIA Parent PCI devices on the system.
|
||||
@ -77,7 +95,7 @@ func (m *nvmdev) GetAllParentDevices() ([]*ParentDevice, error) {
|
||||
var nvdevices []*ParentDevice
|
||||
for _, deviceDir := range deviceDirs {
|
||||
devicePath := path.Join(m.mdevParentsRoot, deviceDir.Name())
|
||||
nvdevice, err := NewParentDevice(devicePath)
|
||||
nvdevice, err := m.NewParentDevice(devicePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error constructing NVIDIA parent device: %v", err)
|
||||
}
|
||||
@ -110,7 +128,7 @@ func (m *nvmdev) GetAllDevices() ([]*Device, error) {
|
||||
|
||||
var nvdevices []*Device
|
||||
for _, deviceDir := range deviceDirs {
|
||||
nvdevice, err := NewDevice(m.mdevDevicesRoot, deviceDir.Name())
|
||||
nvdevice, err := m.NewDevice(m.mdevDevicesRoot, deviceDir.Name())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error constructing MDEV device: %v", err)
|
||||
}
|
||||
@ -124,7 +142,7 @@ func (m *nvmdev) GetAllDevices() ([]*Device, error) {
|
||||
}
|
||||
|
||||
// NewDevice constructs a Device, which represents an NVIDIA mdev (vGPU) device.
|
||||
func NewDevice(root string, uuid string) (*Device, error) {
|
||||
func (n *nvmdev) NewDevice(root string, uuid string) (*Device, error) {
|
||||
path := path.Join(root, uuid)
|
||||
|
||||
m, err := newMdev(path)
|
||||
@ -132,7 +150,7 @@ func NewDevice(root string, uuid string) (*Device, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
parent, err := NewParentDevice(m.parentDevicePath())
|
||||
parent, err := n.NewParentDevice(m.parentDevicePath())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error constructing NVIDIA PCI device: %v", err)
|
||||
}
|
||||
@ -241,9 +259,9 @@ func (m mdev) iommuGroup() (int, error) {
|
||||
}
|
||||
|
||||
// NewParentDevice constructs a ParentDevice.
|
||||
func NewParentDevice(devicePath string) (*ParentDevice, error) {
|
||||
func (m *nvmdev) NewParentDevice(devicePath string) (*ParentDevice, error) {
|
||||
address := filepath.Base(devicePath)
|
||||
nvdevice, err := nvpci.New().GetGPUByPciBusID(address)
|
||||
nvdevice, err := m.nvpci.GetGPUByPciBusID(address)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to construct NVIDIA PCI device: %v", err)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user