Merge branch 'add-device-apis' into 'main'

Move MIG apis to device package and add extended APIs for top-level devices to it

See merge request nvidia/cloud-native/go-nvlib!19
This commit is contained in:
Evan Lezar 2022-09-16 14:15:32 +00:00
commit 01649c65ea
5 changed files with 381 additions and 69 deletions

View File

@ -14,43 +14,50 @@
* limitations under the License. * limitations under the License.
*/ */
package mig package device
import ( import (
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml" "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml"
) )
// Interface provides the API to the mig package // Interface provides the API to the 'device' package
type Interface interface { type Interface interface {
NewProfile(giProfileID, ciProfileID, ciEngProfileID int, migMemorySizeMB, deviceMemorySizeBytes uint64) (Profile, error) GetDevices() ([]Device, error)
ParseProfile(profile string) (Profile, error) GetMigDevices() ([]MigDevice, error)
GetMigProfiles() ([]MigProfile, error)
NewDevice(d nvml.Device) (Device, error) NewDevice(d nvml.Device) (Device, error)
NewMigDevice(d nvml.Device) (MigDevice, error)
NewMigProfile(giProfileID, ciProfileID, ciEngProfileID int, migMemorySizeMB, deviceMemorySizeBytes uint64) (MigProfile, error)
ParseMigProfile(profile string) (MigProfile, error)
VisitDevices(func(i int, d Device) error) error
VisitMigDevices(func(i int, d Device, j int, m MigDevice) error) error
VisitMigProfiles(func(p MigProfile) error) error
} }
type miglib struct { type devicelib struct {
nvml nvml.Interface nvml nvml.Interface
} }
var _ Interface = &miglib{} var _ Interface = &devicelib{}
// New creates a new instance of the 'mig' interface // New creates a new instance of the 'device' interface
func New(opts ...Option) Interface { func New(opts ...Option) Interface {
m := &miglib{} d := &devicelib{}
for _, opt := range opts { for _, opt := range opts {
opt(m) opt(d)
} }
if m.nvml == nil { if d.nvml == nil {
m.nvml = nvml.New() d.nvml = nvml.New()
} }
return m return d
} }
// WithNvml provides an Option to set the NVML library used by the 'mig' interface // WithNvml provides an Option to set the NVML library used by the 'device' interface
func WithNvml(nvml nvml.Interface) Option { func WithNvml(nvml nvml.Interface) Option {
return func(m *miglib) { return func(d *devicelib) {
m.nvml = nvml d.nvml = nvml
} }
} }
// Option defines a function for passing options to the New() call // Option defines a function for passing options to the New() call
type Option func(*miglib) type Option func(*devicelib)

304
pkg/nvlib/device/device.go Normal file
View File

