Add nvmlDeviceHandle function to Device interface

This change allows the underlying device handle to be returned
without relying on type-casting.

Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
Evan Lezar 2024-01-09 14:04:57 +01:00
parent f3264c8a6a
commit 06cbc571ef
3 changed files with 48 additions and 3 deletions

View File

@ -22,6 +22,11 @@ type nvmlDevice nvml.Device
var _ Device = (*nvmlDevice)(nil) var _ Device = (*nvmlDevice)(nil)
// nvmlDeviceHandle returns a pointer to the underlying device.
func (d nvmlDevice) nvmlDeviceHandle() *nvml.Device {
return (*nvml.Device)(&d)
}
// GetIndex returns the index of a Device // GetIndex returns the index of a Device
func (d nvmlDevice) GetIndex() (int, Return) { func (d nvmlDevice) GetIndex() (int, Return) {
i, r := nvml.Device(d).GetIndex() i, r := nvml.Device(d).GetIndex()
@ -181,12 +186,12 @@ func (d nvmlDevice) GetSupportedEventTypes() (uint64, Return) {
// GetTopologyCommonAncestor retrieves the common ancestor for two devices. // GetTopologyCommonAncestor retrieves the common ancestor for two devices.
func (d nvmlDevice) GetTopologyCommonAncestor(o Device) (GpuTopologyLevel, Return) { func (d nvmlDevice) GetTopologyCommonAncestor(o Device) (GpuTopologyLevel, Return) {
other, ok := o.(nvmlDevice) other := o.nvmlDeviceHandle()
if !ok { if other == nil {
return 0, ERROR_INVALID_ARGUMENT return 0, ERROR_INVALID_ARGUMENT
} }
l, r := nvml.Device(d).GetTopologyCommonAncestor(nvml.Device(other)) l, r := nvml.Device(d).GetTopologyCommonAncestor(*other)
return GpuTopologyLevel(l), Return(r) return GpuTopologyLevel(l), Return(r)
} }

View File

@ -4,6 +4,7 @@
package nvml package nvml
import ( import (
"github.com/NVIDIA/go-nvml/pkg/nvml"
"sync" "sync"
) )
@ -101,6 +102,9 @@ var _ Device = &DeviceMock{}
// SetMigModeFunc: func(Mode int) (Return, Return) { // SetMigModeFunc: func(Mode int) (Return, Return) {
// panic("mock out the SetMigMode method") // panic("mock out the SetMigMode method")
// }, // },
// nvmlDeviceHandleFunc: func() *nvml.Device {
// panic("mock out the nvmlDeviceHandle method")
// },
// } // }
// //
// // use mockedDevice in code that requires Device // // use mockedDevice in code that requires Device
@ -192,6 +196,9 @@ type DeviceMock struct {
// SetMigModeFunc mocks the SetMigMode method. // SetMigModeFunc mocks the SetMigMode method.
SetMigModeFunc func(Mode int) (Return, Return) SetMigModeFunc func(Mode int) (Return, Return)
// nvmlDeviceHandleFunc mocks the nvmlDeviceHandle method.
nvmlDeviceHandleFunc func() *nvml.Device
// calls tracks calls to the methods. // calls tracks calls to the methods.
calls struct { calls struct {
// CreateGpuInstanceWithPlacement holds details about calls to the CreateGpuInstanceWithPlacement method. // CreateGpuInstanceWithPlacement holds details about calls to the CreateGpuInstanceWithPlacement method.
@ -304,6 +311,9 @@ type DeviceMock struct {
// Mode is the Mode argument value. // Mode is the Mode argument value.
Mode int Mode int
} }
// nvmlDeviceHandle holds details about calls to the nvmlDeviceHandle method.
nvmlDeviceHandle []struct {
}
} }
lockCreateGpuInstanceWithPlacement sync.RWMutex lockCreateGpuInstanceWithPlacement sync.RWMutex
lockGetArchitecture sync.RWMutex lockGetArchitecture sync.RWMutex
@ -333,6 +343,7 @@ type DeviceMock struct {
lockIsMigDeviceHandle sync.RWMutex lockIsMigDeviceHandle sync.RWMutex
lockRegisterEvents sync.RWMutex lockRegisterEvents sync.RWMutex
lockSetMigMode sync.RWMutex lockSetMigMode sync.RWMutex
locknvmlDeviceHandle sync.RWMutex
} }
// CreateGpuInstanceWithPlacement calls CreateGpuInstanceWithPlacementFunc. // CreateGpuInstanceWithPlacement calls CreateGpuInstanceWithPlacementFunc.
@ -1153,3 +1164,30 @@ func (mock *DeviceMock) SetMigModeCalls() []struct {
mock.lockSetMigMode.RUnlock() mock.lockSetMigMode.RUnlock()
return calls return calls
} }
// nvmlDeviceHandle calls nvmlDeviceHandleFunc.
func (mock *DeviceMock) nvmlDeviceHandle() *nvml.Device {
if mock.nvmlDeviceHandleFunc == nil {
panic("DeviceMock.nvmlDeviceHandleFunc: method is nil but Device.nvmlDeviceHandle was just called")
}
callInfo := struct {
}{}
mock.locknvmlDeviceHandle.Lock()
mock.calls.nvmlDeviceHandle = append(mock.calls.nvmlDeviceHandle, callInfo)
mock.locknvmlDeviceHandle.Unlock()
return mock.nvmlDeviceHandleFunc()
}
// nvmlDeviceHandleCalls gets all the calls that were made to nvmlDeviceHandle.
// Check the length with:
//
// len(mockedDevice.nvmlDeviceHandleCalls())
func (mock *DeviceMock) nvmlDeviceHandleCalls() []struct {
} {
var calls []struct {
}
mock.locknvmlDeviceHandle.RLock()
calls = mock.calls.nvmlDeviceHandle
mock.locknvmlDeviceHandle.RUnlock()
return calls
}

View File

@ -68,6 +68,8 @@ type Device interface {
IsMigDeviceHandle() (bool, Return) IsMigDeviceHandle() (bool, Return)
RegisterEvents(uint64, EventSet) Return RegisterEvents(uint64, EventSet) Return
SetMigMode(Mode int) (Return, Return) SetMigMode(Mode int) (Return, Return)
// nvmlDeviceHandle returns a pointer to the underlying NVML device.
nvmlDeviceHandle() *nvml.Device
} }
// GpuInstance defines the functions implemented by a GpuInstance // GpuInstance defines the functions implemented by a GpuInstance