Add additional functions to Device interface

Add the following functions to the Device interface:
* GetCudaComputeCapability
* GetAttributes
* GetName

Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
Evan Lezar 2022-09-02 15:05:09 +02:00
parent da71bc2bff
commit a404873b12
3 changed files with 132 additions and 0 deletions

View File

@ -115,3 +115,21 @@ func (d nvmlDevice) GetComputeInstanceId() (int, Return) {
ci, r := nvml.Device(d).GetComputeInstanceId()
return ci, Return(r)
}
// GetCudaComputeCapability returns the compute capability major and minor versions for a device
func (d nvmlDevice) GetCudaComputeCapability() (int, int, Return) {
major, minor, r := nvml.Device(d).GetCudaComputeCapability()
return major, minor, Return(r)
}
// GetAttributes returns the device attributes for a MIG device
func (d nvmlDevice) GetAttributes() (DeviceAttributes, Return) {
a, r := nvml.Device(d).GetAttributes()
return DeviceAttributes(a), Return(r)
}
// GetName returns the device attributes for a MIG device
func (d nvmlDevice) GetName() (string, Return) {
n, r := nvml.Device(d).GetName()
return n, Return(r)
}

View File

@ -17,9 +17,15 @@ var _ Device = &DeviceMock{}
//
// // make and configure a mocked Device
// mockedDevice := &DeviceMock{
// GetAttributesFunc: func() (DeviceAttributes, Return) {
// panic("mock out the GetAttributes method")
// },
// GetComputeInstanceIdFunc: func() (int, Return) {
// panic("mock out the GetComputeInstanceId method")
// },
// GetCudaComputeCapabilityFunc: func() (int, int, Return) {
// panic("mock out the GetCudaComputeCapability method")
// },
// GetDeviceHandleFromMigDeviceHandleFunc: func() (Device, Return) {
// panic("mock out the GetDeviceHandleFromMigDeviceHandle method")
// },
@ -50,6 +56,9 @@ var _ Device = &DeviceMock{}
// GetMinorNumberFunc: func() (int, Return) {
// panic("mock out the GetMinorNumber method")
// },
// GetNameFunc: func() (string, Return) {
// panic("mock out the GetName method")
// },
// GetPciInfoFunc: func() (PciInfo, Return) {
// panic("mock out the GetPciInfo method")
// },
@ -69,9 +78,15 @@ var _ Device = &DeviceMock{}
//
// }
type DeviceMock struct {
// GetAttributesFunc mocks the GetAttributes method.
GetAttributesFunc func() (DeviceAttributes, Return)
// GetComputeInstanceIdFunc mocks the GetComputeInstanceId method.
GetComputeInstanceIdFunc func() (int, Return)
// GetCudaComputeCapabilityFunc mocks the GetCudaComputeCapability method.
GetCudaComputeCapabilityFunc func() (int, int, Return)
// GetDeviceHandleFromMigDeviceHandleFunc mocks the GetDeviceHandleFromMigDeviceHandle method.
GetDeviceHandleFromMigDeviceHandleFunc func() (Device, Return)
@ -102,6 +117,9 @@ type DeviceMock struct {
// GetMinorNumberFunc mocks the GetMinorNumber method.
GetMinorNumberFunc func() (int, Return)
// GetNameFunc mocks the GetName method.
GetNameFunc func() (string, Return)
// GetPciInfoFunc mocks the GetPciInfo method.
GetPciInfoFunc func() (PciInfo, Return)
@ -116,9 +134,15 @@ type DeviceMock struct {
// calls tracks calls to the methods.
calls struct {
// GetAttributes holds details about calls to the GetAttributes method.
GetAttributes []struct {
}
// GetComputeInstanceId holds details about calls to the GetComputeInstanceId method.
GetComputeInstanceId []struct {
}
// GetCudaComputeCapability holds details about calls to the GetCudaComputeCapability method.
GetCudaComputeCapability []struct {
}
// GetDeviceHandleFromMigDeviceHandle holds details about calls to the GetDeviceHandleFromMigDeviceHandle method.
GetDeviceHandleFromMigDeviceHandle []struct {
}
@ -155,6 +179,9 @@ type DeviceMock struct {
// GetMinorNumber holds details about calls to the GetMinorNumber method.
GetMinorNumber []struct {
}
// GetName holds details about calls to the GetName method.
GetName []struct {
}
// GetPciInfo holds details about calls to the GetPciInfo method.
GetPciInfo []struct {
}
@ -170,7 +197,9 @@ type DeviceMock struct {
Mode int
}
}
lockGetAttributes sync.RWMutex
lockGetComputeInstanceId sync.RWMutex
lockGetCudaComputeCapability sync.RWMutex
lockGetDeviceHandleFromMigDeviceHandle sync.RWMutex
lockGetGpuInstanceId sync.RWMutex
lockGetGpuInstanceProfileInfo sync.RWMutex
@ -181,12 +210,39 @@ type DeviceMock struct {
lockGetMigDeviceHandleByIndex sync.RWMutex
lockGetMigMode sync.RWMutex
lockGetMinorNumber sync.RWMutex
lockGetName sync.RWMutex
lockGetPciInfo sync.RWMutex
lockGetUUID sync.RWMutex
lockIsMigDeviceHandle sync.RWMutex
lockSetMigMode sync.RWMutex
}
// GetAttributes calls GetAttributesFunc.
func (mock *DeviceMock) GetAttributes() (DeviceAttributes, Return) {
if mock.GetAttributesFunc == nil {
panic("DeviceMock.GetAttributesFunc: method is nil but Device.GetAttributes was just called")
}
callInfo := struct {
}{}
mock.lockGetAttributes.Lock()
mock.calls.GetAttributes = append(mock.calls.GetAttributes, callInfo)
mock.lockGetAttributes.Unlock()
return mock.GetAttributesFunc()
}
// GetAttributesCalls gets all the calls that were made to GetAttributes.
// Check the length with:
// len(mockedDevice.GetAttributesCalls())
func (mock *DeviceMock) GetAttributesCalls() []struct {
} {
var calls []struct {
}
mock.lockGetAttributes.RLock()
calls = mock.calls.GetAttributes
mock.lockGetAttributes.RUnlock()
return calls
}
// GetComputeInstanceId calls GetComputeInstanceIdFunc.
func (mock *DeviceMock) GetComputeInstanceId() (int, Return) {
if mock.GetComputeInstanceIdFunc == nil {
@ -213,6 +269,32 @@ func (mock *DeviceMock) GetComputeInstanceIdCalls() []struct {
return calls
}
// GetCudaComputeCapability calls GetCudaComputeCapabilityFunc.
func (mock *DeviceMock) GetCudaComputeCapability() (int, int, Return) {
if mock.GetCudaComputeCapabilityFunc == nil {
panic("DeviceMock.GetCudaComputeCapabilityFunc: method is nil but Device.GetCudaComputeCapability was just called")
}
callInfo := struct {
}{}
mock.lockGetCudaComputeCapability.Lock()
mock.calls.GetCudaComputeCapability = append(mock.calls.GetCudaComputeCapability, callInfo)
mock.lockGetCudaComputeCapability.Unlock()
return mock.GetCudaComputeCapabilityFunc()
}
// GetCudaComputeCapabilityCalls gets all the calls that were made to GetCudaComputeCapability.
// Check the length with:
// len(mockedDevice.GetCudaComputeCapabilityCalls())
func (mock *DeviceMock) GetCudaComputeCapabilityCalls() []struct {
} {
var calls []struct {
}
mock.lockGetCudaComputeCapability.RLock()
calls = mock.calls.GetCudaComputeCapability
mock.lockGetCudaComputeCapability.RUnlock()
return calls
}
// GetDeviceHandleFromMigDeviceHandle calls GetDeviceHandleFromMigDeviceHandleFunc.
func (mock *DeviceMock) GetDeviceHandleFromMigDeviceHandle() (Device, Return) {
if mock.GetDeviceHandleFromMigDeviceHandleFunc == nil {
@ -488,6 +570,32 @@ func (mock *DeviceMock) GetMinorNumberCalls() []struct {
return calls
}
// GetName calls GetNameFunc.
func (mock *DeviceMock) GetName() (string, Return) {
if mock.GetNameFunc == nil {
panic("DeviceMock.GetNameFunc: method is nil but Device.GetName was just called")
}
callInfo := struct {
}{}
mock.lockGetName.Lock()
mock.calls.GetName = append(mock.calls.GetName, callInfo)
mock.lockGetName.Unlock()
return mock.GetNameFunc()
}
// GetNameCalls gets all the calls that were made to GetName.
// Check the length with:
// len(mockedDevice.GetNameCalls())
func (mock *DeviceMock) GetNameCalls() []struct {
} {
var calls []struct {
}
mock.lockGetName.RLock()
calls = mock.calls.GetName
mock.lockGetName.RUnlock()
return calls
}
// GetPciInfo calls GetPciInfoFunc.
func (mock *DeviceMock) GetPciInfo() (PciInfo, Return) {
if mock.GetPciInfoFunc == nil {

View File

@ -53,6 +53,9 @@ type Device interface {
GetMigDeviceHandleByIndex(Index int) (Device, Return)
GetGpuInstanceId() (int, Return)
GetComputeInstanceId() (int, Return)
GetCudaComputeCapability() (int, int, Return)
GetAttributes() (DeviceAttributes, Return)
GetName() (string, Return)
}
// GpuInstance defines the functions implemented by a GpuInstance
@ -111,3 +114,6 @@ type ComputeInstanceProfileInfo nvml.ComputeInstanceProfileInfo
// ComputeInstancePlacement holds placement info about a Compute Instance
type ComputeInstancePlacement nvml.ComputeInstancePlacement
// DeviceAttributes stores information about MIG devices
type DeviceAttributes nvml.DeviceAttributes