mirror of
https://github.com/NVIDIA/nvidia-container-toolkit
synced 2024-11-24 21:14:00 +00:00
ca528c4f53
Bumps [github.com/NVIDIA/go-nvlib](https://github.com/NVIDIA/go-nvlib) from 0.5.0 to 0.6.0. - [Release notes](https://github.com/NVIDIA/go-nvlib/releases) - [Commits](https://github.com/NVIDIA/go-nvlib/compare/v0.5.0...v0.6.0) --- updated-dependencies: - dependency-name: github.com/NVIDIA/go-nvlib dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] <support@github.com>
499 lines
14 KiB
Go
499 lines
14 KiB
Go
/*
|
|
* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
package device
|
|
|
|
import (
|
|
"fmt"
|
|
"strings"
|
|
|
|
"github.com/NVIDIA/go-nvml/pkg/nvml"
|
|
)
|
|
|
|
// Device defines the set of extended functions associated with a device.Device.
|
|
type Device interface {
|
|
nvml.Device
|
|
GetArchitectureAsString() (string, error)
|
|
GetBrandAsString() (string, error)
|
|
GetCudaComputeCapabilityAsString() (string, error)
|
|
GetMigDevices() ([]MigDevice, error)
|
|
GetMigProfiles() ([]MigProfile, error)
|
|
GetPCIBusID() (string, error)
|
|
IsMigCapable() (bool, error)
|
|
IsMigEnabled() (bool, error)
|
|
VisitMigDevices(func(j int, m MigDevice) error) error
|
|
VisitMigProfiles(func(p MigProfile) error) error
|
|
}
|
|
|
|
type device struct {
|
|
nvml.Device
|
|
lib *devicelib
|
|
migProfiles []MigProfile
|
|
}
|
|
|
|
var _ Device = &device{}
|
|
|
|
// NewDevice builds a new Device from an nvml.Device.
|
|
func (d *devicelib) NewDevice(dev nvml.Device) (Device, error) {
|
|
return d.newDevice(dev)
|
|
}
|
|
|
|
// NewDeviceByUUID builds a new Device from a UUID.
|
|
func (d *devicelib) NewDeviceByUUID(uuid string) (Device, error) {
|
|
dev, ret := d.nvmllib.DeviceGetHandleByUUID(uuid)
|
|
if ret != nvml.SUCCESS {
|
|
return nil, fmt.Errorf("error getting device handle for uuid '%v': %v", uuid, ret)
|
|
}
|
|
return d.newDevice(dev)
|
|
}
|
|
|
|
// newDevice creates a device from an nvml.Device.
|
|
func (d *devicelib) newDevice(dev nvml.Device) (*device, error) {
|
|
return &device{dev, d, nil}, nil
|
|
}
|
|
|
|
// GetArchitectureAsString returns the Device architecture as a string.
|
|
func (d *device) GetArchitectureAsString() (string, error) {
|
|
arch, ret := d.GetArchitecture()
|
|
if ret != nvml.SUCCESS {
|
|
return "", fmt.Errorf("error getting device architecture: %v", ret)
|
|
}
|
|
switch arch {
|
|
case nvml.DEVICE_ARCH_KEPLER:
|
|
return "Kepler", nil
|
|
case nvml.DEVICE_ARCH_MAXWELL:
|
|
return "Maxwell", nil
|
|
case nvml.DEVICE_ARCH_PASCAL:
|
|
return "Pascal", nil
|
|
case nvml.DEVICE_ARCH_VOLTA:
|
|
return "Volta", nil
|
|
case nvml.DEVICE_ARCH_TURING:
|
|
return "Turing", nil
|
|
case nvml.DEVICE_ARCH_AMPERE:
|
|
return "Ampere", nil
|
|
case nvml.DEVICE_ARCH_ADA:
|
|
return "Ada", nil
|
|
case nvml.DEVICE_ARCH_HOPPER:
|
|
return "Hopper", nil
|
|
case nvml.DEVICE_ARCH_UNKNOWN:
|
|
return "Unknown", nil
|
|
}
|
|
return "", fmt.Errorf("error interpreting device architecture as string: %v", arch)
|
|
}
|
|
|
|
// GetBrandAsString returns the Device architecture as a string.
|
|
func (d *device) GetBrandAsString() (string, error) {
|
|
brand, ret := d.GetBrand()
|
|
if ret != nvml.SUCCESS {
|
|
return "", fmt.Errorf("error getting device brand: %v", ret)
|
|
}
|
|
switch brand {
|
|
case nvml.BRAND_UNKNOWN:
|
|
return "Unknown", nil
|
|
case nvml.BRAND_QUADRO:
|
|
return "Quadro", nil
|
|
case nvml.BRAND_TESLA:
|
|
return "Tesla", nil
|
|
case nvml.BRAND_NVS:
|
|
return "NVS", nil
|
|
case nvml.BRAND_GRID:
|
|
return "Grid", nil
|
|
case nvml.BRAND_GEFORCE:
|
|
return "GeForce", nil
|
|
case nvml.BRAND_TITAN:
|
|
return "Titan", nil
|
|
case nvml.BRAND_NVIDIA_VAPPS:
|
|
return "NvidiaVApps", nil
|
|
case nvml.BRAND_NVIDIA_VPC:
|
|
return "NvidiaVPC", nil
|
|
case nvml.BRAND_NVIDIA_VCS:
|
|
return "NvidiaVCS", nil
|
|
case nvml.BRAND_NVIDIA_VWS:
|
|
return "NvidiaVWS", nil
|
|
// Deprecated in favor of nvml.BRAND_NVIDIA_CLOUD_GAMING
|
|
//case nvml.BRAND_NVIDIA_VGAMING:
|
|
// return "VGaming", nil
|
|
case nvml.BRAND_NVIDIA_CLOUD_GAMING:
|
|
return "NvidiaCloudGaming", nil
|
|
case nvml.BRAND_QUADRO_RTX:
|
|
return "QuadroRTX", nil
|
|
case nvml.BRAND_NVIDIA_RTX:
|
|
return "NvidiaRTX", nil
|
|
case nvml.BRAND_NVIDIA:
|
|
return "Nvidia", nil
|
|
case nvml.BRAND_GEFORCE_RTX:
|
|
return "GeForceRTX", nil
|
|
case nvml.BRAND_TITAN_RTX:
|
|
return "TitanRTX", nil
|
|
}
|
|
return "", fmt.Errorf("error interpreting device brand as string: %v", brand)
|
|
}
|
|
|
|
// GetPCIBusID returns the string representation of the bus ID.
|
|
func (d *device) GetPCIBusID() (string, error) {
|
|
info, ret := d.GetPciInfo()
|
|
if ret != nvml.SUCCESS {
|
|
return "", fmt.Errorf("error getting PCI info: %w", ret)
|
|
}
|
|
|
|
var bytes []byte
|
|
for _, b := range info.BusId {
|
|
if byte(b) == '\x00' {
|
|
break
|
|
}
|
|
bytes = append(bytes, byte(b))
|
|
}
|
|
id := strings.ToLower(string(bytes))
|
|
|
|
if id != "0000" {
|
|
id = strings.TrimPrefix(id, "0000")
|
|
}
|
|
|
|
return id, nil
|
|
}
|
|
|
|
// GetCudaComputeCapabilityAsString returns the Device's CUDA compute capability as a version string.
|
|
func (d *device) GetCudaComputeCapabilityAsString() (string, error) {
|
|
major, minor, ret := d.GetCudaComputeCapability()
|
|
if ret != nvml.SUCCESS {
|
|
return "", fmt.Errorf("error getting CUDA compute capability: %v", ret)
|
|
}
|
|
return fmt.Sprintf("%d.%d", major, minor), nil
|
|
}
|
|
|
|
// IsMigCapable checks if a device is capable of having MIG paprtitions created on it.
|
|
func (d *device) IsMigCapable() (bool, error) {
|
|
if !d.lib.hasSymbol("nvmlDeviceGetMigMode") {
|
|
return false, nil
|
|
}
|
|
|
|
_, _, ret := nvml.Device(d).GetMigMode()
|
|
if ret == nvml.ERROR_NOT_SUPPORTED {
|
|
return false, nil
|
|
}
|
|
if ret != nvml.SUCCESS {
|
|
return false, fmt.Errorf("error getting MIG mode: %v", ret)
|
|
}
|
|
|
|
return true, nil
|
|
}
|
|
|
|
// IsMigEnabled checks if a device has MIG mode currently enabled on it.
|
|
func (d *device) IsMigEnabled() (bool, error) {
|
|
if !d.lib.hasSymbol("nvmlDeviceGetMigMode") {
|
|
return false, nil
|
|
}
|
|
|
|
mode, _, ret := nvml.Device(d).GetMigMode()
|
|
if ret == nvml.ERROR_NOT_SUPPORTED {
|
|
return false, nil
|
|
}
|
|
if ret != nvml.SUCCESS {
|
|
return false, fmt.Errorf("error getting MIG mode: %v", ret)
|
|
}
|
|
|
|
return (mode == nvml.DEVICE_MIG_ENABLE), nil
|
|
}
|
|
|
|
// VisitMigDevices walks a top-level device and invokes a callback function for each MIG device configured on it.
|
|
func (d *device) VisitMigDevices(visit func(int, MigDevice) error) error {
|
|
capable, err := d.IsMigCapable()
|
|
if err != nil {
|
|
return fmt.Errorf("error checking if GPU is MIG capable: %v", err)
|
|
}
|
|
if !capable {
|
|
return nil
|
|
}
|
|
|
|
count, ret := nvml.Device(d).GetMaxMigDeviceCount()
|
|
if ret != nvml.SUCCESS {
|
|
return fmt.Errorf("error getting max MIG device count: %v", ret)
|
|
}
|
|
|
|
for i := 0; i < count; i++ {
|
|
device, ret := nvml.Device(d).GetMigDeviceHandleByIndex(i)
|
|
if ret == nvml.ERROR_NOT_FOUND {
|
|
continue
|
|
}
|
|
if ret == nvml.ERROR_INVALID_ARGUMENT {
|
|
continue
|
|
}
|
|
if ret != nvml.SUCCESS {
|
|
return fmt.Errorf("error getting MIG device handle at index '%v': %v", i, ret)
|
|
}
|
|
mig, err := d.lib.NewMigDevice(device)
|
|
if err != nil {
|
|
return fmt.Errorf("error creating new MIG device wrapper: %v", err)
|
|
}
|
|
err = visit(i, mig)
|
|
if err != nil {
|
|
return fmt.Errorf("error visiting MIG device: %v", err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// VisitMigProfiles walks a top-level device and invokes a callback function for each unique MIG Profile that can be configured on it.
|
|
func (d *device) VisitMigProfiles(visit func(MigProfile) error) error {
|
|
capable, err := d.IsMigCapable()
|
|
if err != nil {
|
|
return fmt.Errorf("error checking if GPU is MIG capable: %v", err)
|
|
}
|
|
|
|
if !capable {
|
|
return nil
|
|
}
|
|
|
|
memory, ret := d.GetMemoryInfo()
|
|
if ret != nvml.SUCCESS {
|
|
return fmt.Errorf("error getting device memory info: %v", ret)
|
|
}
|
|
|
|
for i := 0; i < nvml.GPU_INSTANCE_PROFILE_COUNT; i++ {
|
|
giProfileInfo, ret := d.GetGpuInstanceProfileInfo(i)
|
|
if ret == nvml.ERROR_NOT_SUPPORTED {
|
|
continue
|
|
}
|
|
if ret == nvml.ERROR_INVALID_ARGUMENT {
|
|
continue
|
|
}
|
|
if ret != nvml.SUCCESS {
|
|
return fmt.Errorf("error getting GPU Instance profile info: %v", ret)
|
|
}
|
|
|
|
for j := 0; j < nvml.COMPUTE_INSTANCE_PROFILE_COUNT; j++ {
|
|
for k := 0; k < nvml.COMPUTE_INSTANCE_ENGINE_PROFILE_COUNT; k++ {
|
|
p, err := d.lib.NewMigProfile(i, j, k, giProfileInfo.MemorySizeMB, memory.Total)
|
|
if err != nil {
|
|
return fmt.Errorf("error creating MIG profile: %v", err)
|
|
}
|
|
|
|
// NOTE: The NVML API doesn't currently let us query the set of
|
|
// valid Compute Instance profiles without first instantiating
|
|
// a GPU Instance to check against. In theory, it should be
|
|
// possible to get this information without a reference to a
|
|
// GPU instance, but no API is provided for that at the moment.
|
|
// We run the checks below to weed out invalid profiles
|
|
// heuristically, given what we know about how they are
|
|
// physically constructed. In the future we should do this via
|
|
// NVML once a proper API for this exists.
|
|
pi := p.GetInfo()
|
|
if pi.C > pi.G {
|
|
continue
|
|
}
|
|
if (pi.C < pi.G) && ((pi.C * 2) > (pi.G + 1)) {
|
|
continue
|
|
}
|
|
|
|
err = visit(p)
|
|
if err != nil {
|
|
return fmt.Errorf("error visiting MIG profile: %v", err)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GetMigDevices gets the set of MIG devices associated with a top-level device.
|
|
func (d *device) GetMigDevices() ([]MigDevice, error) {
|
|
var migs []MigDevice
|
|
err := d.VisitMigDevices(func(j int, m MigDevice) error {
|
|
migs = append(migs, m)
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return migs, nil
|
|
}
|
|
|
|
// GetMigProfiles gets the set of unique MIG profiles associated with a top-level device.
|
|
func (d *device) GetMigProfiles() ([]MigProfile, error) {
|
|
// Return the cached list if available
|
|
if d.migProfiles != nil {
|
|
return d.migProfiles, nil
|
|
}
|
|
|
|
// Otherwise generate it...
|
|
var profiles []MigProfile
|
|
err := d.VisitMigProfiles(func(p MigProfile) error {
|
|
profiles = append(profiles, p)
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// And cache it before returning.
|
|
d.migProfiles = profiles
|
|
return profiles, nil
|
|
}
|
|
|
|
// isSkipped checks whether the device should be skipped.
|
|
func (d *device) isSkipped() (bool, error) {
|
|
name, ret := d.GetName()
|
|
if ret != nvml.SUCCESS {
|
|
return false, fmt.Errorf("error getting device name: %v", ret)
|
|
}
|
|
|
|
if _, exists := d.lib.skippedDevices[name]; exists {
|
|
return true, nil
|
|
}
|
|
|
|
return false, nil
|
|
}
|
|
|
|
// VisitDevices visits each top-level device and invokes a callback function for it.
|
|
func (d *devicelib) VisitDevices(visit func(int, Device) error) error {
|
|
count, ret := d.nvmllib.DeviceGetCount()
|
|
if ret != nvml.SUCCESS {
|
|
return fmt.Errorf("error getting device count: %v", ret)
|
|
}
|
|
|
|
for i := 0; i < count; i++ {
|
|
device, ret := d.nvmllib.DeviceGetHandleByIndex(i)
|
|
if ret != nvml.SUCCESS {
|
|
return fmt.Errorf("error getting device handle for index '%v': %v", i, ret)
|
|
}
|
|
dev, err := d.newDevice(device)
|
|
if err != nil {
|
|
return fmt.Errorf("error creating new device wrapper: %v", err)
|
|
}
|
|
|
|
isSkipped, err := dev.isSkipped()
|
|
if err != nil {
|
|
return fmt.Errorf("error checking whether device is skipped: %v", err)
|
|
}
|
|
if isSkipped {
|
|
continue
|
|
}
|
|
|
|
err = visit(i, dev)
|
|
if err != nil {
|
|
return fmt.Errorf("error visiting device: %v", err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// VisitMigDevices walks a top-level device and invokes a callback function for each MIG device configured on it.
|
|
func (d *devicelib) VisitMigDevices(visit func(int, Device, int, MigDevice) error) error {
|
|
err := d.VisitDevices(func(i int, dev Device) error {
|
|
err := dev.VisitMigDevices(func(j int, mig MigDevice) error {
|
|
err := visit(i, dev, j, mig)
|
|
if err != nil {
|
|
return fmt.Errorf("error visiting MIG device: %v", err)
|
|
}
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("error visiting device: %v", err)
|
|
}
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("error visiting devices: %v", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// VisitMigProfiles walks a top-level device and invokes a callback function for each unique MIG profile found on them.
|
|
func (d *devicelib) VisitMigProfiles(visit func(MigProfile) error) error {
|
|
visited := make(map[string]bool)
|
|
err := d.VisitDevices(func(i int, dev Device) error {
|
|
err := dev.VisitMigProfiles(func(p MigProfile) error {
|
|
if visited[p.String()] {
|
|
return nil
|
|
}
|
|
|
|
err := visit(p)
|
|
if err != nil {
|
|
return fmt.Errorf("error visiting MIG profile: %v", err)
|
|
}
|
|
|
|
visited[p.String()] = true
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("error visiting device: %v", err)
|
|
}
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("error visiting devices: %v", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GetDevices gets the set of all top-level devices.
|
|
func (d *devicelib) GetDevices() ([]Device, error) {
|
|
var devs []Device
|
|
err := d.VisitDevices(func(i int, dev Device) error {
|
|
devs = append(devs, dev)
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return devs, nil
|
|
}
|
|
|
|
// GetMigDevices gets the set of MIG devices across all top-level devices.
|
|
func (d *devicelib) GetMigDevices() ([]MigDevice, error) {
|
|
var migs []MigDevice
|
|
err := d.VisitMigDevices(func(i int, dev Device, j int, m MigDevice) error {
|
|
migs = append(migs, m)
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return migs, nil
|
|
}
|
|
|
|
// GetMigProfiles gets the set of unique MIG profiles across all top-level devices.
|
|
func (d *devicelib) GetMigProfiles() ([]MigProfile, error) {
|
|
// Return the cached list if available
|
|
if d.migProfiles != nil {
|
|
return d.migProfiles, nil
|
|
}
|
|
|
|
// Otherwise generate it...
|
|
var profiles []MigProfile
|
|
err := d.VisitMigProfiles(func(p MigProfile) error {
|
|
profiles = append(profiles, p)
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// And cache it before returning.
|
|
d.migProfiles = profiles
|
|
return profiles, nil
|
|
}
|
|
|
|
// hasSymbol checks to see if the given symbol is present in the NVML library.
|
|
// If devicelib is configured to not verify symbols, then all symbols are assumed to exist.
|
|
func (d *devicelib) hasSymbol(symbol string) bool {
|
|
if !*d.verifySymbols {
|
|
return true
|
|
}
|
|
|
|
return d.nvmllib.Extensions().LookupSymbol(symbol) == nil
|
|
}
|