From 8e749776c541842c61fe73d8223419f72829eb6d Mon Sep 17 00:00:00 2001 From: Kevin Klues Date: Thu, 15 Sep 2022 17:02:28 +0000 Subject: [PATCH] Add nvml wrappers for getting GIs and CIs by ID Signed-off-by: Kevin Klues --- pkg/nvml/device.go | 6 ++++++ pkg/nvml/device_mock.go | 43 +++++++++++++++++++++++++++++++++++++++++ pkg/nvml/gi.go | 6 ++++++ pkg/nvml/gi_mock.go | 43 +++++++++++++++++++++++++++++++++++++++++ pkg/nvml/types.go | 2 ++ 5 files changed, 100 insertions(+) diff --git a/pkg/nvml/device.go b/pkg/nvml/device.go index 91bea46..37ba0b1 100644 --- a/pkg/nvml/device.go +++ b/pkg/nvml/device.go @@ -76,6 +76,12 @@ func (d nvmlDevice) GetMigMode() (int, int, Return) { return s1, s2, Return(r) } +// GetGpuInstanceById returns the GPU Instance associated with a particular ID +func (d nvmlDevice) GetGpuInstanceById(id int) (GpuInstance, Return) { + gi, r := nvml.Device(d).GetGpuInstanceById(id) + return nvmlGpuInstance(gi), Return(r) +} + // GetGpuInstanceProfileInfo returns the profile info of a GPU Instance func (d nvmlDevice) GetGpuInstanceProfileInfo(profile int) (GpuInstanceProfileInfo, Return) { p, r := nvml.Device(d).GetGpuInstanceProfileInfo(profile) diff --git a/pkg/nvml/device_mock.go b/pkg/nvml/device_mock.go index adc5e28..be397ae 100644 --- a/pkg/nvml/device_mock.go +++ b/pkg/nvml/device_mock.go @@ -29,6 +29,9 @@ var _ Device = &DeviceMock{} // GetDeviceHandleFromMigDeviceHandleFunc: func() (Device, Return) { // panic("mock out the GetDeviceHandleFromMigDeviceHandle method") // }, +// GetGpuInstanceByIdFunc: func(ID int) (GpuInstance, Return) { +// panic("mock out the GetGpuInstanceById method") +// }, // GetGpuInstanceIdFunc: func() (int, Return) { // panic("mock out the GetGpuInstanceId method") // }, @@ -90,6 +93,9 @@ type DeviceMock struct { // GetDeviceHandleFromMigDeviceHandleFunc mocks the GetDeviceHandleFromMigDeviceHandle method. GetDeviceHandleFromMigDeviceHandleFunc func() (Device, Return) + // GetGpuInstanceByIdFunc mocks the GetGpuInstanceById method. + GetGpuInstanceByIdFunc func(ID int) (GpuInstance, Return) + // GetGpuInstanceIdFunc mocks the GetGpuInstanceId method. GetGpuInstanceIdFunc func() (int, Return) @@ -146,6 +152,11 @@ type DeviceMock struct { // GetDeviceHandleFromMigDeviceHandle holds details about calls to the GetDeviceHandleFromMigDeviceHandle method. GetDeviceHandleFromMigDeviceHandle []struct { } + // GetGpuInstanceById holds details about calls to the GetGpuInstanceById method. + GetGpuInstanceById []struct { + // ID is the ID argument value. + ID int + } // GetGpuInstanceId holds details about calls to the GetGpuInstanceId method. GetGpuInstanceId []struct { } @@ -201,6 +212,7 @@ type DeviceMock struct { lockGetComputeInstanceId sync.RWMutex lockGetCudaComputeCapability sync.RWMutex lockGetDeviceHandleFromMigDeviceHandle sync.RWMutex + lockGetGpuInstanceById sync.RWMutex lockGetGpuInstanceId sync.RWMutex lockGetGpuInstanceProfileInfo sync.RWMutex lockGetGpuInstances sync.RWMutex @@ -321,6 +333,37 @@ func (mock *DeviceMock) GetDeviceHandleFromMigDeviceHandleCalls() []struct { return calls } +// GetGpuInstanceById calls GetGpuInstanceByIdFunc. +func (mock *DeviceMock) GetGpuInstanceById(ID int) (GpuInstance, Return) { + if mock.GetGpuInstanceByIdFunc == nil { + panic("DeviceMock.GetGpuInstanceByIdFunc: method is nil but Device.GetGpuInstanceById was just called") + } + callInfo := struct { + ID int + }{ + ID: ID, + } + mock.lockGetGpuInstanceById.Lock() + mock.calls.GetGpuInstanceById = append(mock.calls.GetGpuInstanceById, callInfo) + mock.lockGetGpuInstanceById.Unlock() + return mock.GetGpuInstanceByIdFunc(ID) +} + +// GetGpuInstanceByIdCalls gets all the calls that were made to GetGpuInstanceById. +// Check the length with: +// len(mockedDevice.GetGpuInstanceByIdCalls()) +func (mock *DeviceMock) GetGpuInstanceByIdCalls() []struct { + ID int +} { + var calls []struct { + ID int + } + mock.lockGetGpuInstanceById.RLock() + calls = mock.calls.GetGpuInstanceById + mock.lockGetGpuInstanceById.RUnlock() + return calls +} + // GetGpuInstanceId calls GetGpuInstanceIdFunc. func (mock *DeviceMock) GetGpuInstanceId() (int, Return) { if mock.GetGpuInstanceIdFunc == nil { diff --git a/pkg/nvml/gi.go b/pkg/nvml/gi.go index cd775bc..bc4d373 100644 --- a/pkg/nvml/gi.go +++ b/pkg/nvml/gi.go @@ -36,6 +36,12 @@ func (gi nvmlGpuInstance) GetInfo() (GpuInstanceInfo, Return) { return info, Return(r) } +// GetComputeInstanceById returns the Compute Instance associated with a particular ID. +func (gi nvmlGpuInstance) GetComputeInstanceById(id int) (ComputeInstance, Return) { + ci, r := nvml.GpuInstance(gi).GetComputeInstanceById(id) + return nvmlComputeInstance(ci), Return(r) +} + // GetComputeInstanceProfileInfo returns info about a given Compute Instance profile func (gi nvmlGpuInstance) GetComputeInstanceProfileInfo(profile int, engProfile int) (ComputeInstanceProfileInfo, Return) { p, r := nvml.GpuInstance(gi).GetComputeInstanceProfileInfo(profile, engProfile) diff --git a/pkg/nvml/gi_mock.go b/pkg/nvml/gi_mock.go index 7393a04..d2a3487 100644 --- a/pkg/nvml/gi_mock.go +++ b/pkg/nvml/gi_mock.go @@ -23,6 +23,9 @@ var _ GpuInstance = &GpuInstanceMock{} // DestroyFunc: func() Return { // panic("mock out the Destroy method") // }, +// GetComputeInstanceByIdFunc: func(ID int) (ComputeInstance, Return) { +// panic("mock out the GetComputeInstanceById method") +// }, // GetComputeInstanceProfileInfoFunc: func(Profile int, EngProfile int) (ComputeInstanceProfileInfo, Return) { // panic("mock out the GetComputeInstanceProfileInfo method") // }, @@ -45,6 +48,9 @@ type GpuInstanceMock struct { // DestroyFunc mocks the Destroy method. DestroyFunc func() Return + // GetComputeInstanceByIdFunc mocks the GetComputeInstanceById method. + GetComputeInstanceByIdFunc func(ID int) (ComputeInstance, Return) + // GetComputeInstanceProfileInfoFunc mocks the GetComputeInstanceProfileInfo method. GetComputeInstanceProfileInfoFunc func(Profile int, EngProfile int) (ComputeInstanceProfileInfo, Return) @@ -64,6 +70,11 @@ type GpuInstanceMock struct { // Destroy holds details about calls to the Destroy method. Destroy []struct { } + // GetComputeInstanceById holds details about calls to the GetComputeInstanceById method. + GetComputeInstanceById []struct { + // ID is the ID argument value. + ID int + } // GetComputeInstanceProfileInfo holds details about calls to the GetComputeInstanceProfileInfo method. GetComputeInstanceProfileInfo []struct { // Profile is the Profile argument value. @@ -82,6 +93,7 @@ type GpuInstanceMock struct { } lockCreateComputeInstance sync.RWMutex lockDestroy sync.RWMutex + lockGetComputeInstanceById sync.RWMutex lockGetComputeInstanceProfileInfo sync.RWMutex lockGetComputeInstances sync.RWMutex lockGetInfo sync.RWMutex @@ -144,6 +156,37 @@ func (mock *GpuInstanceMock) DestroyCalls() []struct { return calls } +// GetComputeInstanceById calls GetComputeInstanceByIdFunc. +func (mock *GpuInstanceMock) GetComputeInstanceById(ID int) (ComputeInstance, Return) { + if mock.GetComputeInstanceByIdFunc == nil { + panic("GpuInstanceMock.GetComputeInstanceByIdFunc: method is nil but GpuInstance.GetComputeInstanceById was just called") + } + callInfo := struct { + ID int + }{ + ID: ID, + } + mock.lockGetComputeInstanceById.Lock() + mock.calls.GetComputeInstanceById = append(mock.calls.GetComputeInstanceById, callInfo) + mock.lockGetComputeInstanceById.Unlock() + return mock.GetComputeInstanceByIdFunc(ID) +} + +// GetComputeInstanceByIdCalls gets all the calls that were made to GetComputeInstanceById. +// Check the length with: +// len(mockedGpuInstance.GetComputeInstanceByIdCalls()) +func (mock *GpuInstanceMock) GetComputeInstanceByIdCalls() []struct { + ID int +} { + var calls []struct { + ID int + } + mock.lockGetComputeInstanceById.RLock() + calls = mock.calls.GetComputeInstanceById + mock.lockGetComputeInstanceById.RUnlock() + return calls +} + // GetComputeInstanceProfileInfo calls GetComputeInstanceProfileInfoFunc. func (mock *GpuInstanceMock) GetComputeInstanceProfileInfo(Profile int, EngProfile int) (ComputeInstanceProfileInfo, Return) { if mock.GetComputeInstanceProfileInfoFunc == nil { diff --git a/pkg/nvml/types.go b/pkg/nvml/types.go index e4ae40f..38ae3f2 100644 --- a/pkg/nvml/types.go +++ b/pkg/nvml/types.go @@ -42,6 +42,7 @@ type Device interface { GetComputeInstanceId() (int, Return) GetCudaComputeCapability() (int, int, Return) GetDeviceHandleFromMigDeviceHandle() (Device, Return) + GetGpuInstanceById(ID int) (GpuInstance, Return) GetGpuInstanceId() (int, Return) GetGpuInstanceProfileInfo(Profile int) (GpuInstanceProfileInfo, Return) GetGpuInstances(Info *GpuInstanceProfileInfo) ([]GpuInstance, Return) @@ -64,6 +65,7 @@ type Device interface { type GpuInstance interface { CreateComputeInstance(Info *ComputeInstanceProfileInfo) (ComputeInstance, Return) Destroy() Return + GetComputeInstanceById(ID int) (ComputeInstance, Return) GetComputeInstanceProfileInfo(Profile int, EngProfile int) (ComputeInstanceProfileInfo, Return) GetComputeInstances(Info *ComputeInstanceProfileInfo) ([]ComputeInstance, Return) GetInfo() (GpuInstanceInfo, Return)