diff --git a/pkg/nvml/device.go b/pkg/nvml/device.go index 26f3d87..95877b7 100644 --- a/pkg/nvml/device.go +++ b/pkg/nvml/device.go @@ -27,6 +27,12 @@ func (d nvmlDevice) nvmlDeviceHandle() *nvml.Device { return (*nvml.Device)(&d) } +// CreateGpuInstance creates a GPU instance matching the provided GpuInstanceProfileInfo +func (d nvmlDevice) CreateGpuInstance(Info *GpuInstanceProfileInfo) (GpuInstance, Return) { + gpuInstance, r := nvml.Device(d).CreateGpuInstance((*nvml.GpuInstanceProfileInfo)(Info)) + return nvmlGpuInstance(gpuInstance), Return(r) +} + // GetIndex returns the index of a Device func (d nvmlDevice) GetIndex() (int, Return) { i, r := nvml.Device(d).GetIndex() diff --git a/pkg/nvml/device_mock.go b/pkg/nvml/device_mock.go index 203676c..6e275db 100644 --- a/pkg/nvml/device_mock.go +++ b/pkg/nvml/device_mock.go @@ -18,6 +18,9 @@ var _ Device = &DeviceMock{} // // // make and configure a mocked Device // mockedDevice := &DeviceMock{ +// CreateGpuInstanceFunc: func(Info *GpuInstanceProfileInfo) (GpuInstance, Return) { +// panic("mock out the CreateGpuInstance method") +// }, // CreateGpuInstanceWithPlacementFunc: func(gpuInstanceProfileInfo *GpuInstanceProfileInfo, gpuInstancePlacement *GpuInstancePlacement) (GpuInstance, Return) { // panic("mock out the CreateGpuInstanceWithPlacement method") // }, @@ -112,6 +115,9 @@ var _ Device = &DeviceMock{} // // } type DeviceMock struct { + // CreateGpuInstanceFunc mocks the CreateGpuInstance method. + CreateGpuInstanceFunc func(Info *GpuInstanceProfileInfo) (GpuInstance, Return) + // CreateGpuInstanceWithPlacementFunc mocks the CreateGpuInstanceWithPlacement method. CreateGpuInstanceWithPlacementFunc func(gpuInstanceProfileInfo *GpuInstanceProfileInfo, gpuInstancePlacement *GpuInstancePlacement) (GpuInstance, Return) @@ -201,6 +207,11 @@ type DeviceMock struct { // calls tracks calls to the methods. calls struct { + // CreateGpuInstance holds details about calls to the CreateGpuInstance method. + CreateGpuInstance []struct { + // Info is the Info argument value. + Info *GpuInstanceProfileInfo + } // CreateGpuInstanceWithPlacement holds details about calls to the CreateGpuInstanceWithPlacement method. CreateGpuInstanceWithPlacement []struct { // GpuInstanceProfileInfo is the gpuInstanceProfileInfo argument value. @@ -315,6 +326,7 @@ type DeviceMock struct { nvmlDeviceHandle []struct { } } + lockCreateGpuInstance sync.RWMutex lockCreateGpuInstanceWithPlacement sync.RWMutex lockGetArchitecture sync.RWMutex lockGetAttributes sync.RWMutex @@ -346,6 +358,38 @@ type DeviceMock struct { locknvmlDeviceHandle sync.RWMutex } +// CreateGpuInstance calls CreateGpuInstanceFunc. +func (mock *DeviceMock) CreateGpuInstance(Info *GpuInstanceProfileInfo) (GpuInstance, Return) { + if mock.CreateGpuInstanceFunc == nil { + panic("DeviceMock.CreateGpuInstanceFunc: method is nil but Device.CreateGpuInstance was just called") + } + callInfo := struct { + Info *GpuInstanceProfileInfo + }{ + Info: Info, + } + mock.lockCreateGpuInstance.Lock() + mock.calls.CreateGpuInstance = append(mock.calls.CreateGpuInstance, callInfo) + mock.lockCreateGpuInstance.Unlock() + return mock.CreateGpuInstanceFunc(Info) +} + +// CreateGpuInstanceCalls gets all the calls that were made to CreateGpuInstance. +// Check the length with: +// +// len(mockedDevice.CreateGpuInstanceCalls()) +func (mock *DeviceMock) CreateGpuInstanceCalls() []struct { + Info *GpuInstanceProfileInfo +} { + var calls []struct { + Info *GpuInstanceProfileInfo + } + mock.lockCreateGpuInstance.RLock() + calls = mock.calls.CreateGpuInstance + mock.lockCreateGpuInstance.RUnlock() + return calls +} + // CreateGpuInstanceWithPlacement calls CreateGpuInstanceWithPlacementFunc. func (mock *DeviceMock) CreateGpuInstanceWithPlacement(gpuInstanceProfileInfo *GpuInstanceProfileInfo, gpuInstancePlacement *GpuInstancePlacement) (GpuInstance, Return) { if mock.CreateGpuInstanceWithPlacementFunc == nil { diff --git a/pkg/nvml/types.go b/pkg/nvml/types.go index 02dbab3..5fca53b 100644 --- a/pkg/nvml/types.go +++ b/pkg/nvml/types.go @@ -40,6 +40,7 @@ type Interface interface { // //go:generate moq -out device_mock.go . Device type Device interface { + CreateGpuInstance(Info *GpuInstanceProfileInfo) (GpuInstance, Return) CreateGpuInstanceWithPlacement(*GpuInstanceProfileInfo, *GpuInstancePlacement) (GpuInstance, Return) GetArchitecture() (DeviceArchitecture, Return) GetAttributes() (DeviceAttributes, Return)