go-nvlib/pkg/nvlib/device/device.go
Evan Lezar fa5d0408ce Ensure pci bus ID is lower case
The PCI Bus ID returned by NVML is upper case and results in the following error:

error getting PCI device class for device:
failed to construct PCI device:
unable to read PCI device vendor id for 0000:0A:00.0:
open /sys/bus/pci/devices/0000:0A:00.0/vendor:
no such file or directory

Signed-off-by: Evan Lezar <elezar@nvidia.com>
2022-11-16 12:12:07 +01:00

375 lines
9.8 KiB
Go

/*
* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package device
import (
"fmt"
"strings"
"github.com/NVIDIA/go-nvml/pkg/dl"
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml"
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvpci"
)
// Device defines the set of extended functions associated with a device.Device.
// It embeds nvml.Device and adds helpers for querying and walking the MIG
// devices and MIG profiles of a device.
type Device interface {
nvml.Device
// GetMigDevices returns the MIG devices associated with the device.
GetMigDevices() ([]MigDevice, error)
// GetMigProfiles returns the MIG profiles associated with the device.
GetMigProfiles() ([]MigProfile, error)
// IsMigCapable reports whether the device is capable of MIG.
IsMigCapable() (bool, error)
// IsMigEnabled reports whether MIG mode is currently enabled on the device.
IsMigEnabled() (bool, error)
// VisitMigDevices invokes the callback for each MIG device on the device;
// the int argument is the index of the MIG device on its parent.
VisitMigDevices(func(j int, m MigDevice) error) error
// VisitMigProfiles invokes the callback for each MIG profile of the device.
VisitMigProfiles(func(p MigProfile) error) error
}
// device is the package's implementation of Device. It wraps an nvml.Device
// together with the devicelib that created it.
type device struct {
nvml.Device
// lib is the owning devicelib; it is used to construct the MIG device and
// MIG profile wrappers handed to visitors.
lib *devicelib
}
// Class represents the PCI class for a device.
// A device's class is compared against the classes selected on the devicelib
// to decide whether the device is visited.
type Class uint32
// Define constants for common device classes
const (
// ClassCompute is the device class corresponding to nvpci.PCI3dControllerClass.
ClassCompute = Class(nvpci.PCI3dControllerClass)
// ClassDisplay is the device class corresponding to nvpci.PCIVgaControllerClass.
ClassDisplay = Class(nvpci.PCIVgaControllerClass)
)
// Compile-time check that *device satisfies the Device interface.
var _ Device = (*device)(nil)
// NewDevice builds a new Device from an nvml.Device.
//
// On failure it returns a literal nil Device alongside the error. Returning
// the (*device, error) pair from newDevice directly would convert a nil
// *device into a non-nil Device interface value, which callers comparing the
// result against nil would misread as a valid device.
func (d *devicelib) NewDevice(dev nvml.Device) (Device, error) {
	wrapped, err := d.newDevice(dev)
	if err != nil {
		return nil, err
	}
	return wrapped, nil
}
// newDevice wraps an nvml.Device in a *device that remembers the devicelib
// that created it. It currently never fails; the error result is always nil.
func (d *devicelib) newDevice(dev nvml.Device) (*device, error) {
	wrapper := &device{
		Device: dev,
		lib:    d,
	}
	return wrapper, nil
}
// classIsSelected reports whether the given class is among the device classes
// selected when the devicelib was constructed.
func (d *devicelib) classIsSelected(c Class) bool {
	// Indexing a nil map is safe and reports the key as absent, so an unset
	// selection yields false without an explicit nil check.
	_, selected := d.selectedDeviceClasses[c]
	return selected
}
// IsMigCapable checks if a device is capable of having MIG partitions created on it.
//
// It returns (false, nil) when the nvmlDeviceGetMigMode symbol cannot be found
// in the NVML library or when NVML reports ERROR_NOT_SUPPORTED for the device.
// Any other non-success return code is reported as an error.
func (d *device) IsMigCapable() (bool, error) {
	// A missing symbol is treated as "not capable" rather than as a failure.
	if err := nvmlLookupSymbol("nvmlDeviceGetMigMode"); err != nil {
		return false, nil
	}

	_, _, ret := nvml.Device(d).GetMigMode()
	switch ret {
	case nvml.SUCCESS:
		return true, nil
	case nvml.ERROR_NOT_SUPPORTED:
		return false, nil
	default:
		return false, fmt.Errorf("error getting MIG mode: %v", ret)
	}
}
// IsMigEnabled checks if a device has MIG mode currently enabled on it.
//
// It returns (false, nil) when the nvmlDeviceGetMigMode symbol cannot be found
// in the NVML library or when NVML reports ERROR_NOT_SUPPORTED for the device.
// Any other non-success return code is reported as an error.
func (d *device) IsMigEnabled() (bool, error) {
	// A missing symbol is treated as "not enabled" rather than as a failure.
	if err := nvmlLookupSymbol("nvmlDeviceGetMigMode"); err != nil {
		return false, nil
	}

	currentMode, _, ret := nvml.Device(d).GetMigMode()
	switch ret {
	case nvml.ERROR_NOT_SUPPORTED:
		return false, nil
	case nvml.SUCCESS:
		return currentMode == nvml.DEVICE_MIG_ENABLE, nil
	default:
		return false, fmt.Errorf("error getting MIG mode: %v", ret)
	}
}
// VisitMigDevices walks a top-level device and invokes a callback function for
// each MIG device configured on it.
//
// A device that is not MIG capable is treated as having no MIG devices: the
// function returns nil without issuing any MIG-specific NVML calls. This
// matches the guard used by VisitMigProfiles.
//
// The callback receives the MIG device's index on its parent and the wrapped
// MIG device. Indices for which NVML reports ERROR_NOT_FOUND or
// ERROR_INVALID_ARGUMENT are skipped. The walk stops at the first error, which
// is returned with added context.
func (d *device) VisitMigDevices(visit func(int, MigDevice) error) error {
	capable, err := d.IsMigCapable()
	if err != nil {
		return fmt.Errorf("error checking if GPU is MIG capable: %v", err)
	}
	if !capable {
		return nil
	}

	count, ret := nvml.Device(d).GetMaxMigDeviceCount()
	if ret != nvml.SUCCESS {
		return fmt.Errorf("error getting max MIG device count: %v", ret)
	}

	for i := 0; i < count; i++ {
		handle, ret := nvml.Device(d).GetMigDeviceHandleByIndex(i)
		// Not every index up to the maximum has a MIG device behind it.
		if ret == nvml.ERROR_NOT_FOUND || ret == nvml.ERROR_INVALID_ARGUMENT {
			continue
		}
		if ret != nvml.SUCCESS {
			return fmt.Errorf("error getting MIG device handle at index '%v': %v", i, ret)
		}

		mig, err := d.lib.NewMigDevice(handle)
		if err != nil {
			return fmt.Errorf("error creating new MIG device wrapper: %v", err)
		}

		if err := visit(i, mig); err != nil {
			return fmt.Errorf("error visiting MIG device: %v", err)
		}
	}
	return nil
}
// VisitMigProfiles walks a top-level device and invokes a callback function for
// each MIG profile that can be constructed for it.
//
// Devices that are not MIG capable yield no profiles and no error. For every
// GPU instance profile index that NVML accepts, a MigProfile is built for each
// combination of compute instance profile index and compute instance engine
// profile index, and passed to the callback. GPU instance profile indices for
// which NVML reports ERROR_NOT_SUPPORTED or ERROR_INVALID_ARGUMENT are skipped.
// The walk stops at the first error.
func (d *device) VisitMigProfiles(visit func(MigProfile) error) error {
	migCapable, err := d.IsMigCapable()
	if err != nil {
		return fmt.Errorf("error checking if GPU is MIG capable: %v", err)
	}
	if !migCapable {
		return nil
	}

	memInfo, ret := d.GetMemoryInfo()
	if ret != nvml.SUCCESS {
		return fmt.Errorf("error getting device memory info: %v", ret)
	}

	for gi := 0; gi < nvml.GPU_INSTANCE_PROFILE_COUNT; gi++ {
		giInfo, ret := d.GetGpuInstanceProfileInfo(gi)
		switch ret {
		case nvml.SUCCESS:
			// Profile is available; build the combinations below.
		case nvml.ERROR_NOT_SUPPORTED, nvml.ERROR_INVALID_ARGUMENT:
			continue
		default:
			return fmt.Errorf("error getting GPU Instance profile info: %v", ret)
		}

		for ci := 0; ci < nvml.COMPUTE_INSTANCE_PROFILE_COUNT; ci++ {
			for eng := 0; eng < nvml.COMPUTE_INSTANCE_ENGINE_PROFILE_COUNT; eng++ {
				profile, err := d.lib.NewMigProfile(gi, ci, eng, giInfo.MemorySizeMB, memInfo.Total)
				if err != nil {
					return fmt.Errorf("error creating MIG profile: %v", err)
				}
				if err := visit(profile); err != nil {
					return fmt.Errorf("error visiting MIG profile: %v", err)
				}
			}
		}
	}
	return nil
}
// GetMigDevices gets the set of MIG devices associated with a top-level device.
// It collects every MIG device yielded by VisitMigDevices, in visiting order,
// and returns a nil slice together with the error if the walk fails.
func (d *device) GetMigDevices() ([]MigDevice, error) {
	var collected []MigDevice
	collect := func(_ int, mig MigDevice) error {
		collected = append(collected, mig)
		return nil
	}
	if err := d.VisitMigDevices(collect); err != nil {
		return nil, err
	}
	return collected, nil
}
// GetMigProfiles gets the set of MIG profiles associated with a top-level device.
// It collects every profile yielded by VisitMigProfiles, in visiting order,
// and returns a nil slice together with the error if the walk fails.
func (d *device) GetMigProfiles() ([]MigProfile, error) {
	var collected []MigProfile
	collect := func(profile MigProfile) error {
		collected = append(collected, profile)
		return nil
	}
	if err := d.VisitMigProfiles(collect); err != nil {
		return nil, err
	}
	return collected, nil
}
// getClass returns the PCI device class for the device.
//
// The bus ID reported by NVML is converted to a lower-case string, its leading
// "0000" is stripped, and the result is used to look the device up through
// nvpci; the class of the PCI device found there is returned.
func (d *device) getClass() (Class, error) {
	pciInfo, ret := d.GetPciInfo()
	if ret != nvml.SUCCESS {
		return 0, fmt.Errorf("failed to get PCI info: %v", ret)
	}

	// BusId is a fixed-size, NUL-terminated buffer; keep only the characters
	// before the first NUL.
	raw := make([]byte, 0, len(pciInfo.BusId))
	for _, c := range pciInfo.BusId {
		if c == 0 {
			break
		}
		raw = append(raw, byte(c))
	}

	// The PCI device lookup fails on an upper-case bus ID, so normalise the case.
	busID := strings.ToLower(string(raw))
	// NOTE(review): presumably NVML reports a wider PCI domain than nvpci
	// expects, hence the stripped "0000" prefix — confirm against NVML docs.
	if busID != "0000" {
		busID = strings.TrimPrefix(busID, "0000")
	}

	pciDevice, err := nvpci.New().GetGPUByPciBusID(busID)
	if err != nil {
		return 0, fmt.Errorf("failed to construct PCI device: %v", err)
	}
	return Class(pciDevice.Class), nil
}
// VisitDevices visits each top-level device and invokes a callback function for it.
//
// Only devices whose PCI class has been selected on the devicelib are passed
// to the callback. The index handed to the callback is the NVML device index,
// so it may have gaps when devices are filtered out. The walk stops at the
// first error.
func (d *devicelib) VisitDevices(visit func(int, Device) error) error {
	total, ret := d.nvml.DeviceGetCount()
	if ret != nvml.SUCCESS {
		return fmt.Errorf("error getting device count: %v", ret)
	}

	for idx := 0; idx < total; idx++ {
		handle, ret := d.nvml.DeviceGetHandleByIndex(idx)
		if ret != nvml.SUCCESS {
			return fmt.Errorf("error getting device handle for index '%v': %v", idx, ret)
		}

		wrapped, err := d.newDevice(handle)
		if err != nil {
			return fmt.Errorf("error creating new device wrapper: %v", err)
		}

		pciClass, err := wrapped.getClass()
		if err != nil {
			return fmt.Errorf("error getting PCI device class for device: %v", err)
		}
		if !d.classIsSelected(pciClass) {
			continue
		}

		if err := visit(idx, wrapped); err != nil {
			return fmt.Errorf("error visiting device: %v", err)
		}
	}
	return nil
}
// VisitMigDevices walks all top-level devices visited by VisitDevices and
// invokes a callback function for each MIG device configured on them.
//
// The callback receives the parent device's index, the parent device, the MIG
// device's index on that parent, and the MIG device itself. The walk stops at
// the first error.
func (d *devicelib) VisitMigDevices(visit func(int, Device, int, MigDevice) error) error {
	perDevice := func(i int, dev Device) error {
		perMig := func(j int, mig MigDevice) error {
			if err := visit(i, dev, j, mig); err != nil {
				return fmt.Errorf("error visiting MIG device: %v", err)
			}
			return nil
		}
		if err := dev.VisitMigDevices(perMig); err != nil {
			return fmt.Errorf("error visiting device: %v", err)
		}
		return nil
	}

	if err := d.VisitDevices(perDevice); err != nil {
		return fmt.Errorf("error visiting devices: %v", err)
	}
	return nil
}
// VisitMigProfiles walks all top-level devices visited by VisitDevices and
// invokes a callback function for each unique MIG profile found on them.
//
// Profiles are de-duplicated by their String() representation. A profile is
// only recorded as seen after the callback has succeeded for it. The walk
// stops at the first error.
func (d *devicelib) VisitMigProfiles(visit func(MigProfile) error) error {
	seen := make(map[string]struct{})

	perProfile := func(p MigProfile) error {
		if _, dup := seen[p.String()]; dup {
			return nil
		}
		if err := visit(p); err != nil {
			return fmt.Errorf("error visiting MIG profile: %v", err)
		}
		seen[p.String()] = struct{}{}
		return nil
	}

	perDevice := func(_ int, dev Device) error {
		if err := dev.VisitMigProfiles(perProfile); err != nil {
			return fmt.Errorf("error visiting device: %v", err)
		}
		return nil
	}

	if err := d.VisitDevices(perDevice); err != nil {
		return fmt.Errorf("error visiting devices: %v", err)
	}
	return nil
}
// GetDevices gets the set of all top-level devices yielded by VisitDevices,
// in visiting order. It returns a nil slice together with the error if the
// walk fails.
func (d *devicelib) GetDevices() ([]Device, error) {
	var collected []Device
	collect := func(_ int, dev Device) error {
		collected = append(collected, dev)
		return nil
	}
	if err := d.VisitDevices(collect); err != nil {
		return nil, err
	}
	return collected, nil
}
// GetMigDevices gets the set of MIG devices across all top-level devices, in
// visiting order. It returns a nil slice together with the error if the walk
// fails.
func (d *devicelib) GetMigDevices() ([]MigDevice, error) {
	var collected []MigDevice
	collect := func(_ int, _ Device, _ int, mig MigDevice) error {
		collected = append(collected, mig)
		return nil
	}
	if err := d.VisitMigDevices(collect); err != nil {
		return nil, err
	}
	return collected, nil
}
// GetMigProfiles gets the set of unique MIG profiles across all top-level
// devices, in visiting order. It returns a nil slice together with the error
// if the walk fails.
func (d *devicelib) GetMigProfiles() ([]MigProfile, error) {
	var collected []MigProfile
	collect := func(profile MigProfile) error {
		collected = append(collected, profile)
		return nil
	}
	if err := d.VisitMigProfiles(collect); err != nil {
		return nil, err
	}
	return collected, nil
}
// nvmlLookupSymbol checks to see if the given symbol is present in the NVML library.
// It opens libnvidia-ml.so.1, looks the symbol up, and closes the library
// again before returning. A nil result means the lookup succeeded.
func nvmlLookupSymbol(symbol string) error {
	nvmlLib := dl.New("libnvidia-ml.so.1", dl.RTLD_LAZY|dl.RTLD_GLOBAL)
	if nvmlLib == nil {
		return fmt.Errorf("error instantiating DynamicLibrary for NVML")
	}

	if err := nvmlLib.Open(); err != nil {
		return fmt.Errorf("error opening DynamicLibrary for NVML: %v", err)
	}
	defer nvmlLib.Close()

	return nvmlLib.Lookup(symbol)
}