Merge branch 'add-mig-pkg' into 'main'

Add MIG package with abstraction for MIG profiles in it

See merge request nvidia/cloud-native/go-nvlib!15
This commit is contained in:
Kevin Klues 2022-09-16 08:17:41 +00:00
commit 8719e258a8
9 changed files with 938 additions and 0 deletions

135
pkg/nvlib/mig/device.go Normal file
View File

@ -0,0 +1,135 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package mig
import (
"fmt"
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml"
)
// Device defines the set of extended functions associated with a mig.Device
type Device interface {
GetProfile() (Profile, error)
}
type device struct {
nvml.Device
miglib *miglib
profile Profile
}
var _ Device = &device{}
// NewDevice builds a new Device from an nvml.Device
func (m *miglib) NewDevice(d nvml.Device) (Device, error) {
isMig, ret := d.IsMigDeviceHandle()
if ret != nvml.SUCCESS {
return nil, fmt.Errorf("error checking if device is a MIG device: %v", ret)
}
if !isMig {
return nil, fmt.Errorf("not a MIG device")
}
return &device{d, m, nil}, nil
}
// GetProfile returns the MIG profile associated with a MIG device
func (d *device) GetProfile() (Profile, error) {
if d.profile != nil {
return d.profile, nil
}
parent, ret := d.Device.GetDeviceHandleFromMigDeviceHandle()
if ret != nvml.SUCCESS {
return nil, fmt.Errorf("error getting parent device handle: %v", ret)
}
parentMemoryInfo, ret := parent.GetMemoryInfo()
if ret != nvml.SUCCESS {
return nil, fmt.Errorf("error getting parent memory info: %v", ret)
}
attributes, ret := d.Device.GetAttributes()
if ret != nvml.SUCCESS {
return nil, fmt.Errorf("error getting MIG device attributes: %v", ret)
}
giID, ret := d.Device.GetGpuInstanceId()
if ret != nvml.SUCCESS {
return nil, fmt.Errorf("error getting MIG device GPU Instance ID: %v", ret)
}
ciID, ret := d.Device.GetComputeInstanceId()
if ret != nvml.SUCCESS {
return nil, fmt.Errorf("error getting MIG device Compute Instance ID: %v", ret)
}
gi, ret := parent.GetGpuInstanceById(giID)
if ret != nvml.SUCCESS {
return nil, fmt.Errorf("error getting GPU Instance: %v", ret)
}
ci, ret := gi.GetComputeInstanceById(ciID)
if ret != nvml.SUCCESS {
return nil, fmt.Errorf("error getting Compute Instance: %v", ret)
}
giInfo, ret := gi.GetInfo()
if ret != nvml.SUCCESS {
return nil, fmt.Errorf("error getting GPU Instance info: %v", ret)
}
ciInfo, ret := ci.GetInfo()
if ret != nvml.SUCCESS {
return nil, fmt.Errorf("error getting Compute Instance info: %v", ret)
}
for i := 0; i < nvml.GPU_INSTANCE_PROFILE_COUNT; i++ {
giProfileInfo, ret := parent.GetGpuInstanceProfileInfo(i)
if ret != nvml.SUCCESS {
return nil, fmt.Errorf("error getting GPU Instance profile info: %v", ret)
}
if giProfileInfo.Id != giInfo.ProfileId {
continue
}
for j := 0; j < nvml.COMPUTE_INSTANCE_PROFILE_COUNT; j++ {
for k := 0; k < nvml.COMPUTE_INSTANCE_ENGINE_PROFILE_COUNT; k++ {
ciProfileInfo, ret := gi.GetComputeInstanceProfileInfo(j, k)
if ret != nvml.SUCCESS {
return nil, fmt.Errorf("error getting Compute Instance profile info: %v", ret)
}
if ciProfileInfo.Id != ciInfo.ProfileId {
continue
}
p, err := d.miglib.NewProfile(i, j, k, attributes.MemorySizeMB, parentMemoryInfo.Total)
if err != nil {
return nil, fmt.Errorf("error creating MIG profile: %v", err)
}
d.profile = p
return p, nil
}
}
}
return nil, fmt.Errorf("no matching profile IDs found")
}

