Merge branch 'fix-mir-profile-equals' into 'main'

Fix bug where MigProfile.Equals() would not work with wrapper type

See merge request nvidia/cloud-native/go-nvlib!36
This commit is contained in:
Kevin Klues 2023-03-27 17:12:01 +00:00
commit 97a3f2d5c5
2 changed files with 145 additions and 76 deletions

View File

@ -142,7 +142,7 @@ func (d *devicelib) ParseMigProfile(profile string) (MigProfile, error) {
}
// String returns the string representation of a Profile
func (p *MigProfileInfo) String() string {
func (p MigProfileInfo) String() string {
var suffix string
if len(p.Attributes) > 0 {
suffix = "+" + strings.Join(p.Attributes, ",")
@ -154,39 +154,36 @@ func (p *MigProfileInfo) String() string {
}
// GetInfo returns detailed info about a Profile
func (p *MigProfileInfo) GetInfo() MigProfileInfo {
return *p
func (p MigProfileInfo) GetInfo() MigProfileInfo {
return p
}
// Equals checks if two Profiles are identical or not
func (p *MigProfileInfo) Equals(other MigProfile) bool {
switch o := other.(type) {
case *MigProfileInfo:
if p.C != o.C {
return false
}
if p.G != o.G {
return false
}
if p.GB != o.GB {
return false
}
if p.GIProfileID != o.GIProfileID {
return false
}
if p.CIProfileID != o.CIProfileID {
return false
}
if p.CIEngProfileID != o.CIEngProfileID {
return false
}
return true
func (p MigProfileInfo) Equals(other MigProfile) bool {
o := other.GetInfo()
if p.C != o.C {
return false
}
return false
if p.G != o.G {
return false
}
if p.GB != o.GB {
return false
}
if p.GIProfileID != o.GIProfileID {
return false
}
if p.CIProfileID != o.CIProfileID {
return false
}
if p.CIEngProfileID != o.CIEngProfileID {
return false
}
return true
}
// Matches checks if a MigProfile matches the string passed in
func (p *MigProfileInfo) Matches(profile string) bool {
func (p MigProfileInfo) Matches(profile string) bool {
c, g, gb, attrs, err := parseMigProfile(profile)
if err != nil {
return false

View File

@ -24,6 +24,62 @@ import (
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml"
)
type MigProfileInfoWrapper struct {
MigProfileInfo
}
func newMockDeviceLib() Interface {
mockDevice := &nvml.DeviceMock{
GetNameFunc: func() (string, nvml.Return) {
return "MockDevice", nvml.SUCCESS
},
GetMigModeFunc: func() (int, int, nvml.Return) {
return nvml.DEVICE_MIG_ENABLE, nvml.DEVICE_MIG_ENABLE, nvml.SUCCESS
},
GetMemoryInfoFunc: func() (nvml.Memory, nvml.Return) {
memory := nvml.Memory{
Total: 40 * 1024 * 1024 * 1024,
}
return memory, nvml.SUCCESS
},
GetGpuInstanceProfileInfoFunc: func(Profile int) (nvml.GpuInstanceProfileInfo, nvml.Return) {
info := nvml.GpuInstanceProfileInfo{}
switch Profile {
case nvml.GPU_INSTANCE_PROFILE_1_SLICE,
nvml.GPU_INSTANCE_PROFILE_1_SLICE_REV1:
info.MemorySizeMB = 5 * 1024
case nvml.GPU_INSTANCE_PROFILE_1_SLICE_REV2:
info.MemorySizeMB = 10 * 1024
case nvml.GPU_INSTANCE_PROFILE_2_SLICE,
nvml.GPU_INSTANCE_PROFILE_2_SLICE_REV1:
info.MemorySizeMB = 10 * 1024
case nvml.GPU_INSTANCE_PROFILE_3_SLICE:
info.MemorySizeMB = 20 * 1024
case nvml.GPU_INSTANCE_PROFILE_4_SLICE:
info.MemorySizeMB = 20 * 1024
case nvml.GPU_INSTANCE_PROFILE_7_SLICE:
info.MemorySizeMB = 40 * 1024
case nvml.GPU_INSTANCE_PROFILE_6_SLICE,
nvml.GPU_INSTANCE_PROFILE_8_SLICE:
fallthrough
default:
return info, nvml.ERROR_NOT_SUPPORTED
}
return info, nvml.SUCCESS
},
}
mockNvml := &nvml.InterfaceMock{
DeviceGetCountFunc: func() (int, nvml.Return) {
return 1, nvml.SUCCESS
},
DeviceGetHandleByIndexFunc: func(Index int) (nvml.Device, nvml.Return) {
return mockDevice, nvml.SUCCESS
},
}
return New(WithNvml(mockNvml), WithVerifySymbols(false))
}
func TestParseMigProfile(t *testing.T) {
testCases := []struct {
description string
@ -291,55 +347,7 @@ func TestParseMigProfile(t *testing.T) {
},
}
mockDevice := &nvml.DeviceMock{
GetNameFunc: func() (string, nvml.Return) {
return "MockDevice", nvml.SUCCESS
},
GetMigModeFunc: func() (int, int, nvml.Return) {
return nvml.DEVICE_MIG_ENABLE, nvml.DEVICE_MIG_ENABLE, nvml.SUCCESS
},
GetMemoryInfoFunc: func() (nvml.Memory, nvml.Return) {
memory := nvml.Memory{
Total: 40 * 1024 * 1024 * 1024,
}
return memory, nvml.SUCCESS
},
GetGpuInstanceProfileInfoFunc: func(Profile int) (nvml.GpuInstanceProfileInfo, nvml.Return) {
info := nvml.GpuInstanceProfileInfo{}
switch Profile {
case nvml.GPU_INSTANCE_PROFILE_1_SLICE,
nvml.GPU_INSTANCE_PROFILE_1_SLICE_REV1:
info.MemorySizeMB = 5 * 1024
case nvml.GPU_INSTANCE_PROFILE_1_SLICE_REV2:
info.MemorySizeMB = 10 * 1024
case nvml.GPU_INSTANCE_PROFILE_2_SLICE,
nvml.GPU_INSTANCE_PROFILE_2_SLICE_REV1:
info.MemorySizeMB = 10 * 1024
case nvml.GPU_INSTANCE_PROFILE_3_SLICE:
info.MemorySizeMB = 20 * 1024
case nvml.GPU_INSTANCE_PROFILE_4_SLICE:
info.MemorySizeMB = 20 * 1024
case nvml.GPU_INSTANCE_PROFILE_7_SLICE:
info.MemorySizeMB = 40 * 1024
case nvml.GPU_INSTANCE_PROFILE_6_SLICE,
nvml.GPU_INSTANCE_PROFILE_8_SLICE:
fallthrough
default:
return info, nvml.ERROR_NOT_SUPPORTED
}
return info, nvml.SUCCESS
},
}
mockNvml := &nvml.InterfaceMock{
DeviceGetCountFunc: func() (int, nvml.Return) {
return 1, nvml.SUCCESS
},
DeviceGetHandleByIndexFunc: func(Index int) (nvml.Device, nvml.Return) {
return mockDevice, nvml.SUCCESS
},
}
d := New(WithNvml(mockNvml), WithVerifySymbols(false))
d := newMockDeviceLib()
for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
err := d.AssertValidMigProfileFormat(tc.device)
@ -358,6 +366,70 @@ func TestParseMigProfile(t *testing.T) {
}
}
func TestParseMigProfileEquals(t *testing.T) {
testCases := []struct {
description string
profile1 string
profile2 string
valid bool
}{
{
"Exactly equal",
"1g.5gb",
"1g.5gb",
true,
},
{
"Equal when expanded",
"1c.1g.5gb",
"1g.5gb",
true,
},
{
"Equal with attributes",
"1g.5gb+me",
"1g.5gb+me",
true,
},
{
"Not equal C slices",
"1c.2g.10gb",
"2c.2g.10gb",
false,
},
{
"Not equal G slices",
"1c.1g.10gb",
"1c.2g.10gb",
false,
},
{
"Not equal attributes",
"1g.5gb",
"1g.5gb+me",
false,
},
}
d := newMockDeviceLib()
for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
p1, err := d.ParseMigProfile(tc.profile1)
require.Nil(t, err)
p2, err := d.ParseMigProfile(tc.profile2)
require.Nil(t, err)
wrapper := MigProfileInfoWrapper{p2.GetInfo()}
if tc.valid {
require.True(t, p1.Equals(p2))
require.True(t, p1.Equals(wrapper))
} else {
require.False(t, p1.Equals(p2))
require.False(t, p1.Equals(wrapper))
}
})
}
}
func TestGetMigMemorySizeGB(t *testing.T) {
type testCase struct {
totalDeviceMemory uint64