diff --git a/pkg/nvpci/mock.go b/pkg/nvpci/mock.go
index 3400d9e..fdeb461 100644
--- a/pkg/nvpci/mock.go
+++ b/pkg/nvpci/mock.go
@@ -25,17 +25,17 @@ import (
 	"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvpci/bytes"
 )
 
-type MockA100 struct {
+// MockNvpci is an nvpci whose PCI devices root is a temporary directory that
+// tests fill with mock devices.
+type MockNvpci struct {
 	*nvpci
 }
 
-func (m *MockA100) Cleanup() {
-	os.RemoveAll(m.pciDevicesRoot)
-}
+var _ Interface = (*MockNvpci)(nil)
 
-var _ Interface = (*MockA100)(nil)
-
-func NewMockA100() (mock *MockA100, rerr error) {
+// NewMockNvpci creates a temporary directory and returns a MockNvpci that uses
+// it as its PCI devices root. Call Cleanup to remove the directory.
+func NewMockNvpci() (mock *MockNvpci, rerr error) {
 	rootDir, err := ioutil.TempDir("", "")
 	if err != nil {
 		return nil, err
@@ -46,42 +46,67 @@ func NewMockA100() (mock *MockA100, rerr error) {
 		}
 	}()
 
-	deviceDir := filepath.Join(rootDir, "0000:80:05.1")
-	err = os.MkdirAll(deviceDir, 0755)
+	mock = &MockNvpci{
+		&nvpci{rootDir},
+	}
+
+	return mock, nil
+}
+
+// Cleanup removes the mock PCI devices root and everything beneath it.
+func (m *MockNvpci) Cleanup() {
+	os.RemoveAll(m.pciDevicesRoot)
+}
+
+// AddMockA100 creates a device directory named address under the mock PCI
+// devices root and fills it with the files describing a mock A100 (device ID
+// 0x20bf) whose numa_node file holds numaNode.
+func (m *MockNvpci) AddMockA100(address string, numaNode int) error {
+	deviceDir := filepath.Join(m.pciDevicesRoot, address)
+	err := os.MkdirAll(deviceDir, 0755)
 	if err != nil {
-		return nil, err
+		return err
 	}
 
 	vendor, err := os.Create(filepath.Join(deviceDir, "vendor"))
 	if err != nil {
-		return nil, err
+		return err
 	}
 	_, err = vendor.WriteString(fmt.Sprintf("0x%x", pciNvidiaVendorID))
 	if err != nil {
-		return nil, err
+		return err
 	}
 
 	class, err := os.Create(filepath.Join(deviceDir, "class"))
 	if err != nil {
-		return nil, err
+		return err
 	}
 	_, err = class.WriteString(fmt.Sprintf("0x%x", pci3dControllerClass))
 	if err != nil {
-		return nil, err
+		return err
 	}
 
 	device, err := os.Create(filepath.Join(deviceDir, "device"))
 	if err != nil {
-		return nil, err
+		return err
 	}
 	_, err = device.WriteString("0x20bf")
 	if err != nil {
-		return nil, err
+		return err
+	}
+
+	numa, err := os.Create(filepath.Join(deviceDir, "numa_node"))
+	if err != nil {
+		return err
+	}
+	_, err = numa.WriteString(fmt.Sprintf("%v", numaNode))
+	if err != nil {
+		return err
 	}
 
 	config, err := os.Create(filepath.Join(deviceDir, "config"))
 	if err != nil {
-		return nil, err
+		return err
 	}
 	_data := make([]byte, pciCfgSpaceStandardSize)
 	data := bytes.New(&_data)
@@ -89,32 +114,28 @@ func NewMockA100() (mock *MockA100, rerr error) {
 	data.Write16(2, uint16(0x20bf))
 	_, err = config.Write(*data.Raw())
 	if err != nil {
-		return nil, err
+		return err
 	}
 
 	bar0 := []uint64{0x00000000c2000000, 0x00000000c2ffffff, 0x0000000000040200}
 	resource, err := os.Create(filepath.Join(deviceDir, "resource"))
 	_, err = resource.WriteString(fmt.Sprintf("0x%x 0x%x 0x%x", bar0[0], bar0[1], bar0[2]))
 	if err != nil {
-		return nil, err
+		return err
 	}
 
 	pmcID := uint32(0x170000a1)
 	resource0, err := os.Create(filepath.Join(deviceDir, "resource0"))
 	if err != nil {
-		return nil, err
+		return err
 	}
 	_data = make([]byte, bar0[1]-bar0[0]+1)
 	data = bytes.New(&_data).LittleEndian()
 	data.Write32(0, pmcID)
 	_, err = resource0.Write(*data.Raw())
 	if err != nil {
-		return nil, err
+		return err
 	}
 
-	mock = &MockA100{
-		&nvpci{rootDir},
-	}
-
-	return mock, nil
+	return nil
 }
diff --git a/pkg/nvpci/nvpci.go b/pkg/nvpci/nvpci.go
index 0bd5b21..f4069f9 100644
--- a/pkg/nvpci/nvpci.go
+++ b/pkg/nvpci/nvpci.go
@@ -61,6 +61,7 @@ type NvidiaPCIDevice struct {
 	Vendor    uint16
 	Class     uint32
 	Device    uint16
+	NumaNode  int
 	Config    *ConfigSpace
 	Resources map[int]*MemoryResource
 }
@@ -147,6 +148,16 @@ func (p *nvpci) GetAllDevices() ([]*NvidiaPCIDevice, error) {
 			return nil, fmt.Errorf("unable to convert device string to uint16: %v", deviceStr)
 		}
 
+		numa, err := ioutil.ReadFile(path.Join(devicePath, "numa_node"))
+		if err != nil {
+			return nil, fmt.Errorf("unable to read PCI NUMA node for %s: %v", address, err)
+		}
+		numaStr := strings.TrimSpace(string(numa))
+		numaNode, err := strconv.ParseInt(numaStr, 0, 64)
+		if err != nil {
+			return nil, fmt.Errorf("unable to convert NUMA node string to int64: %v", numaStr)
+		}
+
 		config := &ConfigSpace{
 			Path: path.Join(devicePath, "config"),
 		}
@@ -183,6 +194,7 @@ func (p *nvpci) GetAllDevices() ([]*NvidiaPCIDevice, error) {
 			Vendor:    uint16(vendorID),
 			Class:     uint32(classID),
 			Device:    uint16(deviceID),
+			NumaNode:  int(numaNode),
 			Config:    config,
 			Resources: resources,
 		}
diff --git a/pkg/nvpci/nvpci_test.go b/pkg/nvpci/nvpci_test.go
index ab3c544..807ec0a 100644
--- a/pkg/nvpci/nvpci_test.go
+++ b/pkg/nvpci/nvpci_test.go
@@ -27,14 +27,19 @@ const (
 )
 
 func TestNvpci(t *testing.T) {
-	nvpci, err := NewMockA100()
-	require.Nil(t, err, "Error creating NewMockA100")
+	nvpci, err := NewMockNvpci()
+	require.Nil(t, err, "Error creating NewMockNvpci")
 	defer nvpci.Cleanup()
 
+	err = nvpci.AddMockA100("0000:80:05.1", 0)
+	require.Nil(t, err, "Error adding Mock A100 device to MockNvpci")
+
 	devices, err := nvpci.GetGPUs()
 	require.Nil(t, err, "Error getting GPUs")
 	require.Equal(t, 1, len(devices), "Wrong number of GPU devices")
 	require.Equal(t, 1, len(devices[0].Resources), "Wrong number GPU resources found")
+	require.Equal(t, "0000:80:05.1", devices[0].Address, "Wrong Address found for device")
+	require.Equal(t, 0, devices[0].NumaNode, "Wrong NUMA node found for device")
 
 	config, err := devices[0].Config.Read()
 	require.Nil(t, err, "Error reading config")
@@ -58,3 +63,39 @@ func TestNvpci(t *testing.T) {
 	require.Equal(t, int(resource0.End-resource0.Start+1), bar0.Len())
 	require.Equal(t, ga100PmcID, bar0.Read32(0))
 }
+
+func TestNvpciNUMANode(t *testing.T) {
+	testCases := []struct {
+		Description string
+		NumaNode    int
+	}{
+		{
+			Description: "Numa Node -1",
+			NumaNode:    -1,
+		},
+		{
+			Description: "Numa Node 0",
+			NumaNode:    0,
+		},
+		{
+			Description: "Numa Node 1",
+			NumaNode:    1,
+		},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.Description, func(t *testing.T) {
+			nvpci, err := NewMockNvpci()
+			require.Nil(t, err, "Error creating NewMockNvpci")
+			defer nvpci.Cleanup()
+
+			err = nvpci.AddMockA100("0000:80:05.1", tc.NumaNode)
+			require.Nil(t, err, "Error adding Mock A100 device to MockNvpci")
+
+			devices, err := nvpci.GetGPUs()
+			require.Nil(t, err, "Error getting GPUs")
+			require.Equal(t, 1, len(devices), "Wrong number of GPU devices")
+			require.Equal(t, tc.NumaNode, devices[0].NumaNode, "Wrong NUMA node found for device")
+		})
+	}
+}