56
pkg/nvlib/mig/mig.go Normal file
View File

@ -0,0 +1,56 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package mig
import (
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml"
)
// Interface provides the API to the mig package
type Interface interface {
NewProfile(giProfileID, ciProfileID, ciEngProfileID int, migMemorySizeMB, deviceMemorySizeBytes uint64) (Profile, error)
ParseProfile(profile string) (Profile, error)
NewDevice(d nvml.Device) (Device, error)
}
type miglib struct {
nvml nvml.Interface
}
var _ Interface = &miglib{}
// New creates a new instance of the 'mig' interface
func New(opts ...Option) Interface {
m := &miglib{}
for _, opt := range opts {
opt(m)
}
if m.nvml == nil {
m.nvml = nvml.New()
}
return m
}
// WithNvml provides an Option to set the NVML library used by the 'mig' interface
func WithNvml(nvml nvml.Interface) Option {
return func(m *miglib) {
m.nvml = nvml
}
}
// Option defines a function for passing options to the New() call
type Option func(*miglib)

332
pkg/nvlib/mig/profile.go Normal file
View File

@ -0,0 +1,332 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package mig
import (
"fmt"
"math"
"strconv"
"strings"
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml"
)
const (
// AttributeMediaExtensions holds the string representation for the media extension MIG profile attribute.
AttributeMediaExtensions = "me"
)
// Profile represents a specific MIG profile.
// Examples include "1g.5gb", "2g.10gb", "1c.2g.10gb", or "1c.1g.5gb+me", etc.
type Profile interface {
String() string
GetInfo() ProfileInfo
Equals(other Profile) bool
}
// ProfileInfo holds all info associated with a specific MIG profile
type ProfileInfo struct {
C int
G int
GB int
Attributes []string
GIProfileID int
CIProfileID int
CIEngProfileID int
}
var _ Profile = &ProfileInfo{}
// NewProfile constructs a new Profile struct using info from the giProfiles and ciProfiles used to create it.
func (m *miglib) NewProfile(giProfileID, ciProfileID, ciEngProfileID int, migMemorySizeMB, deviceMemorySizeBytes uint64) (Profile, error) {
giSlices := 0
switch giProfileID {
case nvml.GPU_INSTANCE_PROFILE_1_SLICE:
giSlices = 1
case nvml.GPU_INSTANCE_PROFILE_1_SLICE_REV1:
giSlices = 1
case nvml.GPU_INSTANCE_PROFILE_2_SLICE:
giSlices = 2
case nvml.GPU_INSTANCE_PROFILE_3_SLICE:
giSlices = 3
case nvml.GPU_INSTANCE_PROFILE_4_SLICE:
giSlices = 4
case nvml.GPU_INSTANCE_PROFILE_6_SLICE:
giSlices = 6
case nvml.GPU_INSTANCE_PROFILE_7_SLICE:
giSlices = 7
case nvml.GPU_INSTANCE_PROFILE_8_SLICE:
giSlices = 8
default:
return nil, fmt.Errorf("invalid GPU Instance Profile ID: %v", giProfileID)
}
ciSlices := 0
switch ciProfileID {
case nvml.COMPUTE_INSTANCE_PROFILE_1_SLICE:
ciSlices = 1
case nvml.COMPUTE_INSTANCE_PROFILE_2_SLICE:
ciSlices = 2
case nvml.COMPUTE_INSTANCE_PROFILE_3_SLICE:
ciSlices = 3
case nvml.COMPUTE_INSTANCE_PROFILE_4_SLICE:
ciSlices = 4
case nvml.COMPUTE_INSTANCE_PROFILE_6_SLICE:
ciSlices = 6
case nvml.COMPUTE_INSTANCE_PROFILE_7_SLICE:
ciSlices = 7
case nvml.COMPUTE_INSTANCE_PROFILE_8_SLICE:
ciSlices = 8
default:
return nil, fmt.Errorf("invalid Compute Instance Profile ID: %v", ciProfileID)
}
var attrs []string
switch giProfileID {
case nvml.GPU_INSTANCE_PROFILE_1_SLICE_REV1:
attrs = append(attrs, AttributeMediaExtensions)
}
p := &ProfileInfo{
C: ciSlices,
G: giSlices,
GB: int(getMigMemorySizeGB(deviceMemorySizeBytes, migMemorySizeMB)),
Attributes: attrs,
GIProfileID: giProfileID,
CIProfileID: ciProfileID,
CIEngProfileID: ciEngProfileID,
}
return p, nil
}
// ParseProfile converts a string representation of a Profile into an object.
func (m *miglib) ParseProfile(profile string) (Profile, error) {
var err error
var c, g, gb int
var attrs []string
if len(profile) == 0 {
return nil, fmt.Errorf("empty Profile string")
}
split := strings.SplitN(profile, "+", 2)
if len(split) == 2 {
attrs, err = parseProfileAttributes(split[1])
if err != nil {
return nil, fmt.Errorf("error parsing attributes following '+' in Profile string: %v", err)
}
}
c, g, gb, err = parseProfileFields(split[0])
if err != nil {
return nil, fmt.Errorf("error parsing '.' separated fields in Profile string: %v", err)
}
p := &ProfileInfo{
C: c,
G: g,
GB: gb,
Attributes: attrs,
}
switch c {
case 1:
p.CIProfileID = nvml.COMPUTE_INSTANCE_PROFILE_1_SLICE
case 2:
p.CIProfileID = nvml.COMPUTE_INSTANCE_PROFILE_2_SLICE
case 3:
p.CIProfileID = nvml.COMPUTE_INSTANCE_PROFILE_3_SLICE
case 4:
p.CIProfileID = nvml.COMPUTE_INSTANCE_PROFILE_4_SLICE
case 6:
p.CIProfileID = nvml.COMPUTE_INSTANCE_PROFILE_6_SLICE
case 7:
p.CIProfileID = nvml.COMPUTE_INSTANCE_PROFILE_7_SLICE
case 8:
p.CIProfileID = nvml.COMPUTE_INSTANCE_PROFILE_8_SLICE
default:
return nil, fmt.Errorf("unknown Compute Instance slice size: %v", c)
}
switch g {
case 1:
p.GIProfileID = nvml.GPU_INSTANCE_PROFILE_1_SLICE
case 2:
p.GIProfileID = nvml.GPU_INSTANCE_PROFILE_2_SLICE
case 3:
p.GIProfileID = nvml.GPU_INSTANCE_PROFILE_3_SLICE
case 4:
p.GIProfileID = nvml.GPU_INSTANCE_PROFILE_4_SLICE
case 6:
p.GIProfileID = nvml.GPU_INSTANCE_PROFILE_6_SLICE
case 7:
p.GIProfileID = nvml.GPU_INSTANCE_PROFILE_7_SLICE
case 8:
p.GIProfileID = nvml.GPU_INSTANCE_PROFILE_8_SLICE
default:
return nil, fmt.Errorf("unknown GPU Instance slice size: %v", g)
}
p.CIEngProfileID = nvml.COMPUTE_INSTANCE_ENGINE_PROFILE_SHARED
for _, a := range attrs {
switch a {
case AttributeMediaExtensions:
p.GIProfileID = nvml.GPU_INSTANCE_PROFILE_1_SLICE_REV1
default:
return nil, fmt.Errorf("unknown Profile attribute: %v", a)
}
}
return p, nil
}
// String returns the string representation of a Profile
func (p *ProfileInfo) String() string {
var suffix string
if len(p.Attributes) > 0 {
suffix = "+" + strings.Join(p.Attributes, ",")
}
if p.C == p.G {
return fmt.Sprintf("%dg.%dgb%s", p.G, p.GB, suffix)
}
return fmt.Sprintf("%dc.%dg.%dgb%s", p.C, p.G, p.GB, suffix)
}
// GetInfo returns detailed info about a Profile
func (p *ProfileInfo) GetInfo() ProfileInfo {
return *p
}
// Equals checks if two Profiles are identical or not
func (p *ProfileInfo) Equals(other Profile) bool {
switch o := other.(type) {
case *ProfileInfo:
if p.C != o.C {
return false
}
if p.G != o.G {
return false
}
if p.GB != o.GB {
return false
}
if p.GIProfileID != o.GIProfileID {
return false
}
if p.CIProfileID != o.CIProfileID {
return false
}
if p.CIEngProfileID != o.CIEngProfileID {
return false
}
return true
}
return false
}
func parseProfileField(s string, field string) (int, error) {
if strings.TrimSpace(s) != s {
return -1, fmt.Errorf("leading or trailing spaces on '%%d%s'", field)
}
if !strings.HasSuffix(s, field) {
return -1, fmt.Errorf("missing '%s' from '%%d%s'", field, field)
}
v, err := strconv.Atoi(strings.TrimSuffix(s, field))
if err != nil {
return -1, fmt.Errorf("malformed number in '%%d%s'", field)
}
return v, nil
}
func parseProfileFields(s string) (int, int, int, error) {
var err error
var c, g, gb int
split := strings.SplitN(s, ".", 3)
if len(split) == 3 {
c, err = parseProfileField(split[0], "c")
if err != nil {
return -1, -1, -1, err
}
g, err = parseProfileField(split[1], "g")
if err != nil {
return -1, -1, -1, err
}
gb, err = parseProfileField(split[2], "gb")
if err != nil {
return -1, -1, -1, err
}
return c, g, gb, err
}
if len(split) == 2 {
g, err = parseProfileField(split[0], "g")
if err != nil {
return -1, -1, -1, err
}
gb, err = parseProfileField(split[1], "gb")
if err != nil {
return -1, -1, -1, err
}
return g, g, gb, nil
}
return -1, -1, -1, fmt.Errorf("parsed wrong number of fields, expected 2 or 3")
}
func parseProfileAttributes(s string) ([]string, error) {
attr := strings.Split(s, ",")
if len(attr) == 0 {
return nil, fmt.Errorf("empty attribute list")
}
unique := make(map[string]int)
for _, a := range attr {
if unique[a] > 0 {
return nil, fmt.Errorf("non unique attribute in list")
}
if a == "" {
return nil, fmt.Errorf("empty attribute in list")
}
if strings.TrimSpace(a) != a {
return nil, fmt.Errorf("leading or trailing spaces in attribute")
}
if a[0] >= '0' && a[0] <= '9' {
return nil, fmt.Errorf("attribute begins with a number")
}
for _, c := range a {
if (c < 'a' || c > 'z') && (c < 'A' || c > 'Z') && (c < '0' || c > '9') {
return nil, fmt.Errorf("non alpha-numeric character or digit in attribute")
}
}
unique[a]++
}
return attr, nil
}
func getMigMemorySizeGB(totalDeviceMemory, migMemorySizeMB uint64) uint64 {
const fracDenominator = 8
const oneMB = 1024 * 1024
const oneGB = 1024 * 1024 * 1024
fractionalGpuMem := (float64(migMemorySizeMB) * oneMB) / float64(totalDeviceMemory)
fractionalGpuMem = math.Ceil(fractionalGpuMem*fracDenominator) / fracDenominator
totalMemGB := float64((totalDeviceMemory + oneGB - 1) / oneGB)
return uint64(math.Round(fractionalGpuMem * totalMemGB))
}

