From 06cbc571ef2e45c499971561955c7ede0c24129c Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Tue, 9 Jan 2024 14:04:57 +0100 Subject: [PATCH] Add nvmlDeviceHandle function to Device interface This change allows the underlying device handle to be returned without relying on type-casting. Signed-off-by: Evan Lezar --- pkg/nvml/device.go | 11 ++++++++--- pkg/nvml/device_mock.go | 38 ++++++++++++++++++++++++++++++++++++++ pkg/nvml/types.go | 2 ++ 3 files changed, 48 insertions(+), 3 deletions(-) diff --git a/pkg/nvml/device.go b/pkg/nvml/device.go index 3901bb7..26f3d87 100644 --- a/pkg/nvml/device.go +++ b/pkg/nvml/device.go @@ -22,6 +22,11 @@ type nvmlDevice nvml.Device var _ Device = (*nvmlDevice)(nil) +// nvmlDeviceHandle returns a pointer to the underlying device. +func (d nvmlDevice) nvmlDeviceHandle() *nvml.Device { + return (*nvml.Device)(&d) +} + // GetIndex returns the index of a Device func (d nvmlDevice) GetIndex() (int, Return) { i, r := nvml.Device(d).GetIndex() @@ -181,12 +186,12 @@ func (d nvmlDevice) GetSupportedEventTypes() (uint64, Return) { // GetTopologyCommonAncestor retrieves the common ancestor for two devices. func (d nvmlDevice) GetTopologyCommonAncestor(o Device) (GpuTopologyLevel, Return) { - other, ok := o.(nvmlDevice) - if !ok { + other := o.nvmlDeviceHandle() + if other == nil { return 0, ERROR_INVALID_ARGUMENT } - l, r := nvml.Device(d).GetTopologyCommonAncestor(nvml.Device(other)) + l, r := nvml.Device(d).GetTopologyCommonAncestor(*other) return GpuTopologyLevel(l), Return(r) } diff --git a/pkg/nvml/device_mock.go b/pkg/nvml/device_mock.go index 17dfe25..203676c 100644 --- a/pkg/nvml/device_mock.go +++ b/pkg/nvml/device_mock.go @@ -4,6 +4,7 @@ package nvml import ( + "github.com/NVIDIA/go-nvml/pkg/nvml" "sync" ) @@ -101,6 +102,9 @@ var _ Device = &DeviceMock{} // SetMigModeFunc: func(Mode int) (Return, Return) { // panic("mock out the SetMigMode method") // }, +// nvmlDeviceHandleFunc: func() *nvml.Device { +// panic("mock out the nvmlDeviceHandle method") +// }, // } // // // use mockedDevice in code that requires Device @@ -192,6 +196,9 @@ type DeviceMock struct { // SetMigModeFunc mocks the SetMigMode method. SetMigModeFunc func(Mode int) (Return, Return) + // nvmlDeviceHandleFunc mocks the nvmlDeviceHandle method. + nvmlDeviceHandleFunc func() *nvml.Device + // calls tracks calls to the methods. calls struct { // CreateGpuInstanceWithPlacement holds details about calls to the CreateGpuInstanceWithPlacement method. @@ -304,6 +311,9 @@ type DeviceMock struct { // Mode is the Mode argument value. Mode int } + // nvmlDeviceHandle holds details about calls to the nvmlDeviceHandle method. + nvmlDeviceHandle []struct { + } } lockCreateGpuInstanceWithPlacement sync.RWMutex lockGetArchitecture sync.RWMutex @@ -333,6 +343,7 @@ type DeviceMock struct { lockIsMigDeviceHandle sync.RWMutex lockRegisterEvents sync.RWMutex lockSetMigMode sync.RWMutex + locknvmlDeviceHandle sync.RWMutex } // CreateGpuInstanceWithPlacement calls CreateGpuInstanceWithPlacementFunc. @@ -1153,3 +1164,30 @@ func (mock *DeviceMock) SetMigModeCalls() []struct { mock.lockSetMigMode.RUnlock() return calls } + +// nvmlDeviceHandle calls nvmlDeviceHandleFunc. +func (mock *DeviceMock) nvmlDeviceHandle() *nvml.Device { + if mock.nvmlDeviceHandleFunc == nil { + panic("DeviceMock.nvmlDeviceHandleFunc: method is nil but Device.nvmlDeviceHandle was just called") + } + callInfo := struct { + }{} + mock.locknvmlDeviceHandle.Lock() + mock.calls.nvmlDeviceHandle = append(mock.calls.nvmlDeviceHandle, callInfo) + mock.locknvmlDeviceHandle.Unlock() + return mock.nvmlDeviceHandleFunc() +} + +// nvmlDeviceHandleCalls gets all the calls that were made to nvmlDeviceHandle. +// Check the length with: +// +// len(mockedDevice.nvmlDeviceHandleCalls()) +func (mock *DeviceMock) nvmlDeviceHandleCalls() []struct { +} { + var calls []struct { + } + mock.locknvmlDeviceHandle.RLock() + calls = mock.calls.nvmlDeviceHandle + mock.locknvmlDeviceHandle.RUnlock() + return calls +} diff --git a/pkg/nvml/types.go b/pkg/nvml/types.go index ecbe4e6..02dbab3 100644 --- a/pkg/nvml/types.go +++ b/pkg/nvml/types.go @@ -68,6 +68,8 @@ type Device interface { IsMigDeviceHandle() (bool, Return) RegisterEvents(uint64, EventSet) Return SetMigMode(Mode int) (Return, Return) + // nvmlDeviceHandle returns a pointer to the underlying NVML device. + nvmlDeviceHandle() *nvml.Device } // GpuInstance defines the functions implemented by a GpuInstance