diff --git a/pkg/nvlib/device/device.go b/pkg/nvlib/device/device.go index e2603cf..f828cab 100644 --- a/pkg/nvlib/device/device.go +++ b/pkg/nvlib/device/device.go @@ -27,6 +27,7 @@ import ( type Device interface { nvml.Device GetArchitectureAsString() (string, error) + GetBrandAsString() (string, error) GetCudaComputeCapabilityAsString() (string, error) GetMigDevices() ([]MigDevice, error) GetMigProfiles() ([]MigProfile, error) @@ -92,6 +93,54 @@ func (d *device) GetArchitectureAsString() (string, error) { return "", fmt.Errorf("error interpreting device architecture as string: %v", arch) } +// GetBrandAsString returns the Device architecture as a string +func (d *device) GetBrandAsString() (string, error) { + brand, ret := d.GetBrand() + if ret != nvml.SUCCESS { + return "", fmt.Errorf("error getting device brand: %v", ret) + } + switch brand { + case nvml.BRAND_UNKNOWN: + return "Unknown", nil + case nvml.BRAND_QUADRO: + return "Quadro", nil + case nvml.BRAND_TESLA: + return "Tesla", nil + case nvml.BRAND_NVS: + return "NVS", nil + case nvml.BRAND_GRID: + return "Grid", nil + case nvml.BRAND_GEFORCE: + return "GeForce", nil + case nvml.BRAND_TITAN: + return "Titan", nil + case nvml.BRAND_NVIDIA_VAPPS: + return "NvidiaVApps", nil + case nvml.BRAND_NVIDIA_VPC: + return "NvidiaVPC", nil + case nvml.BRAND_NVIDIA_VCS: + return "NvidiaVCS", nil + case nvml.BRAND_NVIDIA_VWS: + return "NvidiaVWS", nil + // Deprecated in favor of nvml.BRAND_NVIDIA_CLOUD_GAMING + //case nvml.BRAND_NVIDIA_VGAMING: + // return "VGaming", nil + case nvml.BRAND_NVIDIA_CLOUD_GAMING: + return "NvidiaCloudGaming", nil + case nvml.BRAND_QUADRO_RTX: + return "QuadroRTX", nil + case nvml.BRAND_NVIDIA_RTX: + return "NvidiaRTX", nil + case nvml.BRAND_NVIDIA: + return "Nvidia", nil + case nvml.BRAND_GEFORCE_RTX: + return "GeForceRTX", nil + case nvml.BRAND_TITAN_RTX: + return "TitanRTX", nil + } + return "", fmt.Errorf("error interpreting device brand as string: %v", brand) +} + // GetCudaComputeCapabilityAsString returns the Device's CUDA compute capability as a version string func (d *device) GetCudaComputeCapabilityAsString() (string, error) { major, minor, ret := d.GetCudaComputeCapability() diff --git a/pkg/nvml/consts.go b/pkg/nvml/consts.go index 6cccee7..c9b85de 100644 --- a/pkg/nvml/consts.go +++ b/pkg/nvml/consts.go @@ -62,6 +62,29 @@ const ( DEVICE_ARCH_UNKNOWN = nvml.DEVICE_ARCH_UNKNOWN ) +// Device brand constants +const ( + BRAND_UNKNOWN = BrandType(nvml.BRAND_UNKNOWN) + BRAND_QUADRO = BrandType(nvml.BRAND_QUADRO) + BRAND_TESLA = BrandType(nvml.BRAND_TESLA) + BRAND_NVS = BrandType(nvml.BRAND_NVS) + BRAND_GRID = BrandType(nvml.BRAND_GRID) + BRAND_GEFORCE = BrandType(nvml.BRAND_GEFORCE) + BRAND_TITAN = BrandType(nvml.BRAND_TITAN) + BRAND_NVIDIA_VAPPS = BrandType(nvml.BRAND_NVIDIA_VAPPS) + BRAND_NVIDIA_VPC = BrandType(nvml.BRAND_NVIDIA_VPC) + BRAND_NVIDIA_VCS = BrandType(nvml.BRAND_NVIDIA_VCS) + BRAND_NVIDIA_VWS = BrandType(nvml.BRAND_NVIDIA_VWS) + BRAND_NVIDIA_CLOUD_GAMING = BrandType(nvml.BRAND_NVIDIA_CLOUD_GAMING) + BRAND_NVIDIA_VGAMING = BrandType(nvml.BRAND_NVIDIA_VGAMING) + BRAND_QUADRO_RTX = BrandType(nvml.BRAND_QUADRO_RTX) + BRAND_NVIDIA_RTX = BrandType(nvml.BRAND_NVIDIA_RTX) + BRAND_NVIDIA = BrandType(nvml.BRAND_NVIDIA) + BRAND_GEFORCE_RTX = BrandType(nvml.BRAND_GEFORCE_RTX) + BRAND_TITAN_RTX = BrandType(nvml.BRAND_TITAN_RTX) + BRAND_COUNT = BrandType(nvml.BRAND_COUNT) +) + // MIG Mode constants const ( DEVICE_MIG_ENABLE = nvml.DEVICE_MIG_ENABLE diff --git a/pkg/nvml/device.go b/pkg/nvml/device.go index ddfe6ed..3c318a7 100644 --- a/pkg/nvml/device.go +++ b/pkg/nvml/device.go @@ -156,6 +156,12 @@ func (d nvmlDevice) GetName() (string, Return) { return n, Return(r) } +// GetBrand returns the brand of a Device +func (d nvmlDevice) GetBrand() (BrandType, Return) { + b, r := nvml.Device(d).GetBrand() + return BrandType(b), Return(r) +} + // GetArchitecture returns the architecture of a Device func (d nvmlDevice) GetArchitecture() (DeviceArchitecture, Return) { a, r := nvml.Device(d).GetArchitecture() diff --git a/pkg/nvml/device_mock.go b/pkg/nvml/device_mock.go index 636e01c..34e563c 100644 --- a/pkg/nvml/device_mock.go +++ b/pkg/nvml/device_mock.go @@ -26,6 +26,9 @@ var _ Device = &DeviceMock{} // GetAttributesFunc: func() (DeviceAttributes, Return) { // panic("mock out the GetAttributes method") // }, +// GetBrandFunc: func() (BrandType, Return) { +// panic("mock out the GetBrand method") +// }, // GetComputeInstanceIdFunc: func() (int, Return) { // panic("mock out the GetComputeInstanceId method") // }, @@ -105,6 +108,9 @@ type DeviceMock struct { // GetAttributesFunc mocks the GetAttributes method. GetAttributesFunc func() (DeviceAttributes, Return) + // GetBrandFunc mocks the GetBrand method. + GetBrandFunc func() (BrandType, Return) + // GetComputeInstanceIdFunc mocks the GetComputeInstanceId method. GetComputeInstanceIdFunc func() (int, Return) @@ -183,6 +189,9 @@ type DeviceMock struct { // GetAttributes holds details about calls to the GetAttributes method. GetAttributes []struct { } + // GetBrand holds details about calls to the GetBrand method. + GetBrand []struct { + } // GetComputeInstanceId holds details about calls to the GetComputeInstanceId method. GetComputeInstanceId []struct { } @@ -266,6 +275,7 @@ type DeviceMock struct { lockCreateGpuInstanceWithPlacement sync.RWMutex lockGetArchitecture sync.RWMutex lockGetAttributes sync.RWMutex + lockGetBrand sync.RWMutex lockGetComputeInstanceId sync.RWMutex lockGetCudaComputeCapability sync.RWMutex lockGetDeviceHandleFromMigDeviceHandle sync.RWMutex @@ -379,6 +389,33 @@ func (mock *DeviceMock) GetAttributesCalls() []struct { return calls } +// GetBrand calls GetBrandFunc. +func (mock *DeviceMock) GetBrand() (BrandType, Return) { + if mock.GetBrandFunc == nil { + panic("DeviceMock.GetBrandFunc: method is nil but Device.GetBrand was just called") + } + callInfo := struct { + }{} + mock.lockGetBrand.Lock() + mock.calls.GetBrand = append(mock.calls.GetBrand, callInfo) + mock.lockGetBrand.Unlock() + return mock.GetBrandFunc() +} + +// GetBrandCalls gets all the calls that were made to GetBrand. +// Check the length with: +// +// len(mockedDevice.GetBrandCalls()) +func (mock *DeviceMock) GetBrandCalls() []struct { +} { + var calls []struct { + } + mock.lockGetBrand.RLock() + calls = mock.calls.GetBrand + mock.lockGetBrand.RUnlock() + return calls +} + // GetComputeInstanceId calls GetComputeInstanceIdFunc. func (mock *DeviceMock) GetComputeInstanceId() (int, Return) { if mock.GetComputeInstanceIdFunc == nil { diff --git a/pkg/nvml/types.go b/pkg/nvml/types.go index c360a47..39d005f 100644 --- a/pkg/nvml/types.go +++ b/pkg/nvml/types.go @@ -42,6 +42,7 @@ type Device interface { CreateGpuInstanceWithPlacement(*GpuInstanceProfileInfo, *GpuInstancePlacement) (GpuInstance, Return) GetArchitecture() (DeviceArchitecture, Return) GetAttributes() (DeviceAttributes, Return) + GetBrand() (BrandType, Return) GetComputeInstanceId() (int, Return) GetCudaComputeCapability() (int, int, Return) GetDeviceHandleFromMigDeviceHandle() (Device, Return) @@ -140,3 +141,6 @@ type DeviceAttributes nvml.DeviceAttributes // DeviceArchitecture represents the hardware architecture of a GPU device type DeviceArchitecture nvml.DeviceArchitecture + +// BrandType represents the brand of a GPU device +type BrandType nvml.BrandType