mirror of
https://github.com/clearml/go-nvlib
synced 2025-04-22 15:15:53 +00:00
Merge branch 'add-device-apis' into 'main'
Move MIG apis to device package and add extended APIs for top-level devices to it See merge request nvidia/cloud-native/go-nvlib!19
This commit is contained in:
commit
01649c65ea
@ -14,43 +14,50 @@
|
|||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package mig
|
package device
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml"
|
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Interface provides the API to the mig package
|
// Interface provides the API to the 'device' package
|
||||||
type Interface interface {
|
type Interface interface {
|
||||||
NewProfile(giProfileID, ciProfileID, ciEngProfileID int, migMemorySizeMB, deviceMemorySizeBytes uint64) (Profile, error)
|
GetDevices() ([]Device, error)
|
||||||
ParseProfile(profile string) (Profile, error)
|
GetMigDevices() ([]MigDevice, error)
|
||||||
|
GetMigProfiles() ([]MigProfile, error)
|
||||||
NewDevice(d nvml.Device) (Device, error)
|
NewDevice(d nvml.Device) (Device, error)
|
||||||
|
NewMigDevice(d nvml.Device) (MigDevice, error)
|
||||||
|
NewMigProfile(giProfileID, ciProfileID, ciEngProfileID int, migMemorySizeMB, deviceMemorySizeBytes uint64) (MigProfile, error)
|
||||||
|
ParseMigProfile(profile string) (MigProfile, error)
|
||||||
|
VisitDevices(func(i int, d Device) error) error
|
||||||
|
VisitMigDevices(func(i int, d Device, j int, m MigDevice) error) error
|
||||||
|
VisitMigProfiles(func(p MigProfile) error) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type miglib struct {
|
type devicelib struct {
|
||||||
nvml nvml.Interface
|
nvml nvml.Interface
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ Interface = &miglib{}
|
var _ Interface = &devicelib{}
|
||||||
|
|
||||||
// New creates a new instance of the 'mig' interface
|
// New creates a new instance of the 'device' interface
|
||||||
func New(opts ...Option) Interface {
|
func New(opts ...Option) Interface {
|
||||||
m := &miglib{}
|
d := &devicelib{}
|
||||||
for _, opt := range opts {
|
for _, opt := range opts {
|
||||||
opt(m)
|
opt(d)
|
||||||
}
|
}
|
||||||
if m.nvml == nil {
|
if d.nvml == nil {
|
||||||
m.nvml = nvml.New()
|
d.nvml = nvml.New()
|
||||||
}
|
}
|
||||||
return m
|
return d
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithNvml provides an Option to set the NVML library used by the 'mig' interface
|
// WithNvml provides an Option to set the NVML library used by the 'device' interface
|
||||||
func WithNvml(nvml nvml.Interface) Option {
|
func WithNvml(nvml nvml.Interface) Option {
|
||||||
return func(m *miglib) {
|
return func(d *devicelib) {
|
||||||
m.nvml = nvml
|
d.nvml = nvml
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Option defines a function for passing options to the New() call
|
// Option defines a function for passing options to the New() call
|
||||||
type Option func(*miglib)
|
type Option func(*devicelib)
|
304
pkg/nvlib/device/device.go
Normal file
304
pkg/nvlib/device/device.go
Normal file
@ -0,0 +1,304 @@
|
|||||||
|
/*
|
||||||
|
* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package device
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/NVIDIA/go-nvml/pkg/dl"
|
||||||
|
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Device defines the set of extended functions associated with a device.Device
|
||||||
|
type Device interface {
|
||||||
|
nvml.Device
|
||||||
|
GetMigDevices() ([]MigDevice, error)
|
||||||
|
GetMigProfiles() ([]MigProfile, error)
|
||||||
|
IsMigCapable() (bool, error)
|
||||||
|
IsMigEnabled() (bool, error)
|
||||||
|
VisitMigDevices(func(j int, m MigDevice) error) error
|
||||||
|
VisitMigProfiles(func(p MigProfile) error) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type device struct {
|
||||||
|
nvml.Device
|
||||||
|
lib *devicelib
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ Device = &device{}
|
||||||
|
|
||||||
|
// NewDevice builds a new Device from an nvml.Device
|
||||||
|
func (d *devicelib) NewDevice(dev nvml.Device) (Device, error) {
|
||||||
|
return &device{dev, d}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsMigCapable checks if a device is capable of having MIG paprtitions created on it
|
||||||
|
func (d *device) IsMigCapable() (bool, error) {
|
||||||
|
err := nvmlLookupSymbol("nvmlDeviceGetMigMode")
|
||||||
|
if err != nil {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _, ret := nvml.Device(d).GetMigMode()
|
||||||
|
if ret == nvml.ERROR_NOT_SUPPORTED {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
if ret != nvml.SUCCESS {
|
||||||
|
return false, fmt.Errorf("error getting MIG mode: %v", ret)
|
||||||
|
}
|
||||||
|
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsMigEnabled checks if a device has MIG mode currently enabled on it
|
||||||
|
func (d *device) IsMigEnabled() (bool, error) {
|
||||||
|
err := nvmlLookupSymbol("nvmlDeviceGetMigMode")
|
||||||
|
if err != nil {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
mode, _, ret := nvml.Device(d).GetMigMode()
|
||||||
|
if ret == nvml.ERROR_NOT_SUPPORTED {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
if ret != nvml.SUCCESS {
|
||||||
|
return false, fmt.Errorf("error getting MIG mode: %v", ret)
|
||||||
|
}
|
||||||
|
|
||||||
|
return (mode == nvml.DEVICE_MIG_ENABLE), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// VisitMigDevices walks a top-level device and invokes a callback function for each MIG device configured on it
|
||||||
|
func (d *device) VisitMigDevices(visit func(int, MigDevice) error) error {
|
||||||
|
count, ret := nvml.Device(d).GetMaxMigDeviceCount()
|
||||||
|
if ret != nvml.SUCCESS {
|
||||||
|
return fmt.Errorf("error getting max MIG device count: %v", ret)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < count; i++ {
|
||||||
|
device, ret := nvml.Device(d).GetMigDeviceHandleByIndex(i)
|
||||||
|
if ret == nvml.ERROR_NOT_FOUND {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if ret == nvml.ERROR_INVALID_ARGUMENT {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if ret != nvml.SUCCESS {
|
||||||
|
return fmt.Errorf("error getting MIG device handle at index '%v': %v", i, ret)
|
||||||
|
}
|
||||||
|
mig, err := d.lib.NewMigDevice(device)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error creating new MIG device wrapper: %v", err)
|
||||||
|
}
|
||||||
|
err = visit(i, mig)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error visiting MIG device: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// VisitMigProfiles walks a top-level device and invokes a callback function for each unique MIG Profile that can be configured on it
|
||||||
|
func (d *device) VisitMigProfiles(visit func(MigProfile) error) error {
|
||||||
|
capable, err := d.IsMigCapable()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error checking if GPU is MIG capable: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !capable {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
memory, ret := d.GetMemoryInfo()
|
||||||
|
if ret != nvml.SUCCESS {
|
||||||
|
return fmt.Errorf("error getting device memory info: %v", ret)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < nvml.GPU_INSTANCE_PROFILE_COUNT; i++ {
|
||||||
|
giProfileInfo, ret := d.GetGpuInstanceProfileInfo(i)
|
||||||
|
if ret != nvml.SUCCESS {
|
||||||
|
return fmt.Errorf("error getting GPU Instance profile info: %v", ret)
|
||||||
|
}
|
||||||
|
|
||||||
|
for j := 0; j < nvml.COMPUTE_INSTANCE_PROFILE_COUNT; j++ {
|
||||||
|
for k := 0; k < nvml.COMPUTE_INSTANCE_ENGINE_PROFILE_COUNT; k++ {
|
||||||
|
p, err := d.lib.NewMigProfile(i, j, k, giProfileInfo.MemorySizeMB, memory.Total)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error creating MIG profile: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = visit(p)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error visiting MIG profile: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMigDevices gets the set of MIG devices associated with a top-level device
|
||||||
|
func (d *device) GetMigDevices() ([]MigDevice, error) {
|
||||||
|
var migs []MigDevice
|
||||||
|
err := d.VisitMigDevices(func(j int, m MigDevice) error {
|
||||||
|
migs = append(migs, m)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return migs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMigProfiles gets the set of unique MIG profiles associated with a top-level device
|
||||||
|
func (d *device) GetMigProfiles() ([]MigProfile, error) {
|
||||||
|
var profiles []MigProfile
|
||||||
|
err := d.VisitMigProfiles(func(p MigProfile) error {
|
||||||
|
profiles = append(profiles, p)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return profiles, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// VisitDevices visits each top-level device and invokes a callback function for it
|
||||||
|
func (d *devicelib) VisitDevices(visit func(int, Device) error) error {
|
||||||
|
count, ret := d.nvml.DeviceGetCount()
|
||||||
|
if ret != nvml.SUCCESS {
|
||||||
|
return fmt.Errorf("error getting device count: %v", ret)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < count; i++ {
|
||||||
|
device, ret := d.nvml.DeviceGetHandleByIndex(i)
|
||||||
|
if ret != nvml.SUCCESS {
|
||||||
|
return fmt.Errorf("error getting device handle for index '%v': %v", i, ret)
|
||||||
|
}
|
||||||
|
dev, err := d.NewDevice(device)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error creating new device wrapper: %v", err)
|
||||||
|
}
|
||||||
|
err = visit(i, dev)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error visiting device: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// VisitMigDevices walks a top-level device and invokes a callback function for each MIG device configured on it
|
||||||
|
func (d *devicelib) VisitMigDevices(visit func(int, Device, int, MigDevice) error) error {
|
||||||
|
err := d.VisitDevices(func(i int, dev Device) error {
|
||||||
|
err := dev.VisitMigDevices(func(j int, mig MigDevice) error {
|
||||||
|
err := visit(i, dev, j, mig)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error visiting MIG device: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error visiting device: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error visiting devices: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// VisitMigProfiles walks a top-level device and invokes a callback function for each unique MIG profile found on them
|
||||||
|
func (d *devicelib) VisitMigProfiles(visit func(MigProfile) error) error {
|
||||||
|
visited := make(map[string]bool)
|
||||||
|
err := d.VisitDevices(func(i int, dev Device) error {
|
||||||
|
err := dev.VisitMigProfiles(func(p MigProfile) error {
|
||||||
|
if visited[p.String()] {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err := visit(p)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error visiting MIG profile: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
visited[p.String()] = true
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error visiting device: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error visiting devices: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDevices gets the set of all top-level devices
|
||||||
|
func (d *devicelib) GetDevices() ([]Device, error) {
|
||||||
|
var devs []Device
|
||||||
|
err := d.VisitDevices(func(i int, dev Device) error {
|
||||||
|
devs = append(devs, dev)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return devs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMigDevices gets the set of MIG devices across all top-level devices
|
||||||
|
func (d *devicelib) GetMigDevices() ([]MigDevice, error) {
|
||||||
|
var migs []MigDevice
|
||||||
|
err := d.VisitMigDevices(func(i int, dev Device, j int, m MigDevice) error {
|
||||||
|
migs = append(migs, m)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return migs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMigProfiles gets the set of unique MIG profiles across all top-level devices
|
||||||
|
func (d *devicelib) GetMigProfiles() ([]MigProfile, error) {
|
||||||
|
var profiles []MigProfile
|
||||||
|
err := d.VisitMigProfiles(func(p MigProfile) error {
|
||||||
|
profiles = append(profiles, p)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return profiles, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// nvmlLookupSymbol checks to see if the given symbol is present in the NVML library
|
||||||
|
func nvmlLookupSymbol(symbol string) error {
|
||||||
|
lib := dl.New("libnvidia-ml.so.1", dl.RTLD_LAZY|dl.RTLD_GLOBAL)
|
||||||
|
if lib == nil {
|
||||||
|
return fmt.Errorf("error instantiating DynamicLibrary for NVML")
|
||||||
|
}
|
||||||
|
err := lib.Open()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error opening DynamicLibrary for NVML: %v", err)
|
||||||
|
}
|
||||||
|
defer lib.Close()
|
||||||
|
return lib.Lookup(symbol)
|
||||||
|
}
|
@ -14,7 +14,7 @@
|
|||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package mig
|
package device
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -22,38 +22,39 @@ import (
|
|||||||
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml"
|
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Device defines the set of extended functions associated with a mig.Device
|
// MigDevice defines the set of extended functions associated with a MIG device
|
||||||
type Device interface {
|
type MigDevice interface {
|
||||||
GetProfile() (Profile, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type device struct {
|
|
||||||
nvml.Device
|
nvml.Device
|
||||||
miglib *miglib
|
GetProfile() (MigProfile, error)
|
||||||
profile Profile
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ Device = &device{}
|
type migdevice struct {
|
||||||
|
nvml.Device
|
||||||
|
lib *devicelib
|
||||||
|
profile MigProfile
|
||||||
|
}
|
||||||
|
|
||||||
// NewDevice builds a new Device from an nvml.Device
|
var _ MigDevice = &migdevice{}
|
||||||
func (m *miglib) NewDevice(d nvml.Device) (Device, error) {
|
|
||||||
isMig, ret := d.IsMigDeviceHandle()
|
// NewMigDevice builds a new MigDevice from an nvml.Device
|
||||||
|
func (d *devicelib) NewMigDevice(handle nvml.Device) (MigDevice, error) {
|
||||||
|
isMig, ret := handle.IsMigDeviceHandle()
|
||||||
if ret != nvml.SUCCESS {
|
if ret != nvml.SUCCESS {
|
||||||
return nil, fmt.Errorf("error checking if device is a MIG device: %v", ret)
|
return nil, fmt.Errorf("error checking if device is a MIG device: %v", ret)
|
||||||
}
|
}
|
||||||
if !isMig {
|
if !isMig {
|
||||||
return nil, fmt.Errorf("not a MIG device")
|
return nil, fmt.Errorf("not a MIG device")
|
||||||
}
|
}
|
||||||
return &device{d, m, nil}, nil
|
return &migdevice{handle, d, nil}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetProfile returns the MIG profile associated with a MIG device
|
// GetProfile returns the MIG profile associated with a MIG device
|
||||||
func (d *device) GetProfile() (Profile, error) {
|
func (m *migdevice) GetProfile() (MigProfile, error) {
|
||||||
if d.profile != nil {
|
if m.profile != nil {
|
||||||
return d.profile, nil
|
return m.profile, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
parent, ret := d.Device.GetDeviceHandleFromMigDeviceHandle()
|
parent, ret := m.Device.GetDeviceHandleFromMigDeviceHandle()
|
||||||
if ret != nvml.SUCCESS {
|
if ret != nvml.SUCCESS {
|
||||||
return nil, fmt.Errorf("error getting parent device handle: %v", ret)
|
return nil, fmt.Errorf("error getting parent device handle: %v", ret)
|
||||||
}
|
}
|
||||||
@ -63,17 +64,17 @@ func (d *device) GetProfile() (Profile, error) {
|
|||||||
return nil, fmt.Errorf("error getting parent memory info: %v", ret)
|
return nil, fmt.Errorf("error getting parent memory info: %v", ret)
|
||||||
}
|
}
|
||||||
|
|
||||||
attributes, ret := d.Device.GetAttributes()
|
attributes, ret := m.Device.GetAttributes()
|
||||||
if ret != nvml.SUCCESS {
|
if ret != nvml.SUCCESS {
|
||||||
return nil, fmt.Errorf("error getting MIG device attributes: %v", ret)
|
return nil, fmt.Errorf("error getting MIG device attributes: %v", ret)
|
||||||
}
|
}
|
||||||
|
|
||||||
giID, ret := d.Device.GetGpuInstanceId()
|
giID, ret := m.Device.GetGpuInstanceId()
|
||||||
if ret != nvml.SUCCESS {
|
if ret != nvml.SUCCESS {
|
||||||
return nil, fmt.Errorf("error getting MIG device GPU Instance ID: %v", ret)
|
return nil, fmt.Errorf("error getting MIG device GPU Instance ID: %v", ret)
|
||||||
}
|
}
|
||||||
|
|
||||||
ciID, ret := d.Device.GetComputeInstanceId()
|
ciID, ret := m.Device.GetComputeInstanceId()
|
||||||
if ret != nvml.SUCCESS {
|
if ret != nvml.SUCCESS {
|
||||||
return nil, fmt.Errorf("error getting MIG device Compute Instance ID: %v", ret)
|
return nil, fmt.Errorf("error getting MIG device Compute Instance ID: %v", ret)
|
||||||
}
|
}
|
||||||
@ -120,12 +121,12 @@ func (d *device) GetProfile() (Profile, error) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
p, err := d.miglib.NewProfile(i, j, k, attributes.MemorySizeMB, parentMemoryInfo.Total)
|
p, err := m.lib.NewMigProfile(i, j, k, attributes.MemorySizeMB, parentMemoryInfo.Total)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error creating MIG profile: %v", err)
|
return nil, fmt.Errorf("error creating MIG profile: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
d.profile = p
|
m.profile = p
|
||||||
return p, nil
|
return p, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -14,7 +14,7 @@
|
|||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package mig
|
package device
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -30,16 +30,16 @@ const (
|
|||||||
AttributeMediaExtensions = "me"
|
AttributeMediaExtensions = "me"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Profile represents a specific MIG profile.
|
// MigProfile represents a specific MIG profile.
|
||||||
// Examples include "1g.5gb", "2g.10gb", "1c.2g.10gb", or "1c.1g.5gb+me", etc.
|
// Examples include "1g.5gb", "2g.10gb", "1c.2g.10gb", or "1c.1g.5gb+me", etc.
|
||||||
type Profile interface {
|
type MigProfile interface {
|
||||||
String() string
|
String() string
|
||||||
GetInfo() ProfileInfo
|
GetInfo() MigProfileInfo
|
||||||
Equals(other Profile) bool
|
Equals(other MigProfile) bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProfileInfo holds all info associated with a specific MIG profile
|
// MigProfileInfo holds all info associated with a specific MIG profile
|
||||||
type ProfileInfo struct {
|
type MigProfileInfo struct {
|
||||||
C int
|
C int
|
||||||
G int
|
G int
|
||||||
GB int
|
GB int
|
||||||
@ -49,10 +49,10 @@ type ProfileInfo struct {
|
|||||||
CIEngProfileID int
|
CIEngProfileID int
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ Profile = &ProfileInfo{}
|
var _ MigProfile = &MigProfileInfo{}
|
||||||
|
|
||||||
// NewProfile constructs a new Profile struct using info from the giProfiles and ciProfiles used to create it.
|
// NewProfile constructs a new Profile struct using info from the giProfiles and ciProfiles used to create it.
|
||||||
func (m *miglib) NewProfile(giProfileID, ciProfileID, ciEngProfileID int, migMemorySizeMB, deviceMemorySizeBytes uint64) (Profile, error) {
|
func (d *devicelib) NewMigProfile(giProfileID, ciProfileID, ciEngProfileID int, migMemorySizeMB, deviceMemorySizeBytes uint64) (MigProfile, error) {
|
||||||
giSlices := 0
|
giSlices := 0
|
||||||
switch giProfileID {
|
switch giProfileID {
|
||||||
case nvml.GPU_INSTANCE_PROFILE_1_SLICE:
|
case nvml.GPU_INSTANCE_PROFILE_1_SLICE:
|
||||||
@ -101,7 +101,7 @@ func (m *miglib) NewProfile(giProfileID, ciProfileID, ciEngProfileID int, migMem
|
|||||||
attrs = append(attrs, AttributeMediaExtensions)
|
attrs = append(attrs, AttributeMediaExtensions)
|
||||||
}
|
}
|
||||||
|
|
||||||
p := &ProfileInfo{
|
p := &MigProfileInfo{
|
||||||
C: ciSlices,
|
C: ciSlices,
|
||||||
G: giSlices,
|
G: giSlices,
|
||||||
GB: int(getMigMemorySizeGB(deviceMemorySizeBytes, migMemorySizeMB)),
|
GB: int(getMigMemorySizeGB(deviceMemorySizeBytes, migMemorySizeMB)),
|
||||||
@ -114,8 +114,8 @@ func (m *miglib) NewProfile(giProfileID, ciProfileID, ciEngProfileID int, migMem
|
|||||||
return p, nil
|
return p, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParseProfile converts a string representation of a Profile into an object.
|
// ParseMigProfile converts a string representation of a MigProfile into an object
|
||||||
func (m *miglib) ParseProfile(profile string) (Profile, error) {
|
func (d *devicelib) ParseMigProfile(profile string) (MigProfile, error) {
|
||||||
var err error
|
var err error
|
||||||
var c, g, gb int
|
var c, g, gb int
|
||||||
var attrs []string
|
var attrs []string
|
||||||
@ -126,18 +126,18 @@ func (m *miglib) ParseProfile(profile string) (Profile, error) {
|
|||||||
|
|
||||||
split := strings.SplitN(profile, "+", 2)
|
split := strings.SplitN(profile, "+", 2)
|
||||||
if len(split) == 2 {
|
if len(split) == 2 {
|
||||||
attrs, err = parseProfileAttributes(split[1])
|
attrs, err = parseMigProfileAttributes(split[1])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error parsing attributes following '+' in Profile string: %v", err)
|
return nil, fmt.Errorf("error parsing attributes following '+' in Profile string: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
c, g, gb, err = parseProfileFields(split[0])
|
c, g, gb, err = parseMigProfileFields(split[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error parsing '.' separated fields in Profile string: %v", err)
|
return nil, fmt.Errorf("error parsing '.' separated fields in Profile string: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
p := &ProfileInfo{
|
p := &MigProfileInfo{
|
||||||
C: c,
|
C: c,
|
||||||
G: g,
|
G: g,
|
||||||
GB: gb,
|
GB: gb,
|
||||||
@ -197,7 +197,7 @@ func (m *miglib) ParseProfile(profile string) (Profile, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// String returns the string representation of a Profile
|
// String returns the string representation of a Profile
|
||||||
func (p *ProfileInfo) String() string {
|
func (p *MigProfileInfo) String() string {
|
||||||
var suffix string
|
var suffix string
|
||||||
if len(p.Attributes) > 0 {
|
if len(p.Attributes) > 0 {
|
||||||
suffix = "+" + strings.Join(p.Attributes, ",")
|
suffix = "+" + strings.Join(p.Attributes, ",")
|
||||||
@ -209,14 +209,14 @@ func (p *ProfileInfo) String() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetInfo returns detailed info about a Profile
|
// GetInfo returns detailed info about a Profile
|
||||||
func (p *ProfileInfo) GetInfo() ProfileInfo {
|
func (p *MigProfileInfo) GetInfo() MigProfileInfo {
|
||||||
return *p
|
return *p
|
||||||
}
|
}
|
||||||
|
|
||||||
// Equals checks if two Profiles are identical or not
|
// Equals checks if two Profiles are identical or not
|
||||||
func (p *ProfileInfo) Equals(other Profile) bool {
|
func (p *MigProfileInfo) Equals(other MigProfile) bool {
|
||||||
switch o := other.(type) {
|
switch o := other.(type) {
|
||||||
case *ProfileInfo:
|
case *MigProfileInfo:
|
||||||
if p.C != o.C {
|
if p.C != o.C {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@ -240,7 +240,7 @@ func (p *ProfileInfo) Equals(other Profile) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseProfileField(s string, field string) (int, error) {
|
func parseMigProfileField(s string, field string) (int, error) {
|
||||||
if strings.TrimSpace(s) != s {
|
if strings.TrimSpace(s) != s {
|
||||||
return -1, fmt.Errorf("leading or trailing spaces on '%%d%s'", field)
|
return -1, fmt.Errorf("leading or trailing spaces on '%%d%s'", field)
|
||||||
}
|
}
|
||||||
@ -257,32 +257,32 @@ func parseProfileField(s string, field string) (int, error) {
|
|||||||
return v, nil
|
return v, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseProfileFields(s string) (int, int, int, error) {
|
func parseMigProfileFields(s string) (int, int, int, error) {
|
||||||
var err error
|
var err error
|
||||||
var c, g, gb int
|
var c, g, gb int
|
||||||
|
|
||||||
split := strings.SplitN(s, ".", 3)
|
split := strings.SplitN(s, ".", 3)
|
||||||
if len(split) == 3 {
|
if len(split) == 3 {
|
||||||
c, err = parseProfileField(split[0], "c")
|
c, err = parseMigProfileField(split[0], "c")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return -1, -1, -1, err
|
return -1, -1, -1, err
|
||||||
}
|
}
|
||||||
g, err = parseProfileField(split[1], "g")
|
g, err = parseMigProfileField(split[1], "g")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return -1, -1, -1, err
|
return -1, -1, -1, err
|
||||||
}
|
}
|
||||||
gb, err = parseProfileField(split[2], "gb")
|
gb, err = parseMigProfileField(split[2], "gb")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return -1, -1, -1, err
|
return -1, -1, -1, err
|
||||||
}
|
}
|
||||||
return c, g, gb, err
|
return c, g, gb, err
|
||||||
}
|
}
|
||||||
if len(split) == 2 {
|
if len(split) == 2 {
|
||||||
g, err = parseProfileField(split[0], "g")
|
g, err = parseMigProfileField(split[0], "g")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return -1, -1, -1, err
|
return -1, -1, -1, err
|
||||||
}
|
}
|
||||||
gb, err = parseProfileField(split[1], "gb")
|
gb, err = parseMigProfileField(split[1], "gb")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return -1, -1, -1, err
|
return -1, -1, -1, err
|
||||||
}
|
}
|
||||||
@ -292,7 +292,7 @@ func parseProfileFields(s string) (int, int, int, error) {
|
|||||||
return -1, -1, -1, fmt.Errorf("parsed wrong number of fields, expected 2 or 3")
|
return -1, -1, -1, fmt.Errorf("parsed wrong number of fields, expected 2 or 3")
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseProfileAttributes(s string) ([]string, error) {
|
func parseMigProfileAttributes(s string) ([]string, error) {
|
||||||
attr := strings.Split(s, ",")
|
attr := strings.Split(s, ",")
|
||||||
if len(attr) == 0 {
|
if len(attr) == 0 {
|
||||||
return nil, fmt.Errorf("empty attribute list")
|
return nil, fmt.Errorf("empty attribute list")
|
@ -14,7 +14,7 @@
|
|||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package mig
|
package device
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -256,10 +256,10 @@ func TestParseMigProfile(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
m := New()
|
d := New()
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.description, func(t *testing.T) {
|
t.Run(tc.description, func(t *testing.T) {
|
||||||
_, err := m.ParseProfile(tc.device)
|
_, err := d.ParseMigProfile(tc.device)
|
||||||
if tc.valid {
|
if tc.valid {
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
} else {
|
} else {
|
Loading…
Reference in New Issue
Block a user