Pass device.GetBrand() through from NVMl and wrap it to print a string

Signed-off-by: Kevin Klues <kklues@nvidia.com>
This commit is contained in:
Kevin Klues 2023-03-26 21:15:51 +00:00
parent 649f196fb7
commit 087de4f458
5 changed files with 119 additions and 0 deletions

View File

@ -27,6 +27,7 @@ import (
type Device interface {
nvml.Device
GetArchitectureAsString() (string, error)
GetBrandAsString() (string, error)
GetCudaComputeCapabilityAsString() (string, error)
GetMigDevices() ([]MigDevice, error)
GetMigProfiles() ([]MigProfile, error)
@ -92,6 +93,54 @@ func (d *device) GetArchitectureAsString() (string, error) {
return "", fmt.Errorf("error interpreting device architecture as string: %v", arch)
}
// GetBrandAsString returns the Device architecture as a string
func (d *device) GetBrandAsString() (string, error) {
brand, ret := d.GetBrand()
if ret != nvml.SUCCESS {
return "", fmt.Errorf("error getting device brand: %v", ret)
}
switch brand {
case nvml.BRAND_UNKNOWN:
return "Unknown", nil
case nvml.BRAND_QUADRO:
return "Quadro", nil
case nvml.BRAND_TESLA:
return "Tesla", nil
case nvml.BRAND_NVS:
return "NVS", nil
case nvml.BRAND_GRID:
return "Grid", nil
case nvml.BRAND_GEFORCE:
return "GeForce", nil
case nvml.BRAND_TITAN:
return "Titan", nil
case nvml.BRAND_NVIDIA_VAPPS:
return "NvidiaVApps", nil
case nvml.BRAND_NVIDIA_VPC:
return "NvidiaVPC", nil
case nvml.BRAND_NVIDIA_VCS:
return "NvidiaVCS", nil
case nvml.BRAND_NVIDIA_VWS:
return "NvidiaVWS", nil
// Deprecated in favor of nvml.BRAND_NVIDIA_CLOUD_GAMING
//case nvml.BRAND_NVIDIA_VGAMING:
// return "VGaming", nil
case nvml.BRAND_NVIDIA_CLOUD_GAMING:
return "NvidiaCloudGaming", nil
case nvml.BRAND_QUADRO_RTX:
return "QuadroRTX", nil
case nvml.BRAND_NVIDIA_RTX:
return "NvidiaRTX", nil
case nvml.BRAND_NVIDIA:
return "Nvidia", nil
case nvml.BRAND_GEFORCE_RTX:
return "GeForceRTX", nil
case nvml.BRAND_TITAN_RTX:
return "TitanRTX", nil
}
return "", fmt.Errorf("error interpreting device brand as string: %v", brand)
}
// GetCudaComputeCapabilityAsString returns the Device's CUDA compute capability as a version string
func (d *device) GetCudaComputeCapabilityAsString() (string, error) {
major, minor, ret := d.GetCudaComputeCapability()

View File

@ -62,6 +62,29 @@ const (
DEVICE_ARCH_UNKNOWN = nvml.DEVICE_ARCH_UNKNOWN
)
// Device brand constants
const (
BRAND_UNKNOWN = BrandType(nvml.BRAND_UNKNOWN)
BRAND_QUADRO = BrandType(nvml.BRAND_QUADRO)
BRAND_TESLA = BrandType(nvml.BRAND_TESLA)
BRAND_NVS = BrandType(nvml.BRAND_NVS)
BRAND_GRID = BrandType(nvml.BRAND_GRID)
BRAND_GEFORCE = BrandType(nvml.BRAND_GEFORCE)
BRAND_TITAN = BrandType(nvml.BRAND_TITAN)
BRAND_NVIDIA_VAPPS = BrandType(nvml.BRAND_NVIDIA_VAPPS)
BRAND_NVIDIA_VPC = BrandType(nvml.BRAND_NVIDIA_VPC)
BRAND_NVIDIA_VCS = BrandType(nvml.BRAND_NVIDIA_VCS)
BRAND_NVIDIA_VWS = BrandType(nvml.BRAND_NVIDIA_VWS)
BRAND_NVIDIA_CLOUD_GAMING = BrandType(nvml.BRAND_NVIDIA_CLOUD_GAMING)
BRAND_NVIDIA_VGAMING = BrandType(nvml.BRAND_NVIDIA_VGAMING)
BRAND_QUADRO_RTX = BrandType(nvml.BRAND_QUADRO_RTX)
BRAND_NVIDIA_RTX = BrandType(nvml.BRAND_NVIDIA_RTX)
BRAND_NVIDIA = BrandType(nvml.BRAND_NVIDIA)
BRAND_GEFORCE_RTX = BrandType(nvml.BRAND_GEFORCE_RTX)
BRAND_TITAN_RTX = BrandType(nvml.BRAND_TITAN_RTX)
BRAND_COUNT = BrandType(nvml.BRAND_COUNT)
)
// MIG Mode constants
const (
DEVICE_MIG_ENABLE = nvml.DEVICE_MIG_ENABLE

View File

@ -156,6 +156,12 @@ func (d nvmlDevice) GetName() (string, Return) {
return n, Return(r)
}
// GetBrand returns the brand of a Device
func (d nvmlDevice) GetBrand() (BrandType, Return) {
b, r := nvml.Device(d).GetBrand()
return BrandType(b), Return(r)
}
// GetArchitecture returns the architecture of a Device
func (d nvmlDevice) GetArchitecture() (DeviceArchitecture, Return) {
a, r := nvml.Device(d).GetArchitecture()

View File

@ -26,6 +26,9 @@ var _ Device = &DeviceMock{}
// GetAttributesFunc: func() (DeviceAttributes, Return) {
// panic("mock out the GetAttributes method")
// },
// GetBrandFunc: func() (BrandType, Return) {
// panic("mock out the GetBrand method")
// },
// GetComputeInstanceIdFunc: func() (int, Return) {
// panic("mock out the GetComputeInstanceId method")
// },
@ -105,6 +108,9 @@ type DeviceMock struct {
// GetAttributesFunc mocks the GetAttributes method.
GetAttributesFunc func() (DeviceAttributes, Return)
// GetBrandFunc mocks the GetBrand method.
GetBrandFunc func() (BrandType, Return)
// GetComputeInstanceIdFunc mocks the GetComputeInstanceId method.
GetComputeInstanceIdFunc func() (int, Return)
@ -183,6 +189,9 @@ type DeviceMock struct {
// GetAttributes holds details about calls to the GetAttributes method.
GetAttributes []struct {
}
// GetBrand holds details about calls to the GetBrand method.
GetBrand []struct {
}
// GetComputeInstanceId holds details about calls to the GetComputeInstanceId method.
GetComputeInstanceId []struct {
}
@ -266,6 +275,7 @@ type DeviceMock struct {
lockCreateGpuInstanceWithPlacement sync.RWMutex
lockGetArchitecture sync.RWMutex
lockGetAttributes sync.RWMutex
lockGetBrand sync.RWMutex
lockGetComputeInstanceId sync.RWMutex
lockGetCudaComputeCapability sync.RWMutex
lockGetDeviceHandleFromMigDeviceHandle sync.RWMutex
@ -379,6 +389,33 @@ func (mock *DeviceMock) GetAttributesCalls() []struct {
return calls
}
// GetBrand calls GetBrandFunc.
func (mock *DeviceMock) GetBrand() (BrandType, Return) {
if mock.GetBrandFunc == nil {
panic("DeviceMock.GetBrandFunc: method is nil but Device.GetBrand was just called")
}
callInfo := struct {
}{}
mock.lockGetBrand.Lock()
mock.calls.GetBrand = append(mock.calls.GetBrand, callInfo)
mock.lockGetBrand.Unlock()
return mock.GetBrandFunc()
}
// GetBrandCalls gets all the calls that were made to GetBrand.
// Check the length with:
//
// len(mockedDevice.GetBrandCalls())
func (mock *DeviceMock) GetBrandCalls() []struct {
} {
var calls []struct {
}
mock.lockGetBrand.RLock()
calls = mock.calls.GetBrand
mock.lockGetBrand.RUnlock()
return calls
}
// GetComputeInstanceId calls GetComputeInstanceIdFunc.
func (mock *DeviceMock) GetComputeInstanceId() (int, Return) {
if mock.GetComputeInstanceIdFunc == nil {

View File

@ -42,6 +42,7 @@ type Device interface {
CreateGpuInstanceWithPlacement(*GpuInstanceProfileInfo, *GpuInstancePlacement) (GpuInstance, Return)
GetArchitecture() (DeviceArchitecture, Return)
GetAttributes() (DeviceAttributes, Return)
GetBrand() (BrandType, Return)
GetComputeInstanceId() (int, Return)
GetCudaComputeCapability() (int, int, Return)
GetDeviceHandleFromMigDeviceHandle() (Device, Return)
@ -140,3 +141,6 @@ type DeviceAttributes nvml.DeviceAttributes
// DeviceArchitecture represents the hardware architecture of a GPU device
type DeviceArchitecture nvml.DeviceArchitecture
// BrandType represents the brand of a GPU device
type BrandType nvml.BrandType