View File

@ -0,0 +1,315 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package mig
import (
"fmt"
"testing"
"github.com/stretchr/testify/require"
)
func TestParseMigProfile(t *testing.T) {
testCases := []struct {
description string
device string
valid bool
}{
{
"Empty device type",
"",
false,
},
{
"Valid 1g.5gb",
"1g.5gb",
true,
},
{
"Valid 1c.1g.5gb",
"1c.1g.5gb",
true,
},
{
"Valid 1g.5gb+me",
"1g.5gb+me",
true,
},
{
"Valid 1c.1g.5gb+me",
"1c.1g.5gb+me",
true,
},
{
"Invalid 0g.0gb",
"0g.0gb",
false,
},
{
"Invalid 0c.0g.0gb",
"0c.0g.0gb",
false,
},
{
"Invalid 10000g.500000gb",
"10000g.500000gb",
false,
},
{
"Invalid 10000c.10000g.500000gb",
"10000c.10000g.500000gb",
false,
},
{
"Invalid ' 1c.1g.5gb'",
" 1c.1g.5gb",
false,
},
{
"Invalid '1 c.1g.5gb'",
"1 c.1g.5gb",
false,
},
{
"Invalid '1c .1g.5gb'",
"1c .1g.5gb",
false,
},
{
"Invalid '1c. 1g.5gb'",
"1c. 1g.5gb",
false,
},
{
"Invalid '1c.1 g.5gb'",
"1c.1 g.5gb",
false,
},
{
"Invalid '1c.1g .5gb'",
"1c.1g .5gb",
false,
},
{
"Invalid '1c.1g. 5gb'",
"1c.1g. 5gb",
false,
},
{
"Invalid '1c.1g.5 gb'",
"1c.1g.5 gb",
false,
},
{
"Invalid '1c.1g.5g b'",
"1c.1g.5g b",
false,
},
{
"Invalid '1c.1g.5gb '",
"1c.1g.5gb ",
false,
},
{
"Invalid '1c . 1g . 5gb'",
"1c . 1g . 5gb",
false,
},
{
"Invalid 1c.f1g.5gb",
"1c.f1g.5gb",
false,
},
{
"Invalid 1r.1g.5gb",
"1r.1g.5gb",
false,
},
{
"Invalid 1g.5gbk",
"1g.5gbk",
false,
},
{
"Invalid 1g.5",
"1g.5",
false,
},
{
"Invalid g.5gb",
"1g.5",
false,
},
{
"Invalid g.5gb",
"g.5gb",
false,
},
{
"Invalid 1g.gb",
"1g.gb",
false,
},
{
"Invalid 1g.5gb+me,me",
"1g.5gb+me,me",
false,
},
{
"Invalid 1g.5gb+me,you,them",
"1g.5gb+me,you,them",
false,
},
{
"Invalid 1c.1g.5gb+me,you,them",
"1c.1g.5gb+me,you,them",
false,
},
{
"Invalid 1g.5gb+",
"1g.5gb+",
false,
},
{
"Invalid 1g.5gb +",
"1g.5gb+",
false,
},
{
"Invalid 1g.5gb+ ",
"1g.5gb+",
false,
},
{
"Invalid 1g.5gb+ ,",
"1g.5gb+",
false,
},
{
"Invalid 1g.5gb+,",
"1g.5gb+",
false,
},
{
"Invalid 1g.5gb+,,",
"1g.5gb+",
false,
},
{
"Invalid 1g.5gb+me,",
"1g.5gb+",
false,
},
{
"Invalid 1g.5gb+me,,",
"1g.5gb+",
false,
},
{
"Invalid 1g.5gb+me, ",
"1g.5gb+",
false,
},
{
"Invalid 1g.5gb+2me",
"1g.5gb+2me",
false,
},
{
"Inavlid 1g.5gb*me",
"1g.5gb*me",
false,
},
{
"Invalid 1c.1g.5gb*me",
"1c.1g.5gb*me",
false,
},
{
"Invalid 1g.5gb*me,you,them",
"1g.5gb*me,you,them",
false,
},
{
"Invalid 1c.1g.5gb*me,you,them",
"1c.1g.5gb*me,you,them",
false,
},
{
"Invalid bogus",
"bogus",
false,
},
}
m := New()
for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
_, err := m.ParseProfile(tc.device)
if tc.valid {
require.Nil(t, err)
} else {
require.Error(t, err)
}
})
}
}
func TestGetMigMemorySizeGB(t *testing.T) {
type testCase struct {
totalDeviceMemory uint64
migMemorySizeMB uint64
expectedMemorySizeGB uint64
}
const maxMemorySlices = 8
const oneMB = uint64(1024 * 1024)
const oneGB = uint64(1024 * 1024 * 1024)
totalDeviceMemory := []uint64{
24 * oneGB,
40 * oneGB,
80 * oneGB,
}
testCases := []testCase{}
for _, tdm := range totalDeviceMemory {
sliceSize := tdm / maxMemorySlices
const stepSize = oneGB / 32
for i := stepSize; i <= tdm; i += stepSize {
tc := testCase{
totalDeviceMemory: tdm,
migMemorySizeMB: i / oneMB,
}
for j := uint64(sliceSize); j <= tdm; j += sliceSize {
if i <= j {
tc.expectedMemorySizeGB = j / oneGB
break
}
}
testCases = append(testCases, tc)
}
}
for _, tc := range testCases {
t.Run(fmt.Sprintf("%v", tc.migMemorySizeMB), func(t *testing.T) {
memorySizeGB := getMigMemorySizeGB(tc.totalDeviceMemory, tc.migMemorySizeMB)
require.Equal(t, int(tc.expectedMemorySizeGB), int(memorySizeGB))
})
}
}

