From a404873b1217f43b0e52e76be3897885b1c70b16 Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Fri, 2 Sep 2022 15:05:09 +0200 Subject: [PATCH] Add additional functions to Device interface Add the following functions to the Device interface: * GetCudaComputeCapability * GetAttributes * GetName Signed-off-by: Evan Lezar --- pkg/nvml/device.go | 18 +++++++ pkg/nvml/device_mock.go | 108 ++++++++++++++++++++++++++++++++++++++++ pkg/nvml/types.go | 6 +++ 3 files changed, 132 insertions(+) diff --git a/pkg/nvml/device.go b/pkg/nvml/device.go index 9b0caf8..91bea46 100644 --- a/pkg/nvml/device.go +++ b/pkg/nvml/device.go @@ -115,3 +115,21 @@ func (d nvmlDevice) GetComputeInstanceId() (int, Return) { ci, r := nvml.Device(d).GetComputeInstanceId() return ci, Return(r) } + +// GetCudaComputeCapability returns the compute capability major and minor versions for a device +func (d nvmlDevice) GetCudaComputeCapability() (int, int, Return) { + major, minor, r := nvml.Device(d).GetCudaComputeCapability() + return major, minor, Return(r) +} + +// GetAttributes returns the device attributes for a MIG device +func (d nvmlDevice) GetAttributes() (DeviceAttributes, Return) { + a, r := nvml.Device(d).GetAttributes() + return DeviceAttributes(a), Return(r) +} + +// GetName returns the device attributes for a MIG device +func (d nvmlDevice) GetName() (string, Return) { + n, r := nvml.Device(d).GetName() + return n, Return(r) +} diff --git a/pkg/nvml/device_mock.go b/pkg/nvml/device_mock.go index a25898e..adc5e28 100644 --- a/pkg/nvml/device_mock.go +++ b/pkg/nvml/device_mock.go @@ -17,9 +17,15 @@ var _ Device = &DeviceMock{} // // // make and configure a mocked Device // mockedDevice := &DeviceMock{ +// GetAttributesFunc: func() (DeviceAttributes, Return) { +// panic("mock out the GetAttributes method") +// }, // GetComputeInstanceIdFunc: func() (int, Return) { // panic("mock out the GetComputeInstanceId method") // }, +// GetCudaComputeCapabilityFunc: func() (int, int, Return) { +// panic("mock out the GetCudaComputeCapability method") +// }, // GetDeviceHandleFromMigDeviceHandleFunc: func() (Device, Return) { // panic("mock out the GetDeviceHandleFromMigDeviceHandle method") // }, @@ -50,6 +56,9 @@ var _ Device = &DeviceMock{} // GetMinorNumberFunc: func() (int, Return) { // panic("mock out the GetMinorNumber method") // }, +// GetNameFunc: func() (string, Return) { +// panic("mock out the GetName method") +// }, // GetPciInfoFunc: func() (PciInfo, Return) { // panic("mock out the GetPciInfo method") // }, @@ -69,9 +78,15 @@ var _ Device = &DeviceMock{} // // } type DeviceMock struct { + // GetAttributesFunc mocks the GetAttributes method. + GetAttributesFunc func() (DeviceAttributes, Return) + // GetComputeInstanceIdFunc mocks the GetComputeInstanceId method. GetComputeInstanceIdFunc func() (int, Return) + // GetCudaComputeCapabilityFunc mocks the GetCudaComputeCapability method. + GetCudaComputeCapabilityFunc func() (int, int, Return) + // GetDeviceHandleFromMigDeviceHandleFunc mocks the GetDeviceHandleFromMigDeviceHandle method. GetDeviceHandleFromMigDeviceHandleFunc func() (Device, Return) @@ -102,6 +117,9 @@ type DeviceMock struct { // GetMinorNumberFunc mocks the GetMinorNumber method. GetMinorNumberFunc func() (int, Return) + // GetNameFunc mocks the GetName method. + GetNameFunc func() (string, Return) + // GetPciInfoFunc mocks the GetPciInfo method. GetPciInfoFunc func() (PciInfo, Return) @@ -116,9 +134,15 @@ type DeviceMock struct { // calls tracks calls to the methods. calls struct { + // GetAttributes holds details about calls to the GetAttributes method. + GetAttributes []struct { + } // GetComputeInstanceId holds details about calls to the GetComputeInstanceId method. GetComputeInstanceId []struct { } + // GetCudaComputeCapability holds details about calls to the GetCudaComputeCapability method. + GetCudaComputeCapability []struct { + } // GetDeviceHandleFromMigDeviceHandle holds details about calls to the GetDeviceHandleFromMigDeviceHandle method. GetDeviceHandleFromMigDeviceHandle []struct { } @@ -155,6 +179,9 @@ type DeviceMock struct { // GetMinorNumber holds details about calls to the GetMinorNumber method. GetMinorNumber []struct { } + // GetName holds details about calls to the GetName method. + GetName []struct { + } // GetPciInfo holds details about calls to the GetPciInfo method. GetPciInfo []struct { } @@ -170,7 +197,9 @@ type DeviceMock struct { Mode int } } + lockGetAttributes sync.RWMutex lockGetComputeInstanceId sync.RWMutex + lockGetCudaComputeCapability sync.RWMutex lockGetDeviceHandleFromMigDeviceHandle sync.RWMutex lockGetGpuInstanceId sync.RWMutex lockGetGpuInstanceProfileInfo sync.RWMutex @@ -181,12 +210,39 @@ type DeviceMock struct { lockGetMigDeviceHandleByIndex sync.RWMutex lockGetMigMode sync.RWMutex lockGetMinorNumber sync.RWMutex + lockGetName sync.RWMutex lockGetPciInfo sync.RWMutex lockGetUUID sync.RWMutex lockIsMigDeviceHandle sync.RWMutex lockSetMigMode sync.RWMutex } +// GetAttributes calls GetAttributesFunc. +func (mock *DeviceMock) GetAttributes() (DeviceAttributes, Return) { + if mock.GetAttributesFunc == nil { + panic("DeviceMock.GetAttributesFunc: method is nil but Device.GetAttributes was just called") + } + callInfo := struct { + }{} + mock.lockGetAttributes.Lock() + mock.calls.GetAttributes = append(mock.calls.GetAttributes, callInfo) + mock.lockGetAttributes.Unlock() + return mock.GetAttributesFunc() +} + +// GetAttributesCalls gets all the calls that were made to GetAttributes. +// Check the length with: +// len(mockedDevice.GetAttributesCalls()) +func (mock *DeviceMock) GetAttributesCalls() []struct { +} { + var calls []struct { + } + mock.lockGetAttributes.RLock() + calls = mock.calls.GetAttributes + mock.lockGetAttributes.RUnlock() + return calls +} + // GetComputeInstanceId calls GetComputeInstanceIdFunc. func (mock *DeviceMock) GetComputeInstanceId() (int, Return) { if mock.GetComputeInstanceIdFunc == nil { @@ -213,6 +269,32 @@ func (mock *DeviceMock) GetComputeInstanceIdCalls() []struct { return calls } +// GetCudaComputeCapability calls GetCudaComputeCapabilityFunc. +func (mock *DeviceMock) GetCudaComputeCapability() (int, int, Return) { + if mock.GetCudaComputeCapabilityFunc == nil { + panic("DeviceMock.GetCudaComputeCapabilityFunc: method is nil but Device.GetCudaComputeCapability was just called") + } + callInfo := struct { + }{} + mock.lockGetCudaComputeCapability.Lock() + mock.calls.GetCudaComputeCapability = append(mock.calls.GetCudaComputeCapability, callInfo) + mock.lockGetCudaComputeCapability.Unlock() + return mock.GetCudaComputeCapabilityFunc() +} + +// GetCudaComputeCapabilityCalls gets all the calls that were made to GetCudaComputeCapability. +// Check the length with: +// len(mockedDevice.GetCudaComputeCapabilityCalls()) +func (mock *DeviceMock) GetCudaComputeCapabilityCalls() []struct { +} { + var calls []struct { + } + mock.lockGetCudaComputeCapability.RLock() + calls = mock.calls.GetCudaComputeCapability + mock.lockGetCudaComputeCapability.RUnlock() + return calls +} + // GetDeviceHandleFromMigDeviceHandle calls GetDeviceHandleFromMigDeviceHandleFunc. func (mock *DeviceMock) GetDeviceHandleFromMigDeviceHandle() (Device, Return) { if mock.GetDeviceHandleFromMigDeviceHandleFunc == nil { @@ -488,6 +570,32 @@ func (mock *DeviceMock) GetMinorNumberCalls() []struct { return calls } +// GetName calls GetNameFunc. +func (mock *DeviceMock) GetName() (string, Return) { + if mock.GetNameFunc == nil { + panic("DeviceMock.GetNameFunc: method is nil but Device.GetName was just called") + } + callInfo := struct { + }{} + mock.lockGetName.Lock() + mock.calls.GetName = append(mock.calls.GetName, callInfo) + mock.lockGetName.Unlock() + return mock.GetNameFunc() +} + +// GetNameCalls gets all the calls that were made to GetName. +// Check the length with: +// len(mockedDevice.GetNameCalls()) +func (mock *DeviceMock) GetNameCalls() []struct { +} { + var calls []struct { + } + mock.lockGetName.RLock() + calls = mock.calls.GetName + mock.lockGetName.RUnlock() + return calls +} + // GetPciInfo calls GetPciInfoFunc. func (mock *DeviceMock) GetPciInfo() (PciInfo, Return) { if mock.GetPciInfoFunc == nil { diff --git a/pkg/nvml/types.go b/pkg/nvml/types.go index 47a7eb3..d6ed9c4 100644 --- a/pkg/nvml/types.go +++ b/pkg/nvml/types.go @@ -53,6 +53,9 @@ type Device interface { GetMigDeviceHandleByIndex(Index int) (Device, Return) GetGpuInstanceId() (int, Return) GetComputeInstanceId() (int, Return) + GetCudaComputeCapability() (int, int, Return) + GetAttributes() (DeviceAttributes, Return) + GetName() (string, Return) } // GpuInstance defines the functions implemented by a GpuInstance @@ -111,3 +114,6 @@ type ComputeInstanceProfileInfo nvml.ComputeInstanceProfileInfo // ComputeInstancePlacement holds placement info about a Compute Instance type ComputeInstancePlacement nvml.ComputeInstancePlacement + +// DeviceAttributes stores information about MIG devices +type DeviceAttributes nvml.DeviceAttributes