diff --git a/pkg/nvpci/mock.go b/pkg/nvpci/mock.go index 9468dc0..fdeb461 100644 --- a/pkg/nvpci/mock.go +++ b/pkg/nvpci/mock.go @@ -53,7 +53,7 @@ func (m *MockNvpci) Cleanup() { os.RemoveAll(m.pciDevicesRoot) } -func (m *MockNvpci) AddMockA100(address string) error { +func (m *MockNvpci) AddMockA100(address string, numaNode int) error { deviceDir := filepath.Join(m.pciDevicesRoot, address) err := os.MkdirAll(deviceDir, 0755) if err != nil { @@ -87,6 +87,15 @@ func (m *MockNvpci) AddMockA100(address string) error { return err } + numa, err := os.Create(filepath.Join(deviceDir, "numa_node")) + if err != nil { + return err + } + _, err = numa.WriteString(fmt.Sprintf("%v", numaNode)) + if err != nil { + return err + } + config, err := os.Create(filepath.Join(deviceDir, "config")) if err != nil { return err diff --git a/pkg/nvpci/nvpci_test.go b/pkg/nvpci/nvpci_test.go index 8a1e368..807ec0a 100644 --- a/pkg/nvpci/nvpci_test.go +++ b/pkg/nvpci/nvpci_test.go @@ -31,7 +31,7 @@ func TestNvpci(t *testing.T) { require.Nil(t, err, "Error creating NewMockNvpci") defer nvpci.Cleanup() - err = nvpci.AddMockA100("0000:80:05.1") + err = nvpci.AddMockA100("0000:80:05.1", 0) require.Nil(t, err, "Error adding Mock A100 device to MockNvpci") devices, err := nvpci.GetGPUs() @@ -39,6 +39,7 @@ func TestNvpci(t *testing.T) { require.Equal(t, 1, len(devices), "Wrong number of GPU devices") require.Equal(t, 1, len(devices[0].Resources), "Wrong number GPU resources found") require.Equal(t, "0000:80:05.1", devices[0].Address, "Wrong Address found for device") + require.Equal(t, 0, devices[0].NumaNode, "Wrong NUMA node found for device") config, err := devices[0].Config.Read() require.Nil(t, err, "Error reading config") @@ -62,3 +63,39 @@ func TestNvpci(t *testing.T) { require.Equal(t, int(resource0.End-resource0.Start+1), bar0.Len()) require.Equal(t, ga100PmcID, bar0.Read32(0)) } + +func TestNvpciNUMANode(t *testing.T) { + testCases := []struct { + Description string + NumaNode int + }{ + { + Description: "Numa Node -1", + NumaNode: -1, + }, + { + Description: "Numa Node 0", + NumaNode: 0, + }, + { + Description: "Numa Node 1", + NumaNode: 1, + }, + } + + for _, tc := range testCases { + t.Run(tc.Description, func(t *testing.T) { + nvpci, err := NewMockNvpci() + require.Nil(t, err, "Error creating NewMockNvpci") + defer nvpci.Cleanup() + + err = nvpci.AddMockA100("0000:80:05.1", tc.NumaNode) + require.Nil(t, err, "Error adding Mock A100 device to MockNvpci") + + devices, err := nvpci.GetGPUs() + require.Nil(t, err, "Error getting GPUs") + require.Equal(t, 1, len(devices), "Wrong number of GPU devices") + require.Equal(t, tc.NumaNode, devices[0].NumaNode, "Wrong NUMA node found for device") + }) + } +}