Move MIG apis to device package

We decided it makes sense to have top level device and MIG device abstractions
all under one package rather than trying to separate them. It will make it
easier to hav them clal between each other without package dependency loops.

Signed-off-by: Kevin Klues <kklues@nvidia.com>
This commit is contained in:
Kevin Klues 2022-09-16 10:04:59 +00:00
parent 8719e258a8
commit 1d680a93b6
4 changed files with 71 additions and 70 deletions

View File

@ -14,43 +14,43 @@
* limitations under the License. * limitations under the License.
*/ */
package mig package device
import ( import (
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml" "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml"
) )
// Interface provides the API to the mig package // Interface provides the API to the 'device' package
type Interface interface { type Interface interface {
NewProfile(giProfileID, ciProfileID, ciEngProfileID int, migMemorySizeMB, deviceMemorySizeBytes uint64) (Profile, error) NewMigProfile(giProfileID, ciProfileID, ciEngProfileID int, migMemorySizeMB, deviceMemorySizeBytes uint64) (MigProfile, error)
ParseProfile(profile string) (Profile, error) ParseMigProfile(profile string) (MigProfile, error)
NewDevice(d nvml.Device) (Device, error) NewMigDevice(d nvml.Device) (MigDevice, error)
} }
type miglib struct { type devicelib struct {
nvml nvml.Interface nvml nvml.Interface
} }
var _ Interface = &miglib{} var _ Interface = &devicelib{}
// New creates a new instance of the 'mig' interface // New creates a new instance of the 'device' interface
func New(opts ...Option) Interface { func New(opts ...Option) Interface {
m := &miglib{} d := &devicelib{}
for _, opt := range opts { for _, opt := range opts {
opt(m) opt(d)
} }
if m.nvml == nil { if d.nvml == nil {
m.nvml = nvml.New() d.nvml = nvml.New()
} }
return m return d
} }
// WithNvml provides an Option to set the NVML library used by the 'mig' interface // WithNvml provides an Option to set the NVML library used by the 'device' interface
func WithNvml(nvml nvml.Interface) Option { func WithNvml(nvml nvml.Interface) Option {
return func(m *miglib) { return func(d *devicelib) {
m.nvml = nvml d.nvml = nvml
} }
} }
// Option defines a function for passing options to the New() call // Option defines a function for passing options to the New() call
type Option func(*miglib) type Option func(*devicelib)

View File

@ -14,7 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */
package mig package device
import ( import (
"fmt" "fmt"
@ -22,38 +22,39 @@ import (
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml" "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml"
) )
// Device defines the set of extended functions associated with a mig.Device // MigDevice defines the set of extended functions associated with a MIG device
type Device interface { type MigDevice interface {
GetProfile() (Profile, error)
}
type device struct {
nvml.Device nvml.Device
miglib *miglib GetProfile() (MigProfile, error)
profile Profile
} }
var _ Device = &device{} type migdevice struct {
nvml.Device
lib *devicelib
profile MigProfile
}
// NewDevice builds a new Device from an nvml.Device var _ MigDevice = &migdevice{}
func (m *miglib) NewDevice(d nvml.Device) (Device, error) {
isMig, ret := d.IsMigDeviceHandle() // NewMigDevice builds a new MigDevice from an nvml.Device
func (d *devicelib) NewMigDevice(handle nvml.Device) (MigDevice, error) {
isMig, ret := handle.IsMigDeviceHandle()
if ret != nvml.SUCCESS { if ret != nvml.SUCCESS {
return nil, fmt.Errorf("error checking if device is a MIG device: %v", ret) return nil, fmt.Errorf("error checking if device is a MIG device: %v", ret)
} }
if !isMig { if !isMig {
return nil, fmt.Errorf("not a MIG device") return nil, fmt.Errorf("not a MIG device")
} }
return &device{d, m, nil}, nil return &migdevice{handle, d, nil}, nil
} }
// GetProfile returns the MIG profile associated with a MIG device // GetProfile returns the MIG profile associated with a MIG device
func (d *device) GetProfile() (Profile, error) { func (m *migdevice) GetProfile() (MigProfile, error) {
if d.profile != nil { if m.profile != nil {
return d.profile, nil return m.profile, nil
} }
parent, ret := d.Device.GetDeviceHandleFromMigDeviceHandle() parent, ret := m.Device.GetDeviceHandleFromMigDeviceHandle()
if ret != nvml.SUCCESS { if ret != nvml.SUCCESS {
return nil, fmt.Errorf("error getting parent device handle: %v", ret) return nil, fmt.Errorf("error getting parent device handle: %v", ret)
} }
@ -63,17 +64,17 @@ func (d *device) GetProfile() (Profile, error) {
return nil, fmt.Errorf("error getting parent memory info: %v", ret) return nil, fmt.Errorf("error getting parent memory info: %v", ret)
} }
attributes, ret := d.Device.GetAttributes() attributes, ret := m.Device.GetAttributes()
if ret != nvml.SUCCESS { if ret != nvml.SUCCESS {
return nil, fmt.Errorf("error getting MIG device attributes: %v", ret) return nil, fmt.Errorf("error getting MIG device attributes: %v", ret)
} }
giID, ret := d.Device.GetGpuInstanceId() giID, ret := m.Device.GetGpuInstanceId()
if ret != nvml.SUCCESS { if ret != nvml.SUCCESS {
return nil, fmt.Errorf("error getting MIG device GPU Instance ID: %v", ret) return nil, fmt.Errorf("error getting MIG device GPU Instance ID: %v", ret)
} }
ciID, ret := d.Device.GetComputeInstanceId() ciID, ret := m.Device.GetComputeInstanceId()
if ret != nvml.SUCCESS { if ret != nvml.SUCCESS {
return nil, fmt.Errorf("error getting MIG device Compute Instance ID: %v", ret) return nil, fmt.Errorf("error getting MIG device Compute Instance ID: %v", ret)
} }
@ -120,12 +121,12 @@ func (d *device) GetProfile() (Profile, error) {
continue continue
} }
p, err := d.miglib.NewProfile(i, j, k, attributes.MemorySizeMB, parentMemoryInfo.Total) p, err := m.lib.NewMigProfile(i, j, k, attributes.MemorySizeMB, parentMemoryInfo.Total)
if err != nil { if err != nil {
return nil, fmt.Errorf("error creating MIG profile: %v", err) return nil, fmt.Errorf("error creating MIG profile: %v", err)
} }
d.profile = p m.profile = p
return p, nil return p, nil
} }
} }