@ -0,0 +1,304 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package device
import (
"fmt"
"github.com/NVIDIA/go-nvml/pkg/dl"
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml"
)
// Device defines the set of extended functions associated with a device.Device
type Device interface {
nvml.Device
GetMigDevices() ([]MigDevice, error)
GetMigProfiles() ([]MigProfile, error)
IsMigCapable() (bool, error)
IsMigEnabled() (bool, error)
VisitMigDevices(func(j int, m MigDevice) error) error
VisitMigProfiles(func(p MigProfile) error) error
}
type device struct {
nvml.Device
lib *devicelib
}
var _ Device = &device{}
// NewDevice builds a new Device from an nvml.Device
func (d *devicelib) NewDevice(dev nvml.Device) (Device, error) {
return &device{dev, d}, nil
}
// IsMigCapable checks if a device is capable of having MIG paprtitions created on it
func (d *device) IsMigCapable() (bool, error) {
err := nvmlLookupSymbol("nvmlDeviceGetMigMode")
if err != nil {
return false, nil
}
_, _, ret := nvml.Device(d).GetMigMode()
if ret == nvml.ERROR_NOT_SUPPORTED {
return false, nil
}
if ret != nvml.SUCCESS {
return false, fmt.Errorf("error getting MIG mode: %v", ret)
}
return true, nil
}
// IsMigEnabled checks if a device has MIG mode currently enabled on it
func (d *device) IsMigEnabled() (bool, error) {
err := nvmlLookupSymbol("nvmlDeviceGetMigMode")
if err != nil {
return false, nil
}
mode, _, ret := nvml.Device(d).GetMigMode()
if ret == nvml.ERROR_NOT_SUPPORTED {
return false, nil
}
if ret != nvml.SUCCESS {
return false, fmt.Errorf("error getting MIG mode: %v", ret)
}
return (mode == nvml.DEVICE_MIG_ENABLE), nil
}
// VisitMigDevices walks a top-level device and invokes a callback function for each MIG device configured on it
func (d *device) VisitMigDevices(visit func(int, MigDevice) error) error {
count, ret := nvml.Device(d).GetMaxMigDeviceCount()
if ret != nvml.SUCCESS {
return fmt.Errorf("error getting max MIG device count: %v", ret)
}
for i := 0; i < count; i++ {
device, ret := nvml.Device(d).GetMigDeviceHandleByIndex(i)
if ret == nvml.ERROR_NOT_FOUND {
continue
}
if ret == nvml.ERROR_INVALID_ARGUMENT {
continue
}
if ret != nvml.SUCCESS {
return fmt.Errorf("error getting MIG device handle at index '%v': %v", i, ret)
}
mig, err := d.lib.NewMigDevice(device)
if err != nil {
return fmt.Errorf("error creating new MIG device wrapper: %v", err)
}
err = visit(i, mig)
if err != nil {
return fmt.Errorf("error visiting MIG device: %v", err)
}
}
return nil
}
// VisitMigProfiles walks a top-level device and invokes a callback function for each unique MIG Profile that can be configured on it
func (d *device) VisitMigProfiles(visit func(MigProfile) error) error {
capable, err := d.IsMigCapable()
if err != nil {
return fmt.Errorf("error checking if GPU is MIG capable: %v", err)
}
if !capable {
return nil
}
memory, ret := d.GetMemoryInfo()
if ret != nvml.SUCCESS {
return fmt.Errorf("error getting device memory info: %v", ret)
}
for i := 0; i < nvml.GPU_INSTANCE_PROFILE_COUNT; i++ {
giProfileInfo, ret := d.GetGpuInstanceProfileInfo(i)
if ret != nvml.SUCCESS {
return fmt.Errorf("error getting GPU Instance profile info: %v", ret)
}
for j := 0; j < nvml.COMPUTE_INSTANCE_PROFILE_COUNT; j++ {
for k := 0; k < nvml.COMPUTE_INSTANCE_ENGINE_PROFILE_COUNT; k++ {
p, err := d.lib.NewMigProfile(i, j, k, giProfileInfo.MemorySizeMB, memory.Total)
if err != nil {
return fmt.Errorf("error creating MIG profile: %v", err)
}
err = visit(p)
if err != nil {
return fmt.Errorf("error visiting MIG profile: %v", err)
}
}
}
}
return nil
}
// GetMigDevices gets the set of MIG devices associated with a top-level device
func (d *device) GetMigDevices() ([]MigDevice, error) {
var migs []MigDevice
err := d.VisitMigDevices(func(j int, m MigDevice) error {
migs = append(migs, m)
return nil
})
if err != nil {
return nil, err
}
return migs, nil
}
// GetMigProfiles gets the set of unique MIG profiles associated with a top-level device
func (d *device) GetMigProfiles() ([]MigProfile, error) {
var profiles []MigProfile
err := d.VisitMigProfiles(func(p MigProfile) error {
profiles = append(profiles, p)
return nil
})
if err != nil {
return nil, err
}
return profiles, nil
}
// VisitDevices visits each top-level device and invokes a callback function for it
func (d *devicelib) VisitDevices(visit func(int, Device) error) error {
count, ret := d.nvml.DeviceGetCount()
if ret != nvml.SUCCESS {
return fmt.Errorf("error getting device count: %v", ret)
}
for i := 0; i < count; i++ {
device, ret := d.nvml.DeviceGetHandleByIndex(i)
if ret != nvml.SUCCESS {
return fmt.Errorf("error getting device handle for index '%v': %v", i, ret)
}
dev, err := d.NewDevice(device)
if err != nil {
return fmt.Errorf("error creating new device wrapper: %v", err)
}
err = visit(i, dev)
if err != nil {
return fmt.Errorf("error visiting device: %v", err)
}
}
return nil
}
// VisitMigDevices walks a top-level device and invokes a callback function for each MIG device configured on it
func (d *devicelib) VisitMigDevices(visit func(int, Device, int, MigDevice) error) error {
err := d.VisitDevices(func(i int, dev Device) error {
err := dev.VisitMigDevices(func(j int, mig MigDevice) error {
err := visit(i, dev, j, mig)
if err != nil {
return fmt.Errorf("error visiting MIG device: %v", err)
}
return nil
})
if err != nil {
return fmt.Errorf("error visiting device: %v", err)
}
return nil
})
if err != nil {
return fmt.Errorf("error visiting devices: %v", err)
}
return nil
}
// VisitMigProfiles walks a top-level device and invokes a callback function for each unique MIG profile found on them
func (d *devicelib) VisitMigProfiles(visit func(MigProfile) error) error {
visited := make(map[string]bool)
err := d.VisitDevices(func(i int, dev Device) error {
err := dev.VisitMigProfiles(func(p MigProfile) error {
if visited[p.String()] {
return nil
}
err := visit(p)
if err != nil {
return fmt.Errorf("error visiting MIG profile: %v", err)
}
visited[p.String()] = true
return nil
})
if err != nil {
return fmt.Errorf("error visiting device: %v", err)
}
return nil
})
if err != nil {
return fmt.Errorf("error visiting devices: %v", err)
}
return nil
}
// GetDevices gets the set of all top-level devices
func (d *devicelib) GetDevices() ([]Device, error) {
var devs []Device
err := d.VisitDevices(func(i int, dev Device) error {
devs = append(devs, dev)
return nil
})
if err != nil {
return nil, err
}
return devs, nil
}
// GetMigDevices gets the set of MIG devices across all top-level devices
func (d *devicelib) GetMigDevices() ([]MigDevice, error) {
var migs []MigDevice
err := d.VisitMigDevices(func(i int, dev Device, j int, m MigDevice) error {
migs = append(migs, m)
return nil
})
if err != nil {
return nil, err
}
return migs, nil
}
// GetMigProfiles gets the set of unique MIG profiles across all top-level devices
func (d *devicelib) GetMigProfiles() ([]MigProfile, error) {
var profiles []MigProfile
err := d.VisitMigProfiles(func(p MigProfile) error {
profiles = append(profiles, p)
return nil
})
if err != nil {
return nil, err
}
return profiles, nil
}
// nvmlLookupSymbol checks to see if the given symbol is present in the NVML library
func nvmlLookupSymbol(symbol string) error {
lib := dl.New("libnvidia-ml.so.1", dl.RTLD_LAZY|dl.RTLD_GLOBAL)
if lib == nil {
return fmt.Errorf("error instantiating DynamicLibrary for NVML")
}
err := lib.Open()
if err != nil {
return fmt.Errorf("error opening DynamicLibrary for NVML: %v", err)
}
defer lib.Close()
return lib.Lookup(symbol)
}

