Merge branch 'add-device-apis' into 'main'

Move the MIG APIs into the device package and add extended APIs for top-level devices to it

See merge request nvidia/cloud-native/go-nvlib!19
This commit is contained in:
Evan Lezar 2022-09-16 14:15:32 +00:00
commit 01649c65ea
5 changed files with 381 additions and 69 deletions

View File

@ -14,43 +14,50 @@
* limitations under the License.
*/
package mig
package device
import (
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml"
)
// Interface provides the API to the mig package
// Interface provides the API to the 'device' package
type Interface interface {
NewProfile(giProfileID, ciProfileID, ciEngProfileID int, migMemorySizeMB, deviceMemorySizeBytes uint64) (Profile, error)
ParseProfile(profile string) (Profile, error)
GetDevices() ([]Device, error)
GetMigDevices() ([]MigDevice, error)
GetMigProfiles() ([]MigProfile, error)
NewDevice(d nvml.Device) (Device, error)
NewMigDevice(d nvml.Device) (MigDevice, error)
NewMigProfile(giProfileID, ciProfileID, ciEngProfileID int, migMemorySizeMB, deviceMemorySizeBytes uint64) (MigProfile, error)
ParseMigProfile(profile string) (MigProfile, error)
VisitDevices(func(i int, d Device) error) error
VisitMigDevices(func(i int, d Device, j int, m MigDevice) error) error
VisitMigProfiles(func(p MigProfile) error) error
}
type miglib struct {
type devicelib struct {
nvml nvml.Interface
}
var _ Interface = &miglib{}
var _ Interface = &devicelib{}
// New creates a new instance of the 'mig' interface
// New creates a new instance of the 'device' interface
func New(opts ...Option) Interface {
m := &miglib{}
d := &devicelib{}
for _, opt := range opts {
opt(m)
opt(d)
}
if m.nvml == nil {
m.nvml = nvml.New()
if d.nvml == nil {
d.nvml = nvml.New()
}
return m
return d
}
// WithNvml provides an Option to set the NVML library used by the 'mig' interface
// WithNvml provides an Option to set the NVML library used by the 'device' interface
func WithNvml(nvml nvml.Interface) Option {
return func(m *miglib) {
m.nvml = nvml
return func(d *devicelib) {
d.nvml = nvml
}
}
// Option defines a function for passing options to the New() call
type Option func(*miglib)
type Option func(*devicelib)

304
pkg/nvlib/device/device.go Normal file
View File

@ -0,0 +1,304 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package device
import (
"fmt"
"github.com/NVIDIA/go-nvml/pkg/dl"
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml"
)
// Device defines the set of extended functions associated with a top-level
// device. It embeds nvml.Device, so all NVML device calls remain available on
// the wrapper alongside the MIG helpers declared here.
type Device interface {
nvml.Device
// GetMigDevices returns the MIG devices found on this device.
GetMigDevices() ([]MigDevice, error)
// GetMigProfiles returns the MIG profiles reported for this device.
GetMigProfiles() ([]MigProfile, error)
// IsMigCapable reports whether MIG partitions can be created on this device.
IsMigCapable() (bool, error)
// IsMigEnabled reports whether MIG mode is currently enabled on this device.
IsMigEnabled() (bool, error)
// VisitMigDevices invokes the callback for each MIG device on this device,
// passing the MIG device's index and its wrapper.
VisitMigDevices(func(j int, m MigDevice) error) error
// VisitMigProfiles invokes the callback for each MIG profile of this device.
VisitMigProfiles(func(p MigProfile) error) error
}
// device is the package's implementation of Device: an embedded nvml.Device
// handle together with the devicelib that created it.
type device struct {
nvml.Device
// lib is the library that built this wrapper; it is used to construct
// MigDevice and MigProfile values while walking the device.
lib *devicelib
}
// Compile-time assertion that *device satisfies the Device interface.
var _ Device = &device{}
// NewDevice wraps an nvml.Device handle in a Device bound to this library.
// No validation is performed on the handle; the returned error is always nil.
func (d *devicelib) NewDevice(dev nvml.Device) (Device, error) {
	wrapped := &device{
		Device: dev,
		lib:    d,
	}
	return wrapped, nil
}
// IsMigCapable reports whether MIG partitions can be created on the device.
//
// The device is treated as not capable (false, nil) when the NVML library does
// not expose the nvmlDeviceGetMigMode symbol, or when querying the MIG mode
// returns ERROR_NOT_SUPPORTED. Any other non-success return code is reported
// as an error.
func (d *device) IsMigCapable() (bool, error) {
	// A failed symbol lookup is deliberately not surfaced as an error: it is
	// taken to mean "this NVML cannot do MIG".
	if err := nvmlLookupSymbol("nvmlDeviceGetMigMode"); err != nil {
		return false, nil
	}

	// Only the return code matters here; the current/pending modes are ignored.
	_, _, ret := d.GetMigMode()
	switch ret {
	case nvml.SUCCESS:
		return true, nil
	case nvml.ERROR_NOT_SUPPORTED:
		return false, nil
	default:
		return false, fmt.Errorf("error getting MIG mode: %v", ret)
	}
}
// IsMigEnabled reports whether MIG mode is currently enabled on the device.
//
// The result is (false, nil) when the NVML library does not expose the
// nvmlDeviceGetMigMode symbol or when the query returns ERROR_NOT_SUPPORTED.
// Any other non-success return code is reported as an error.
func (d *device) IsMigEnabled() (bool, error) {
	// A missing symbol is not an error: MIG simply cannot be enabled.
	if err := nvmlLookupSymbol("nvmlDeviceGetMigMode"); err != nil {
		return false, nil
	}

	// The first value is the current mode; the second value is ignored.
	current, _, ret := d.GetMigMode()
	switch ret {
	case nvml.SUCCESS:
		enabled := current == nvml.DEVICE_MIG_ENABLE
		return enabled, nil
	case nvml.ERROR_NOT_SUPPORTED:
		return false, nil
	default:
		return false, fmt.Errorf("error getting MIG mode: %v", ret)
	}
}
// VisitMigDevices walks a top-level device and invokes a callback function for
// each MIG device configured on it.
//
// The callback receives the MIG device's index on the parent and its MigDevice
// wrapper. A device that is not MIG capable is skipped without error and no
// callbacks are made, mirroring the guard used by VisitMigProfiles. Iteration
// stops at the first failure; errors originating from Go code (including the
// callback) are wrapped with %w so callers can match them with errors.Is or
// errors.As.
func (d *device) VisitMigDevices(visit func(int, MigDevice) error) error {
	// Guard the MIG-specific NVML calls below: IsMigCapable returns false when
	// the MIG symbols are missing or MIG is unsupported on this device.
	capable, err := d.IsMigCapable()
	if err != nil {
		return fmt.Errorf("error checking if GPU is MIG capable: %w", err)
	}
	if !capable {
		return nil
	}

	count, ret := d.GetMaxMigDeviceCount()
	if ret != nvml.SUCCESS {
		return fmt.Errorf("error getting max MIG device count: %v", ret)
	}

	for i := 0; i < count; i++ {
		// Named 'handle' rather than 'device' to avoid shadowing the device type.
		handle, ret := d.GetMigDeviceHandleByIndex(i)
		// NOT_FOUND and INVALID_ARGUMENT are treated as "no MIG device at this
		// index" and skipped rather than aborting the walk.
		if ret == nvml.ERROR_NOT_FOUND || ret == nvml.ERROR_INVALID_ARGUMENT {
			continue
		}
		if ret != nvml.SUCCESS {
			return fmt.Errorf("error getting MIG device handle at index '%v': %v", i, ret)
		}

		mig, err := d.lib.NewMigDevice(handle)
		if err != nil {
			return fmt.Errorf("error creating new MIG device wrapper: %w", err)
		}
		if err := visit(i, mig); err != nil {
			return fmt.Errorf("error visiting MIG device: %w", err)
		}
	}
	return nil
}
// VisitMigProfiles walks a top-level device and invokes a callback function for
// each MIG profile that can be constructed for it.
//
// A device that is not MIG capable yields no callbacks and no error. For each
// GPU Instance profile ID the device reports, a MigProfile is built for every
// combination of Compute Instance profile ID and Compute Instance engine
// profile ID and passed to the callback. GPU Instance profile IDs for which
// NVML returns ERROR_NOT_SUPPORTED or ERROR_INVALID_ARGUMENT are skipped, since
// those codes only indicate that the particular profile is unavailable on this
// GPU. Iteration stops at the first failure; errors originating from Go code
// (including the callback) are wrapped with %w.
func (d *device) VisitMigProfiles(visit func(MigProfile) error) error {
	capable, err := d.IsMigCapable()
	if err != nil {
		return fmt.Errorf("error checking if GPU is MIG capable: %w", err)
	}
	if !capable {
		return nil
	}

	// Total device memory is passed to NewMigProfile alongside the per-profile
	// memory size.
	memory, ret := d.GetMemoryInfo()
	if ret != nvml.SUCCESS {
		return fmt.Errorf("error getting device memory info: %v", ret)
	}

	for i := 0; i < nvml.GPU_INSTANCE_PROFILE_COUNT; i++ {
		giProfileInfo, ret := d.GetGpuInstanceProfileInfo(i)
		// An unsupported or invalid profile ID must not abort the whole walk;
		// move on to the next profile ID instead.
		if ret == nvml.ERROR_NOT_SUPPORTED || ret == nvml.ERROR_INVALID_ARGUMENT {
			continue
		}
		if ret != nvml.SUCCESS {
			return fmt.Errorf("error getting GPU Instance profile info: %v", ret)
		}

		for j := 0; j < nvml.COMPUTE_INSTANCE_PROFILE_COUNT; j++ {
			for k := 0; k < nvml.COMPUTE_INSTANCE_ENGINE_PROFILE_COUNT; k++ {
				p, err := d.lib.NewMigProfile(i, j, k, giProfileInfo.MemorySizeMB, memory.Total)
				if err != nil {
					return fmt.Errorf("error creating MIG profile: %w", err)
				}
				if err := visit(p); err != nil {
					return fmt.Errorf("error visiting MIG profile: %w", err)
				}
			}
		}
	}
	return nil
}
// GetMigDevices collects every MIG device reported by VisitMigDevices for this
// top-level device. On error no partial result is returned.
func (d *device) GetMigDevices() ([]MigDevice, error) {
	var collected []MigDevice
	collect := func(_ int, mig MigDevice) error {
		collected = append(collected, mig)
		return nil
	}

	if err := d.VisitMigDevices(collect); err != nil {
		return nil, err
	}
	return collected, nil
}
// GetMigProfiles collects every MIG profile reported by VisitMigProfiles for
// this top-level device. On error no partial result is returned.
func (d *device) GetMigProfiles() ([]MigProfile, error) {
	var collected []MigProfile
	collect := func(profile MigProfile) error {
		collected = append(collected, profile)
		return nil
	}

	if err := d.VisitMigProfiles(collect); err != nil {
		return nil, err
	}
	return collected, nil
}
// VisitDevices visits each top-level device and invokes a callback function
// for it.
//
// The callback receives the device's NVML index and its Device wrapper.
// Iteration stops at the first failure. Errors originating from Go code
// (including the callback) are wrapped with %w so that callers can recognise
// their own errors with errors.Is or errors.As; NVML return codes are formatted
// into the message.
func (d *devicelib) VisitDevices(visit func(int, Device) error) error {
	count, ret := d.nvml.DeviceGetCount()
	if ret != nvml.SUCCESS {
		return fmt.Errorf("error getting device count: %v", ret)
	}

	for i := 0; i < count; i++ {
		// Named 'handle' rather than 'device' to avoid shadowing the device type.
		handle, ret := d.nvml.DeviceGetHandleByIndex(i)
		if ret != nvml.SUCCESS {
			return fmt.Errorf("error getting device handle for index '%v': %v", i, ret)
		}

		dev, err := d.NewDevice(handle)
		if err != nil {
			return fmt.Errorf("error creating new device wrapper: %w", err)
		}
		if err := visit(i, dev); err != nil {
			return fmt.Errorf("error visiting device: %w", err)
		}
	}
	return nil
}
// VisitMigDevices walks all top-level devices and invokes a callback function
// for each MIG device configured on any of them.
//
// The callback receives the parent's index and Device wrapper, followed by the
// MIG device's index on that parent and its MigDevice wrapper. Iteration stops
// at the first failure; the error is wrapped with %w at each level so callers
// can match the underlying cause with errors.Is or errors.As.
func (d *devicelib) VisitMigDevices(visit func(int, Device, int, MigDevice) error) error {
	err := d.VisitDevices(func(i int, dev Device) error {
		err := dev.VisitMigDevices(func(j int, mig MigDevice) error {
			if err := visit(i, dev, j, mig); err != nil {
				return fmt.Errorf("error visiting MIG device: %w", err)
			}
			return nil
		})
		if err != nil {
			return fmt.Errorf("error visiting device: %w", err)
		}
		return nil
	})
	if err != nil {
		return fmt.Errorf("error visiting devices: %w", err)
	}
	return nil
}
// VisitMigProfiles walks all top-level devices and invokes a callback function
// for each unique MIG profile found across them.
//
// Uniqueness is decided by the profile's String() representation: a profile is
// passed to the callback only the first time its string form is seen, and it is
// recorded as seen only after the callback succeeds. Iteration stops at the
// first failure; the error is wrapped with %w at each level so callers can
// match the underlying cause with errors.Is or errors.As.
func (d *devicelib) VisitMigProfiles(visit func(MigProfile) error) error {
	visited := make(map[string]bool)
	err := d.VisitDevices(func(i int, dev Device) error {
		err := dev.VisitMigProfiles(func(p MigProfile) error {
			// Compute the key once; it is used for both lookup and insertion.
			key := p.String()
			if visited[key] {
				return nil
			}
			if err := visit(p); err != nil {
				return fmt.Errorf("error visiting MIG profile: %w", err)
			}
			visited[key] = true
			return nil
		})
		if err != nil {
			return fmt.Errorf("error visiting device: %w", err)
		}
		return nil
	})
	if err != nil {
		return fmt.Errorf("error visiting devices: %w", err)
	}
	return nil
}
// GetDevices collects every top-level device reported by VisitDevices.
// On error no partial result is returned.
func (d *devicelib) GetDevices() ([]Device, error) {
	var collected []Device
	collect := func(_ int, dev Device) error {
		collected = append(collected, dev)
		return nil
	}

	if err := d.VisitDevices(collect); err != nil {
		return nil, err
	}
	return collected, nil
}
// GetMigDevices collects the MIG devices reported by VisitMigDevices across
// all top-level devices. On error no partial result is returned.
func (d *devicelib) GetMigDevices() ([]MigDevice, error) {
	var collected []MigDevice
	collect := func(_ int, _ Device, _ int, mig MigDevice) error {
		collected = append(collected, mig)
		return nil
	}

	if err := d.VisitMigDevices(collect); err != nil {
		return nil, err
	}
	return collected, nil
}
// GetMigProfiles collects the unique MIG profiles reported by VisitMigProfiles
// across all top-level devices. On error no partial result is returned.
func (d *devicelib) GetMigProfiles() ([]MigProfile, error) {
	var collected []MigProfile
	collect := func(profile MigProfile) error {
		collected = append(collected, profile)
		return nil
	}

	if err := d.VisitMigProfiles(collect); err != nil {
		return nil, err
	}
	return collected, nil
}
// nvmlLookupSymbol checks whether the given symbol is present in the NVML
// shared library (libnvidia-ml.so.1). It returns nil when the lookup succeeds
// and an error when the library cannot be instantiated or opened, or when the
// lookup itself fails.
func nvmlLookupSymbol(symbol string) error {
	handle := dl.New("libnvidia-ml.so.1", dl.RTLD_LAZY|dl.RTLD_GLOBAL)
	if handle == nil {
		return fmt.Errorf("error instantiating DynamicLibrary for NVML")
	}

	if err := handle.Open(); err != nil {
		return fmt.Errorf("error opening DynamicLibrary for NVML: %v", err)
	}
	// The library is opened only for this lookup; close it again on return.
	defer handle.Close()

	return handle.Lookup(symbol)
}

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
package mig
package device
import (
"fmt"
@ -22,38 +22,39 @@ import (
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml"
)
// Device defines the set of extended functions associated with a mig.Device
type Device interface {
GetProfile() (Profile, error)
}
type device struct {
// MigDevice defines the set of extended functions associated with a MIG device
type MigDevice interface {
nvml.Device
miglib *miglib
profile Profile
GetProfile() (MigProfile, error)
}
var _ Device = &device{}
type migdevice struct {
nvml.Device
lib *devicelib
profile MigProfile
}
// NewDevice builds a new Device from an nvml.Device
func (m *miglib) NewDevice(d nvml.Device) (Device, error) {
isMig, ret := d.IsMigDeviceHandle()
var _ MigDevice = &migdevice{}
// NewMigDevice builds a new MigDevice from an nvml.Device
func (d *devicelib) NewMigDevice(handle nvml.Device) (MigDevice, error) {
isMig, ret := handle.IsMigDeviceHandle()
if ret != nvml.SUCCESS {
return nil, fmt.Errorf("error checking if device is a MIG device: %v", ret)
}
if !isMig {
return nil, fmt.Errorf("not a MIG device")
}
return &device{d, m, nil}, nil
return &migdevice{handle, d, nil}, nil
}
// GetProfile returns the MIG profile associated with a MIG device
func (d *device) GetProfile() (Profile, error) {
if d.profile != nil {
return d.profile, nil
func (m *migdevice) GetProfile() (MigProfile, error) {
if m.profile != nil {
return m.profile, nil
}
parent, ret := d.Device.GetDeviceHandleFromMigDeviceHandle()
parent, ret := m.Device.GetDeviceHandleFromMigDeviceHandle()
if ret != nvml.SUCCESS {
return nil, fmt.Errorf("error getting parent device handle: %v", ret)
}
@ -63,17 +64,17 @@ func (d *device) GetProfile() (Profile, error) {
return nil, fmt.Errorf("error getting parent memory info: %v", ret)
}
attributes, ret := d.Device.GetAttributes()
attributes, ret := m.Device.GetAttributes()
if ret != nvml.SUCCESS {
return nil, fmt.Errorf("error getting MIG device attributes: %v", ret)
}
giID, ret := d.Device.GetGpuInstanceId()
giID, ret := m.Device.GetGpuInstanceId()
if ret != nvml.SUCCESS {
return nil, fmt.Errorf("error getting MIG device GPU Instance ID: %v", ret)
}
ciID, ret := d.Device.GetComputeInstanceId()
ciID, ret := m.Device.GetComputeInstanceId()
if ret != nvml.SUCCESS {
return nil, fmt.Errorf("error getting MIG device Compute Instance ID: %v", ret)
}
@ -120,12 +121,12 @@ func (d *device) GetProfile() (Profile, error) {
continue
}
p, err := d.miglib.NewProfile(i, j, k, attributes.MemorySizeMB, parentMemoryInfo.Total)
p, err := m.lib.NewMigProfile(i, j, k, attributes.MemorySizeMB, parentMemoryInfo.Total)
if err != nil {
return nil, fmt.Errorf("error creating MIG profile: %v", err)
}
d.profile = p
m.profile = p
return p, nil
}
}

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
package mig
package device
import (
"fmt"
@ -30,16 +30,16 @@ const (
AttributeMediaExtensions = "me"
)
// Profile represents a specific MIG profile.
// MigProfile represents a specific MIG profile.
// Examples include "1g.5gb", "2g.10gb", "1c.2g.10gb", or "1c.1g.5gb+me", etc.
type Profile interface {
type MigProfile interface {
String() string
GetInfo() ProfileInfo
Equals(other Profile) bool
GetInfo() MigProfileInfo
Equals(other MigProfile) bool
}
// ProfileInfo holds all info associated with a specific MIG profile
type ProfileInfo struct {
// MigProfileInfo holds all info associated with a specific MIG profile
type MigProfileInfo struct {
C int
G int
GB int
@ -49,10 +49,10 @@ type ProfileInfo struct {
CIEngProfileID int
}
var _ Profile = &ProfileInfo{}
var _ MigProfile = &MigProfileInfo{}
// NewProfile constructs a new Profile struct using info from the giProfiles and ciProfiles used to create it.
func (m *miglib) NewProfile(giProfileID, ciProfileID, ciEngProfileID int, migMemorySizeMB, deviceMemorySizeBytes uint64) (Profile, error) {
func (d *devicelib) NewMigProfile(giProfileID, ciProfileID, ciEngProfileID int, migMemorySizeMB, deviceMemorySizeBytes uint64) (MigProfile, error) {
giSlices := 0
switch giProfileID {
case nvml.GPU_INSTANCE_PROFILE_1_SLICE:
@ -101,7 +101,7 @@ func (m *miglib) NewProfile(giProfileID, ciProfileID, ciEngProfileID int, migMem
attrs = append(attrs, AttributeMediaExtensions)
}
p := &ProfileInfo{
p := &MigProfileInfo{
C: ciSlices,
G: giSlices,
GB: int(getMigMemorySizeGB(deviceMemorySizeBytes, migMemorySizeMB)),
@ -114,8 +114,8 @@ func (m *miglib) NewProfile(giProfileID, ciProfileID, ciEngProfileID int, migMem
return p, nil
}
// ParseProfile converts a string representation of a Profile into an object.
func (m *miglib) ParseProfile(profile string) (Profile, error) {
// ParseMigProfile converts a string representation of a MigProfile into an object
func (d *devicelib) ParseMigProfile(profile string) (MigProfile, error) {
var err error
var c, g, gb int
var attrs []string
@ -126,18 +126,18 @@ func (m *miglib) ParseProfile(profile string) (Profile, error) {
split := strings.SplitN(profile, "+", 2)
if len(split) == 2 {
attrs, err = parseProfileAttributes(split[1])
attrs, err = parseMigProfileAttributes(split[1])
if err != nil {
return nil, fmt.Errorf("error parsing attributes following '+' in Profile string: %v", err)
}
}
c, g, gb, err = parseProfileFields(split[0])
c, g, gb, err = parseMigProfileFields(split[0])
if err != nil {
return nil, fmt.Errorf("error parsing '.' separated fields in Profile string: %v", err)
}
p := &ProfileInfo{
p := &MigProfileInfo{
C: c,
G: g,
GB: gb,
@ -197,7 +197,7 @@ func (m *miglib) ParseProfile(profile string) (Profile, error) {
}
// String returns the string representation of a Profile
func (p *ProfileInfo) String() string {
func (p *MigProfileInfo) String() string {
var suffix string
if len(p.Attributes) > 0 {
suffix = "+" + strings.Join(p.Attributes, ",")
@ -209,14 +209,14 @@ func (p *ProfileInfo) String() string {
}
// GetInfo returns detailed info about a Profile
func (p *ProfileInfo) GetInfo() ProfileInfo {
func (p *MigProfileInfo) GetInfo() MigProfileInfo {
return *p
}
// Equals checks if two Profiles are identical or not
func (p *ProfileInfo) Equals(other Profile) bool {
func (p *MigProfileInfo) Equals(other MigProfile) bool {
switch o := other.(type) {
case *ProfileInfo:
case *MigProfileInfo:
if p.C != o.C {
return false
}
@ -240,7 +240,7 @@ func (p *ProfileInfo) Equals(other Profile) bool {
return false
}
func parseProfileField(s string, field string) (int, error) {
func parseMigProfileField(s string, field string) (int, error) {
if strings.TrimSpace(s) != s {
return -1, fmt.Errorf("leading or trailing spaces on '%%d%s'", field)
}
@ -257,32 +257,32 @@ func parseProfileField(s string, field string) (int, error) {
return v, nil
}
func parseProfileFields(s string) (int, int, int, error) {
func parseMigProfileFields(s string) (int, int, int, error) {
var err error
var c, g, gb int
split := strings.SplitN(s, ".", 3)
if len(split) == 3 {
c, err = parseProfileField(split[0], "c")
c, err = parseMigProfileField(split[0], "c")
if err != nil {
return -1, -1, -1, err
}
g, err = parseProfileField(split[1], "g")
g, err = parseMigProfileField(split[1], "g")
if err != nil {
return -1, -1, -1, err
}
gb, err = parseProfileField(split[2], "gb")
gb, err = parseMigProfileField(split[2], "gb")
if err != nil {
return -1, -1, -1, err
}
return c, g, gb, err
}
if len(split) == 2 {
g, err = parseProfileField(split[0], "g")
g, err = parseMigProfileField(split[0], "g")
if err != nil {
return -1, -1, -1, err
}
gb, err = parseProfileField(split[1], "gb")
gb, err = parseMigProfileField(split[1], "gb")
if err != nil {
return -1, -1, -1, err
}
@ -292,7 +292,7 @@ func parseProfileFields(s string) (int, int, int, error) {
return -1, -1, -1, fmt.Errorf("parsed wrong number of fields, expected 2 or 3")
}
func parseProfileAttributes(s string) ([]string, error) {
func parseMigProfileAttributes(s string) ([]string, error) {
attr := strings.Split(s, ",")
if len(attr) == 0 {
return nil, fmt.Errorf("empty attribute list")

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
package mig
package device
import (
"fmt"
@ -256,10 +256,10 @@ func TestParseMigProfile(t *testing.T) {
},
}
m := New()
d := New()
for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
_, err := m.ParseProfile(tc.device)
_, err := d.ParseMigProfile(tc.device)
if tc.valid {
require.Nil(t, err)
} else {