Move MIG apis to device package

We decided it makes sense to have top level device and MIG device abstractions
all under one package rather than trying to separate them. It will make it
easier to hav them clal between each other without package dependency loops.

Signed-off-by: Kevin Klues <kklues@nvidia.com>
This commit is contained in:
Kevin Klues 2022-09-16 10:04:59 +00:00
parent 8719e258a8
commit 1d680a93b6
4 changed files with 71 additions and 70 deletions

View File

@ -14,43 +14,43 @@
* limitations under the License.
*/
package mig
package device
import (
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml"
)
// Interface provides the API to the mig package
// Interface provides the API to the 'device' package
type Interface interface {
NewProfile(giProfileID, ciProfileID, ciEngProfileID int, migMemorySizeMB, deviceMemorySizeBytes uint64) (Profile, error)
ParseProfile(profile string) (Profile, error)
NewDevice(d nvml.Device) (Device, error)
NewMigProfile(giProfileID, ciProfileID, ciEngProfileID int, migMemorySizeMB, deviceMemorySizeBytes uint64) (MigProfile, error)
ParseMigProfile(profile string) (MigProfile, error)
NewMigDevice(d nvml.Device) (MigDevice, error)
}
type miglib struct {
type devicelib struct {
nvml nvml.Interface
}
var _ Interface = &miglib{}
var _ Interface = &devicelib{}
// New creates a new instance of the 'mig' interface
// New creates a new instance of the 'device' interface
func New(opts ...Option) Interface {
m := &miglib{}
d := &devicelib{}
for _, opt := range opts {
opt(m)
opt(d)
}
if m.nvml == nil {
m.nvml = nvml.New()
if d.nvml == nil {
d.nvml = nvml.New()
}
return m
return d
}
// WithNvml provides an Option to set the NVML library used by the 'mig' interface
// WithNvml provides an Option to set the NVML library used by the 'device' interface
func WithNvml(nvml nvml.Interface) Option {
return func(m *miglib) {
m.nvml = nvml
return func(d *devicelib) {
d.nvml = nvml
}
}
// Option defines a function for passing options to the New() call
type Option func(*miglib)
type Option func(*devicelib)

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
package mig
package device
import (
"fmt"
@ -22,38 +22,39 @@ import (
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml"
)
// Device defines the set of extended functions associated with a mig.Device
type Device interface {
GetProfile() (Profile, error)
}
type device struct {
// MigDevice defines the set of extended functions associated with a MIG device
type MigDevice interface {
nvml.Device
miglib *miglib
profile Profile
GetProfile() (MigProfile, error)
}
var _ Device = &device{}
type migdevice struct {
nvml.Device
lib *devicelib
profile MigProfile
}
// NewDevice builds a new Device from an nvml.Device
func (m *miglib) NewDevice(d nvml.Device) (Device, error) {
isMig, ret := d.IsMigDeviceHandle()
var _ MigDevice = &migdevice{}
// NewMigDevice builds a new MigDevice from an nvml.Device
func (d *devicelib) NewMigDevice(handle nvml.Device) (MigDevice, error) {
isMig, ret := handle.IsMigDeviceHandle()
if ret != nvml.SUCCESS {
return nil, fmt.Errorf("error checking if device is a MIG device: %v", ret)
}
if !isMig {
return nil, fmt.Errorf("not a MIG device")
}
return &device{d, m, nil}, nil
return &migdevice{handle, d, nil}, nil
}
// GetProfile returns the MIG profile associated with a MIG device
func (d *device) GetProfile() (Profile, error) {
if d.profile != nil {
return d.profile, nil
func (m *migdevice) GetProfile() (MigProfile, error) {
if m.profile != nil {
return m.profile, nil
}
parent, ret := d.Device.GetDeviceHandleFromMigDeviceHandle()
parent, ret := m.Device.GetDeviceHandleFromMigDeviceHandle()
if ret != nvml.SUCCESS {
return nil, fmt.Errorf("error getting parent device handle: %v", ret)
}
@ -63,17 +64,17 @@ func (d *device) GetProfile() (Profile, error) {
return nil, fmt.Errorf("error getting parent memory info: %v", ret)
}
attributes, ret := d.Device.GetAttributes()
attributes, ret := m.Device.GetAttributes()
if ret != nvml.SUCCESS {
return nil, fmt.Errorf("error getting MIG device attributes: %v", ret)
}
giID, ret := d.Device.GetGpuInstanceId()
giID, ret := m.Device.GetGpuInstanceId()
if ret != nvml.SUCCESS {
return nil, fmt.Errorf("error getting MIG device GPU Instance ID: %v", ret)
}
ciID, ret := d.Device.GetComputeInstanceId()
ciID, ret := m.Device.GetComputeInstanceId()
if ret != nvml.SUCCESS {
return nil, fmt.Errorf("error getting MIG device Compute Instance ID: %v", ret)
}
@ -120,12 +121,12 @@ func (d *device) GetProfile() (Profile, error) {
continue
}
p, err := d.miglib.NewProfile(i, j, k, attributes.MemorySizeMB, parentMemoryInfo.Total)
p, err := m.lib.NewMigProfile(i, j, k, attributes.MemorySizeMB, parentMemoryInfo.Total)
if err != nil {
return nil, fmt.Errorf("error creating MIG profile: %v", err)
}
d.profile = p
m.profile = p
return p, nil
}
}

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
package mig
package device
import (
"fmt"
@ -30,16 +30,16 @@ const (
AttributeMediaExtensions = "me"
)
// Profile represents a specific MIG profile.
// MigProfile represents a specific MIG profile.
// Examples include "1g.5gb", "2g.10gb", "1c.2g.10gb", or "1c.1g.5gb+me", etc.
type Profile interface {
type MigProfile interface {
String() string
GetInfo() ProfileInfo
Equals(other Profile) bool
GetInfo() MigProfileInfo
Equals(other MigProfile) bool
}
// ProfileInfo holds all info associated with a specific MIG profile
type ProfileInfo struct {
// MigProfileInfo holds all info associated with a specific MIG profile
type MigProfileInfo struct {
C int
G int
GB int
@ -49,10 +49,10 @@ type ProfileInfo struct {
CIEngProfileID int
}
var _ Profile = &ProfileInfo{}
var _ MigProfile = &MigProfileInfo{}
// NewProfile constructs a new Profile struct using info from the giProfiles and ciProfiles used to create it.
func (m *miglib) NewProfile(giProfileID, ciProfileID, ciEngProfileID int, migMemorySizeMB, deviceMemorySizeBytes uint64) (Profile, error) {
func (d *devicelib) NewMigProfile(giProfileID, ciProfileID, ciEngProfileID int, migMemorySizeMB, deviceMemorySizeBytes uint64) (MigProfile, error) {
giSlices := 0
switch giProfileID {
case nvml.GPU_INSTANCE_PROFILE_1_SLICE:
@ -101,7 +101,7 @@ func (m *miglib) NewProfile(giProfileID, ciProfileID, ciEngProfileID int, migMem
attrs = append(attrs, AttributeMediaExtensions)
}
p := &ProfileInfo{
p := &MigProfileInfo{
C: ciSlices,
G: giSlices,
GB: int(getMigMemorySizeGB(deviceMemorySizeBytes, migMemorySizeMB)),
@ -114,8 +114,8 @@ func (m *miglib) NewProfile(giProfileID, ciProfileID, ciEngProfileID int, migMem
return p, nil
}
// ParseProfile converts a string representation of a Profile into an object.
func (m *miglib) ParseProfile(profile string) (Profile, error) {
// ParseMigProfile converts a string representation of a MigProfile into an object
func (d *devicelib) ParseMigProfile(profile string) (MigProfile, error) {
var err error
var c, g, gb int
var attrs []string
@ -126,18 +126,18 @@ func (m *miglib) ParseProfile(profile string) (Profile, error) {
split := strings.SplitN(profile, "+", 2)
if len(split) == 2 {
attrs, err = parseProfileAttributes(split[1])
attrs, err = parseMigProfileAttributes(split[1])
if err != nil {
return nil, fmt.Errorf("error parsing attributes following '+' in Profile string: %v", err)
}
}
c, g, gb, err = parseProfileFields(split[0])
c, g, gb, err = parseMigProfileFields(split[0])
if err != nil {
return nil, fmt.Errorf("error parsing '.' separated fields in Profile string: %v", err)
}
p := &ProfileInfo{
p := &MigProfileInfo{
C: c,
G: g,
GB: gb,
@ -197,7 +197,7 @@ func (m *miglib) ParseProfile(profile string) (Profile, error) {
}
// String returns the string representation of a Profile
func (p *ProfileInfo) String() string {
func (p *MigProfileInfo) String() string {
var suffix string
if len(p.Attributes) > 0 {
suffix = "+" + strings.Join(p.Attributes, ",")
@ -209,14 +209,14 @@ func (p *ProfileInfo) String() string {
}
// GetInfo returns detailed info about a Profile
func (p *ProfileInfo) GetInfo() ProfileInfo {
func (p *MigProfileInfo) GetInfo() MigProfileInfo {
return *p
}
// Equals checks if two Profiles are identical or not
func (p *ProfileInfo) Equals(other Profile) bool {
func (p *MigProfileInfo) Equals(other MigProfile) bool {
switch o := other.(type) {
case *ProfileInfo:
case *MigProfileInfo:
if p.C != o.C {
return false
}
@ -240,7 +240,7 @@ func (p *ProfileInfo) Equals(other Profile) bool {
return false
}
func parseProfileField(s string, field string) (int, error) {
func parseMigProfileField(s string, field string) (int, error) {
if strings.TrimSpace(s) != s {
return -1, fmt.Errorf("leading or trailing spaces on '%%d%s'", field)
}
@ -257,32 +257,32 @@ func parseProfileField(s string, field string) (int, error) {
return v, nil
}
func parseProfileFields(s string) (int, int, int, error) {
func parseMigProfileFields(s string) (int, int, int, error) {
var err error
var c, g, gb int
split := strings.SplitN(s, ".", 3)
if len(split) == 3 {
c, err = parseProfileField(split[0], "c")
c, err = parseMigProfileField(split[0], "c")
if err != nil {
return -1, -1, -1, err
}
g, err = parseProfileField(split[1], "g")
g, err = parseMigProfileField(split[1], "g")
if err != nil {
return -1, -1, -1, err
}
gb, err = parseProfileField(split[2], "gb")
gb, err = parseMigProfileField(split[2], "gb")
if err != nil {
return -1, -1, -1, err
}
return c, g, gb, err
}
if len(split) == 2 {
g, err = parseProfileField(split[0], "g")
g, err = parseMigProfileField(split[0], "g")
if err != nil {
return -1, -1, -1, err
}
gb, err = parseProfileField(split[1], "gb")
gb, err = parseMigProfileField(split[1], "gb")
if err != nil {
return -1, -1, -1, err
}
@ -292,7 +292,7 @@ func parseProfileFields(s string) (int, int, int, error) {
return -1, -1, -1, fmt.Errorf("parsed wrong number of fields, expected 2 or 3")
}
func parseProfileAttributes(s string) ([]string, error) {
func parseMigProfileAttributes(s string) ([]string, error) {
attr := strings.Split(s, ",")
if len(attr) == 0 {
return nil, fmt.Errorf("empty attribute list")

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
package mig
package device
import (
"fmt"
@ -256,10 +256,10 @@ func TestParseMigProfile(t *testing.T) {
},
}
m := New()
d := New()
for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
_, err := m.ParseProfile(tc.device)
_, err := d.ParseMigProfile(tc.device)
if tc.valid {
require.Nil(t, err)
} else {