View File

@ -14,7 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */
package mig package device
import ( import (
"fmt" "fmt"
@ -22,38 +22,39 @@ import (
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml" "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml"
) )
// Device defines the set of extended functions associated with a mig.Device // MigDevice defines the set of extended functions associated with a MIG device
type Device interface { type MigDevice interface {
GetProfile() (Profile, error)
}
type device struct {
nvml.Device nvml.Device
miglib *miglib GetProfile() (MigProfile, error)
profile Profile
} }
var _ Device = &device{} type migdevice struct {
nvml.Device
lib *devicelib
profile MigProfile
}
// NewDevice builds a new Device from an nvml.Device var _ MigDevice = &migdevice{}
func (m *miglib) NewDevice(d nvml.Device) (Device, error) {
isMig, ret := d.IsMigDeviceHandle() // NewMigDevice builds a new MigDevice from an nvml.Device
func (d *devicelib) NewMigDevice(handle nvml.Device) (MigDevice, error) {
isMig, ret := handle.IsMigDeviceHandle()
if ret != nvml.SUCCESS { if ret != nvml.SUCCESS {
return nil, fmt.Errorf("error checking if device is a MIG device: %v", ret) return nil, fmt.Errorf("error checking if device is a MIG device: %v", ret)
} }
if !isMig { if !isMig {
return nil, fmt.Errorf("not a MIG device") return nil, fmt.Errorf("not a MIG device")
} }
return &device{d, m, nil}, nil return &migdevice{handle, d, nil}, nil
} }
// GetProfile returns the MIG profile associated with a MIG device // GetProfile returns the MIG profile associated with a MIG device
func (d *device) GetProfile() (Profile, error) { func (m *migdevice) GetProfile() (MigProfile, error) {
if d.profile != nil { if m.profile != nil {
return d.profile, nil return m.profile, nil
} }
parent, ret := d.Device.GetDeviceHandleFromMigDeviceHandle() parent, ret := m.Device.GetDeviceHandleFromMigDeviceHandle()
if ret != nvml.SUCCESS { if ret != nvml.SUCCESS {
return nil, fmt.Errorf("error getting parent device handle: %v", ret) return nil, fmt.Errorf("error getting parent device handle: %v", ret)
} }
@ -63,17 +64,17 @@ func (d *device) GetProfile() (Profile, error) {
return nil, fmt.Errorf("error getting parent memory info: %v", ret) return nil, fmt.Errorf("error getting parent memory info: %v", ret)
} }
attributes, ret := d.Device.GetAttributes() attributes, ret := m.Device.GetAttributes()
if ret != nvml.SUCCESS { if ret != nvml.SUCCESS {
return nil, fmt.Errorf("error getting MIG device attributes: %v", ret) return nil, fmt.Errorf("error getting MIG device attributes: %v", ret)
} }
giID, ret := d.Device.GetGpuInstanceId() giID, ret := m.Device.GetGpuInstanceId()
if ret != nvml.SUCCESS { if ret != nvml.SUCCESS {
return nil, fmt.Errorf("error getting MIG device GPU Instance ID: %v", ret) return nil, fmt.Errorf("error getting MIG device GPU Instance ID: %v", ret)
} }
ciID, ret := d.Device.GetComputeInstanceId() ciID, ret := m.Device.GetComputeInstanceId()
if ret != nvml.SUCCESS { if ret != nvml.SUCCESS {
return nil, fmt.Errorf("error getting MIG device Compute Instance ID: %v", ret) return nil, fmt.Errorf("error getting MIG device Compute Instance ID: %v", ret)
} }
@ -120,12 +121,12 @@ func (d *device) GetProfile() (Profile, error) {
continue continue
} }
p, err := d.miglib.NewProfile(i, j, k, attributes.MemorySizeMB, parentMemoryInfo.Total) p, err := m.lib.NewMigProfile(i, j, k, attributes.MemorySizeMB, parentMemoryInfo.Total)
if err != nil { if err != nil {
return nil, fmt.Errorf("error creating MIG profile: %v", err) return nil, fmt.Errorf("error creating MIG profile: %v", err)
} }
d.profile = p m.profile = p
return p, nil return p, nil
} }
} }