View File

@ -14,7 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */
package mig package device
import ( import (
"fmt" "fmt"
@ -30,16 +30,16 @@ const (
AttributeMediaExtensions = "me" AttributeMediaExtensions = "me"
) )
// Profile represents a specific MIG profile. // MigProfile represents a specific MIG profile.
// Examples include "1g.5gb", "2g.10gb", "1c.2g.10gb", or "1c.1g.5gb+me", etc. // Examples include "1g.5gb", "2g.10gb", "1c.2g.10gb", or "1c.1g.5gb+me", etc.
type Profile interface { type MigProfile interface {
String() string String() string
GetInfo() ProfileInfo GetInfo() MigProfileInfo
Equals(other Profile) bool Equals(other MigProfile) bool
} }
// ProfileInfo holds all info associated with a specific MIG profile // MigProfileInfo holds all info associated with a specific MIG profile
type ProfileInfo struct { type MigProfileInfo struct {
C int C int
G int G int
GB int GB int
@ -49,10 +49,10 @@ type ProfileInfo struct {
CIEngProfileID int CIEngProfileID int
} }
var _ Profile = &ProfileInfo{} var _ MigProfile = &MigProfileInfo{}
// NewProfile constructs a new Profile struct using info from the giProfiles and ciProfiles used to create it. // NewProfile constructs a new Profile struct using info from the giProfiles and ciProfiles used to create it.
func (m *miglib) NewProfile(giProfileID, ciProfileID, ciEngProfileID int, migMemorySizeMB, deviceMemorySizeBytes uint64) (Profile, error) { func (d *devicelib) NewMigProfile(giProfileID, ciProfileID, ciEngProfileID int, migMemorySizeMB, deviceMemorySizeBytes uint64) (MigProfile, error) {
giSlices := 0 giSlices := 0
switch giProfileID { switch giProfileID {
case nvml.GPU_INSTANCE_PROFILE_1_SLICE: case nvml.GPU_INSTANCE_PROFILE_1_SLICE:
@ -101,7 +101,7 @@ func (m *miglib) NewProfile(giProfileID, ciProfileID, ciEngProfileID int, migMem
attrs = append(attrs, AttributeMediaExtensions) attrs = append(attrs, AttributeMediaExtensions)
} }
p := &ProfileInfo{ p := &MigProfileInfo{
C: ciSlices, C: ciSlices,
G: giSlices, G: giSlices,
GB: int(getMigMemorySizeGB(deviceMemorySizeBytes, migMemorySizeMB)), GB: int(getMigMemorySizeGB(deviceMemorySizeBytes, migMemorySizeMB)),
@ -114,8 +114,8 @@ func (m *miglib) NewProfile(giProfileID, ciProfileID, ciEngProfileID int, migMem
return p, nil return p, nil
} }
// ParseProfile converts a string representation of a Profile into an object. // ParseMigProfile converts a string representation of a MigProfile into an object
func (m *miglib) ParseProfile(profile string) (Profile, error) { func (d *devicelib) ParseMigProfile(profile string) (MigProfile, error) {
var err error var err error
var c, g, gb int var c, g, gb int
var attrs []string var attrs []string
@ -126,18 +126,18 @@ func (m *miglib) ParseProfile(profile string) (Profile, error) {
split := strings.SplitN(profile, "+", 2) split := strings.SplitN(profile, "+", 2)
if len(split) == 2 { if len(split) == 2 {
attrs, err = parseProfileAttributes(split[1]) attrs, err = parseMigProfileAttributes(split[1])
if err != nil { if err != nil {
return nil, fmt.Errorf("error parsing attributes following '+' in Profile string: %v", err) return nil, fmt.Errorf("error parsing attributes following '+' in Profile string: %v", err)
} }
} }
c, g, gb, err = parseProfileFields(split[0]) c, g, gb, err = parseMigProfileFields(split[0])
if err != nil { if err != nil {
return nil, fmt.Errorf("error parsing '.' separated fields in Profile string: %v", err) return nil, fmt.Errorf("error parsing '.' separated fields in Profile string: %v", err)
} }
p := &ProfileInfo{ p := &MigProfileInfo{
C: c, C: c,
G: g, G: g,
GB: gb, GB: gb,
@ -197,7 +197,7 @@ func (m *miglib) ParseProfile(profile string) (Profile, error) {
} }
// String returns the string representation of a Profile // String returns the string representation of a Profile
func (p *ProfileInfo) String() string { func (p *MigProfileInfo) String() string {
var suffix string var suffix string
if len(p.Attributes) > 0 { if len(p.Attributes) > 0 {
suffix = "+" + strings.Join(p.Attributes, ",") suffix = "+" + strings.Join(p.Attributes, ",")
@ -209,14 +209,14 @@ func (p *ProfileInfo) String() string {
} }
// GetInfo returns detailed info about a Profile // GetInfo returns detailed info about a Profile
func (p *ProfileInfo) GetInfo() ProfileInfo { func (p *MigProfileInfo) GetInfo() MigProfileInfo {
return *p return *p
} }
// Equals checks if two Profiles are identical or not // Equals checks if two Profiles are identical or not
func (p *ProfileInfo) Equals(other Profile) bool { func (p *MigProfileInfo) Equals(other MigProfile) bool {
switch o := other.(type) { switch o := other.(type) {
case *ProfileInfo: case *MigProfileInfo:
if p.C != o.C { if p.C != o.C {
return false return false
} }
@ -240,7 +240,7 @@ func (p *ProfileInfo) Equals(other Profile) bool {
return false return false
} }
func parseProfileField(s string, field string) (int, error) { func parseMigProfileField(s string, field string) (int, error) {
if strings.TrimSpace(s) != s { if strings.TrimSpace(s) != s {
return -1, fmt.Errorf("leading or trailing spaces on '%%d%s'", field) return -1, fmt.Errorf("leading or trailing spaces on '%%d%s'", field)
} }
@ -257,32 +257,32 @@ func parseProfileField(s string, field string) (int, error) {
return v, nil return v, nil
} }
func parseProfileFields(s string) (int, int, int, error) { func parseMigProfileFields(s string) (int, int, int, error) {
var err error var err error
var c, g, gb int var c, g, gb int
split := strings.SplitN(s, ".", 3) split := strings.SplitN(s, ".", 3)
if len(split) == 3 { if len(split) == 3 {
c, err = parseProfileField(split[0], "c") c, err = parseMigProfileField(split[0], "c")
if err != nil { if err != nil {
return -1, -1, -1, err return -1, -1, -1, err
} }
g, err = parseProfileField(split[1], "g") g, err = parseMigProfileField(split[1], "g")
if err != nil { if err != nil {
return -1, -1, -1, err return -1, -1, -1, err
} }
gb, err = parseProfileField(split[2], "gb") gb, err = parseMigProfileField(split[2], "gb")
if err != nil { if err != nil {
return -1, -1, -1, err return -1, -1, -1, err
} }
return c, g, gb, err return c, g, gb, err
} }
if len(split) == 2 { if len(split) == 2 {
g, err = parseProfileField(split[0], "g") g, err = parseMigProfileField(split[0], "g")
if err != nil { if err != nil {
return -1, -1, -1, err return -1, -1, -1, err
} }
gb, err = parseProfileField(split[1], "gb") gb, err = parseMigProfileField(split[1], "gb")
if err != nil { if err != nil {
return -1, -1, -1, err return -1, -1, -1, err
} }
@ -292,7 +292,7 @@ func parseProfileFields(s string) (int, int, int, error) {
return -1, -1, -1, fmt.Errorf("parsed wrong number of fields, expected 2 or 3") return -1, -1, -1, fmt.Errorf("parsed wrong number of fields, expected 2 or 3")
} }
func parseProfileAttributes(s string) ([]string, error) { func parseMigProfileAttributes(s string) ([]string, error) {
attr := strings.Split(s, ",") attr := strings.Split(s, ",")
if len(attr) == 0 { if len(attr) == 0 {
return nil, fmt.Errorf("empty attribute list") return nil, fmt.Errorf("empty attribute list")

View File

@ -14,7 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */
package mig package device
import ( import (
"fmt" "fmt"
@ -256,10 +256,10 @@ func TestParseMigProfile(t *testing.T) {
}, },
} }
m := New() d := New()
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) { t.Run(tc.description, func(t *testing.T) {
_, err := m.ParseProfile(tc.device) _, err := d.ParseMigProfile(tc.device)
if tc.valid { if tc.valid {
require.Nil(t, err) require.Nil(t, err)
} else { } else {