nvidia-container-toolkit/vendor/github.com/NVIDIA/go-nvlib/pkg/nvpci/mock.go
dependabot[bot] ca528c4f53
Bump github.com/NVIDIA/go-nvlib from 0.5.0 to 0.6.0
Bumps [github.com/NVIDIA/go-nvlib](https://github.com/NVIDIA/go-nvlib) from 0.5.0 to 0.6.0.
- [Release notes](https://github.com/NVIDIA/go-nvlib/releases)
- [Commits](https://github.com/NVIDIA/go-nvlib/compare/v0.5.0...v0.6.0)

---
updated-dependencies:
- dependency-name: github.com/NVIDIA/go-nvlib
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2024-06-23 08:42:01 +00:00

261 lines
6.1 KiB
Go

/*
* Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nvpci
import (
"fmt"
"os"
"path/filepath"
"regexp"
"strconv"
"github.com/NVIDIA/go-nvlib/pkg/nvpci/bytes"
)
// MockNvpci mock pci device.
type MockNvpci struct {
*nvpci
}
var _ Interface = (*MockNvpci)(nil)
// NewMockNvpci create new mock PCI and remove old devices.
func NewMockNvpci() (mock *MockNvpci, rerr error) {
rootDir, err := os.MkdirTemp(os.TempDir(), "")
if err != nil {
return nil, err
}
defer func() {
if rerr != nil {
os.RemoveAll(rootDir)
}
}()
mock = &MockNvpci{
New(WithPCIDevicesRoot(rootDir)).(*nvpci),
}
return mock, nil
}
// Cleanup remove the mocked PCI devices root folder.
func (m *MockNvpci) Cleanup() {
os.RemoveAll(m.pciDevicesRoot)
}
func validatePCIAddress(addr string) error {
r := regexp.MustCompile(`0{4}:[0-9a-f]{2}:[0-9a-f]{2}\.[0-9]`)
if !r.Match([]byte(addr)) {
return fmt.Errorf(`invalid PCI address should match 0{4}:[0-9a-f]{2}:[0-9a-f]{2}\.[0-9]: %s`, addr)
}
return nil
}
// AddMockA100 Create an A100 like GPU mock device.
func (m *MockNvpci) AddMockA100(address string, numaNode int, sriov *SriovInfo) error {
err := validatePCIAddress(address)
if err != nil {
return err
}
deviceDir := filepath.Join(m.pciDevicesRoot, address)
err = os.MkdirAll(deviceDir, 0755)
if err != nil {
return err
}
err = createNVIDIAgpuFiles(deviceDir)
if err != nil {
return err
}
iommuGroup := 20
_, err = os.Create(filepath.Join(deviceDir, strconv.Itoa(iommuGroup)))
if err != nil {
return err
}
err = os.Symlink(filepath.Join(deviceDir, strconv.Itoa(iommuGroup)), filepath.Join(deviceDir, "iommu_group"))
if err != nil {
return err
}
numa, err := os.Create(filepath.Join(deviceDir, "numa_node"))
if err != nil {
return err
}
_, err = numa.WriteString(fmt.Sprintf("%v", numaNode))
if err != nil {
return err
}
if sriov != nil && sriov.PhysicalFunction != nil {
totalVFs, err := os.Create(filepath.Join(deviceDir, "sriov_totalvfs"))
if err != nil {
return err
}
_, err = fmt.Fprintf(totalVFs, "%d", sriov.PhysicalFunction.TotalVFs)
if err != nil {
return err
}
numVFs, err := os.Create(filepath.Join(deviceDir, "sriov_numvfs"))
if err != nil {
return err
}
_, err = fmt.Fprintf(numVFs, "%d", sriov.PhysicalFunction.NumVFs)
if err != nil {
return err
}
for i := 1; i <= int(sriov.PhysicalFunction.NumVFs); i++ {
err = m.createVf(address, i, iommuGroup, numaNode)
if err != nil {
return err
}
}
}
return nil
}
func createNVIDIAgpuFiles(deviceDir string) error {
vendor, err := os.Create(filepath.Join(deviceDir, "vendor"))
if err != nil {
return err
}
_, err = vendor.WriteString(fmt.Sprintf("0x%x", PCINvidiaVendorID))
if err != nil {
return err
}
class, err := os.Create(filepath.Join(deviceDir, "class"))
if err != nil {
return err
}
_, err = class.WriteString(fmt.Sprintf("0x%x", PCI3dControllerClass))
if err != nil {
return err
}
device, err := os.Create(filepath.Join(deviceDir, "device"))
if err != nil {
return err
}
_, err = device.WriteString("0x20bf")
if err != nil {
return err
}
_, err = os.Create(filepath.Join(deviceDir, "nvidia"))
if err != nil {
return err
}
err = os.Symlink(filepath.Join(deviceDir, "nvidia"), filepath.Join(deviceDir, "driver"))
if err != nil {
return err
}
config, err := os.Create(filepath.Join(deviceDir, "config"))
if err != nil {
return err
}
_data := make([]byte, PCICfgSpaceStandardSize)
data := bytes.New(&_data)
data.Write16(0, PCINvidiaVendorID)
data.Write16(2, uint16(0x20bf))
data.Write8(PCIStatusBytePosition, PCIStatusCapabilityList)
_, err = config.Write(*data.Raw())
if err != nil {
return err
}
bar0 := []uint64{0x00000000c2000000, 0x00000000c2ffffff, 0x0000000000040200}
resource, err := os.Create(filepath.Join(deviceDir, "resource"))
if err != nil {
return err
}
_, err = resource.WriteString(fmt.Sprintf("0x%x 0x%x 0x%x", bar0[0], bar0[1], bar0[2]))
if err != nil {
return err
}
pmcID := uint32(0x170000a1)
resource0, err := os.Create(filepath.Join(deviceDir, "resource0"))
if err != nil {
return err
}
_data = make([]byte, bar0[1]-bar0[0]+1)
data = bytes.New(&_data).LittleEndian()
data.Write32(0, pmcID)
_, err = resource0.Write(*data.Raw())
if err != nil {
return err
}
return nil
}
func (m *MockNvpci) createVf(pfAddress string, id, iommu_group, numaNode int) error {
functionID := pfAddress[len(pfAddress)-1]
// we are verifying the last character of pfAddress is integer.
functionNumber, err := strconv.Atoi(string(functionID))
if err != nil {
return fmt.Errorf("can't conver physical function pci address function number %s to integer: %v", string(functionID), err)
}
vfFunctionNumber := functionNumber + id
vfAddress := pfAddress[:len(pfAddress)-1] + strconv.Itoa(vfFunctionNumber)
deviceDir := filepath.Join(m.pciDevicesRoot, vfAddress)
err = os.MkdirAll(deviceDir, 0755)
if err != nil {
return err
}
err = createNVIDIAgpuFiles(deviceDir)
if err != nil {
return err
}
vfIommuGroup := strconv.Itoa(iommu_group + id)
_, err = os.Create(filepath.Join(deviceDir, vfIommuGroup))
if err != nil {
return err
}
err = os.Symlink(filepath.Join(deviceDir, vfIommuGroup), filepath.Join(deviceDir, "iommu_group"))
if err != nil {
return err
}
numa, err := os.Create(filepath.Join(deviceDir, "numa_node"))
if err != nil {
return err
}
_, err = numa.WriteString(fmt.Sprintf("%v", numaNode))
if err != nil {
return err
}
err = os.Symlink(filepath.Join(m.pciDevicesRoot, pfAddress), filepath.Join(deviceDir, "physfn"))
if err != nil {
return err
}
return nil
}