diff --git a/pkg/nvmdev/mock.go b/pkg/nvmdev/mock.go index 8fd0e8c..ed7a055 100644 --- a/pkg/nvmdev/mock.go +++ b/pkg/nvmdev/mock.go @@ -183,14 +183,20 @@ func (m *MockNvmdev) AddMockA100Parent(address string, numaNode int) error { // AddMockA100Mdev creates an A100 like MDEV (vGPU) mock device. // The corresponding mocked parent A100 device must be created beforehand. -func (m *MockNvmdev) AddMockA100Mdev(uuid string, mdevType string, parentMdevTypeDir string) error { - deviceDir := filepath.Join(m.mdevDevicesRoot, uuid) - err := os.MkdirAll(deviceDir, 0755) +func (m *MockNvmdev) AddMockA100Mdev(uuid string, mdevType string, mdevTypeDir string, parentDeviceDir string) error { + mdevDeviceDir := filepath.Join(parentDeviceDir, uuid) + err := os.Mkdir(mdevDeviceDir, 0755) if err != nil { return err } - err = os.Symlink(parentMdevTypeDir, filepath.Join(deviceDir, "mdev_type")) + parentMdevTypeDir := filepath.Join(parentDeviceDir, "mdev_supported_types", mdevTypeDir) + err = os.Symlink(parentMdevTypeDir, filepath.Join(mdevDeviceDir, "mdev_type")) + if err != nil { + return err + } + + err = os.Symlink(mdevDeviceDir, filepath.Join(m.mdevDevicesRoot, uuid)) if err != nil { return err } diff --git a/pkg/nvmdev/nvmdev.go b/pkg/nvmdev/nvmdev.go index 6a6f060..3e4befe 100644 --- a/pkg/nvmdev/nvmdev.go +++ b/pkg/nvmdev/nvmdev.go @@ -153,27 +153,33 @@ func NewDevice(root string, uuid string) (*Device, error) { return &device, nil } +// mdev represents the path to an NVIDIA mdev (vGPU) device. type mdev string func newMdev(devicePath string) (mdev, error) { - mdevTypeDir, err := filepath.EvalSymlinks(path.Join(devicePath, "mdev_type")) + mdevDir, err := filepath.EvalSymlinks(devicePath) if err != nil { - return "", fmt.Errorf("error resolving mdev_type link: %v", err) + return "", fmt.Errorf("error resolving symlink for %s: %v", devicePath, err) } - return mdev(mdevTypeDir), nil + return mdev(mdevDir), nil } func (m mdev) String() string { return string(m) } func (m mdev) parentDevicePath() string { - // /sys/bus/pci/devices//mdev_supported_types/ - return path.Dir(path.Dir(string(m))) + // /sys/bus/pci/devices// + return path.Dir(string(m)) } func (m mdev) Type() (string, error) { - mdevType, err := os.ReadFile(path.Join(string(m), "name")) + mdevTypeDir, err := filepath.EvalSymlinks(path.Join(string(m), "mdev_type")) + if err != nil { + return "", fmt.Errorf("error resolving mdev_type link for mdev %s: %v", m, err) + } + + mdevType, err := os.ReadFile(path.Join(mdevTypeDir, "name")) if err != nil { return "", fmt.Errorf("unable to read mdev_type name for mdev %s: %v", m, err) } diff --git a/pkg/nvmdev/nvmdev_test.go b/pkg/nvmdev/nvmdev_test.go index 43815f4..d044a9c 100644 --- a/pkg/nvmdev/nvmdev_test.go +++ b/pkg/nvmdev/nvmdev_test.go @@ -18,7 +18,6 @@ package nvmdev import ( "github.com/stretchr/testify/require" - "path/filepath" "testing" ) @@ -41,8 +40,7 @@ func TestNvmdev(t *testing.T) { require.Nil(t, err, "Error checking if A100-4Q vGPU type is available for creation") require.True(t, available, "A100-4C should be available to create") - err = nvmdev.AddMockA100Mdev("b1914f0a-15cf-416e-8967-55fc7cb68e20", "A100-4C", - filepath.Join(parentDevs[0].Path, "mdev_supported_types/nvidia-500")) + err = nvmdev.AddMockA100Mdev("b1914f0a-15cf-416e-8967-55fc7cb68e20", "A100-4C", "nvidia-500", parentDevs[0].Path) require.Nil(t, err, "Error adding Mock A100 mediated device") mdevs, err := nvmdev.GetAllDevices()