diff --git a/pkg/nvlib/device/mig_profile.go b/pkg/nvlib/device/mig_profile.go index a5c455e..13a5dbf 100644 --- a/pkg/nvlib/device/mig_profile.go +++ b/pkg/nvlib/device/mig_profile.go @@ -142,7 +142,7 @@ func (d *devicelib) ParseMigProfile(profile string) (MigProfile, error) { } // String returns the string representation of a Profile -func (p *MigProfileInfo) String() string { +func (p MigProfileInfo) String() string { var suffix string if len(p.Attributes) > 0 { suffix = "+" + strings.Join(p.Attributes, ",") @@ -154,39 +154,36 @@ func (p *MigProfileInfo) String() string { } // GetInfo returns detailed info about a Profile -func (p *MigProfileInfo) GetInfo() MigProfileInfo { - return *p +func (p MigProfileInfo) GetInfo() MigProfileInfo { + return p } // Equals checks if two Profiles are identical or not -func (p *MigProfileInfo) Equals(other MigProfile) bool { - switch o := other.(type) { - case *MigProfileInfo: - if p.C != o.C { - return false - } - if p.G != o.G { - return false - } - if p.GB != o.GB { - return false - } - if p.GIProfileID != o.GIProfileID { - return false - } - if p.CIProfileID != o.CIProfileID { - return false - } - if p.CIEngProfileID != o.CIEngProfileID { - return false - } - return true +func (p MigProfileInfo) Equals(other MigProfile) bool { + o := other.GetInfo() + if p.C != o.C { + return false } - return false + if p.G != o.G { + return false + } + if p.GB != o.GB { + return false + } + if p.GIProfileID != o.GIProfileID { + return false + } + if p.CIProfileID != o.CIProfileID { + return false + } + if p.CIEngProfileID != o.CIEngProfileID { + return false + } + return true } // Matches checks if a MigProfile matches the string passed in -func (p *MigProfileInfo) Matches(profile string) bool { +func (p MigProfileInfo) Matches(profile string) bool { c, g, gb, attrs, err := parseMigProfile(profile) if err != nil { return false diff --git a/pkg/nvlib/device/mig_profile_test.go b/pkg/nvlib/device/mig_profile_test.go index d1a4f5b..4bd3eda 100644 --- a/pkg/nvlib/device/mig_profile_test.go +++ b/pkg/nvlib/device/mig_profile_test.go @@ -24,6 +24,62 @@ import ( "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml" ) +type MigProfileInfoWrapper struct { + MigProfileInfo +} + +func newMockDeviceLib() Interface { + mockDevice := &nvml.DeviceMock{ + GetNameFunc: func() (string, nvml.Return) { + return "MockDevice", nvml.SUCCESS + }, + GetMigModeFunc: func() (int, int, nvml.Return) { + return nvml.DEVICE_MIG_ENABLE, nvml.DEVICE_MIG_ENABLE, nvml.SUCCESS + }, + GetMemoryInfoFunc: func() (nvml.Memory, nvml.Return) { + memory := nvml.Memory{ + Total: 40 * 1024 * 1024 * 1024, + } + return memory, nvml.SUCCESS + }, + GetGpuInstanceProfileInfoFunc: func(Profile int) (nvml.GpuInstanceProfileInfo, nvml.Return) { + info := nvml.GpuInstanceProfileInfo{} + switch Profile { + case nvml.GPU_INSTANCE_PROFILE_1_SLICE, + nvml.GPU_INSTANCE_PROFILE_1_SLICE_REV1: + info.MemorySizeMB = 5 * 1024 + case nvml.GPU_INSTANCE_PROFILE_1_SLICE_REV2: + info.MemorySizeMB = 10 * 1024 + case nvml.GPU_INSTANCE_PROFILE_2_SLICE, + nvml.GPU_INSTANCE_PROFILE_2_SLICE_REV1: + info.MemorySizeMB = 10 * 1024 + case nvml.GPU_INSTANCE_PROFILE_3_SLICE: + info.MemorySizeMB = 20 * 1024 + case nvml.GPU_INSTANCE_PROFILE_4_SLICE: + info.MemorySizeMB = 20 * 1024 + case nvml.GPU_INSTANCE_PROFILE_7_SLICE: + info.MemorySizeMB = 40 * 1024 + case nvml.GPU_INSTANCE_PROFILE_6_SLICE, + nvml.GPU_INSTANCE_PROFILE_8_SLICE: + fallthrough + default: + return info, nvml.ERROR_NOT_SUPPORTED + } + return info, nvml.SUCCESS + }, + } + mockNvml := &nvml.InterfaceMock{ + DeviceGetCountFunc: func() (int, nvml.Return) { + return 1, nvml.SUCCESS + }, + DeviceGetHandleByIndexFunc: func(Index int) (nvml.Device, nvml.Return) { + return mockDevice, nvml.SUCCESS + }, + } + + return New(WithNvml(mockNvml), WithVerifySymbols(false)) +} + func TestParseMigProfile(t *testing.T) { testCases := []struct { description string @@ -291,55 +347,7 @@ func TestParseMigProfile(t *testing.T) { }, } - mockDevice := &nvml.DeviceMock{ - GetNameFunc: func() (string, nvml.Return) { - return "MockDevice", nvml.SUCCESS - }, - GetMigModeFunc: func() (int, int, nvml.Return) { - return nvml.DEVICE_MIG_ENABLE, nvml.DEVICE_MIG_ENABLE, nvml.SUCCESS - }, - GetMemoryInfoFunc: func() (nvml.Memory, nvml.Return) { - memory := nvml.Memory{ - Total: 40 * 1024 * 1024 * 1024, - } - return memory, nvml.SUCCESS - }, - GetGpuInstanceProfileInfoFunc: func(Profile int) (nvml.GpuInstanceProfileInfo, nvml.Return) { - info := nvml.GpuInstanceProfileInfo{} - switch Profile { - case nvml.GPU_INSTANCE_PROFILE_1_SLICE, - nvml.GPU_INSTANCE_PROFILE_1_SLICE_REV1: - info.MemorySizeMB = 5 * 1024 - case nvml.GPU_INSTANCE_PROFILE_1_SLICE_REV2: - info.MemorySizeMB = 10 * 1024 - case nvml.GPU_INSTANCE_PROFILE_2_SLICE, - nvml.GPU_INSTANCE_PROFILE_2_SLICE_REV1: - info.MemorySizeMB = 10 * 1024 - case nvml.GPU_INSTANCE_PROFILE_3_SLICE: - info.MemorySizeMB = 20 * 1024 - case nvml.GPU_INSTANCE_PROFILE_4_SLICE: - info.MemorySizeMB = 20 * 1024 - case nvml.GPU_INSTANCE_PROFILE_7_SLICE: - info.MemorySizeMB = 40 * 1024 - case nvml.GPU_INSTANCE_PROFILE_6_SLICE, - nvml.GPU_INSTANCE_PROFILE_8_SLICE: - fallthrough - default: - return info, nvml.ERROR_NOT_SUPPORTED - } - return info, nvml.SUCCESS - }, - } - mockNvml := &nvml.InterfaceMock{ - DeviceGetCountFunc: func() (int, nvml.Return) { - return 1, nvml.SUCCESS - }, - DeviceGetHandleByIndexFunc: func(Index int) (nvml.Device, nvml.Return) { - return mockDevice, nvml.SUCCESS - }, - } - - d := New(WithNvml(mockNvml), WithVerifySymbols(false)) + d := newMockDeviceLib() for _, tc := range testCases { t.Run(tc.description, func(t *testing.T) { err := d.AssertValidMigProfileFormat(tc.device) @@ -358,6 +366,70 @@ func TestParseMigProfile(t *testing.T) { } } +func TestParseMigProfileEquals(t *testing.T) { + testCases := []struct { + description string + profile1 string + profile2 string + valid bool + }{ + { + "Exactly equal", + "1g.5gb", + "1g.5gb", + true, + }, + { + "Equal when expanded", + "1c.1g.5gb", + "1g.5gb", + true, + }, + { + "Equal with attributes", + "1g.5gb+me", + "1g.5gb+me", + true, + }, + { + "Not equal C slices", + "1c.2g.10gb", + "2c.2g.10gb", + false, + }, + { + "Not equal G slices", + "1c.1g.10gb", + "1c.2g.10gb", + false, + }, + { + "Not equal attributes", + "1g.5gb", + "1g.5gb+me", + false, + }, + } + + d := newMockDeviceLib() + for _, tc := range testCases { + t.Run(tc.description, func(t *testing.T) { + p1, err := d.ParseMigProfile(tc.profile1) + require.Nil(t, err) + p2, err := d.ParseMigProfile(tc.profile2) + require.Nil(t, err) + wrapper := MigProfileInfoWrapper{p2.GetInfo()} + if tc.valid { + require.True(t, p1.Equals(p2)) + require.True(t, p1.Equals(wrapper)) + } else { + require.False(t, p1.Equals(p2)) + require.False(t, p1.Equals(wrapper)) + } + }) + } +} + func TestGetMigMemorySizeGB(t *testing.T) { type testCase struct { totalDeviceMemory uint64