View File

@ -14,7 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */
package mig package device
import ( import (
"fmt" "fmt"
@ -30,16 +30,16 @@ const (
AttributeMediaExtensions = "me" AttributeMediaExtensions = "me"
) )
// Profile represents a specific MIG profile. // MigProfile represents a specific MIG profile.
// Examples include "1g.5gb", "2g.10gb", "1c.2g.10gb", or "1c.1g.5gb+me", etc. // Examples include "1g.5gb", "2g.10gb", "1c.2g.10gb", or "1c.1g.5gb+me", etc.
type Profile interface { type MigProfile interface {
String() string String() string
GetInfo() ProfileInfo GetInfo() MigProfileInfo
Equals(other Profile) bool Equals(other MigProfile) bool
} }
// ProfileInfo holds all info associated with a specific MIG profile // MigProfileInfo holds all info associated with a specific MIG profile
type ProfileInfo struct { type MigProfileInfo struct {
C int C int
G int G int
GB int GB int
@ -49,10 +49,10 @@ type ProfileInfo struct {
CIEngProfileID int CIEngProfileID int
} }
var _ Profile = &ProfileInfo{} var _ MigProfile = &MigProfileInfo{}
// NewProfile constructs a new Profile struct using info from the giProfiles and ciProfiles used to create it. // NewProfile constructs a new Profile struct using info from the giProfiles and ciProfiles used to create it.
func (m *miglib) NewProfile(giProfileID, ciProfileID, ciEngProfileID int, migMemorySizeMB, deviceMemorySizeBytes uint64) (Profile, error) { func (d *devicelib) NewMigProfile(giProfileID, ciProfileID, ciEngProfileID int, migMemorySizeMB, deviceMemorySizeBytes uint64) (MigProfile, error) {
giSlices := 0 giSlices := 0
switch giProfileID { switch giProfileID {
case nvml.GPU_INSTANCE_PROFILE_1_SLICE: case nvml.GPU_INSTANCE_PROFILE_1_SLICE:
@ -101,7 +101,7 @@ func (m *miglib) NewProfile(giProfileID, ciProfileID, ciEngProfileID int, migMem
attrs = append(attrs, AttributeMediaExtensions) attrs = append(attrs, AttributeMediaExtensions)
} }
p := &ProfileInfo{ p := &MigProfileInfo{
C: ciSlices, C: ciSlices,
G: giSlices, G: giSlices,
GB: int(getMigMemorySizeGB(deviceMemorySizeBytes, migMemorySizeMB)), GB: int(getMigMemorySizeGB(deviceMemorySizeBytes, migMemorySizeMB)),
@ -114,8 +114,8 @@ func (m *miglib) NewProfile(giProfileID, ciProfileID, ciEngProfileID int, migMem
return p, nil return p, nil
} }
// ParseProfile converts a string representation of a Profile into an object. // ParseMigProfile converts a string representation of a MigProfile into an object
func (m *miglib) ParseProfile(profile string) (Profile, error) { func (d *devicelib) ParseMigProfile(profile string) (MigProfile, error) {
var err error var err error
var c, g, gb int var c, g, gb int
var attrs []string var attrs []string
@ -126,18 +126,18 @@ func (m *miglib) ParseProfile(profile string) (Profile, error) {
split := strings.SplitN(profile, "+", 2) split := strings.SplitN(profile, "+", 2)
if len(split) == 2 { if len(split) == 2 {
attrs, err = parseProfileAttributes(split[1]) attrs, err = parseMigProfileAttributes(split[1])
if err != nil { if err != nil {
return nil, fmt.Errorf("error parsing attributes following '+' in Profile string: %v", err) return nil, fmt.Errorf("error parsing attributes following '+' in Profile string: %v", err)
} }
} }
c, g, gb, err = parseProfileFields(split[0]) c, g, gb, err = parseMigProfileFields(split[0])
if err != nil { if err != nil {
return nil, fmt.Errorf("error parsing '.' separated fields in Profile string: %v", err) return nil, fmt.Errorf("error parsing '.' separated fields in Profile string: %v", err)
} }
p := &ProfileInfo{ p := &MigProfileInfo{
C: c, C: c,
G: g, G: g,
GB: gb, GB: gb,
@ -197,7 +197,7 @@ func (m *miglib) ParseProfile(profile string) (Profile, error) {
} }
// String returns the string representation of a Profile // String returns the string representation of a Profile
func (p *ProfileInfo) String() string { func (p *MigProfileInfo) String() string {
var suffix string var suffix string
if len(p.Attributes) > 0 { if len(p.Attributes) > 0 {
suffix = "+" + strings.Join(p.Attributes, ",") suffix = "+" + strings.Join(p.Attributes, ",")
@ -209,14 +209,14 @@ func (p *ProfileInfo) String() string {
} }
// GetInfo returns detailed info about a Profile // GetInfo returns detailed info about a Profile
func (p *ProfileInfo) GetInfo() ProfileInfo { func (p *MigProfileInfo) GetInfo() MigProfileInfo {
return *p return *p
} }
// Equals checks if two Profiles are identical or not // Equals checks if two Profiles are identical or not
func (p *ProfileInfo) Equals(other Profile) bool { func (p *MigProfileInfo) Equals(other MigProfile) bool {
switch o := other.(type) { switch o := other.(type) {
case *ProfileInfo: case *MigProfileInfo:
if p.C != o.C { if p.C != o.C {
return false return false
} }
@ -240,7 +240,7 @@ func (p *ProfileInfo) Equals(other Profile) bool {
return false return false
} }
func parseProfileField(s string, field string) (int, error) { func parseMigProfileField(s string, field string) (int, error) {
if strings.TrimSpace(s) != s { if strings.TrimSpace(s) != s {
return -1, fmt.Errorf("leading or trailing spaces on '%%d%s'", field) return -1, fmt.Errorf("leading or trailing spaces on '%%d%s'", field)
} }
@ -257,32 +257,32 @@ func parseProfileField(s string, field string) (int, error) {
return v, nil return v, nil
} }
func parseProfileFields(s string) (int, int, int, error) { func parseMigProfileFields(s string) (int, int, int, error) {
var err error var err error
var c, g, gb int var c, g, gb int
split := strings.SplitN(s, ".", 3) split := strings.SplitN(s, ".", 3)
if len(split) == 3 { if len(split) == 3 {
c, err = parseProfileField(split[0], "c") c, err = parseMigProfileField(split[0], "c")
if err != nil { if err != nil {
return -1, -1, -1, err return -1, -1, -1, err
} }
g, err = parseProfileField(split[1], "g") g, err = parseMigProfileField(split[1], "g")
if err != nil { if err != nil {
return -1, -1, -1, err return -1, -1, -1, err
} }
gb, err = parseProfileField(split[2], "gb") gb, err = parseMigProfileField(split[2], "gb")
if err != nil { if err != nil {
return -1, -1, -1, err return -1, -1, -1, err
} }
return c, g, gb, err return c, g, gb, err
} }
if len(split) == 2 { if len(split) == 2 {
g, err = parseProfileField(split[0], "g") g, err = parseMigProfileField(split[0], "g")
if err != nil { if err != nil {
return -1, -1, -1, err return -1, -1, -1, err
} }
gb, err = parseProfileField(split[1], "gb") gb, err = parseMigProfileField(split[1], "gb")
if err != nil { if err != nil {
return -1, -1, -1, err return -1, -1, -1, err
} }
@ -292,7 +292,7 @@ func parseProfileFields(s string) (int, int, int, error) {
return -1, -1, -1, fmt.Errorf("parsed wrong number of fields, expected 2 or 3") return -1, -1, -1, fmt.Errorf("parsed wrong number of fields, expected 2 or 3")
} }
func parseProfileAttributes(s string) ([]string, error) { func parseMigProfileAttributes(s string) ([]string, error) {
attr := strings.Split(s, ",") attr := strings.Split(s, ",")
if len(attr) == 0 { if len(attr) == 0 {
return nil, fmt.Errorf("empty attribute list") return nil, fmt.Errorf("empty attribute list")

View File

@ -14,7 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */
package mig package device
import ( import (
"fmt" "fmt"
@ -256,10 +256,10 @@ func TestParseMigProfile(t *testing.T) {
}, },
} }
m := New() d := New()
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) { t.Run(tc.description, func(t *testing.T) {
_, err := m.ParseProfile(tc.device) _, err := d.ParseMigProfile(tc.device)
if tc.valid { if tc.valid {
require.Nil(t, err) require.Nil(t, err)
} else { } else {