View File

@ -76,6 +76,12 @@ func (d nvmlDevice) GetMigMode() (int, int, Return) {
return s1, s2, Return(r)
}
// GetGpuInstanceById returns the GPU Instance associated with a particular ID
func (d nvmlDevice) GetGpuInstanceById(id int) (GpuInstance, Return) {
gi, r := nvml.Device(d).GetGpuInstanceById(id)
return nvmlGpuInstance(gi), Return(r)
}
// GetGpuInstanceProfileInfo returns the profile info of a GPU Instance
func (d nvmlDevice) GetGpuInstanceProfileInfo(profile int) (GpuInstanceProfileInfo, Return) {
p, r := nvml.Device(d).GetGpuInstanceProfileInfo(profile)

View File

@ -29,6 +29,9 @@ var _ Device = &DeviceMock{}
// GetDeviceHandleFromMigDeviceHandleFunc: func() (Device, Return) {
// panic("mock out the GetDeviceHandleFromMigDeviceHandle method")
// },
// GetGpuInstanceByIdFunc: func(ID int) (GpuInstance, Return) {
// panic("mock out the GetGpuInstanceById method")
// },
// GetGpuInstanceIdFunc: func() (int, Return) {
// panic("mock out the GetGpuInstanceId method")
// },
@ -90,6 +93,9 @@ type DeviceMock struct {
// GetDeviceHandleFromMigDeviceHandleFunc mocks the GetDeviceHandleFromMigDeviceHandle method.
GetDeviceHandleFromMigDeviceHandleFunc func() (Device, Return)
// GetGpuInstanceByIdFunc mocks the GetGpuInstanceById method.
GetGpuInstanceByIdFunc func(ID int) (GpuInstance, Return)
// GetGpuInstanceIdFunc mocks the GetGpuInstanceId method.
GetGpuInstanceIdFunc func() (int, Return)
@ -146,6 +152,11 @@ type DeviceMock struct {
// GetDeviceHandleFromMigDeviceHandle holds details about calls to the GetDeviceHandleFromMigDeviceHandle method.
GetDeviceHandleFromMigDeviceHandle []struct {
}
// GetGpuInstanceById holds details about calls to the GetGpuInstanceById method.
GetGpuInstanceById []struct {
// ID is the ID argument value.
ID int
}
// GetGpuInstanceId holds details about calls to the GetGpuInstanceId method.
GetGpuInstanceId []struct {
}
@ -201,6 +212,7 @@ type DeviceMock struct {
lockGetComputeInstanceId sync.RWMutex
lockGetCudaComputeCapability sync.RWMutex
lockGetDeviceHandleFromMigDeviceHandle sync.RWMutex
lockGetGpuInstanceById sync.RWMutex
lockGetGpuInstanceId sync.RWMutex
lockGetGpuInstanceProfileInfo sync.RWMutex
lockGetGpuInstances sync.RWMutex
@ -321,6 +333,37 @@ func (mock *DeviceMock) GetDeviceHandleFromMigDeviceHandleCalls() []struct {
return calls
}
// GetGpuInstanceById calls GetGpuInstanceByIdFunc.
func (mock *DeviceMock) GetGpuInstanceById(ID int) (GpuInstance, Return) {
if mock.GetGpuInstanceByIdFunc == nil {
panic("DeviceMock.GetGpuInstanceByIdFunc: method is nil but Device.GetGpuInstanceById was just called")
}
callInfo := struct {
ID int
}{
ID: ID,
}
mock.lockGetGpuInstanceById.Lock()
mock.calls.GetGpuInstanceById = append(mock.calls.GetGpuInstanceById, callInfo)
mock.lockGetGpuInstanceById.Unlock()
return mock.GetGpuInstanceByIdFunc(ID)
}
// GetGpuInstanceByIdCalls gets all the calls that were made to GetGpuInstanceById.
// Check the length with:
// len(mockedDevice.GetGpuInstanceByIdCalls())
func (mock *DeviceMock) GetGpuInstanceByIdCalls() []struct {
ID int
} {
var calls []struct {
ID int
}
mock.lockGetGpuInstanceById.RLock()
calls = mock.calls.GetGpuInstanceById
mock.lockGetGpuInstanceById.RUnlock()
return calls
}
// GetGpuInstanceId calls GetGpuInstanceIdFunc.
func (mock *DeviceMock) GetGpuInstanceId() (int, Return) {
if mock.GetGpuInstanceIdFunc == nil {

View File

@ -36,6 +36,12 @@ func (gi nvmlGpuInstance) GetInfo() (GpuInstanceInfo, Return) {
return info, Return(r)
}
// GetComputeInstanceById returns the Compute Instance associated with a particular ID.
func (gi nvmlGpuInstance) GetComputeInstanceById(id int) (ComputeInstance, Return) {
ci, r := nvml.GpuInstance(gi).GetComputeInstanceById(id)
return nvmlComputeInstance(ci), Return(r)
}
// GetComputeInstanceProfileInfo returns info about a given Compute Instance profile
func (gi nvmlGpuInstance) GetComputeInstanceProfileInfo(profile int, engProfile int) (ComputeInstanceProfileInfo, Return) {
p, r := nvml.GpuInstance(gi).GetComputeInstanceProfileInfo(profile, engProfile)

View File

@ -23,6 +23,9 @@ var _ GpuInstance = &GpuInstanceMock{}
// DestroyFunc: func() Return {
// panic("mock out the Destroy method")
// },
// GetComputeInstanceByIdFunc: func(ID int) (ComputeInstance, Return) {
// panic("mock out the GetComputeInstanceById method")
// },
// GetComputeInstanceProfileInfoFunc: func(Profile int, EngProfile int) (ComputeInstanceProfileInfo, Return) {
// panic("mock out the GetComputeInstanceProfileInfo method")
// },
@ -45,6 +48,9 @@ type GpuInstanceMock struct {
// DestroyFunc mocks the Destroy method.
DestroyFunc func() Return
// GetComputeInstanceByIdFunc mocks the GetComputeInstanceById method.
GetComputeInstanceByIdFunc func(ID int) (ComputeInstance, Return)
// GetComputeInstanceProfileInfoFunc mocks the GetComputeInstanceProfileInfo method.
GetComputeInstanceProfileInfoFunc func(Profile int, EngProfile int) (ComputeInstanceProfileInfo, Return)
@ -64,6 +70,11 @@ type GpuInstanceMock struct {
// Destroy holds details about calls to the Destroy method.
Destroy []struct {
}
// GetComputeInstanceById holds details about calls to the GetComputeInstanceById method.
GetComputeInstanceById []struct {
// ID is the ID argument value.
ID int
}
// GetComputeInstanceProfileInfo holds details about calls to the GetComputeInstanceProfileInfo method.
GetComputeInstanceProfileInfo []struct {
// Profile is the Profile argument value.
@ -82,6 +93,7 @@ type GpuInstanceMock struct {
}
lockCreateComputeInstance sync.RWMutex
lockDestroy sync.RWMutex
lockGetComputeInstanceById sync.RWMutex
lockGetComputeInstanceProfileInfo sync.RWMutex
lockGetComputeInstances sync.RWMutex
lockGetInfo sync.RWMutex
@ -144,6 +156,37 @@ func (mock *GpuInstanceMock) DestroyCalls() []struct {
return calls
}
// GetComputeInstanceById calls GetComputeInstanceByIdFunc.
func (mock *GpuInstanceMock) GetComputeInstanceById(ID int) (ComputeInstance, Return) {
if mock.GetComputeInstanceByIdFunc == nil {
panic("GpuInstanceMock.GetComputeInstanceByIdFunc: method is nil but GpuInstance.GetComputeInstanceById was just called")
}
callInfo := struct {
ID int
}{
ID: ID,
}
mock.lockGetComputeInstanceById.Lock()
mock.calls.GetComputeInstanceById = append(mock.calls.GetComputeInstanceById, callInfo)
mock.lockGetComputeInstanceById.Unlock()
return mock.GetComputeInstanceByIdFunc(ID)
}
// GetComputeInstanceByIdCalls gets all the calls that were made to GetComputeInstanceById.
// Check the length with:
// len(mockedGpuInstance.GetComputeInstanceByIdCalls())
func (mock *GpuInstanceMock) GetComputeInstanceByIdCalls() []struct {
ID int
} {
var calls []struct {
ID int
}
mock.lockGetComputeInstanceById.RLock()
calls = mock.calls.GetComputeInstanceById
mock.lockGetComputeInstanceById.RUnlock()
return calls
}
// GetComputeInstanceProfileInfo calls GetComputeInstanceProfileInfoFunc.
func (mock *GpuInstanceMock) GetComputeInstanceProfileInfo(Profile int, EngProfile int) (ComputeInstanceProfileInfo, Return) {
if mock.GetComputeInstanceProfileInfoFunc == nil {

View File

@ -42,6 +42,7 @@ type Device interface {
GetComputeInstanceId() (int, Return)
GetCudaComputeCapability() (int, int, Return)
GetDeviceHandleFromMigDeviceHandle() (Device, Return)
GetGpuInstanceById(ID int) (GpuInstance, Return)
GetGpuInstanceId() (int, Return)
GetGpuInstanceProfileInfo(Profile int) (GpuInstanceProfileInfo, Return)
GetGpuInstances(Info *GpuInstanceProfileInfo) ([]GpuInstance, Return)
@ -64,6 +65,7 @@ type Device interface {
type GpuInstance interface {
CreateComputeInstance(Info *ComputeInstanceProfileInfo) (ComputeInstance, Return)
Destroy() Return
GetComputeInstanceById(ID int) (ComputeInstance, Return)
GetComputeInstanceProfileInfo(Profile int, EngProfile int) (ComputeInstanceProfileInfo, Return)
GetComputeInstances(Info *ComputeInstanceProfileInfo) ([]ComputeInstance, Return)
GetInfo() (GpuInstanceInfo, Return)