nvidia-container-toolkit/internal/platform-support/dgpu/nvml_test.go
Evan Lezar fa16d83494 [no-relnotes] Add initial unit test for MIG CDI spec generation
Signed-off-by: Evan Lezar <elezar@nvidia.com>
2024-07-31 11:56:36 +02:00

172 lines
4.4 KiB
Go

/**
# Copyright 2024 NVIDIA CORPORATION
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package dgpu
import (
"testing"
"github.com/NVIDIA/go-nvlib/pkg/nvlib/device"
"github.com/NVIDIA/go-nvml/pkg/nvml"
"github.com/NVIDIA/go-nvml/pkg/nvml/mock"
testlog "github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/require"
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
"github.com/NVIDIA/nvidia-container-toolkit/internal/nvcaps"
)
// TestNewNvmlDGPUDiscoverer builds a discoverer from a mocked NVML device via
// newNvmlDGPUDiscoverer and compares the devices, hooks, and mounts it reports
// against the expectations of each table entry.
//
// TODO: In order to properly test this, we need a mechanism to inject /
// override the char device discoverer.
func TestNewNvmlDGPUDiscoverer(t *testing.T) {
	logger, _ := testlog.NewNullLogger()

	nvmllib := &mock.Interface{}
	devicelib := device.New(
		nvmllib,
	)

	testCases := []struct {
		description     string
		device          nvml.Device
		expectedError   error
		expectedDevices []discover.Device
		expectedHooks   []discover.Hook
		expectedMounts  []discover.Mount
	}{
		{
			description: "",
			device: &mock.Device{
				GetMinorNumberFunc: func() (int, nvml.Return) {
					return 3, nvml.SUCCESS
				},
				GetPciInfoFunc: func() (nvml.PciInfo, nvml.Return) {
					// Copy the bus ID string byte-by-byte into the
					// fixed-size int8 array used by nvml.PciInfo; the
					// remaining elements stay zero.
					var busID [32]int8
					for i, b := range []byte("00000000:45:00:00") {
						busID[i] = int8(b)
					}
					info := nvml.PciInfo{
						BusId: busID,
					}
					return info, nvml.SUCCESS
				},
			},
			// No expectations are set, so the discoverer must report nil
			// devices, hooks, and mounts for this entry.
		},
	}

	for _, tc := range testCases {
		t.Run(tc.description, func(t *testing.T) {
			o := &options{logger: logger}

			device, err := devicelib.NewDevice(tc.device)
			require.NoError(t, err)

			d, err := o.newNvmlDGPUDiscoverer(&toRequiredInfo{device})
			require.ErrorIs(t, err, tc.expectedError)
			if tc.expectedError != nil {
				// A discoverer returned alongside an error is not meaningful
				// (and may be nil), so there is nothing further to inspect.
				return
			}

			// NOTE(review): the errors returned by Devices, Hooks, and Mounts
			// are discarded here; only the returned values are compared.
			devices, _ := d.Devices()
			require.EqualValues(t, tc.expectedDevices, devices)
			hooks, _ := d.Hooks()
			require.EqualValues(t, tc.expectedHooks, hooks)
			mounts, _ := d.Mounts()
			require.EqualValues(t, tc.expectedMounts, mounts)
		})
	}
}
// TestNewNvmlMIGDiscoverer builds a discoverer from a mocked MIG device and
// its mocked parent via NewForMigDevice, supplying the MIG capabilities of the
// table entry, and compares the devices, hooks, and mounts it reports against
// the expectations of that entry.
func TestNewNvmlMIGDiscoverer(t *testing.T) {
	logger, _ := testlog.NewNullLogger()

	nvmllib := &mock.Interface{}
	devicelib := device.New(
		nvmllib,
	)

	testCases := []struct {
		description     string
		mig             *mock.Device
		parent          nvml.Device
		migCaps         nvcaps.MigCaps
		expectedError   error
		expectedDevices []discover.Device
		expectedHooks   []discover.Hook
		expectedMounts  []discover.Mount
	}{
		{
			description: "",
			mig: &mock.Device{
				IsMigDeviceHandleFunc: func() (bool, nvml.Return) {
					return true, nvml.SUCCESS
				},
				GetGpuInstanceIdFunc: func() (int, nvml.Return) {
					return 1, nvml.SUCCESS
				},
				GetComputeInstanceIdFunc: func() (int, nvml.Return) {
					return 2, nvml.SUCCESS
				},
			},
			parent: &mock.Device{
				GetMinorNumberFunc: func() (int, nvml.Return) {
					return 3, nvml.SUCCESS
				},
				GetPciInfoFunc: func() (nvml.PciInfo, nvml.Return) {
					// Copy the bus ID string byte-by-byte into the
					// fixed-size int8 array used by nvml.PciInfo; the
					// remaining elements stay zero.
					var busID [32]int8
					for i, b := range []byte("00000000:45:00:00") {
						busID[i] = int8(b)
					}
					info := nvml.PciInfo{
						BusId: busID,
					}
					return info, nvml.SUCCESS
				},
			},
			// The keys use the parent's minor number (3) and the GPU / compute
			// instance IDs (1 / 2) returned by the mocks above.
			migCaps: nvcaps.MigCaps{
				"gpu3/gi1/access":     31,
				"gpu3/gi1/ci2/access": 312,
			},
			expectedDevices: nil,
			expectedMounts:  nil,
			expectedHooks:   []discover.Hook{},
		},
	}

	for _, tc := range testCases {
		t.Run(tc.description, func(t *testing.T) {
			// Make the MIG mock report tc.parent as the device it belongs to.
			tc.mig.GetDeviceHandleFromMigDeviceHandleFunc = func() (nvml.Device, nvml.Return) {
				return tc.parent, nvml.SUCCESS
			}

			parent, err := devicelib.NewDevice(tc.parent)
			require.NoError(t, err)

			mig, err := devicelib.NewMigDevice(tc.mig)
			require.NoError(t, err)

			d, err := NewForMigDevice(parent, mig,
				WithLogger(logger),
				WithMIGCaps(tc.migCaps),
			)
			require.ErrorIs(t, err, tc.expectedError)
			if tc.expectedError != nil {
				// A discoverer returned alongside an error is not meaningful
				// (and may be nil), so there is nothing further to inspect.
				return
			}

			// NOTE(review): the errors returned by Devices, Hooks, and Mounts
			// are discarded here; only the returned values are compared.
			devices, _ := d.Devices()
			require.EqualValues(t, tc.expectedDevices, devices)
			hooks, _ := d.Hooks()
			require.EqualValues(t, tc.expectedHooks, hooks)
			mounts, _ := d.Mounts()
			require.EqualValues(t, tc.expectedMounts, mounts)
		})
	}
}